sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

@@ -19,25 +19,36 @@ import inspect
 import json
 import logging
 import os
+import socket
+import threading
 import time
+from collections import defaultdict
 from dataclasses import dataclass
 from typing import List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist

+from sglang.srt.configs import FalconH1Config, NemotronHConfig, Qwen3NextConfig
 from sglang.srt.configs.device_config import DeviceConfig
-from sglang.srt.configs.load_config import LoadConfig
-from sglang.srt.configs.model_config import
+from sglang.srt.configs.load_config import LoadConfig, LoadFormat
+from sglang.srt.configs.model_config import (
+    AttentionArch,
+    ModelConfig,
+    get_nsa_index_head_dim,
+    is_deepseek_nsa,
+)
 from sglang.srt.configs.update_config import adjust_config_with_unaligned_cpu_tp
 from sglang.srt.constants import GPU_MEMORY_TYPE_WEIGHTS
 from sglang.srt.distributed import (
+    get_pp_group,
     get_tp_group,
     get_world_group,
     init_distributed_environment,
     initialize_model_parallel,
     set_custom_all_reduce,
     set_mscclpp_all_reduce,
+    set_symm_mem_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.eplb.eplb_manager import EPLBManager
@@ -53,6 +64,10 @@ from sglang.srt.eplb.expert_location import (
     set_global_expert_location_metadata,
 )
 from sglang.srt.eplb.expert_location_updater import ExpertLocationUpdater
+from sglang.srt.layers.attention.attention_registry import (
+    ATTENTION_BACKENDS,
+    attn_backend_wrapper,
+)
 from sglang.srt.layers.attention.tbo_backend import TboAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
@@ -83,16 +98,23 @@ from sglang.srt.mem_cache.memory_pool import (
     AscendMLAPagedTokenToKVPool,
     AscendTokenToKVPool,
     DoubleSparseTokenToKVPool,
+    HybridLinearKVPool,
+    HybridReqToTokenPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
+    NSATokenToKVPool,
     ReqToTokenPool,
     SWAKVPool,
 )
+from sglang.srt.model_executor.cpu_graph_runner import CPUGraphRunner
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.npu_graph_runner import NPUGraphRunner
 from sglang.srt.model_loader import get_model
 from sglang.srt.model_loader.loader import DefaultModelLoader, get_model_loader
+from sglang.srt.model_loader.remote_instance_weight_loader_utils import (
+    trigger_init_weights_send_group_for_remote_instance_request,
+)
 from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.offloader import (
@@ -100,7 +122,6 @@ from sglang.srt.offloader import (
     get_offloader,
     set_offloader,
 )
-from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
@@ -121,15 +142,38 @@ from sglang.srt.utils import (
     is_no_spec_infer_or_topk_one,
     is_npu,
     is_sm100_supported,
+    log_info_on_rank0,
     monkey_patch_p2p_access_check,
     monkey_patch_vllm_gguf_config,
     set_cuda_arch,
+    slow_rank_detector,
 )
+from sglang.srt.utils.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.weight_sync.tensor_bucket import (
     FlattenedTensorBucket,
     FlattenedTensorMetadata,
 )

+MLA_ATTENTION_BACKENDS = [
+    "aiter",
+    "flashinfer",
+    "fa3",
+    "fa4",
+    "triton",
+    "flashmla",
+    "cutlass_mla",
+    "trtllm_mla",
+    "ascend",
+    "nsa",
+]
+
+
+def add_mla_attention_backend(backend_name):
+    if backend_name not in MLA_ATTENTION_BACKENDS:
+        MLA_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(f"Added {backend_name} to MLA_ATTENTION_BACKENDS.")
+
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
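Note: the module-level `MLA_ATTENTION_BACKENDS` list and `add_mla_attention_backend` above replace the backend names that were previously hard-coded inside `ModelRunner`. A minimal sketch (not part of the diff) of how an integration could register an additional MLA-capable backend name; the `"my_custom_mla"` identifier is hypothetical:

```python
# Hypothetical usage sketch: mark an extra backend name as MLA-capable so that
# ModelRunner's MLA check (attention_backend in MLA_ATTENTION_BACKENDS) accepts it.
from sglang.srt.model_executor.model_runner import add_mla_attention_backend

add_mla_attention_backend("my_custom_mla")  # appended only if not already present
```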
@@ -143,6 +187,13 @@ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 logger = logging.getLogger(__name__)


+if _is_npu:
+    import torch_npu
+
+    torch.npu.config.allow_internal_format = True
+    torch_npu.npu.set_compile_mode(jit_compile=False)
+
+
 class RankZeroFilter(logging.Filter):
     """Filter that only allows INFO level logs from rank 0, but allows all other levels from any rank."""

@@ -237,6 +288,9 @@ class ModelRunner:
         # CPU offload
         set_offloader(create_offloader_from_server_args(server_args, dp_rank=dp_rank))

+        if get_bool_env_var("SGLANG_DETECT_SLOW_RANK"):
+            slow_rank_detector.execute()
+
         # Update deep gemm configure
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
             deep_gemm_wrapper.update_deep_gemm_config(gpu_id, server_args)
@@ -251,6 +305,7 @@ class ModelRunner:

         # For weight updates
         self._model_update_group = {}
+        self._weights_send_group = {}

     def initialize(self, min_per_gpu_memory: float):
         server_args = self.server_args
@@ -300,6 +355,27 @@ class ModelRunner:
             if architectures and not any("Llama4" in arch for arch in architectures):
                 self.is_hybrid = self.model_config.is_hybrid = True

+        if config := self.mambaish_config:
+            class_name = config.__class__.__name__
+            logger.warning(f"{class_name} model detected, disable radix cache")
+            self.server_args.disable_radix_cache = True
+            if self.server_args.max_mamba_cache_size is None:
+                if self.server_args.max_running_requests is not None:
+                    self.server_args.max_mamba_cache_size = (
+                        self.server_args.max_running_requests
+                    )
+                else:
+                    self.server_args.max_mamba_cache_size = 512
+            if self.hybrid_gdn_config is not None:
+                self.server_args.max_mamba_cache_size = (
+                    self.server_args.max_mamba_cache_size
+                    // (
+                        self.server_args.dp_size
+                        if self.server_args.enable_dp_attention
+                        else 1
+                    )
+                )
+
         # For MTP models like DeepSeek-V3 or GLM-4.5, the MTP layer(s) are used separately as draft
         # models for speculative decoding. In those cases, `num_nextn_predict_layers` is used to
         # determine the number of layers.
@@ -341,6 +417,20 @@ class ModelRunner:
         if server_args.enable_lora:
             self.init_lora_manager()

+        # Init Double Sparsity
+        if server_args.enable_double_sparsity:
+            if server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
+
+        # Enable batch invariant mode
+        if server_args.enable_deterministic_inference:
+            from sglang.srt.batch_invariant_ops import enable_batch_invariant_mode
+
+            enable_batch_invariant_mode()
+
         # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
@@ -351,12 +441,12 @@ class ModelRunner:
             self.init_cublas()
             self.init_attention_backend()
             self.init_device_graphs()
-        elif self.device
+        elif self.device in ["npu", "cpu"]:
             self.init_attention_backend()
             self.init_device_graphs()
         else:
             self.graph_runner = None
-            self.
+            self.graph_mem_usage = 0
             self.init_attention_backend()

         # auxiliary hidden capture mode. TODO: expose this to server args?
@@ -452,9 +542,7 @@ class ModelRunner:
         elif _is_hip:
             head_num = self.model_config.get_num_kv_heads(self.tp_size)
             # TODO current aiter only support head number 16 or 128 head number
-            if
-                head_num == 128 or head_num == 16
-            ) and self.spec_algorithm.is_none():
+            if head_num == 128 or head_num == 16:
                 server_args.attention_backend = "aiter"
             else:
                 server_args.attention_backend = "triton"
@@ -467,16 +555,7 @@ class ModelRunner:
                 )
         elif self.use_mla_backend:
             if server_args.device != "cpu":
-                if server_args.attention_backend in
-                    "aiter",
-                    "flashinfer",
-                    "fa3",
-                    "triton",
-                    "flashmla",
-                    "cutlass_mla",
-                    "trtllm_mla",
-                    "ascend",
-                ]:
+                if server_args.attention_backend in MLA_ATTENTION_BACKENDS:
                     logger.info(
                         f"MLA optimization is turned on. Use {server_args.attention_backend} backend."
                     )
@@ -506,11 +585,6 @@ class ModelRunner:
             )
             server_args.attention_backend = "triton"
             server_args.disable_cuda_graph = True
-            if server_args.ds_heavy_channel_type is None:
-                raise ValueError(
-                    "Please specify the heavy channel type for double sparsity optimization."
-                )
-            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)

         if self.is_multimodal:
             if not self.is_multimodal_chunked_prefill_supported:
@@ -548,7 +622,7 @@ class ModelRunner:
                 server_args.hicache_io_backend = "direct"
                 logger.warning(
                     "FlashAttention3 decode backend is not compatible with hierarchical cache. "
-
+                    "Setting hicache_io_backend to vanilla I/O, which may lead to suboptimal performance with small page sizes."
                 )

     def init_torch_distributed(self):
@@ -583,6 +657,7 @@ class ModelRunner:
             dist_init_method = f"tcp://127.0.0.1:{self.dist_port}"
         set_custom_all_reduce(not self.server_args.disable_custom_all_reduce)
         set_mscclpp_all_reduce(self.server_args.enable_mscclpp)
+        set_symm_mem_all_reduce(self.server_args.enable_torch_symm_mem)

         if not self.is_draft_worker:
             if self.device == "cpu":
@@ -593,6 +668,11 @@ class ModelRunner:
                 # Set local size to hint SGLang to use shared memory based AllReduce
                 os.environ["LOCAL_SIZE"] = str(self.tp_size)
                 torch.ops.sgl_kernel.initialize(self.tp_size, self.tp_rank)
+
+                @torch.library.register_fake("sgl_kernel::shm_allgather")
+                def _(data, dim):
+                    return torch.cat([data] * self.tp_size, dim=dim)
+
             else:
                 logger.warning(
                     "init_cpu_threads_env and shared memory based AllReduce is disabled since intel amx backend is not available"
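For context on the `torch.library.register_fake` call added above: it registers a shape-propagation ("fake tensor") implementation for a custom op, so tracing machinery such as `torch.compile` can infer output shapes without running the real kernel. A standalone sketch under assumed names (`my_lib::double` is hypothetical; requires a recent PyTorch, roughly 2.4+):

```python
import torch

# Define a tiny custom op and give it a fake (meta) implementation so that
# FakeTensor tracing knows the output shape without executing the kernel.
@torch.library.custom_op("my_lib::double", mutates_args=())
def double(x: torch.Tensor) -> torch.Tensor:
    return x * 2

@torch.library.register_fake("my_lib::double")
def _(x):
    # Same shape/dtype as the input; no real computation happens here.
    return torch.empty_like(x)

print(double(torch.ones(3)))  # tensor([2., 2., 2.])
```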
@@ -625,6 +705,7 @@ class ModelRunner:
             cpu_group=get_world_group().cpu_group,
         )
         self.tp_group = get_tp_group()
+        self.pp_group = get_pp_group()
         self.attention_tp_group = get_attention_tp_group()

         # Check memory for tensor parallelism
@@ -673,6 +754,10 @@ class ModelRunner:
             load_format=self.server_args.load_format,
             download_dir=self.server_args.download_dir,
             model_loader_extra_config=self.server_args.model_loader_extra_config,
+            tp_rank=self.tp_rank,
+            remote_instance_weight_loader_seed_instance_ip=self.server_args.remote_instance_weight_loader_seed_instance_ip,
+            remote_instance_weight_loader_seed_instance_service_port=self.server_args.remote_instance_weight_loader_seed_instance_service_port,
+            remote_instance_weight_loader_send_weights_group_ports=self.server_args.remote_instance_weight_loader_send_weights_group_ports,
         )
         if self.device == "cpu":
             self.model_config = adjust_config_with_unaligned_cpu_tp(
@@ -681,16 +766,33 @@ class ModelRunner:
         if self.server_args.load_format == "gguf":
             monkey_patch_vllm_gguf_config()

+        if self.server_args.load_format == LoadFormat.REMOTE_INSTANCE:
+            if self.tp_rank == 0:
+                instance_ip = socket.gethostbyname(socket.gethostname())
+                t = threading.Thread(
+                    target=trigger_init_weights_send_group_for_remote_instance_request,
+                    args=(
+                        self.server_args.remote_instance_weight_loader_seed_instance_ip,
+                        self.server_args.remote_instance_weight_loader_seed_instance_service_port,
+                        self.server_args.remote_instance_weight_loader_send_weights_group_ports,
+                        instance_ip,
+                    ),
+                )
+                t.start()
+
         # Load the model
         # Remove monkey_patch when linear.py quant remove dependencies with vllm
         monkey_patch_vllm_parallel_state()
         monkey_patch_isinstance_for_vllm_base_layer()

-        with self.memory_saver_adapter.region(
+        with self.memory_saver_adapter.region(
+            GPU_MEMORY_TYPE_WEIGHTS,
+            enable_cpu_backup=self.server_args.enable_weights_cpu_backup,
+        ):
             self.model = get_model(
                 model_config=self.model_config,
                 load_config=self.load_config,
-                device_config=DeviceConfig(self.device),
+                device_config=DeviceConfig(self.device, self.gpu_id),
             )
         monkey_patch_vllm_parallel_state(reverse=True)
         monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
@@ -781,7 +883,7 @@ class ModelRunner:
         load_config = LoadConfig(load_format=load_format)

         # Only support DefaultModelLoader for now
-        loader = get_model_loader(load_config)
+        loader = get_model_loader(load_config, self.model_config)
         if not isinstance(loader, DefaultModelLoader):
             message = f"Failed to get model loader: {loader}."
             return False, message
@@ -822,6 +924,103 @@ class ModelRunner:
         logger.info("Update weights end.")
         return True, "Succeeded to update model weights."

+    def init_weights_send_group_for_remote_instance(
+        self,
+        master_address,
+        ports,
+        group_rank,
+        world_size,
+        group_name,
+        backend="nccl",
+    ):
+        assert (
+            torch.distributed.is_initialized()
+        ), "Default torch process group must be initialized"
+        assert group_name != "", "Group name cannot be empty"
+
+        ports_list = ports.split(",")
+        assert (
+            len(ports_list) == self.tp_size
+        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
+        group_port = ports_list[self.tp_rank]
+        group_name = f"{group_name}_{group_port}_{self.tp_rank}"
+
+        logger.info(
+            f"init custom process group: tp_rank={self.tp_rank}, gpu_id={self.gpu_id}, master_address={master_address}, master_port={group_port}, "
+            f"group_rank={group_rank}, world_size={world_size}, group_name={group_name}, backend={backend}"
+        )
+
+        torch.cuda.empty_cache()
+        success = False
+        message = ""
+        try:
+            self._weights_send_group[group_name] = init_custom_process_group(
+                backend=backend,
+                init_method=f"tcp://{master_address}:{group_port}",
+                world_size=world_size,
+                rank=group_rank,
+                group_name=group_name,
+                device_id=torch.device("cuda", self.gpu_id),
+            )
+            dist.barrier(group=self._weights_send_group[group_name])
+            success = True
+            message = (
+                f"Succeeded to init group through {master_address}:{group_port} group."
+            )
+        except Exception as e:
+            message = f"Failed to init group: {e}."
+            logger.error(message)
+
+        torch.cuda.empty_cache()
+        return success, message
+
+    def send_weights_to_remote_instance(
+        self,
+        master_address,
+        ports,
+        group_name,
+    ):
+        assert (
+            torch.distributed.is_initialized()
+        ), "Default torch process group must be initialized"
+        assert group_name != "", "Group name cannot be empty"
+
+        ports_list = ports.split(",")
+        assert (
+            len(ports_list) == self.tp_size
+        ), f"Expected {self.tp_size} ports, but got {len(ports_list)} ports."
+        group_port = ports_list[self.tp_rank]
+        group_name = f"{group_name}_{group_port}_{self.tp_rank}"
+
+        if self._weights_send_group[group_name] is not None:
+            send_group = self._weights_send_group[group_name]
+        else:
+            message = f"Group {group_name} not in _weights_send_group list. Please call `init_weights_send_group_for_remote_instance` first."
+            logger.error(message)
+            return False, message
+
+        torch.cuda.empty_cache()
+        success = False
+        message = ""
+        try:
+            for _, weights in self.model.named_parameters():
+                torch.distributed.broadcast(
+                    weights,
+                    src=0,
+                    group=send_group,
+                )
+            success = True
+            message = f"Succeeded to send weights through {master_address}:{group_port} {group_name}."
+        except Exception as e:
+            message = f"Failed to send weights: {e}."
+            logger.error(message)
+
+        # destroy the process group after sending weights
+        del self._weights_send_group[group_name]
+        torch.distributed.distributed_c10d.destroy_process_group(send_group)
+        torch.cuda.empty_cache()
+        return success, message
+
     def init_weights_update_group(
         self,
         master_address,
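The two new methods above let a "seed" instance stream its weights to a freshly started instance by broadcasting every parameter over a temporary NCCL group (rank 0 sends, the joining ranks receive). A simplified sketch of the same broadcast pattern, with a hypothetical helper name and an externally created group, not the actual SGLang API:

```python
import torch
import torch.distributed as dist

def broadcast_all_params(model: torch.nn.Module, group, src_rank: int = 0) -> None:
    # Every rank must iterate parameters in the same deterministic order;
    # rank `src_rank` sends each tensor, the other ranks receive in place.
    for _, param in model.named_parameters():
        dist.broadcast(param.data, src=src_rank, group=group)
```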
@@ -867,6 +1066,19 @@ class ModelRunner:
             logger.error(message)
             return False, message

+    def destroy_weights_update_group(self, group_name):
+        try:
+            if group_name in self._model_update_group:
+                pg = self._model_update_group.pop(group_name)
+                torch.distributed.destroy_process_group(pg)
+                return True, "Succeeded to destroy custom process group."
+            else:
+                return False, "The group to be destroyed does not exist."
+        except Exception as e:
+            message = f"Failed to destroy custom process group: {e}."
+            logger.error(message)
+            return False, message
+
     def update_weights_from_distributed(self, names, dtypes, shapes, group_name):
         """
         Update specific parameter in the model weights online
@@ -904,7 +1116,7 @@ class ModelRunner:
                 handle.wait()

             self.model.load_weights(weights)
-            return True,
+            return True, "Succeeded to update parameter online."

         except Exception as e:
             error_msg = (
@@ -1008,6 +1220,7 @@ class ModelRunner:
             max_lora_rank=self.server_args.max_lora_rank,
             target_modules=self.server_args.lora_target_modules,
             lora_paths=self.server_args.lora_paths,
+            server_args=self.server_args,
         )

     def load_lora_adapter(self, lora_ref: LoRARef):
@@ -1057,6 +1270,8 @@ class ModelRunner:
                 "num_nextn_predict_layers",
                 self.num_effective_layers,
             )
+        elif config := self.mambaish_config:
+            num_layers = len(config.full_attention_layer_ids)
         else:
             num_layers = self.num_effective_layers
         if self.use_mla_backend:
@@ -1065,6 +1280,17 @@ class ModelRunner:
                 * num_layers
                 * torch._utils._element_size(self.kv_cache_dtype)
             )
+            # Add indexer KV cache overhead for NSA models (DeepSeek V3.2)
+            if is_deepseek_nsa(self.model_config.hf_config):
+                index_head_dim = get_nsa_index_head_dim(self.model_config.hf_config)
+                indexer_size_per_token = (
+                    index_head_dim
+                    + index_head_dim // NSATokenToKVPool.quant_block_size * 4
+                )
+                element_size = torch._utils._element_size(
+                    NSATokenToKVPool.index_k_with_scale_buffer_dtype
+                )
+                cell_size += indexer_size_per_token * num_layers * element_size
         else:
             cell_size = (
                 self.model_config.get_num_kv_heads(get_attention_tp_size())
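The indexer term added to `cell_size` above accounts for the per-token index-key buffer that NSA models keep alongside the MLA KV cache: `index_head_dim` stored elements plus what appears to be a 4-byte scale per `quant_block_size` block. A back-of-the-envelope sketch using assumed, illustrative values (not read from any real config):

```python
# Assumed values for illustration only; the real numbers come from the model config
# and from NSATokenToKVPool's class attributes.
index_head_dim = 128
quant_block_size = 128
element_size = 1  # assume 1 byte per stored element
num_layers = 61

indexer_size_per_token = index_head_dim + index_head_dim // quant_block_size * 4  # 132
extra_bytes_per_token = indexer_size_per_token * num_layers * element_size
print(extra_bytes_per_token)  # 8052 extra bytes of indexer cache per token
```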
@@ -1076,9 +1302,33 @@ class ModelRunner:
         rest_memory = available_gpu_memory - total_gpu_memory * (
             1 - self.mem_fraction_static
         )
+        if config := self.mambaish_config:
+            rest_memory -= (
+                self.server_args.max_mamba_cache_size
+                * config.mamba2_cache_params.mamba_cache_per_req
+                / (1 << 30)
+            )
         max_num_token = int(rest_memory * (1 << 30) // cell_size)
         return max_num_token

+    @property
+    def hybrid_gdn_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, Qwen3NextConfig):
+            return config
+        return None
+
+    @property
+    def mamba2_config(self):
+        config = self.model_config.hf_config
+        if isinstance(config, FalconH1Config | NemotronHConfig):
+            return config
+        return None
+
+    @property
+    def mambaish_config(self):
+        return self.mamba2_config or self.hybrid_gdn_config
+
     def set_num_token_hybrid(self):
         if (
             "Llama4ForConditionalGeneration"
@@ -1169,7 +1419,18 @@ class ModelRunner:
     ):
         # Determine the kv cache dtype
         if self.server_args.kv_cache_dtype == "auto":
-
+            quant_config = getattr(self.model, "quant_config", None)
+            kv_cache_quant_algo = getattr(quant_config, "kv_cache_quant_algo", None)
+            if (
+                isinstance(kv_cache_quant_algo, str)
+                and kv_cache_quant_algo.upper() == "FP8"
+            ):
+                if _is_hip:
+                    self.kv_cache_dtype = torch.float8_e4m3fnuz
+                else:
+                    self.kv_cache_dtype = torch.float8_e4m3fn
+            else:
+                self.kv_cache_dtype = self.dtype
         elif self.server_args.kv_cache_dtype == "fp8_e5m2":
             if _is_hip:  # Using natively supported format
                 self.kv_cache_dtype = torch.float8_e5m2fnuz
@@ -1185,6 +1446,8 @@ class ModelRunner:
                 f"Unsupported kv_cache_dtype: {self.server_args.kv_cache_dtype}."
             )

+        log_info_on_rank0(logger, f"Using KV cache dtype: {self.kv_cache_dtype}")
+
         self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
         if SGLANG_CI_SMALL_KV_SIZE:
             self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
@@ -1199,8 +1462,10 @@ class ModelRunner:
             ),
             4096,
         )
+        if self.mambaish_config is not None:
+            max_num_reqs = min(max_num_reqs, self.server_args.max_mamba_cache_size)

-        if
+        if self.spec_algorithm.is_eagle() or self.spec_algorithm.is_standalone():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
                 max_num_reqs = self.server_args.max_num_reqs
@@ -1237,13 +1502,24 @@ class ModelRunner:
                 // self.server_args.page_size
                 * self.server_args.page_size
             )
+        # different pp rank may have different num of layers, so we need to reduce the max_total_num_tokens
+        if self.pp_size > 1:
+            tensor = torch.tensor(self.max_total_num_tokens, dtype=torch.int64)
+            torch.distributed.all_reduce(
+                tensor,
+                op=torch.distributed.ReduceOp.MIN,
+                group=get_world_group().cpu_group,
+            )
+            self.max_total_num_tokens = tensor.item()
+
         # create token size for hybrid cache
         if self.is_hybrid:
             self.set_num_token_hybrid()

         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
-                "Not enough memory. Please try to increase --mem-fraction-static."
+                f"Not enough memory. Please try to increase --mem-fraction-static. "
+                f"Current value: {self.server_args.mem_fraction_static=}"
             )

         # Initialize req_to_token_pool
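The pipeline-parallel guard added above takes the minimum token capacity across PP ranks so that every stage allocates the same KV pool size even when layer counts differ per rank. A minimal standalone sketch of that reduction, assuming a process group has already been initialized by the caller:

```python
import torch
import torch.distributed as dist

def min_across_ranks(local_value: int, group=None) -> int:
    # Every rank contributes its local capacity; all ranks receive the minimum.
    t = torch.tensor(local_value, dtype=torch.int64)
    dist.all_reduce(t, op=dist.ReduceOp.MIN, group=group)
    return int(t.item())
```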
@@ -1267,6 +1543,16 @@ class ModelRunner:
|
|
1267
1543
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
1268
1544
|
pre_alloc_size=pre_alloc_size,
|
1269
1545
|
)
|
1546
|
+
elif config := self.mambaish_config:
|
1547
|
+
self.req_to_token_pool = HybridReqToTokenPool(
|
1548
|
+
size=max_num_reqs,
|
1549
|
+
max_context_len=self.model_config.context_len
|
1550
|
+
+ extra_max_context_len,
|
1551
|
+
device=self.device,
|
1552
|
+
enable_memory_saver=self.server_args.enable_memory_saver,
|
1553
|
+
cache_params=config.mamba2_cache_params,
|
1554
|
+
speculative_num_draft_tokens=self.server_args.speculative_num_draft_tokens,
|
1555
|
+
)
|
1270
1556
|
else:
|
1271
1557
|
self.req_to_token_pool = ReqToTokenPool(
|
1272
1558
|
size=max_num_reqs,
|
@@ -1280,6 +1566,7 @@ class ModelRunner:
|
|
1280
1566
|
assert self.is_draft_worker
|
1281
1567
|
|
1282
1568
|
# Initialize token_to_kv_pool
|
1569
|
+
is_nsa_model = is_deepseek_nsa(self.model_config.hf_config)
|
1283
1570
|
if self.server_args.attention_backend == "ascend":
|
1284
1571
|
if self.use_mla_backend:
|
1285
1572
|
self.token_to_kv_pool = AscendMLAPagedTokenToKVPool(
|
@@ -1288,6 +1575,7 @@ class ModelRunner:
|
|
1288
1575
|
dtype=self.kv_cache_dtype,
|
1289
1576
|
kv_lora_rank=self.model_config.kv_lora_rank,
|
1290
1577
|
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
1578
|
+
index_head_dim=self.model_config.index_head_dim,
|
1291
1579
|
layer_num=self.num_effective_layers,
|
1292
1580
|
device=self.device,
|
1293
1581
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
@@ -1307,7 +1595,22 @@ class ModelRunner:
|
|
1307
1595
|
device=self.device,
|
1308
1596
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
1309
1597
|
)
|
1598
|
+
elif self.use_mla_backend and is_nsa_model:
|
1599
|
+
self.token_to_kv_pool = NSATokenToKVPool(
|
1600
|
+
self.max_total_num_tokens,
|
1601
|
+
page_size=self.page_size,
|
1602
|
+
dtype=self.kv_cache_dtype,
|
1603
|
+
kv_lora_rank=self.model_config.kv_lora_rank,
|
1604
|
+
qk_rope_head_dim=self.model_config.qk_rope_head_dim,
|
1605
|
+
layer_num=self.num_effective_layers,
|
1606
|
+
device=self.device,
|
1607
|
+
enable_memory_saver=self.server_args.enable_memory_saver,
|
1608
|
+
start_layer=self.start_layer,
|
1609
|
+
end_layer=self.end_layer,
|
1610
|
+
index_head_dim=get_nsa_index_head_dim(self.model_config.hf_config),
|
1611
|
+
)
|
1310
1612
|
elif self.use_mla_backend:
|
1613
|
+
assert not is_nsa_model
|
1311
1614
|
self.token_to_kv_pool = MLATokenToKVPool(
|
1312
1615
|
self.max_total_num_tokens,
|
1313
1616
|
page_size=self.page_size,
|
@@ -1349,6 +1652,22 @@ class ModelRunner:
|
|
1349
1652
|
enable_kvcache_transpose=False,
|
1350
1653
|
device=self.device,
|
1351
1654
|
)
|
1655
|
+
elif config := self.mambaish_config:
|
1656
|
+
self.token_to_kv_pool = HybridLinearKVPool(
|
1657
|
+
page_size=self.page_size,
|
1658
|
+
size=self.max_total_num_tokens,
|
1659
|
+
dtype=self.kv_cache_dtype,
|
1660
|
+
head_num=self.model_config.get_num_kv_heads(
|
1661
|
+
get_attention_tp_size()
|
1662
|
+
),
|
1663
|
+
head_dim=self.model_config.head_dim,
|
1664
|
+
# if draft worker, we only need 1 attention layer's kv pool
|
1665
|
+
full_attention_layer_ids=(
|
1666
|
+
[0] if self.is_draft_worker else config.full_attention_layer_ids
|
1667
|
+
),
|
1668
|
+
enable_kvcache_transpose=False,
|
1669
|
+
device=self.device,
|
1670
|
+
)
|
1352
1671
|
else:
|
1353
1672
|
self.token_to_kv_pool = MHATokenToKVPool(
|
1354
1673
|
self.max_total_num_tokens,
|
@@ -1363,12 +1682,18 @@ class ModelRunner:
|
|
1363
1682
|
enable_memory_saver=self.server_args.enable_memory_saver,
|
1364
1683
|
start_layer=self.start_layer,
|
1365
1684
|
end_layer=self.end_layer,
|
1685
|
+
enable_kv_cache_copy=(
|
1686
|
+
self.server_args.speculative_algorithm is not None
|
1687
|
+
),
|
1366
1688
|
)
|
1367
1689
|
|
1368
1690
|
# Initialize token_to_kv_pool_allocator
|
1369
1691
|
need_sort = self.server_args.disaggregation_mode in ("decode", "prefill")
|
1370
1692
|
if self.token_to_kv_pool_allocator is None:
|
1371
|
-
if
|
1693
|
+
if _is_npu and (
|
1694
|
+
self.server_args.attention_backend == "ascend"
|
1695
|
+
or self.hybrid_gdn_config is not None
|
1696
|
+
):
|
1372
1697
|
self.token_to_kv_pool_allocator = AscendPagedTokenToKVPoolAllocator(
|
1373
1698
|
self.max_total_num_tokens,
|
1374
1699
|
page_size=self.page_size,
|
@@ -1432,16 +1757,10 @@ class ModelRunner:
|
|
1432
1757
|
|
1433
1758
|
def _get_attention_backend(self):
|
1434
1759
|
"""Init attention kernel backend."""
|
1435
|
-
self.decode_attention_backend_str = (
|
1436
|
-
self.server_args.
|
1437
|
-
if self.server_args.decode_attention_backend
|
1438
|
-
else self.server_args.attention_backend
|
1439
|
-
)
|
1440
|
-
self.prefill_attention_backend_str = (
|
1441
|
-
self.server_args.prefill_attention_backend
|
1442
|
-
if self.server_args.prefill_attention_backend
|
1443
|
-
else self.server_args.attention_backend
|
1760
|
+
self.prefill_attention_backend_str, self.decode_attention_backend_str = (
|
1761
|
+
self.server_args.get_attention_backends()
|
1444
1762
|
)
|
1763
|
+
|
1445
1764
|
if self.decode_attention_backend_str != self.prefill_attention_backend_str:
|
1446
1765
|
from sglang.srt.layers.attention.hybrid_attn_backend import (
|
1447
1766
|
HybridAttnBackend,
|
@@ -1462,8 +1781,8 @@ class ModelRunner:
|
|
1462
1781
|
f"prefill_backend={self.prefill_attention_backend_str}."
|
1463
1782
|
)
|
1464
1783
|
logger.warning(
|
1465
|
-
|
1466
|
-
|
1784
|
+
"Warning: Attention backend specified by --attention-backend or default backend might be overridden."
|
1785
|
+
"The feature of hybrid attention backend is experimental and unstable. Please raise an issue if you encounter any problem."
|
1467
1786
|
)
|
1468
1787
|
else:
|
1469
1788
|
attn_backend = self._get_attention_backend_from_str(
|
@@ -1479,111 +1798,10 @@ class ModelRunner:
1479 1798          return attn_backend
1480 1799
1481 1800      def _get_attention_backend_from_str(self, backend_str: str):
1482      -        if backend_str == "flashinfer":
1483      -            if not self.use_mla_backend:
1484      -                from sglang.srt.layers.attention.flashinfer_backend import (
1485      -                    FlashInferAttnBackend,
1486      -                )
1487      -
1488      -                # Init streams
1489      -                if self.server_args.speculative_algorithm == "EAGLE":
1490      -                    if (
1491      -                        not hasattr(self, "plan_stream_for_flashinfer")
1492      -                        or not self.plan_stream_for_flashinfer
1493      -                    ):
1494      -                        self.plan_stream_for_flashinfer = torch.cuda.Stream()
1495      -                return FlashInferAttnBackend(self)
1496      -            else:
1497      -                from sglang.srt.layers.attention.flashinfer_mla_backend import (
1498      -                    FlashInferMLAAttnBackend,
1499      -                )
1500      -
1501      -                return FlashInferMLAAttnBackend(self)
1502      -        elif backend_str == "aiter":
1503      -            from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
1504      -
1505      -            return AiterAttnBackend(self)
1506      -        elif self.server_args.attention_backend == "wave":
1507      -            from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
1508      -
1509      -            return WaveAttnBackend(self)
1510      -        elif backend_str == "ascend":
1511      -            from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
1512      -
1513      -            return AscendAttnBackend(self)
1514      -        elif backend_str == "triton":
1515      -            assert not self.model_config.is_encoder_decoder, (
1516      -                "Cross attention is not supported in the triton attention backend. "
1517      -                "Please use `--attention-backend flashinfer`."
1518      -            )
1519      -            if self.server_args.enable_double_sparsity:
1520      -                from sglang.srt.layers.attention.double_sparsity_backend import (
1521      -                    DoubleSparseAttnBackend,
1522      -                )
1523      -
1524      -                return DoubleSparseAttnBackend(self)
1525      -            else:
1526      -                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
1527      -
1528      -                return TritonAttnBackend(self)
1529      -        elif backend_str == "torch_native":
1530      -            from sglang.srt.layers.attention.torch_native_backend import (
1531      -                TorchNativeAttnBackend,
1532      -            )
1533      -
1534      -            return TorchNativeAttnBackend(self)
1535      -        elif backend_str == "flashmla":
1536      -            from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
1537      -
1538      -            return FlashMLABackend(self)
1539      -        elif backend_str == "fa3":
1540      -            assert (
1541      -                torch.cuda.get_device_capability()[0] == 8 and not self.use_mla_backend
1542      -            ) or torch.cuda.get_device_capability()[0] == 9, (
1543      -                "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
1544      -                "Please use `--attention-backend flashinfer`."
1545      -            )
1546      -            from sglang.srt.layers.attention.flashattention_backend import (
1547      -                FlashAttentionBackend,
1548      -            )
1549      -
1550      -            return FlashAttentionBackend(self)
1551      -        elif backend_str == "cutlass_mla":
1552      -            from sglang.srt.layers.attention.cutlass_mla_backend import (
1553      -                CutlassMLABackend,
1554      -            )
1555      -
1556      -            return CutlassMLABackend(self)
1557      -        elif backend_str == "trtllm_mla":
1558      -            if not self.use_mla_backend:
1559      -                raise ValueError("trtllm_mla backend can only be used with MLA models.")
1560      -            from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
1561      -
1562      -            return TRTLLMMLABackend(self)
1563      -        elif backend_str == "trtllm_mha":
1564      -            if self.use_mla_backend:
1565      -                raise ValueError(
1566      -                    "trtllm_mha backend can only be used with non-MLA models."
1567      -                )
1568      -            from sglang.srt.layers.attention.trtllm_mha_backend import (
1569      -                TRTLLMHAAttnBackend,
1570      -            )
1571      -
1572      -            return TRTLLMHAAttnBackend(self)
1573      -        elif backend_str == "intel_amx":
1574      -            from sglang.srt.layers.attention.intel_amx_backend import (
1575      -                IntelAMXAttnBackend,
1576      -            )
1577      -
1578      -            return IntelAMXAttnBackend(self)
1579      -        elif backend_str == "dual_chunk_flash_attn":
1580      -            from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
1581      -                DualChunkFlashAttentionBackend,
1582      -            )
1583      -
1584      -            return DualChunkFlashAttentionBackend(self)
1585      -        else:
     1801 +        if backend_str not in ATTENTION_BACKENDS:
1586 1802              raise ValueError(f"Invalid attention backend: {backend_str}")
     1803 +        full_attention_backend = ATTENTION_BACKENDS[backend_str](self)
     1804 +        return attn_backend_wrapper(self, full_attention_backend)
1587 1805
1588 1806      def init_double_sparsity_channel_config(self, selected_channel):
1589 1807          selected_channel = "." + selected_channel + "_proj"
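The hunk above swaps the long if/elif chain for a table lookup: each backend registers a constructor under a string key, the runner validates the key, builds the backend, and passes it through attn_backend_wrapper. A minimal, self-contained sketch of that registry pattern follows; the classes and the register_backend decorator here are illustrative stand-ins, not sglang's actual definitions.

    from typing import Callable, Dict

    # Hypothetical registry mapping backend names to constructor callables.
    ATTENTION_BACKENDS: Dict[str, Callable] = {}

    def register_backend(name: str):
        """Decorator that records a backend class under a string key."""
        def wrap(cls):
            ATTENTION_BACKENDS[name] = cls
            return cls
        return wrap

    class BaseBackend:
        def __init__(self, runner):
            self.runner = runner

    @register_backend("triton")
    class TritonBackend(BaseBackend):
        pass

    @register_backend("torch_native")
    class TorchNativeBackend(BaseBackend):
        pass

    class Runner:
        def get_backend(self, backend_str: str) -> BaseBackend:
            # Same shape as the new code path: validate, then construct from the table.
            if backend_str not in ATTENTION_BACKENDS:
                raise ValueError(f"Invalid attention backend: {backend_str}")
            return ATTENTION_BACKENDS[backend_str](self)

    print(type(Runner().get_backend("triton")).__name__)  # TritonBackend

Adding a new backend then only touches the registry, not a central dispatch function, which is the main point of the refactor shown here.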
@@ -1603,38 +1821,46 @@ class ModelRunner:
1603 1821          )
1604 1822
1605 1823      def init_device_graphs(self):
1606      -        """Capture
     1824 +        """Capture device graphs."""
1607 1825          self.graph_runner = None
1608      -        self.
     1826 +        self.graph_mem_usage = 0
1609 1827
1610 1828          if not self.is_generation:
1611 1829              # TODO: Currently, cuda graph only captures decode steps, which only exists for generation models
1612 1830              return
1613 1831
1614      -        if self.server_args.disable_cuda_graph:
     1832 +        if self.device != "cpu" and self.server_args.disable_cuda_graph:
     1833 +            return
     1834 +
     1835 +        if self.device == "cpu" and not self.server_args.enable_torch_compile:
1615 1836              return
1616 1837
1617 1838          tic = time.perf_counter()
1618 1839          before_mem = get_available_gpu_memory(self.device, self.gpu_id)
1619 1840          logger.info(
1620      -            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
     1841 +            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
1621 1842          )
1622      -
1623      -        CudaGraphRunner
     1843 +        graph_runners = defaultdict(
     1844 +            lambda: CudaGraphRunner,
     1845 +            {
     1846 +                "cpu": CPUGraphRunner,
     1847 +                "npu": NPUGraphRunner,
     1848 +            },
1624 1849          )
     1850 +        self.graph_runner = graph_runners[self.device](self)
     1851 +
1625 1852          after_mem = get_available_gpu_memory(self.device, self.gpu_id)
1626      -        self.
     1853 +        self.graph_mem_usage = before_mem - after_mem
1627 1854          logger.info(
1628      -            f"Capture cuda graph end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
1629      -            f"mem usage={self.
     1855 +            f"Capture {'cpu graph' if self.device == 'cpu' else 'cuda graph'} end. Time elapsed: {time.perf_counter() - tic:.2f} s. "
     1856 +            f"mem usage={self.graph_mem_usage:.2f} GB. avail mem={after_mem:.2f} GB."
1630 1857          )
1631 1858
1632 1859      def init_threads_binding(self):
1633 1860          omp_cpuids = os.environ.get("SGLANG_CPU_OMP_THREADS_BIND", "all")
     1861 +        cpu_ids_by_node = get_cpu_ids_by_node()
     1862 +        n_numa_node = len(cpu_ids_by_node)
1634 1863          if omp_cpuids == "all":
1635      -            cpu_ids_by_node = get_cpu_ids_by_node()
1636      -            n_numa_node = len(cpu_ids_by_node)
1637      -
1638 1864              assert self.tp_size <= n_numa_node, (
1639 1865                  f"SGLANG_CPU_OMP_THREADS_BIND is not set, in this case, "
1640 1866                  f"tp_size {self.tp_size} should be smaller than or equal to number of numa node on the machine {n_numa_node}. "
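In the hunk above, init_device_graphs picks the graph runner class from a defaultdict keyed by device, with CudaGraphRunner as the factory default for any device not listed. A small sketch of that idiom, using stand-in runner classes rather than the real CudaGraphRunner/CPUGraphRunner/NPUGraphRunner:

    from collections import defaultdict

    class CudaGraphRunner:  # stand-in for the default CUDA graph runner
        def __init__(self, model_runner):
            self.model_runner = model_runner

    class CPUGraphRunner(CudaGraphRunner):  # stand-in for the CPU runner
        pass

    class NPUGraphRunner(CudaGraphRunner):  # stand-in for the NPU runner
        pass

    def pick_graph_runner(model_runner, device: str) -> CudaGraphRunner:
        # defaultdict calls the lambda for any device that is not listed,
        # so "cuda", "xpu", etc. all fall back to CudaGraphRunner.
        graph_runners = defaultdict(
            lambda: CudaGraphRunner,
            {"cpu": CPUGraphRunner, "npu": NPUGraphRunner},
        )
        return graph_runners[device](model_runner)

    print(type(pick_graph_runner(None, "cuda")).__name__)  # CudaGraphRunner
    print(type(pick_graph_runner(None, "npu")).__name__)   # NPUGraphRunner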
@@ -1651,7 +1877,18 @@ class ModelRunner:
1651 1877          )
1652 1878              self.local_omp_cpuid = cpu_ids_by_node[self.tp_rank]
1653 1879          else:
1654      -
     1880 +            threads_bind_list = omp_cpuids.split("|")
     1881 +            assert self.tp_size == len(threads_bind_list), (
     1882 +                f"SGLANG_CPU_OMP_THREADS_BIND setting must be aligned with TP size parameter ({self.tp_size}). "
     1883 +                f"Please double check your settings."
     1884 +            )
     1885 +            self.local_omp_cpuid = threads_bind_list[self.tp_rank]
     1886 +            if self.tp_size > n_numa_node:
     1887 +                logger.warning(
     1888 +                    f"TP size ({self.tp_size})is larger than numa node number ({n_numa_node}), "
     1889 +                    f"in this case the available memory amount of each rank cannot be determined in prior. "
     1890 +                    f"Please set proper `--max-total-tokens` to avoid the out-of-memory error."
     1891 +                )
1655 1892
1656 1893      def apply_torch_tp(self):
1657 1894          logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
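The threads-binding hunks move the NUMA-node count out of the "all" branch so it can be checked in both cases: the default binding assigns one NUMA node per TP rank, while an explicit SGLANG_CPU_OMP_THREADS_BIND value is a '|'-separated list with one entry per rank. A rough sketch of that validation, using a hypothetical resolve_omp_binding helper and made-up CPU ranges:

    from typing import List, Union

    def resolve_omp_binding(
        omp_cpuids: str,
        tp_rank: int,
        tp_size: int,
        cpu_ids_by_node: List[List[int]],
    ) -> Union[str, List[int]]:
        """Pick the CPU set for one TP rank, mirroring the logic in this hunk."""
        n_numa_node = len(cpu_ids_by_node)
        if omp_cpuids == "all":
            # Default binding: one NUMA node per rank, so more ranks than nodes is an error.
            assert tp_size <= n_numa_node, (
                f"tp_size {tp_size} must not exceed the number of NUMA nodes ({n_numa_node})"
            )
            return cpu_ids_by_node[tp_rank]
        # Explicit binding: one '|'-separated entry per TP rank, e.g. "0-31|32-63".
        threads_bind_list = omp_cpuids.split("|")
        assert tp_size == len(threads_bind_list), (
            f"SGLANG_CPU_OMP_THREADS_BIND has {len(threads_bind_list)} entries, "
            f"but tp_size is {tp_size}"
        )
        if tp_size > n_numa_node:
            print("warning: more TP ranks than NUMA nodes; memory per rank is unknown in advance")
        return threads_bind_list[tp_rank]

    # Two ranks pinned to explicit CPU ranges; rank 1 gets "32-63".
    print(resolve_omp_binding("0-31|32-63", tp_rank=1, tp_size=2,
                              cpu_ids_by_node=[list(range(0, 32)), list(range(32, 64))]))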
@@ -1771,18 +2008,24 @@ class ModelRunner:
1771 2008          reinit_attn_backend: bool = False,
1772 2009          split_forward_count: int = 1,
1773 2010      ) -> Tuple[Union[LogitsProcessorOutput, PPProxyTensors], bool]:
1774      -
1775      -        forward_batch.forward_mode.
     2011 +        mode_check = (
     2012 +            forward_batch.forward_mode.is_cpu_graph
     2013 +            if self.device == "cpu"
     2014 +            else forward_batch.forward_mode.is_cuda_graph
     2015 +        )
     2016 +        can_run_graph = bool(
     2017 +            mode_check()
1776 2018              and self.graph_runner
1777 2019              and self.graph_runner.can_run(forward_batch)
1778 2020          )
1779      -
     2021 +
     2022 +        if can_run_graph:
1780 2023              ret = self.graph_runner.replay(
1781 2024                  forward_batch,
1782 2025                  skip_attn_backend_init=skip_attn_backend_init,
1783 2026                  pp_proxy_tensors=pp_proxy_tensors,
1784 2027              )
1785      -            return ret,
     2028 +            return ret, can_run_graph
1786 2029
1787 2030          # For MLP sync
1788 2031          if forward_batch.global_num_tokens_cpu is not None:
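The forward-path hunk above folds the graph-eligibility test into a single flag: on CPU it checks the CPU-graph forward mode, otherwise the CUDA-graph mode, and the flag is returned to the caller alongside the output. A toy sketch of that selection, with simplified stand-ins for the forward mode and graph runner:

    class ForwardMode:
        """Stand-in exposing only the two predicates used by the diff."""
        def __init__(self, decode: bool):
            self.decode = decode
        def is_cuda_graph(self) -> bool:
            return self.decode
        def is_cpu_graph(self) -> bool:
            return self.decode

    class DummyGraphRunner:
        def can_run(self, forward_batch) -> bool:
            return True

    def can_run_graph(forward_mode, device, graph_runner, forward_batch=None) -> bool:
        # Pick the device-appropriate predicate, then require a usable graph runner.
        mode_check = (
            forward_mode.is_cpu_graph if device == "cpu" else forward_mode.is_cuda_graph
        )
        return bool(mode_check() and graph_runner and graph_runner.can_run(forward_batch))

    print(can_run_graph(ForwardMode(decode=True), "cuda", DummyGraphRunner()))  # True
    print(can_run_graph(ForwardMode(decode=True), "cpu", None))                 # False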
@@ -1811,23 +2054,22 @@ class ModelRunner:
1811 2054          else:
1812 2055              raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
1813 2056
1814      -        if
     2057 +        if (
     2058 +            forward_batch.global_num_tokens_cpu is not None
     2059 +            and self.pp_group.is_last_rank
     2060 +        ):
1815 2061              forward_batch.post_forward_mlp_sync_batch(ret)
1816 2062
1817      -        return ret,
     2063 +        return ret, can_run_graph
1818 2064
1819 2065      def _preprocess_logits(
1820 2066          self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
1821 2067      ):
1822      -        #
1823      -
1824      -
1825      -
1826      -
1827      -            sampling_info.sampling_info_done.wait()
1828      -        else:
1829      -            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
1830      -            sampling_info.update_regex_vocab_mask()
     2068 +        # NOTE: In overlap mode, the function update_regex_vocab_mask (in sample)
     2069 +        # was executed after we processed last batch's results.
     2070 +
     2071 +        # Calculate logits bias and apply it to next_token_logits.
     2072 +        sampling_info.update_regex_vocab_mask()
1831 2073          sampling_info.apply_logits_bias(logits_output.next_token_logits)
1832 2074
1833 2075      def sample(
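The _preprocess_logits hunk keeps only the unconditional path: build the regex vocab mask and apply the logits bias before sampling. A minimal sketch of what masking disallowed tokens looks like on a plain tensor (torch required; SamplingBatchInfo and the real mask layout are not reproduced here):

    import torch

    def apply_vocab_mask(next_token_logits: torch.Tensor,
                         vocab_mask: torch.Tensor) -> torch.Tensor:
        """Set logits of disallowed tokens to -inf so they can never be sampled."""
        return next_token_logits.masked_fill(vocab_mask, float("-inf"))

    logits = torch.tensor([[1.0, 2.0, 3.0, 4.0]])
    # Hypothetical grammar constraint: only token ids 1 and 3 are allowed.
    mask = torch.tensor([[True, False, True, False]])
    masked = apply_vocab_mask(logits, mask)
    print(torch.softmax(masked, dim=-1))  # probability mass only on allowed tokens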
@@ -1852,7 +2094,6 @@ class ModelRunner:
1852 2094          )
1853 2095
1854 2096          self._preprocess_logits(logits_output, forward_batch.sampling_info)
1855      -
1856 2097          # Sample the next tokens
1857 2098          next_token_ids = self.sampler(
1858 2099              logits_output,
@@ -1860,9 +2101,47 @@ class ModelRunner:
1860 2101              forward_batch.return_logprob,
1861 2102              forward_batch.top_logprobs_nums,
1862 2103              forward_batch.token_ids_logprobs,
     2104 +            # For prefill, we only use the position of the last token.
     2105 +            (
     2106 +                forward_batch.positions
     2107 +                if forward_batch.forward_mode.is_decode()
     2108 +                else forward_batch.seq_lens - 1
     2109 +            ),
1863 2110          )
1864 2111          return next_token_ids
1865 2112
     2113 +    def compute_logprobs_only(
     2114 +        self,
     2115 +        logits_output: LogitsProcessorOutput,
     2116 +        forward_batch: ForwardBatch,
     2117 +    ) -> None:
     2118 +        """
     2119 +        Compute token_ids_logprobs without performing sampling.
     2120 +
     2121 +        Optimized path for prefill-only requests that need token_ids_logprobs but don't
     2122 +        require next token generation. Skips expensive sampling operations
     2123 +        while still providing requested probability information.
     2124 +
     2125 +        Args:
     2126 +            logits_output: The logits output from the model forward
     2127 +            forward_batch: The forward batch that generates logits_output
     2128 +        """
     2129 +        if not forward_batch.token_ids_logprobs:
     2130 +            return
     2131 +
     2132 +        # Preprocess logits (same as in sample method)
     2133 +        self._preprocess_logits(logits_output, forward_batch.sampling_info)
     2134 +
     2135 +        # Delegate to sampler for logprob-only computation
     2136 +        # This populates logits_output with requested token probabilities
     2137 +        self.sampler.compute_logprobs_only(
     2138 +            logits_output,
     2139 +            forward_batch.sampling_info,
     2140 +            forward_batch.return_logprob,
     2141 +            forward_batch.top_logprobs_nums,
     2142 +            forward_batch.token_ids_logprobs,
     2143 +        )
     2144 +
1866 2145      @property
1867 2146      def model_is_mrope(self) -> bool:
1868 2147          """Detect if the model has "mrope" rope_scaling type.