sglang 0.5.3rc2__py3-none-any.whl → 0.5.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +47 -28
- sglang/bench_one_batch_server.py +41 -25
- sglang/bench_serving.py +330 -156
- sglang/check_env.py +1 -1
- sglang/compile_deep_gemm.py +6 -2
- sglang/global_config.py +1 -25
- sglang/lang/api.py +6 -0
- sglang/lang/interpreter.py +1 -0
- sglang/lang/ir.py +13 -0
- sglang/launch_server.py +8 -15
- sglang/profiler.py +18 -1
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +4 -6
- sglang/srt/checkpoint_engine/checkpoint_engine_worker.py +142 -0
- sglang/srt/compilation/backend.py +437 -0
- sglang/srt/compilation/compilation_config.py +20 -0
- sglang/srt/compilation/compilation_counter.py +47 -0
- sglang/srt/compilation/compile.py +210 -0
- sglang/srt/compilation/compiler_interface.py +503 -0
- sglang/srt/compilation/cuda_piecewise_backend.py +228 -0
- sglang/srt/compilation/fix_functionalization.py +134 -0
- sglang/srt/compilation/fx_utils.py +83 -0
- sglang/srt/compilation/inductor_pass.py +140 -0
- sglang/srt/compilation/pass_manager.py +66 -0
- sglang/srt/compilation/piecewise_context_manager.py +40 -0
- sglang/srt/compilation/weak_ref_tensor_jit.py +16 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseek_ocr.py +262 -0
- sglang/srt/configs/deepseekvl2.py +194 -96
- sglang/srt/configs/dots_vlm.py +2 -7
- sglang/srt/configs/falcon_h1.py +13 -64
- sglang/srt/configs/load_config.py +25 -2
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +134 -23
- sglang/srt/configs/modelopt_config.py +30 -0
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/olmo3.py +105 -0
- sglang/srt/configs/points_v15_chat.py +29 -0
- sglang/srt/configs/qwen3_next.py +11 -47
- sglang/srt/configs/qwen3_omni.py +613 -0
- sglang/srt/configs/qwen3_vl.py +0 -10
- sglang/srt/connector/remote_instance.py +1 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -1
- sglang/srt/constrained/llguidance_backend.py +5 -0
- sglang/srt/constrained/outlines_backend.py +1 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +9 -6
- sglang/srt/constrained/utils.py +12 -0
- sglang/srt/constrained/xgrammar_backend.py +20 -11
- sglang/srt/disaggregation/ascend/transfer_engine.py +1 -1
- sglang/srt/disaggregation/base/conn.py +17 -4
- sglang/srt/disaggregation/common/conn.py +4 -2
- sglang/srt/disaggregation/decode.py +123 -31
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +11 -3
- sglang/srt/disaggregation/mooncake/conn.py +157 -19
- sglang/srt/disaggregation/nixl/conn.py +69 -24
- sglang/srt/disaggregation/prefill.py +96 -270
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +4 -4
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +6 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +2 -2
- sglang/srt/distributed/device_communicators/pynccl.py +24 -12
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +2 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +1 -1
- sglang/srt/distributed/naive_distributed.py +5 -4
- sglang/srt/distributed/parallel_state.py +70 -19
- sglang/srt/elastic_ep/elastic_ep.py +74 -0
- sglang/srt/entrypoints/context.py +3 -2
- sglang/srt/entrypoints/engine.py +66 -66
- sglang/srt/entrypoints/grpc_server.py +431 -234
- sglang/srt/entrypoints/harmony_utils.py +2 -2
- sglang/srt/entrypoints/http_server.py +120 -8
- sglang/srt/entrypoints/http_server_engine.py +1 -7
- sglang/srt/entrypoints/openai/protocol.py +225 -37
- sglang/srt/entrypoints/openai/serving_base.py +49 -2
- sglang/srt/entrypoints/openai/serving_chat.py +29 -74
- sglang/srt/entrypoints/openai/serving_classify.py +204 -0
- sglang/srt/entrypoints/openai/serving_completions.py +15 -1
- sglang/srt/entrypoints/openai/serving_responses.py +5 -2
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +42 -4
- sglang/srt/eplb/eplb_algorithms/__init__.py +18 -1
- sglang/srt/eplb/eplb_algorithms/deepseek.py +0 -2
- sglang/srt/eplb/eplb_algorithms/elasticity_aware.py +87 -0
- sglang/srt/eplb/expert_distribution.py +3 -4
- sglang/srt/eplb/expert_location_dispatch.py +2 -2
- sglang/srt/eplb/expert_location_updater.py +2 -2
- sglang/srt/function_call/base_format_detector.py +17 -18
- sglang/srt/function_call/function_call_parser.py +18 -14
- sglang/srt/function_call/glm4_moe_detector.py +1 -5
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/json_array_parser.py +0 -2
- sglang/srt/function_call/utils.py +2 -2
- sglang/srt/grpc/compile_proto.py +3 -3
- sglang/srt/{entrypoints → grpc}/grpc_request_manager.py +112 -52
- sglang/srt/grpc/health_servicer.py +189 -0
- sglang/srt/grpc/scheduler_launcher.py +181 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +66 -10
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +89 -1
- sglang/srt/layers/activation.py +4 -1
- sglang/srt/layers/attention/aiter_backend.py +3 -3
- sglang/srt/layers/attention/ascend_backend.py +17 -1
- sglang/srt/layers/attention/attention_registry.py +43 -23
- sglang/srt/layers/attention/base_attn_backend.py +20 -1
- sglang/srt/layers/attention/double_sparsity_backend.py +2 -2
- sglang/srt/layers/attention/fla/chunk.py +0 -1
- sglang/srt/layers/attention/fla/chunk_o.py +1 -1
- sglang/srt/layers/attention/fla/index.py +0 -2
- sglang/srt/layers/attention/fla/layernorm_gated.py +50 -32
- sglang/srt/layers/attention/fla/utils.py +0 -3
- sglang/srt/layers/attention/fla/wy_fast.py +0 -2
- sglang/srt/layers/attention/flashattention_backend.py +12 -8
- sglang/srt/layers/attention/flashinfer_backend.py +248 -21
- sglang/srt/layers/attention/flashinfer_mla_backend.py +20 -18
- sglang/srt/layers/attention/flashmla_backend.py +2 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -62
- sglang/srt/layers/attention/intel_amx_backend.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -5
- sglang/srt/layers/attention/mamba/mamba.py +189 -241
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +0 -1
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +1 -1
- sglang/srt/layers/attention/nsa/nsa_indexer.py +40 -83
- sglang/srt/layers/attention/nsa/triton_kernel.py +136 -0
- sglang/srt/layers/attention/nsa/utils.py +0 -1
- sglang/srt/layers/attention/nsa_backend.py +404 -90
- sglang/srt/layers/attention/triton_backend.py +208 -34
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +2 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +539 -44
- sglang/srt/layers/attention/trtllm_mha_backend.py +2 -2
- sglang/srt/layers/attention/trtllm_mla_backend.py +361 -30
- sglang/srt/layers/attention/utils.py +11 -7
- sglang/srt/layers/attention/vision.py +3 -3
- sglang/srt/layers/attention/xpu_backend.py +1028 -0
- sglang/srt/layers/communicator.py +11 -7
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/compile_utils.py +4 -8
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/configurer.py +4 -3
- sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/entrypoint.py +3 -3
- sglang/srt/layers/dp_attention.py +17 -0
- sglang/srt/layers/layernorm.py +45 -15
- sglang/srt/layers/linear.py +9 -1
- sglang/srt/layers/logits_processor.py +147 -17
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/cutlass_moe.py +0 -2
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +213 -21
- sglang/srt/layers/moe/ep_moe/kernels.py +35 -457
- sglang/srt/layers/moe/ep_moe/layer.py +119 -397
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +11 -3
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -70
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +18 -42
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton.py +3 -1
- sglang/srt/layers/moe/rocm_moe_utils.py +0 -1
- sglang/srt/layers/moe/router.py +51 -15
- sglang/srt/layers/moe/token_dispatcher/__init__.py +10 -0
- sglang/srt/layers/moe/token_dispatcher/base.py +1 -1
- sglang/srt/layers/moe/token_dispatcher/deepep.py +110 -97
- sglang/srt/layers/moe/token_dispatcher/mooncake.py +386 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +46 -0
- sglang/srt/layers/moe/topk.py +3 -2
- sglang/srt/layers/moe/utils.py +17 -1
- sglang/srt/layers/quantization/__init__.py +2 -53
- sglang/srt/layers/quantization/awq.py +183 -6
- sglang/srt/layers/quantization/awq_triton.py +29 -0
- sglang/srt/layers/quantization/base_config.py +20 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +7 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +20 -49
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +421 -70
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +4 -22
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +339 -0
- sglang/srt/layers/quantization/fp8.py +84 -18
- sglang/srt/layers/quantization/fp8_kernel.py +55 -10
- sglang/srt/layers/quantization/fp8_utils.py +42 -14
- sglang/srt/layers/quantization/fpgemm_fp8.py +2 -3
- sglang/srt/layers/quantization/gptq.py +0 -1
- sglang/srt/layers/quantization/int8_kernel.py +18 -2
- sglang/srt/layers/quantization/marlin_utils.py +12 -0
- sglang/srt/layers/quantization/modelopt_quant.py +125 -100
- sglang/srt/layers/quantization/mxfp4.py +5 -30
- sglang/srt/layers/quantization/petit.py +1 -1
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +3 -3
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +0 -7
- sglang/srt/layers/quantization/unquant.py +1 -4
- sglang/srt/layers/quantization/utils.py +0 -1
- sglang/srt/layers/quantization/w4afp8.py +51 -20
- sglang/srt/layers/quantization/w8a8_int8.py +30 -24
- sglang/srt/layers/radix_attention.py +59 -9
- sglang/srt/layers/rotary_embedding.py +673 -16
- sglang/srt/layers/sampler.py +36 -16
- sglang/srt/layers/sparse_pooler.py +98 -0
- sglang/srt/layers/utils.py +0 -1
- sglang/srt/layers/vocab_parallel_embedding.py +4 -1
- sglang/srt/lora/backend/triton_backend.py +0 -1
- sglang/srt/lora/eviction_policy.py +139 -0
- sglang/srt/lora/lora_manager.py +24 -9
- sglang/srt/lora/lora_registry.py +1 -1
- sglang/srt/lora/mem_pool.py +40 -16
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +1 -1
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +4 -2
- sglang/srt/managers/cache_controller.py +48 -17
- sglang/srt/managers/data_parallel_controller.py +146 -42
- sglang/srt/managers/detokenizer_manager.py +40 -13
- sglang/srt/managers/io_struct.py +66 -16
- sglang/srt/managers/mm_utils.py +20 -18
- sglang/srt/managers/multi_tokenizer_mixin.py +66 -81
- sglang/srt/managers/overlap_utils.py +96 -19
- sglang/srt/managers/schedule_batch.py +241 -511
- sglang/srt/managers/schedule_policy.py +15 -2
- sglang/srt/managers/scheduler.py +399 -499
- sglang/srt/managers/scheduler_metrics_mixin.py +55 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +317 -111
- sglang/srt/managers/scheduler_pp_mixin.py +341 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +57 -10
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +217 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +33 -14
- sglang/srt/managers/tokenizer_communicator_mixin.py +71 -55
- sglang/srt/managers/tokenizer_manager.py +378 -90
- sglang/srt/managers/tp_worker.py +212 -161
- sglang/srt/managers/utils.py +78 -2
- sglang/srt/mem_cache/allocator.py +7 -2
- sglang/srt/mem_cache/allocator_ascend.py +2 -2
- sglang/srt/mem_cache/base_prefix_cache.py +2 -2
- sglang/srt/mem_cache/chunk_cache.py +13 -2
- sglang/srt/mem_cache/common.py +480 -0
- sglang/srt/mem_cache/evict_policy.py +16 -1
- sglang/srt/mem_cache/hicache_storage.py +4 -1
- sglang/srt/mem_cache/hiradix_cache.py +16 -3
- sglang/srt/mem_cache/mamba_radix_cache.py +993 -0
- sglang/srt/mem_cache/memory_pool.py +435 -219
- sglang/srt/mem_cache/memory_pool_host.py +0 -1
- sglang/srt/mem_cache/multimodal_cache.py +0 -1
- sglang/srt/mem_cache/radix_cache.py +53 -19
- sglang/srt/mem_cache/radix_cache_cpp.py +19 -14
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +8 -2
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +1 -13
- sglang/srt/mem_cache/storage/backend_factory.py +2 -2
- sglang/srt/mem_cache/storage/eic/eic_storage.py +5 -6
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +0 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +9 -3
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +5 -3
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +101 -17
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +38 -9
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +1 -1
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +17 -2
- sglang/srt/mem_cache/swa_radix_cache.py +92 -26
- sglang/srt/metrics/collector.py +31 -0
- sglang/srt/metrics/func_timer.py +1 -1
- sglang/srt/model_executor/cuda_graph_runner.py +43 -5
- sglang/srt/model_executor/forward_batch_info.py +28 -23
- sglang/srt/model_executor/model_runner.py +379 -139
- sglang/srt/model_executor/npu_graph_runner.py +2 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +539 -0
- sglang/srt/model_loader/__init__.py +1 -1
- sglang/srt/model_loader/loader.py +424 -27
- sglang/srt/model_loader/utils.py +0 -1
- sglang/srt/model_loader/weight_utils.py +47 -28
- sglang/srt/models/apertus.py +2 -3
- sglang/srt/models/arcee.py +2 -2
- sglang/srt/models/bailing_moe.py +13 -52
- sglang/srt/models/bailing_moe_nextn.py +3 -4
- sglang/srt/models/bert.py +1 -1
- sglang/srt/models/deepseek_nextn.py +19 -3
- sglang/srt/models/deepseek_ocr.py +1516 -0
- sglang/srt/models/deepseek_v2.py +273 -98
- sglang/srt/models/dots_ocr.py +0 -2
- sglang/srt/models/dots_vlm.py +0 -1
- sglang/srt/models/dots_vlm_vit.py +1 -1
- sglang/srt/models/falcon_h1.py +13 -19
- sglang/srt/models/gemma3_mm.py +16 -0
- sglang/srt/models/gemma3n_mm.py +1 -2
- sglang/srt/models/glm4_moe.py +14 -37
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +2 -1
- sglang/srt/models/glm4v_moe.py +5 -5
- sglang/srt/models/gpt_oss.py +5 -5
- sglang/srt/models/grok.py +10 -23
- sglang/srt/models/hunyuan.py +2 -7
- sglang/srt/models/interns1.py +0 -1
- sglang/srt/models/kimi_vl.py +1 -7
- sglang/srt/models/kimi_vl_moonvit.py +3 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/llama_eagle3.py +1 -1
- sglang/srt/models/longcat_flash.py +5 -22
- sglang/srt/models/longcat_flash_nextn.py +3 -14
- sglang/srt/models/mimo.py +2 -13
- sglang/srt/models/mimo_mtp.py +1 -2
- sglang/srt/models/minicpmo.py +7 -5
- sglang/srt/models/mixtral.py +1 -4
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/mllama4.py +13 -3
- sglang/srt/models/nemotron_h.py +511 -0
- sglang/srt/models/olmo2.py +31 -4
- sglang/srt/models/opt.py +5 -5
- sglang/srt/models/phi.py +1 -1
- sglang/srt/models/phi4mm.py +1 -1
- sglang/srt/models/phimoe.py +0 -1
- sglang/srt/models/pixtral.py +0 -3
- sglang/srt/models/points_v15_chat.py +186 -0
- sglang/srt/models/qwen.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +3 -3
- sglang/srt/models/qwen2_audio.py +2 -15
- sglang/srt/models/qwen2_moe.py +15 -12
- sglang/srt/models/qwen2_vl.py +5 -2
- sglang/srt/models/qwen3_moe.py +19 -35
- sglang/srt/models/qwen3_next.py +7 -12
- sglang/srt/models/qwen3_next_mtp.py +3 -4
- sglang/srt/models/qwen3_omni_moe.py +661 -0
- sglang/srt/models/qwen3_vl.py +37 -33
- sglang/srt/models/qwen3_vl_moe.py +57 -185
- sglang/srt/models/roberta.py +55 -3
- sglang/srt/models/sarashina2_vision.py +0 -1
- sglang/srt/models/step3_vl.py +3 -5
- sglang/srt/models/utils.py +11 -1
- sglang/srt/multimodal/processors/base_processor.py +6 -2
- sglang/srt/multimodal/processors/deepseek_ocr.py +37 -0
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +0 -3
- sglang/srt/multimodal/processors/dots_vlm.py +0 -1
- sglang/srt/multimodal/processors/glm4v.py +1 -5
- sglang/srt/multimodal/processors/internvl.py +0 -2
- sglang/srt/multimodal/processors/janus_pro.py +0 -1
- sglang/srt/multimodal/processors/mllama4.py +0 -8
- sglang/srt/multimodal/processors/phi4mm.py +0 -1
- sglang/srt/multimodal/processors/points_v15_chat.py +52 -0
- sglang/srt/multimodal/processors/qwen_vl.py +75 -16
- sglang/srt/multimodal/processors/step3_vl.py +1 -1
- sglang/srt/parser/conversation.py +41 -0
- sglang/srt/parser/reasoning_parser.py +0 -1
- sglang/srt/sampling/custom_logit_processor.py +77 -2
- sglang/srt/sampling/sampling_batch_info.py +17 -22
- sglang/srt/sampling/sampling_params.py +70 -2
- sglang/srt/server_args.py +577 -73
- sglang/srt/server_args_config_parser.py +1 -1
- sglang/srt/single_batch_overlap.py +38 -28
- sglang/srt/speculative/base_spec_worker.py +34 -0
- sglang/srt/speculative/draft_utils.py +226 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +24 -7
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +23 -2
- sglang/srt/speculative/eagle_info.py +57 -18
- sglang/srt/speculative/eagle_info_v2.py +458 -0
- sglang/srt/speculative/eagle_utils.py +138 -0
- sglang/srt/speculative/eagle_worker.py +83 -280
- sglang/srt/speculative/eagle_worker_v2.py +702 -0
- sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +14 -9
- sglang/srt/speculative/ngram_worker.py +12 -11
- sglang/srt/speculative/spec_info.py +2 -0
- sglang/srt/speculative/spec_utils.py +38 -3
- sglang/srt/speculative/standalone_worker.py +4 -14
- sglang/srt/tokenizer/tiktoken_tokenizer.py +2 -2
- sglang/srt/two_batch_overlap.py +28 -14
- sglang/srt/utils/__init__.py +1 -1
- sglang/srt/{bench_utils.py → utils/bench_utils.py} +4 -2
- sglang/srt/utils/common.py +192 -47
- sglang/srt/utils/hf_transformers_utils.py +40 -17
- sglang/srt/{host_shared_memory.py → utils/host_shared_memory.py} +0 -1
- sglang/srt/{offloader.py → utils/offloader.py} +4 -4
- sglang/srt/utils/profile_merger.py +199 -0
- sglang/test/attention/test_flashattn_backend.py +1 -1
- sglang/test/attention/test_flashattn_mla_backend.py +0 -1
- sglang/test/attention/test_prefix_chunk_info.py +0 -2
- sglang/test/attention/test_trtllm_mla_backend.py +221 -53
- sglang/test/few_shot_gsm8k_engine.py +2 -4
- sglang/test/kit_matched_stop.py +157 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +41 -0
- sglang/test/runners.py +2 -0
- sglang/test/send_one.py +42 -7
- sglang/test/simple_eval_common.py +3 -0
- sglang/test/simple_eval_gpqa.py +0 -1
- sglang/test/simple_eval_humaneval.py +0 -3
- sglang/test/simple_eval_longbench_v2.py +344 -0
- sglang/test/test_block_fp8.py +1 -2
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +0 -1
- sglang/test/test_cutlass_moe.py +1 -2
- sglang/test/test_cutlass_w4a8_moe.py +10 -20
- sglang/test/test_deterministic.py +232 -99
- sglang/test/test_deterministic_utils.py +73 -0
- sglang/test/test_disaggregation_utils.py +81 -0
- sglang/test/test_marlin_moe.py +0 -1
- sglang/test/test_utils.py +85 -20
- sglang/version.py +1 -1
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/METADATA +45 -33
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/RECORD +404 -345
- sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
- sglang/srt/speculative/build_eagle_tree.py +0 -427
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/layers/{quantization/deep_gemm_wrapper → deep_gemm_wrapper}/__init__.py +0 -0
- /sglang/srt/{aio_rwlock.py → utils/aio_rwlock.py} +0 -0
- /sglang/srt/{torch_memory_saver_adapter.py → utils/torch_memory_saver_adapter.py} +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc2.dist-info → sglang-0.5.4.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py (+379 -139)

```diff
@@ -24,11 +24,12 @@ import threading
 import time
 from collections import defaultdict
 from dataclasses import dataclass
-from typing import List, Optional, Tuple, Union
+from typing import Callable, List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist

+from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
 from sglang.srt.configs.device_config import DeviceConfig
 from sglang.srt.configs.load_config import LoadConfig, LoadFormat
 from sglang.srt.configs.model_config import (
@@ -50,6 +51,7 @@ from sglang.srt.distributed import (
     set_symm_mem_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
+from sglang.srt.elastic_ep.elastic_ep import ElasticEPStateManager
 from sglang.srt.eplb.eplb_manager import EPLBManager
 from sglang.srt.eplb.expert_distribution import (
     ExpertDistributionRecorder,
@@ -63,6 +65,7 @@ from sglang.srt.eplb.expert_location import (
     set_global_expert_location_metadata,
 )
 from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.attention.attention_registry import (
     ATTENTION_BACKENDS,
     attn_backend_wrapper,
@@ -74,18 +77,11 @@ from sglang.srt.layers.dp_attention import (
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.layers.quantization import (
-    deep_gemm_wrapper,
-    monkey_patch_isinstance_for_vllm_base_layer,
-)
+from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.lora.lora_registry import LoRARef
-from sglang.srt.managers.schedule_batch import (
-    GLOBAL_SERVER_ARGS_KEYS,
-    global_server_args_dict,
-)
 from sglang.srt.mem_cache.allocator import (
     BaseTokenToKVPoolAllocator,
     PagedTokenToKVPoolAllocator,
@@ -109,6 +105,9 @@ from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
+from sglang.srt.model_executor.piecewise_cuda_graph_runner import (
+    PiecewiseCudaGraphRunner,
+)
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
 from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
@@ -116,15 +115,13 @@ from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
 )
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.offloader import (
-    create_offloader_from_server_args,
-    get_offloader,
-    set_offloader,
-)
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import (
+    ServerArgs,
+    get_global_server_args,
+    set_global_server_args_for_scheduler,
+)
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
-from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
     MultiprocessingSerializer,
     cpu_has_amx_support,
@@ -146,8 +143,15 @@ from sglang.srt.utils import (
     monkey_patch_vllm_gguf_config,
     set_cuda_arch,
     slow_rank_detector,
+    xpu_has_xmx_support,
+)
+from sglang.srt.utils.offloader import (
+    create_offloader_from_server_args,
+    get_offloader,
+    set_offloader,
 )
 from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
+from sglang.srt.utils.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.weight_sync.tensor_bucket import (
     FlattenedTensorBucket,
     FlattenedTensorMetadata,
@@ -166,6 +170,15 @@ MLA_ATTENTION_BACKENDS = [
     "nsa",
 ]

+CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS = [
+    "flashinfer",
+    "fa3",
+    "fa4",
+    "flashmla",
+    "cutlass_mla",
+    "trtllm_mla",
+]
+

 def add_mla_attention_backend(backend_name):
     if backend_name not in MLA_ATTENTION_BACKENDS:
@@ -173,9 +186,18 @@ def add_mla_attention_backend(backend_name):
         logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")


+def add_chunked_prefix_cache_attention_backend(backend_name):
+    if backend_name not in CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS:
+        CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(
+            f"Added {backend_name} to CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS."
+        )
+
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
+_is_xpu_xmx_available = xpu_has_xmx_support()

 # Use a small KV cache pool size for tests in CI
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
```
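The new list and helper above mirror the existing MLA backend registry: chunked prefix cache stays enabled only when the selected attention backend appears in the list, and external integrations can register themselves at import time. A minimal sketch of such an opt-in (the backend name here is hypothetical, not one shipped with sglang):

```python
# Hypothetical plugin module: register a custom backend so that
# model_specific_adjustment does not force disable_chunked_prefix_cache
# when this backend is selected.
from sglang.srt.model_executor.model_runner import (
    add_chunked_prefix_cache_attention_backend,
    add_mla_attention_backend,
)

add_mla_attention_backend("my_mla_backend")  # assumed custom backend name
add_chunked_prefix_cache_attention_backend("my_mla_backend")
```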
```diff
@@ -183,8 +205,10 @@ SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 # Detect stragger ranks in model loading
 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300

-
+# the ratio of mamba cache pool size to max_running_requests, it will be safe when it is larger than 2 (yizhang2077)
+MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO = 3

+logger = logging.getLogger(__name__)

 if _is_npu:
     import torch_npu
@@ -257,25 +281,21 @@ class ModelRunner:
         self.use_mla_backend = self.model_config.attention_arch == AttentionArch.MLA
         self.attention_chunk_size = model_config.attention_chunk_size
         self.forward_pass_id = 0
+        self.init_new_workspace = False

         # Apply the rank zero filter to logger
-        if not any(isinstance(f, RankZeroFilter) for f in logger.filters):
-            logger.addFilter(RankZeroFilter(tp_rank == 0))
         if server_args.show_time_cost:
             enable_show_time_cost()

         # Model-specific adjustment
         self.model_specific_adjustment()

-        #
-
-
-
-
-
-                "speculative_algorithm": self.spec_algorithm,
-            }
-        )
+        # Set the global server_args in the scheduler process
+        set_global_server_args_for_scheduler(server_args)
+        global_server_args = get_global_server_args()
+
+        # FIXME: hacky set `use_mla_backend`
+        global_server_args.use_mla_backend = self.use_mla_backend

         # Init OpenMP threads binding for CPU
         if self.device == "cpu":
@@ -306,6 +326,26 @@ class ModelRunner:
         self._model_update_group = {}
         self._weights_send_group = {}

+        if (
+            self.server_args.enable_piecewise_cuda_graph
+            and self.can_run_piecewise_cuda_graph()
+        ):
+            self.attention_layers = []
+            for layer in self.model.model.layers:
+                if hasattr(layer, "self_attn") and hasattr(layer.self_attn, "attn"):
+                    self.attention_layers.append(layer.self_attn.attn)
+            if len(self.attention_layers) < self.model_config.num_hidden_layers:
+                # TODO(yuwei): support Non-Standard GQA
+                log_info_on_rank0(
+                    logger,
+                    "Disable piecewise CUDA graph because some layers do not apply Standard GQA",
+                )
+                self.piecewise_cuda_graph_runner = None
+            else:
+                self.piecewise_cuda_graph_runner = PiecewiseCudaGraphRunner(self)
+        else:
+            self.piecewise_cuda_graph_runner = None
+
     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args

@@ -340,6 +380,11 @@ class ModelRunner:
         )
         self.expert_location_updater = ExpertLocationUpdater()

+        (
+            ElasticEPStateManager.init(self.server_args)
+            if self.server_args.elastic_ep_backend
+            else None
+        )
         # Load the model
         self.sampler = Sampler()
         self.load_model()
@@ -354,24 +399,10 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True

-        if self.
-
+        if config := self.mamba2_config:
+            class_name = config.__class__.__name__
+            logger.warning(f"{class_name} model detected, disable radix cache")
             self.server_args.disable_radix_cache = True
-            if self.server_args.max_mamba_cache_size is None:
-                if self.server_args.max_running_requests is not None:
-                    self.server_args.max_mamba_cache_size = (
-                        self.server_args.max_running_requests
-                    )
-                else:
-                    self.server_args.max_mamba_cache_size = 512
-            self.server_args.max_mamba_cache_size = (
-                self.server_args.max_mamba_cache_size
-                // (
-                    self.server_args.dp_size
-                    if self.server_args.enable_dp_attention
-                    else 1
-                )
-            )

         # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
         # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
@@ -402,7 +433,7 @@ class ModelRunner:
             # In layered loading, torchao may have been applied
             if not torchao_applied:
                 apply_torchao_config_to_model(
-                    self.model,
+                    self.model, get_global_server_args().torchao_config
                 )

             # Apply torch TP if the model supports it
@@ -482,6 +513,16 @@ class ModelRunner:
             )
             server_args.attention_backend = "torch_native"

+        if (
+            server_args.attention_backend == "intel_xpu"
+            and server_args.device == "xpu"
+            and not _is_xpu_xmx_available
+        ):
+            logger.info(
+                "The current platform does not support Intel XMX, will fallback to triton backend."
+            )
+            server_args.attention_backend = "triton"
+
         if server_args.prefill_attention_backend is not None and (
             server_args.prefill_attention_backend
             == server_args.decode_attention_backend
@@ -547,8 +588,9 @@ class ModelRunner:
                 server_args.attention_backend = "ascend"
             else:
                 server_args.attention_backend = "triton"
-
-
+            log_info_on_rank0(
+                logger,
+                f"Attention backend not explicitly specified. Use {server_args.attention_backend} backend by default.",
             )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
@@ -591,11 +633,15 @@ class ModelRunner:
                 f"{self.model_config.hf_config.model_type}"
             )

-        if
+        if (
+            not self.use_mla_backend
+            or server_args.attention_backend
+            not in CHUNKED_PREFIX_CACHE_SUPPORTED_ATTENTION_BACKENDS
+        ):
             server_args.disable_chunked_prefix_cache = True

         if not server_args.disable_chunked_prefix_cache:
-            logger
+            log_info_on_rank0(logger, "Chunked prefix cache is turned on.")

         if server_args.attention_backend == "aiter":
             if self.model_config.context_len > 8192:
@@ -622,6 +668,35 @@ class ModelRunner:
                 "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
             )

+        if self.model_config.hf_config.model_type == "qwen3_vl_moe":
+            if (
+                quantization_config := getattr(
+                    self.model_config.hf_config, "quantization_config", None
+                )
+            ) is not None:
+                weight_block_size_n = quantization_config["weight_block_size"][0]
+
+                if self.tp_size % self.moe_ep_size != 0:
+                    raise ValueError(
+                        f"tp_size {self.tp_size} must be divisible by moe_ep_size {self.moe_ep_size}"
+                    )
+                moe_tp_size = self.tp_size // self.moe_ep_size
+
+                moe_intermediate_size = (
+                    self.model_config.hf_text_config.moe_intermediate_size
+                )
+                if moe_intermediate_size % moe_tp_size != 0:
+                    raise ValueError(
+                        f"moe_intermediate_size {moe_intermediate_size} must be divisible by moe_tp_size ({moe_tp_size}) which is tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size})."
+                    )
+
+                if (moe_intermediate_size // moe_tp_size) % weight_block_size_n != 0:
+                    raise ValueError(
+                        f"For qwen3-vl-fp8 models, please make sure ({moe_intermediate_size=} / {moe_tp_size=}) % {weight_block_size_n=} == 0 "
+                        f"where moe_tp_size is equal to tp_size ({self.tp_size}) divided by moe_ep_size ({self.moe_ep_size}). "
+                        f"You can fix this by setting arguments `--tp-size` and `--ep-size` correctly."
+                    )
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")

```
```diff
@@ -634,7 +709,18 @@ class ModelRunner:
             raise

         if self.device == "cuda":
-
+            if self.server_args.elastic_ep_backend == "mooncake":
+                backend = "mooncake"
+                if self.server_args.mooncake_ib_device:
+                    mooncake_ib_device = self.server_args.mooncake_ib_device.split(",")
+                    try:
+                        from mooncake import ep as mooncake_ep
+
+                        mooncake_ep.set_device_filter(mooncake_ib_device)
+                    except:
+                        pass  # A warning will be raised in `init_distributed_environment`
+            else:
+                backend = "nccl"
         elif self.device == "xpu":
             backend = "xccl"
         elif self.device == "hpu":
@@ -689,6 +775,7 @@ class ModelRunner:
             pipeline_model_parallel_size=self.pp_size,
             expert_model_parallel_size=self.moe_ep_size,
             duplicate_tp_group=self.server_args.enable_pdmux,
+            torch_compile=self.server_args.enable_piecewise_cuda_graph,
         )
         initialize_dp_attention(
             server_args=self.server_args,
@@ -747,6 +834,16 @@ class ModelRunner:
         set_cuda_arch()

         # Prepare the model config
+        from sglang.srt.configs.modelopt_config import ModelOptConfig
+
+        modelopt_config = ModelOptConfig(
+            quant=self.server_args.modelopt_quant,
+            checkpoint_restore_path=self.server_args.modelopt_checkpoint_restore_path,
+            checkpoint_save_path=self.server_args.modelopt_checkpoint_save_path,
+            export_path=self.server_args.modelopt_export_path,
+            quantize_and_serve=self.server_args.quantize_and_serve,
+        )
+
         self.load_config = LoadConfig(
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
@@ -755,6 +852,7 @@ class ModelRunner:
             remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
             remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
             remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
+            modelopt_config=modelopt_config,
         )
         if self.device == "cpu":
             self.model_config = adjust_config_with_unaligned_cpu_tp(
@@ -841,33 +939,56 @@ class ModelRunner:
             f"mem usage={self.weight_load_mem_usage:.2f} GB."
         )

-
-
-        dist.
-
-
-
-
-
-
-
-
+        if self.server_args.elastic_ep_backend == "mooncake":
+            # Mooncake does not support `monitored_barrier`
+            dist.barrier(group=get_tp_group().cpu_group)
+        else:
+            # Handle the case where some ranks do not finish loading.
+            try:
+                dist.monitored_barrier(
+                    group=get_tp_group().cpu_group,
+                    timeout=datetime.timedelta(
+                        seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S
+                    ),
+                    wait_all_ranks=True,
+                )
+            except RuntimeError:
+                raise ValueError(
+                    f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+                ) from None

     def update_expert_location(
         self,
         new_expert_location_metadata: ExpertLocationMetadata,
         update_layer_ids: List[int],
     ):
-
-
-
-
-
-
-
+        if ElasticEPStateManager.instance() is not None:
+            # TODO: refactor the weights update when elastic ep
+            old_expert_location_metadata = get_global_expert_location_metadata()
+            assert old_expert_location_metadata is not None
+            old_expert_location_metadata.update(
+                new_expert_location_metadata,
+                update_layer_ids=update_layer_ids,
+            )
+            self.update_weights_from_disk(
+                self.server_args.model_path,
+                self.server_args.load_format,
+                lambda name: "mlp.experts" in name and "mlp.shared_experts" not in name,
+            )
+        else:
+            self.expert_location_updater.update(
+                self.model.routed_experts_weights_of_layer,
+                new_expert_location_metadata,
+                update_layer_ids=update_layer_ids,
+                nnodes=self.server_args.nnodes,
+                rank=self.tp_rank,
+            )

     def update_weights_from_disk(
-        self,
+        self,
+        model_path: str,
+        load_format: str,
+        weight_name_filter: Optional[Callable[[str], bool]] = None,
     ) -> tuple[bool, str]:
         """Update engine weights in-place from the disk."""
         logger.info(
@@ -880,7 +1001,7 @@ class ModelRunner:
         load_config = LoadConfig(load_format=load_format)

         # Only support DefaultModelLoader for now
-        loader = get_model_loader(load_config)
+        loader = get_model_loader(load_config, self.model_config)
         if not isinstance(loader, DefaultModelLoader):
             message = f"Failed to get model loader: {loader}."
             return False, message
@@ -889,6 +1010,11 @@ class ModelRunner:
             iter = loader._get_weights_iterator(
                 DefaultModelLoader.Source.init_new(config, self.model)
             )
+            if weight_name_filter is not None:
+                iter = (
+                    (name, weight) for name, weight in iter if weight_name_filter(name)
+                )
+
             return iter

         def model_load_weights(model, iter):
@@ -1267,8 +1393,8 @@ class ModelRunner:
                 "num_nextn_predict_layers",
                 self.num_effective_layers,
             )
-        elif self.
-            num_layers = len(
+        elif config := self.mambaish_config:
+            num_layers = len(config.full_attention_layer_ids)
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
@@ -1277,6 +1403,17 @@ class ModelRunner:
                 * num_layers
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
+            # Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
+            if is_deepseek_nsa(self.model_config.hf_config):
+                index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
+                indexer_size_per_token = (
+                    index_head_dim
+                    + index_head_dim // NSATokenToKVPool.quant_block_size * 4
+                )
+                element_size = torch._utils._element_size(
+                    NSATokenToKVPool.index_k_with_scale_buffer_dtype
+                )
+                cell_size += indexer_size_per_token * num_layers * element_size
         else:
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
@@ -1288,22 +1425,77 @@ class ModelRunner:
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
-        if self.
-            rest_memory
-                self.server_args.max_mamba_cache_size
-                * self.model_config.hf_config.mamba_cache_per_req
-                / (1 << 30)
-            )
+        if self.mambaish_config is not None:
+            rest_memory = self.handle_max_mamba_cache(rest_memory)
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

+    def handle_max_mamba_cache(self, total_rest_memory):
+        config = self.mambaish_config
+        server_args = self.server_args
+        assert config is not None
+
+        speculativa_ratio = (
+            0
+            if server_args.speculative_num_draft_tokens is None
+            else server_args.speculative_num_draft_tokens
+        )
+        if (
+            server_args.disable_radix_cache
+            or config.mamba2_cache_params.mamba_cache_per_req == 0
+        ):
+            # with disable radix cache, sets the max_mamba_cache_size based on the max_running_requests
+            if server_args.max_mamba_cache_size is None:
+                if server_args.max_running_requests is not None:
+                    server_args.max_mamba_cache_size = server_args.max_running_requests
+                else:
+                    server_args.max_mamba_cache_size = 512
+        else:
+            # allocate the memory based on the ratio between mamba state memory vs. full kv cache memory
+            # solve the equations:
+            # 1. mamba_state_memory + full_kv_cache_memory == total_rest_memory
+            # 2. mamba_state_memory / full_kv_cache_memory == server_args.mamba_full_memory_ratio
+            mamba_state_memory_raw = (
+                total_rest_memory
+                * server_args.mamba_full_memory_ratio
+                / (1 + server_args.mamba_full_memory_ratio)
+            )
+            # calculate the max_mamba_cache_size based on the given total mamba memory
+            server_args.max_mamba_cache_size = int(
+                (mamba_state_memory_raw * (1 << 30))
+                // config.mamba2_cache_params.mamba_cache_per_req
+                // (1 + speculativa_ratio)
+            )
+
+        if self.hybrid_gdn_config is not None:
+            server_args.max_mamba_cache_size = server_args.max_mamba_cache_size // (
+                server_args.dp_size if server_args.enable_dp_attention else 1
+            )
+        mamba_state_memory = (
+            server_args.max_mamba_cache_size
+            * config.mamba2_cache_params.mamba_cache_per_req
+            * (1 + speculativa_ratio)
+            / (1 << 30)
+        )
+        return total_rest_memory - mamba_state_memory
+
     @property
-    def
-
-
-
-
+    def hybrid_gdn_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, Qwen3NextConfig):
+            return config
+        return None
+
+    @property
+    def mamba2_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, FalconH1Config | NemotronHConfig):
+            return config
+        return None
+
+    @property
+    def mambaish_config(self):
+        return self.mamba2_config or self.hybrid_gdn_config

     def set_num_token_hybrid(self):
         if (
```
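The new handle_max_mamba_cache helper divides the leftover memory budget between the Mamba state pool and the full-attention KV cache by solving the two equations given in its comments; the closed form is mamba = total * r / (1 + r). A standalone sketch of that split with illustrative numbers:

```python
# Standalone sketch of the split solved in handle_max_mamba_cache.
# Constraints: mamba + kv == total and mamba / kv == r, hence
# mamba = total * r / (1 + r) and kv = total / (1 + r).
def split_rest_memory(total_gib: float, mamba_full_memory_ratio: float):
    mamba_gib = total_gib * mamba_full_memory_ratio / (1 + mamba_full_memory_ratio)
    kv_gib = total_gib - mamba_gib
    return mamba_gib, kv_gib

# Example: 40 GiB left and ratio 0.25 -> 8 GiB for Mamba states, 32 GiB for KV cache.
print(split_rest_memory(40.0, 0.25))
```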
```diff
@@ -1387,6 +1579,27 @@ class ModelRunner:
             f"Use Sliding window memory pool. full_layer_tokens={self.full_max_total_num_tokens}, swa_layer_tokens={self.swa_max_total_num_tokens}"
         )

+    def can_run_piecewise_cuda_graph(self):
+        if self.server_args.disable_cuda_graph:
+            log_info_on_rank0(
+                logger, "Disable piecewise CUDA graph because disable_cuda_graph is set"
+            )
+            return False
+        if self.server_args.enable_torch_compile:
+            log_info_on_rank0(
+                logger,
+                "Disable piecewise CUDA graph because piecewise_cuda_graph has conflict with torch compile",
+            )
+            return False
+        if self.pp_size > 1:
+            # TODO(yuwei): support PP
+            log_info_on_rank0(
+                logger,
+                "Disable piecewise CUDA graph because piecewise_cuda_graph does not support PP",
+            )
+            return False
+        return True
+
     def init_memory_pool(
         self,
         total_gpu_memory: int,
@@ -1417,6 +1630,8 @@ class ModelRunner:
                 self.kv_cache_dtype = torch.float8_e4m3fnuz
             else:
                 self.kv_cache_dtype = torch.float8_e4m3fn
+        elif self.server_args.kv_cache_dtype in ("bf16", "bfloat16"):
+            self.kv_cache_dtype = torch.bfloat16
         else:
             raise ValueError(
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
@@ -1438,8 +1653,16 @@ class ModelRunner:
             ),
             4096,
         )
-
-
+
+        if self.mambaish_config is not None:
+            ratio = (
+                MAMBA_CACHE_SIZE_MAX_RUNNING_REQUESTS_RATIO
+                if not self.server_args.disable_radix_cache
+                else 1
+            )
+            max_num_reqs = min(
+                max_num_reqs, self.server_args.max_mamba_cache_size // ratio
+            )

         if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
             if self.is_draft_worker:
@@ -1506,39 +1729,43 @@ class ModelRunner:
             extra_max_context_len += self.server_args.speculative_num_draft_tokens

         if self.server_args.disaggregation_mode == "decode":
-            from sglang.srt.disaggregation.decode import
+            from sglang.srt.disaggregation.decode import (
+                DecodeReqToTokenPool,
+                HybridMambaDecodeReqToTokenPool,
+            )

             # subscribe memory for pre-allocated requests
             # if max_num_reqs <= 32, we pre-allocate 2x requests
             pre_alloc_size = max_num_reqs * 2 if max_num_reqs <= 32 else 0
-            self.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            if config := self.mambaish_config:
+                self.req_to_token_pool = HybridMambaDecodeReqToTokenPool(
+                    size=max_num_reqs,
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    cache_params=config.mamba2_cache_params,
+                    speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+                    pre_alloc_size=pre_alloc_size,
+                )
+            else:
+                self.req_to_token_pool = DecodeReqToTokenPool(
+                    size=max_num_reqs,
+                    max_context_len=self.model_config.context_len
+                    + extra_max_context_len,
+                    device=self.device,
+                    enable_memory_saver=self.server_args.enable_memory_saver,
+                    pre_alloc_size=pre_alloc_size,
+                )
+        elif config := self.mambaish_config:
             self.req_to_token_pool = HybridReqToTokenPool(
                 size=max_num_reqs,
+                mamba_size=self.server_args.max_mamba_cache_size,
                 max_context_len=self.model_config.context_len
                 + extra_max_context_len,
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
-
-                temporal_state_shape=temporal_state_shape,
-                conv_dtype=conv_dtype,
-                ssm_dtype=ssm_dtype,
-                mamba_layers=mamba_layers,
+                cache_params=config.mamba2_cache_params,
                 speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
             )
         else:
@@ -1640,7 +1867,7 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
             )
-        elif self.
+        elif config := self.mambaish_config:
             self.token_to_kv_pool = HybridLinearKVPool(
                 page_size=self.page_size,
                 size=self.max_total_num_tokens,
@@ -1651,12 +1878,11 @@ class ModelRunner:
                 head_dim=self.model_config.head_dim,
                 # if draft worker, we only need 1 attention layer's kv pool
                 full_attention_layer_ids=(
-                    [0]
-                    if self.is_draft_worker
-                    else self.model_config.hf_config.full_attention_layer_ids
+                    [0] if self.is_draft_worker else config.full_attention_layer_ids
                 ),
                 enable_kvcache_transpose=False,
                 device=self.device,
+                mamba_pool=self.req_to_token_pool.mamba_pool,
             )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
@@ -1672,13 +1898,17 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 start_layer=self.start_layer,
                 end_layer=self.end_layer,
+                enable_kv_cache_copy=(
+                    self.server_args.speculative_algorithm is not None
+                ),
             )

         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
             if _is_npu and (
-                self.server_args.attention_backend == "ascend"
+                self.server_args.attention_backend == "ascend"
+                or self.hybrid_gdn_config is not None
             ):
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,
@@ -1743,16 +1973,10 @@ class ModelRunner:

     def _get_attention_backend(self):
         """Init attention kernel backend."""
-        self.decode_attention_backend_str = (
-            self.server_args.decode_attention_backend
-            if self.server_args.decode_attention_backend
-            else self.server_args.attention_backend
-        )
-        self.prefill_attention_backend_str = (
-            self.server_args.prefill_attention_backend
-            if self.server_args.prefill_attention_backend
-            else self.server_args.attention_backend
+        self.prefill_attention_backend_str, self.decode_attention_backend_str = (
+            self.server_args.get_attention_backends()
         )
+
         if self.decode_attention_backend_str != self.prefill_attention_backend_str:
             from sglang.srt.layers.attention.hybrid_attn_backend import (
                 HybridAttnBackend,
@@ -1781,12 +2005,10 @@ class ModelRunner:
             self.server_args.attention_backend
         )

-
-
-
-
-            }
-        )
+        (
+            get_global_server_args().prefill_attention_backend,
+            get_global_server_args().decode_attention_backend,
+        ) = (self.prefill_attention_backend_str, self.decode_attention_backend_str)
         return attn_backend

     def _get_attention_backend_from_str(self, backend_str: str):
@@ -1924,6 +2146,11 @@ class ModelRunner:
             kwargs["input_embeds"] = forward_batch.input_embeds.bfloat16()
         if not self.is_generation:
             kwargs["get_embedding"] = True
+
+        if self.piecewise_cuda_graph_runner is not None:
+            if self.piecewise_cuda_graph_runner.can_run(forward_batch):
+                return self.piecewise_cuda_graph_runner.replay(forward_batch, **kwargs)
+
         return self.model.forward(
             forward_batch.input_ids,
             forward_batch.positions,
@@ -2057,15 +2284,11 @@ class ModelRunner:
     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
     ):
-        #
-
-
-
-
-            sampling_info.sampling_info_done.wait()
-        else:
-            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
-            sampling_info.update_regex_vocab_mask()
+        # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
+        # was executed after we processed last batch's results.
+
+        # Calculate logits bias and apply it to next_token_logits.
+        sampling_info.update_regex_vocab_mask()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)

     def sample(
@@ -2164,6 +2387,23 @@ class ModelRunner:
             )
             ShardedStateLoader.save_model(self.model, path, pattern, max_size)

+    def update_weights_from_ipc(self, recv_req):
+        """Update weights from IPC for checkpoint-engine integration."""
+        try:
+            from sglang.srt.checkpoint_engine.checkpoint_engine_worker import (
+                SGLangCheckpointEngineWorkerExtensionImpl,
+            )
+
+            # Create a worker extension that integrates with SGLang's model
+            worker = SGLangCheckpointEngineWorkerExtensionImpl(self)
+            worker.update_weights_from_ipc(recv_req.zmq_handles)
+            return True, "IPC weight update completed successfully"
+        except ImportError as e:
+            return False, f"IPC weight update failed: ImportError {e}"
+        except Exception as e:
+            logger.error(f"IPC weight update failed: {e}")
+            return False, str(e)
+

 def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
     params_dict = dict(model.named_parameters())
```