sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +192 -113
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +132 -57
- sglang/srt/entrypoints/openai/protocol.py +115 -7
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +207 -58
- sglang/srt/entrypoints/openai/serving_completions.py +17 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +49 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +106 -82
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +53 -7
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +225 -57
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +78 -49
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +215 -314
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +147 -19
- sglang/srt/managers/scheduler.py +501 -304
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +321 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +15 -21
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +58 -34
- sglang/srt/mem_cache/hiradix_cache.py +227 -80
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -223
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +519 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +55 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +98 -57
- sglang/srt/model_executor/model_runner.py +433 -158
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +833 -152
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +14 -5
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +124 -14
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +26 -5
- sglang/srt/models/qwen3_moe.py +71 -12
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +10 -3
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1030 -254
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +253 -136
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +445 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +22 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
@@ -19,7 +19,10 @@ import inspect
 import json
 import logging
 import os
+import socket
+import threading
 import time
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

@@ -27,17 +30,24 @@ import torch
 import torch.distributed as dist

 from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import AttentionArch, ModelConfig
+from sglang.srt.configs.load_config import LoadConfig, LoadFormat
+from sglang.srt.configs.model_config import (
+    AttentionArch,
+    ModelConfig,
+    get_nsa_index_head_dim,
+    is_deepseek_nsa,
+)
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tp_group,
     get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
     set_mscclpp_all_reduce,
+    set_symm_mem_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.eplb.eplb_manager import EPLBManager
@@ -53,6 +63,10 @@ from sglang.srt.eplb.expert_location import (
     set_global_expert_location_metadata,
 )
 from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
+from sglang.srt.layers.attention.attention_registry import (
+    ATTENTION_BACKENDS,
+    attn_backend_wrapper,
+)
 from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
@@ -83,16 +97,23 @@ from sglang.srt.mem_cache.memory_pool import (
     AscendMLAPagedTokenToKVPool,
     AscendTokenToKVPool,
     DoubleSparseTokenToKVPool,
+    HybridLinearKVPool,
+    HybridReqToTokenPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
+    NSATokenToKVPool,
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
+from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+    trigger_init_weights_send_group_for_remote_instance_request,
+)
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.offloader import (
@@ -100,7 +121,6 @@ from sglang.srt.offloader import (
     get_offloader,
     set_offloader,
 )
-from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -121,15 +141,38 @@ from sglang.srt.utils import (
     is_no_spec_infer_or_topk_one,
     is_npu,
     is_sm100_supported,
+    log_info_on_rank0,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cuda_arch,
+    slow_rank_detector,
 )
+from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.weight_sync.tensor_bucket import (
     FlattenedTensorBucket,
     FlattenedTensorMetadata,
 )

+MLA_ATTENTION_BACKENDS = [
+    "aiter",
+    "flashinfer",
+    "fa3",
+    "fa4",
+    "triton",
+    "flashmla",
+    "cutlass_mla",
+    "trtllm_mla",
+    "ascend",
+    "nsa",
+]
+
+
+def add_mla_attention_backend(backend_name):
+    if backend_name not in MLA_ATTENTION_BACKENDS:
+        MLA_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")
+
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
@@ -143,6 +186,13 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 logger = logging.getLogger(__name__)


+if _is_npu:
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = True
+    torch_npu.npu.set_compile_mode(jit_compile=False)
+
+
 class RankZeroFilter(logging.Filter):
     """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""

@@ -237,6 +287,9 @@ class ModelRunner:
         # CPU offload
         set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))

+        if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"):
+            slow_rank_detector.execute()
+
         # Update deep gemm configure
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
@@ -251,6 +304,7 @@ class ModelRunner:

         # For weight updates
         self._model_update_group = {}
+        self._weights_send_group = {}

     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
@@ -300,6 +354,25 @@ class ModelRunner:
         if architectures and not any("Llama4" in arch for arch in architectures):
             self.is_hybrid = self.model_config.is_hybrid = True

+        if self.is_hybrid_gdn:
+            logger.warning("Hybrid GDN model detected, disable radix cache")
+            self.server_args.disable_radix_cache = True
+            if self.server_args.max_mamba_cache_size is None:
+                if self.server_args.max_running_requests is not None:
+                    self.server_args.max_mamba_cache_size = (
+                        self.server_args.max_running_requests
+                    )
+                else:
+                    self.server_args.max_mamba_cache_size = 512
+            self.server_args.max_mamba_cache_size = (
+                self.server_args.max_mamba_cache_size
+                // (
+                    self.server_args.dp_size
+                    if self.server_args.enable_dp_attention
+                    else 1
+                )
+            )
+
         # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
         # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
         # determine the number of layers.
@@ -341,6 +414,20 @@ class ModelRunner:
         if server_args.enable_lora:
             self.init_lora_manager()

+        # Init Double Sparsity
+        if server_args.enable_double_sparsity:
+            if server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
+
+        # Enable batch invariant mode
+        if server_args.enable_deterministic_inference:
+            from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode
+
+            enable_batch_invariant_mode()
+
         # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
@@ -351,12 +438,12 @@ class ModelRunner:
             self.init_cublas()
             self.init_attention_backend()
             self.init_device_graphs()
-        elif self.device
+        elif self.device in ["npu", "cpu"]:
             self.init_attention_backend()
             self.init_device_graphs()
         else:
             self.graph_runner = None
-            self.
+            self.graph_mem_usage = 0
             self.init_attention_backend()

         # auxiliary hidden capture mode. TODO: expose this to server args?
@@ -452,9 +539,7 @@ class ModelRunner:
         elif _is_hip:
             head_num = self.model_config.get_num_kv_heads(self.tp_size)
             # TODO current aiter only support head number 16 or 128 head number
-            if (
-                head_num == 128 or head_num == 16
-            ) and self.spec_algorithm.is_none():
+            if head_num == 128 or head_num == 16:
                 server_args.attention_backend = "aiter"
             else:
                 server_args.attention_backend = "triton"
@@ -467,16 +552,7 @@ class ModelRunner:
             )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
-                if server_args.attention_backend in [
-                    "aiter",
-                    "flashinfer",
-                    "fa3",
-                    "triton",
-                    "flashmla",
-                    "cutlass_mla",
-                    "trtllm_mla",
-                    "ascend",
-                ]:
+                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
                     logger.info(
                         f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )
@@ -506,11 +582,6 @@ class ModelRunner:
             )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
-            if server_args.ds_heavy_channel_type is None:
-                raise ValueError(
-                    "Please specify the heavy channel type for double sparsity optimization."
-                )
-            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

         if self.is_multimodal:
             if not self.is_multimodal_chunked_prefill_supported:
@@ -548,7 +619,7 @@ class ModelRunner:
                 server_args.hicache_io_backend = "direct"
                 logger.warning(
                     "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-
+                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
                 )

     def init_torch_distributed(self):
@@ -583,6 +654,7 @@ class ModelRunner:
         dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
+        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)

         if not self.is_draft_worker:
             if self.device == "cpu":
@@ -593,6 +665,11 @@ class ModelRunner:
                 # Set local size to hint SGLang to use shared memory based AllReduce
                 os.environ["LOCAL_SIZE"] = str(self.tp_size)
                 torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+
+                @torch.library.register_fake("sgl_kernel::shm_allgather")
+                def _(data, dim):
+                    return torch.cat([data] * self.tp_size, dim=dim)
+
             else:
                 logger.warning(
                     "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
@@ -625,6 +702,7 @@ class ModelRunner:
             cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
+        self.pp_group = get_pp_group()
         self.attention_tp_group = get_attention_tp_group()

         # Check memory for tensor parallelism
@@ -673,6 +751,10 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
             model_loader_extra_config=self.server_args.model_loader_extra_config,
+            tp_rank=self.tp_rank,
+            remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
+            remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
+            remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
         )
         if self.device == "cpu":
             self.model_config = adjust_config_with_unaligned_cpu_tp(
@@ -681,16 +763,33 @@ class ModelRunner:
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()

+        if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
+            if self.tp_rank == 0:
+                instance_ip = socket.gethostbyname(socket.gethostname())
+                t = threading.Thread(
+                    target=trigger_init_weights_send_group_for_remote_instance_request,
+                    args=(
+                        self.server_args.remote_instance_weight_loader_seed_instance_ip,
+                        self.server_args.remote_instance_weight_loader_seed_instance_service_port,
+                        self.server_args.remote_instance_weight_loader_send_weights_group_ports,
+                        instance_ip,
+                    ),
+                )
+                t.start()
+
         # Load the model
         # Remove monkey_patch when linear.py quant remove dependencies with vllm
         monkey_patch_vllm_parallel_state()
         monkey_patch_isinstance_for_vllm_base_layer()

-        with self.memory_saver_adapter.region(
+        with self.memory_saver_adapter.region(
+            GPU_MEMORY_TYPE_WEIGHTS,
+            enable_cpu_backup=self.server_args.enable_weights_cpu_backup,
+        ):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
-                device_config=DeviceConfig(self.device),
+                device_config=DeviceConfig(self.device, self.gpu_id),
             )
         monkey_patch_vllm_parallel_state(reverse=True)
         monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
@@ -822,6 +921,103 @@ class ModelRunner:
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."

+    def init_weights_send_group_for_remote_instance(
+        self,
+        master_address,
+        ports,
+        group_rank,
+        world_size,
+        group_name,
+        backend="nccl",
+    ):
+        assert (
+            torch.distributed.is_initialized()
+        ), "Default torch process group must be initialized"
+        assert group_name != "", "Group name cannot be empty"
+
+        ports_list = ports.split(",")
+        assert (
+            len(ports_list) == self.tp_size
+        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
+        group_port = ports_list[self.tp_rank]
+        group_name = f"{group_name}_{group_port}_{self.tp_rank}"
+
+        logger.info(
+            f"init custom process group: tp_rank={self.tp_rank}, gpu_id={self.gpu_id}, master_address={master_address}, master_port={group_port}, "
+            f"group_rank={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
+        )
+
+        torch.cuda.empty_cache()
+        success = False
+        message = ""
+        try:
+            self._weights_send_group[group_name] = init_custom_process_group(
+                backend=backend,
+                init_method=f"tcp://{master_address}:{group_port}",
+                world_size=world_size,
+                rank=group_rank,
+                group_name=group_name,
+                device_id=torch.device("cuda", self.gpu_id),
+            )
+            dist.barrier(group=self._weights_send_group[group_name])
+            success = True
+            message = (
+                f"Succeeded to init group through {master_address}:{group_port} group."
+            )
+        except Exception as e:
+            message = f"Failed to init group: {e}."
+            logger.error(message)
+
+        torch.cuda.empty_cache()
+        return success, message
+
+    def send_weights_to_remote_instance(
+        self,
+        master_address,
+        ports,
+        group_name,
+    ):
+        assert (
+            torch.distributed.is_initialized()
+        ), "Default torch process group must be initialized"
+        assert group_name != "", "Group name cannot be empty"
+
+        ports_list = ports.split(",")
+        assert (
+            len(ports_list) == self.tp_size
+        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
+        group_port = ports_list[self.tp_rank]
+        group_name = f"{group_name}_{group_port}_{self.tp_rank}"
+
+        if self._weights_send_group[group_name] is not None:
+            send_group = self._weights_send_group[group_name]
+        else:
+            message = f"Group {group_name} not in _weights_send_group list. Please call `init_weights_send_group_for_remote_instance` first."
+            logger.error(message)
+            return False, message
+
+        torch.cuda.empty_cache()
+        success = False
+        message = ""
+        try:
+            for _, weights in self.model.named_parameters():
+                torch.distributed.broadcast(
+                    weights,
+                    src=0,
+                    group=send_group,
+                )
+            success = True
+            message = f"Succeeded to send weights through {master_address}:{group_port} {group_name}."
+        except Exception as e:
+            message = f"Failed to send weights: {e}."
+            logger.error(message)
+
+        # destroy the process group after sending weights
+        del self._weights_send_group[group_name]
+        torch.distributed.distributed_c10d.destroy_process_group(send_group)
+        torch.cuda.empty_cache()
+        return success, message
+
     def init_weights_update_group(
         self,
         master_address,
@@ -867,6 +1063,19 @@ class ModelRunner:
             logger.error(message)
             return False, message

+    def destroy_weights_update_group(self, group_name):
+        try:
+            if group_name in self._model_update_group:
+                pg = self._model_update_group.pop(group_name)
+                torch.distributed.destroy_process_group(pg)
+                return True, "Succeeded to destroy custom process group."
+            else:
+                return False, "The group to be destroyed does not exist."
+        except Exception as e:
+            message = f"Failed to destroy custom process group: {e}."
+            logger.error(message)
+            return False, message
+
     def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
         """
         Update specific parameter in the model weights online
@@ -904,7 +1113,7 @@ class ModelRunner:
             handle.wait()

             self.model.load_weights(weights)
-            return True,
+            return True, "Succeeded to update parameter online."

         except Exception as e:
             error_msg = (
@@ -1008,6 +1217,7 @@ class ModelRunner:
             max_lora_rank=self.server_args.max_lora_rank,
             target_modules=self.server_args.lora_target_modules,
             lora_paths=self.server_args.lora_paths,
+            server_args=self.server_args,
         )

     def load_lora_adapter(self, lora_ref: LoRARef):
@@ -1057,6 +1267,8 @@ class ModelRunner:
                 "num_nextn_predict_layers",
                 self.num_effective_layers,
             )
+        elif self.is_hybrid_gdn:
+            num_layers = len(self.model_config.hf_config.full_attention_layer_ids)
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
@@ -1076,9 +1288,23 @@ class ModelRunner:
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
+        if self.is_hybrid_gdn:
+            rest_memory -= (
+                self.server_args.max_mamba_cache_size
+                * self.model_config.hf_config.mamba_cache_per_req
+                / (1 << 30)
+            )
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

+    @property
+    def is_hybrid_gdn(self):
+        return self.model_config.hf_config.architectures[0] in [
+            "Qwen3NextForCausalLM",
+            "Qwen3NextForCausalLMMTP",
+            "FalconH1ForCausalLM",
+        ]
+
     def set_num_token_hybrid(self):
         if (
             "Llama4ForConditionalGeneration"
@@ -1169,7 +1395,18 @@ class ModelRunner:
     ):
         # Determine the kv cache dtype
         if self.server_args.kv_cache_dtype == "auto":
-
+            quant_config = getattr(self.model, "quant_config", None)
+            kv_cache_quant_algo = getattr(quant_config, "kv_cache_quant_algo", None)
+            if (
+                isinstance(kv_cache_quant_algo, str)
+                and kv_cache_quant_algo.upper() == "FP8"
+            ):
+                if _is_hip:
+                    self.kv_cache_dtype = torch.float8_e4m3fnuz
+                else:
+                    self.kv_cache_dtype = torch.float8_e4m3fn
+            else:
+                self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
             if _is_hip:  # Using natively supported format
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
@@ -1185,6 +1422,8 @@ class ModelRunner:
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
             )

+        log_info_on_rank0(logger, f"Using KV cache dtype: {self.kv_cache_dtype}")
+
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if SGLANG_CI_SMALL_KV_SIZE:
             self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
@@ -1199,8 +1438,10 @@ class ModelRunner:
             ),
             4096,
         )
+        if self.is_hybrid_gdn:
+            max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)

-        if
+        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                 max_num_reqs = self.server_args.max_num_reqs
@@ -1237,13 +1478,24 @@ class ModelRunner:
             // self.server_args.page_size
             * self.server_args.page_size
         )
+        # different pp rank may have different num of layers, so we need to reduce the max_total_num_tokens
+        if self.pp_size > 1:
+            tensor = torch.tensor(self.max_total_num_tokens, dtype=torch.int64)
+            torch.distributed.all_reduce(
+                tensor,
+                op=torch.distributed.ReduceOp.MIN,
+                group=get_world_group().cpu_group,
+            )
+            self.max_total_num_tokens = tensor.item()
+
         # create token size for hybrid cache
         if self.is_hybrid:
             self.set_num_token_hybrid()

         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not enough memory. Please try to increase --mem-fraction-static."
+                f"Not enough memory. Please try to increase --mem-fraction-static. "
+                f"Current value: {self.server_args.mem_fraction_static=}"
             )

         # Initialize req_to_token_pool
@@ -1267,6 +1519,28 @@ class ModelRunner:
                 enable_memory_saver=self.server_args.enable_memory_saver,
                 pre_alloc_size=pre_alloc_size,
             )
+        elif self.is_hybrid_gdn:
+            config = self.model_config.hf_config
+            (
+                conv_state_shape,
+                temporal_state_shape,
+                conv_dtype,
+                ssm_dtype,
+                mamba_layers,
+            ) = config.hybrid_gdn_params
+            self.req_to_token_pool = HybridReqToTokenPool(
+                size=max_num_reqs,
+                max_context_len=self.model_config.context_len
+                + extra_max_context_len,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+                conv_state_shape=conv_state_shape,
+                temporal_state_shape=temporal_state_shape,
+                conv_dtype=conv_dtype,
+                ssm_dtype=ssm_dtype,
+                mamba_layers=mamba_layers,
+                speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
+            )
         else:
             self.req_to_token_pool = ReqToTokenPool(
                 size=max_num_reqs,
@@ -1280,6 +1554,7 @@ class ModelRunner:
             assert self.is_draft_worker

         # Initialize token_to_kv_pool
+        is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
         if self.server_args.attention_backend == "ascend":
             if self.use_mla_backend:
                 self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
@@ -1288,6 +1563,7 @@ class ModelRunner:
                     dtype=self.kv_cache_dtype,
                     kv_lora_rank=self.model_config.kv_lora_rank,
                     qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                    index_head_dim=self.model_config.index_head_dim,
                     layer_num=self.num_effective_layers,
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
@@ -1307,7 +1583,22 @@ class ModelRunner:
                     device=self.device,
                     enable_memory_saver=self.server_args.enable_memory_saver,
                 )
+        elif self.use_mla_backend and is_nsa_model:
+            self.token_to_kv_pool = NSATokenToKVPool(
+                self.max_total_num_tokens,
+                page_size=self.page_size,
+                dtype=self.kv_cache_dtype,
+                kv_lora_rank=self.model_config.kv_lora_rank,
+                qk_rope_head_dim=self.model_config.qk_rope_head_dim,
+                layer_num=self.num_effective_layers,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+                start_layer=self.start_layer,
+                end_layer=self.end_layer,
+                index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
+            )
         elif self.use_mla_backend:
+            assert not is_nsa_model
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
                 page_size=self.page_size,
@@ -1349,6 +1640,24 @@ class ModelRunner:
                 enable_kvcache_transpose=False,
                 device=self.device,
             )
+        elif self.is_hybrid_gdn:
+            self.token_to_kv_pool = HybridLinearKVPool(
+                page_size=self.page_size,
+                size=self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                head_num=self.model_config.get_num_kv_heads(
+                    get_attention_tp_size()
+                ),
+                head_dim=self.model_config.head_dim,
+                # if draft worker, we only need 1 attention layer's kv pool
+                full_attention_layer_ids=(
+                    [0]
+                    if self.is_draft_worker
+                    else self.model_config.hf_config.full_attention_layer_ids
+                ),
+                enable_kvcache_transpose=False,
+                device=self.device,
+            )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -1368,7 +1677,9 @@ class ModelRunner:
         # Initialize token_to_kv_pool_allocator
         need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
         if self.token_to_kv_pool_allocator is None:
-            if
+            if _is_npu and (
+                self.server_args.attention_backend == "ascend" or self.is_hybrid_gdn
+            ):
                 self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
                     self.max_total_num_tokens,
                     page_size=self.page_size,
@@ -1462,8 +1773,8 @@ class ModelRunner:
                 f"prefill_backend={self.prefill_attention_backend_str}."
             )
             logger.warning(
-
-
+                "Warning: Attention backend specified by --attention-backend or default backend might be overridden."
+                "The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
             )
         else:
             attn_backend = self._get_attention_backend_from_str(
@@ -1479,111 +1790,10 @@ class ModelRunner:
         return attn_backend

     def _get_attention_backend_from_str(self, backend_str: str):
-        if backend_str == "flashinfer":
-            if not self.use_mla_backend:
-                from sglang.srt.layers.attention.flashinfer_backend import (
-                    FlashInferAttnBackend,
-                )
-
-                # Init streams
-                if self.server_args.speculative_algorithm == "EAGLE":
-                    if (
-                        not hasattr(self, "plan_stream_for_flashinfer")
-                        or not self.plan_stream_for_flashinfer
-                    ):
-                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
-                return FlashInferAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.flashinfer_mla_backend import (
-                    FlashInferMLAAttnBackend,
-                )
-
-                return FlashInferMLAAttnBackend(self)
-        elif backend_str == "aiter":
-            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
-
-            return AiterAttnBackend(self)
-        elif self.server_args.attention_backend == "wave":
-            from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
-
-            return WaveAttnBackend(self)
-        elif backend_str == "ascend":
-            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
-
-            return AscendAttnBackend(self)
-        elif backend_str == "triton":
-            assert not self.model_config.is_encoder_decoder, (
-                "Cross attention is not supported in the triton attention backend. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            if self.server_args.enable_double_sparsity:
-                from sglang.srt.layers.attention.double_sparsity_backend import (
-                    DoubleSparseAttnBackend,
-                )
-
-                return DoubleSparseAttnBackend(self)
-            else:
-                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
-
-                return TritonAttnBackend(self)
-        elif backend_str == "torch_native":
-            from sglang.srt.layers.attention.torch_native_backend import (
-                TorchNativeAttnBackend,
-            )
-
-            return TorchNativeAttnBackend(self)
-        elif backend_str == "flashmla":
-            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
-
-            return FlashMLABackend(self)
-        elif backend_str == "fa3":
-            assert (
-                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
-            ) or torch.cuda.get_device_capability()[0] == 9, (
-                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
-                "Please use `--attention-backend flashinfer`."
-            )
-            from sglang.srt.layers.attention.flashattention_backend import (
-                FlashAttentionBackend,
-            )
-
-            return FlashAttentionBackend(self)
-        elif backend_str == "cutlass_mla":
-            from sglang.srt.layers.attention.cutlass_mla_backend import (
-                CutlassMLABackend,
-            )
-
-            return CutlassMLABackend(self)
-        elif backend_str == "trtllm_mla":
-            if not self.use_mla_backend:
-                raise ValueError("trtllm_mla backend can only be used with MLA models.")
-            from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
-
-            return TRTLLMMLABackend(self)
-        elif backend_str == "trtllm_mha":
-            if self.use_mla_backend:
-                raise ValueError(
-                    "trtllm_mha backend can only be used with non-MLA models."
-                )
-            from sglang.srt.layers.attention.trtllm_mha_backend import (
-                TRTLLMHAAttnBackend,
-            )
-
-            return TRTLLMHAAttnBackend(self)
-        elif backend_str == "intel_amx":
-            from sglang.srt.layers.attention.intel_amx_backend import (
-                IntelAMXAttnBackend,
-            )
-
-            return IntelAMXAttnBackend(self)
-        elif backend_str == "dual_chunk_flash_attn":
-            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
-                DualChunkFlashAttentionBackend,
-            )
-
-            return DualChunkFlashAttentionBackend(self)
-        else:
+        if backend_str not in ATTENTION_BACKENDS:
             raise ValueError(f"Invalid attention backend: {backend_str}")
+        full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
+        return attn_backend_wrapper(self, full_attention_backend)

     def init_double_sparsity_channel_config(self, selected_channel):
         selected_channel = "." + selected_channel + "_proj"
@@ -1603,38 +1813,46 @@ class ModelRunner:
         )

     def init_device_graphs(self):
-        """Capture
+        """Capture device graphs."""
         self.graph_runner = None
-        self.
+        self.graph_mem_usage = 0

         if not self.is_generation:
             # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
             return

-        if self.server_args.disable_cuda_graph:
+        if self.device != "cpu" and self.server_args.disable_cuda_graph:
+            return
+
+        if self.device == "cpu" and not self.server_args.enable_torch_compile:
             return

         tic = time.perf_counter()
         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
-            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
         )
-
-        CudaGraphRunner
+        graph_runners = defaultdict(
+            lambda: CudaGraphRunner,
+            {
+                "cpu": CPUGraphRunner,
+                "npu": NPUGraphRunner,
+            },
         )
+        self.graph_runner = graph_runners[self.device](self)
+
         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
-        self.
+        self.graph_mem_usage = before_mem - after_mem
         logger.info(
-            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
-            f"mem usage={self.
+            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
+            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
        )

     def init_threads_binding(self):
         omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
+        cpu_ids_by_node = get_cpu_ids_by_node()
+        n_numa_node = len(cpu_ids_by_node)
         if omp_cpuids == "all":
-            cpu_ids_by_node = get_cpu_ids_by_node()
-            n_numa_node = len(cpu_ids_by_node)
-
             assert self.tp_size <= n_numa_node, (
                 f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
                 f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
@@ -1651,11 +1869,22 @@ class ModelRunner:
             )
             self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
         else:
-
+            threads_bind_list = omp_cpuids.split("|")
+            assert self.tp_size == len(threads_bind_list), (
+                f"SGLANG_CPU_OMP_THREADS_BIND setting must be aligned with TP size parameter ({self.tp_size}). "
+                f"Please double check your settings."
+            )
+            self.local_omp_cpuid = threads_bind_list[self.tp_rank]
+            if self.tp_size > n_numa_node:
+                logger.warning(
+                    f"TP size ({self.tp_size})is larger than numa node number ({n_numa_node}), "
+                    f"in this case the available memory amount of each rank cannot be determined in prior. "
+                    f"Please set proper `--max-total-tokens` to avoid the out-of-memory error."
+                )

     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
-        from sglang.srt.model_parallel import tensor_parallel
+        from sglang.srt.layers.model_parallel import tensor_parallel

         device_mesh = torch.distributed.init_device_mesh(self.device, (self.tp_size,))
         tensor_parallel(self.model, device_mesh)
@@ -1771,18 +2000,24 @@ class ModelRunner:
         reinit_attn_backend: bool = False,
         split_forward_count: int = 1,
     ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
-
-        forward_batch.forward_mode.
+        mode_check = (
+            forward_batch.forward_mode.is_cpu_graph
+            if self.device == "cpu"
+            else forward_batch.forward_mode.is_cuda_graph
+        )
+        can_run_graph = bool(
+            mode_check()
             and self.graph_runner
             and self.graph_runner.can_run(forward_batch)
         )
-
+
+        if can_run_graph:
             ret = self.graph_runner.replay(
                 forward_batch,
                 skip_attn_backend_init=skip_attn_backend_init,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return ret,
+            return ret, can_run_graph

         # For MLP sync
         if forward_batch.global_num_tokens_cpu is not None:
@@ -1811,10 +2046,13 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

-        if
+        if (
+            forward_batch.global_num_tokens_cpu is not None
+            and self.pp_group.is_last_rank
+        ):
             forward_batch.post_forward_mlp_sync_batch(ret)

-        return ret,
+        return ret, can_run_graph

     def _preprocess_logits(
         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
@@ -1852,7 +2090,6 @@ class ModelRunner:
         )

         self._preprocess_logits(logits_output, forward_batch.sampling_info)
-
         # Sample the next tokens
         next_token_ids = self.sampler(
             logits_output,
@@ -1860,9 +2097,47 @@ class ModelRunner:
             forward_batch.return_logprob,
             forward_batch.top_logprobs_nums,
             forward_batch.token_ids_logprobs,
+            # For prefill, we only use the position of the last token.
+            (
+                forward_batch.positions
+                if forward_batch.forward_mode.is_decode()
+                else forward_batch.seq_lens - 1
+            ),
         )
         return next_token_ids

+    def compute_logprobs_only(
+        self,
+        logits_output: LogitsProcessorOutput,
+        forward_batch: ForwardBatch,
+    ) -> None:
+        """
+        Compute token_ids_logprobs without performing sampling.
+
+        Optimized path for prefill-only requests that need token_ids_logprobs but don't
+        require next token generation. Skips expensive sampling operations
+        while still providing requested probability information.
+
+        Args:
+            logits_output: The logits output from the model forward
+            forward_batch: The forward batch that generates logits_output
+        """
+        if not forward_batch.token_ids_logprobs:
+            return
+
+        # Preprocess logits (same as in sample method)
+        self._preprocess_logits(logits_output, forward_batch.sampling_info)
+
+        # Delegate to sampler for logprob-only computation
+        # This populates logits_output with requested token probabilities
+        self.sampler.compute_logprobs_only(
+            logits_output,
+            forward_batch.sampling_info,
+            forward_batch.return_logprob,
+            forward_batch.top_logprobs_nums,
+            forward_batch.token_ids_logprobs,
+        )
+
     @property
     def model_is_mrope(self) -> bool:
         """Detect if the model has "mrope" rope_scaling type.