sglang 0.5.2rc2__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -11
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +474 -142
- sglang/compile_deep_gemm.py +3 -0
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +10 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +314 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/mamba_utils.py +117 -0
- sglang/srt/configs/model_config.py +228 -92
- sglang/srt/configs/nemotron_h.py +286 -0
- sglang/srt/configs/qwen3_next.py +294 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +78 -37
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +373 -68
- sglang/srt/disaggregation/prefill.py +53 -49
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +842 -0
- sglang/srt/entrypoints/grpc_server.py +950 -0
- sglang/srt/entrypoints/http_server.py +179 -60
- sglang/srt/entrypoints/openai/protocol.py +265 -29
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +213 -122
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
- sglang/srt/environ.py +289 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +17 -8
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +119 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +492 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +327 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +215 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +343 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +40 -8
- sglang/srt/layers/attention/flashinfer_backend.py +341 -204
- sglang/srt/layers/attention/flashinfer_mla_backend.py +28 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +708 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +974 -0
- sglang/srt/layers/attention/mamba/mamba.py +577 -0
- sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
- sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +214 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +562 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +646 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +264 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +180 -18
- sglang/srt/layers/modelopt_utils.py +11 -0
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +21 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -454
- sglang/srt/layers/moe/ep_moe/layer.py +248 -333
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +68 -72
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
- sglang/srt/layers/moe/moe_runner/runner.py +83 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +29 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/__init__.py +1 -1
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +155 -60
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +191 -56
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +28 -33
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +44 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +55 -0
- sglang/srt/managers/schedule_batch.py +343 -212
- sglang/srt/managers/schedule_policy.py +145 -18
- sglang/srt/managers/scheduler.py +653 -273
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +255 -108
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +579 -674
- sglang/srt/managers/tp_worker.py +96 -26
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +21 -22
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +9 -2
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +651 -80
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +227 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +93 -48
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +74 -46
- sglang/srt/model_executor/model_runner.py +455 -176
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +10 -4
- sglang/srt/model_loader/loader.py +319 -10
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +161 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +578 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +17 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/grok.py +5 -13
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mixtral.py +1 -3
- sglang/srt/models/mllama4.py +50 -4
- sglang/srt/models/nemotron_h.py +514 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +55 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +49 -26
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1051 -285
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +98 -29
- sglang/srt/speculative/ngram_info.py +428 -0
- sglang/srt/speculative/ngram_worker.py +246 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +605 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +451 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +55 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/longbench_v2/__init__.py +1 -0
- sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
- sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
- sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
- sglang/test/run_eval.py +119 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_longbench_v2.py +332 -0
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +9 -19
- sglang/test/test_deterministic.py +313 -0
- sglang/test/test_deterministic_utils.py +81 -0
- sglang/test/test_disaggregation_utils.py +140 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +407 -8
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +392 -251
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/managers/tp_worker_overlap_thread.py +0 -296
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/test/test_block_fp8_ep.py +0 -358
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/ascend_backend.py

@@ -5,13 +5,15 @@ from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch_npu
-from torch.nn.functional import scaled_dot_product_attention
 
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import get_bool_env_var
 
 if TYPE_CHECKING:
@@ -33,6 +35,9 @@ class ForwardMetadata:
     extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_list: Optional[List[int]] = None
+    seq_lens_list_cumsum: Optional[List[int]] = None
+    seq_lens: Optional[torch.Tensor] = None
+    actual_seq_lengths_q: Optional[torch.Tensor] = None
 
 
 class AscendAttnBackend(AttentionBackend):
@@ -64,6 +69,9 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+            self.q_head_dim = (
+                self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
+            )
         self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
@@ -83,6 +91,7 @@ class AscendAttnBackend(AttentionBackend):
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
+        tp_size = get_attention_tp_size()
        self.forward_metadata = ForwardMetadata()
 
         self.forward_metadata.block_tables = (
@@ -96,9 +105,9 @@ class AscendAttnBackend(AttentionBackend):
                 forward_batch.extend_seq_lens.cpu().int()
             )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
-
-
-
+
+        seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
+        self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
 
         self.graph_mode = False
 
@@ -119,12 +128,16 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         metadata = ForwardMetadata()
 
         metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
         metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
+        metadata.seq_lens = seq_lens
+        metadata.actual_seq_lengths_q = torch.tensor(
+            [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
+        )
 
         self.graph_metadata[bs] = metadata
         self.forward_metadata = metadata
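Note: `actual_seq_lengths_q` captured above is presumably the running count of query tokens in the TND layout; with one decode token per request, the comprehension `[1 + i * 1 for i in range(bs)]` is just an `arange`. A quick sanity check of that equivalence (the `bs` value is illustrative):

import torch

bs = 4  # illustrative batch size
lhs = torch.tensor([1 + i * 1 for i in range(bs)], dtype=torch.int32)
rhs = torch.arange(1, bs + 1, dtype=torch.int32)
assert torch.equal(lhs, rhs)  # both are [1, 2, 3, 4]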
@@ -139,7 +152,7 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         metadata = self.graph_metadata[bs]
@@ -153,6 +166,8 @@ class AscendAttnBackend(AttentionBackend):
             metadata.block_tables[:bs, max_seq_pages:].fill_(0)
             metadata.block_tables[bs:, :].fill_(0)
 
+        metadata.seq_lens[:bs].copy_(seq_lens[:bs])
+
         self.forward_metadata = metadata
 
         self.graph_mode = True
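Note: the replay hook writes into the captured tensor with `copy_` rather than rebinding the attribute, because a captured graph holds references to fixed buffers. A minimal standalone illustration of that design choice (buffer size and values are made up):

import torch

captured = torch.zeros(8, dtype=torch.int32)  # stands in for a buffer baked into the graph
new_seq_lens = torch.tensor([5, 9, 3], dtype=torch.int32)

bs = new_seq_lens.numel()
captured[:bs].copy_(new_seq_lens)  # in-place write: the graph sees the new lengths
# captured = new_seq_lens          # would only rebind the Python name and leave
#                                  # the graph's buffer stale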
@@ -160,6 +175,64 @@ class AscendAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 0
 
+    def forward_sparse(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: torch.Tensor = None,
+    ):
+
+        is_prefill = forward_batch.forward_mode.is_extend()
+
+        if save_kv_cache:
+            k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+            k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
+        q_nope, q_pe = q, q_rope
+        k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        block_table = self.forward_metadata.block_tables
+        if is_prefill:
+            actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
+        else:
+            if self.forward_metadata.actual_seq_lengths_q is None:
+                actual_seq_qlen = (
+                    torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
+                )
+            else:
+                actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
+        if self.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens
+        else:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
+
+        attn_out = torch.ops.custom.npu_sparse_flash_attention(
+            query=q_nope,
+            key=k_nope,
+            value=k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            sparse_indices=topk_indices,
+            scale_value=layer.scaling,
+            actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
+            actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
+            block_table=block_table,
+            sparse_block_size=1,
+            layout_query="TND",
+            layout_kv="PA_BSND",
+            sparse_mode=3,
+        )
+
+        return attn_out
+
     def forward_extend(
         self,
         q,
@@ -168,7 +241,23 @@ class AscendAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
         if not self.use_mla:
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -368,7 +457,7 @@ class AscendAttnBackend(AttentionBackend):
             -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
         )
 
-        q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
+        q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank).contiguous()
         q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
         if self.forward_metadata.seq_lens_cpu_int is None:
             actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
@@ -394,7 +483,7 @@ class AscendAttnBackend(AttentionBackend):
             antiquant_scale=None,
             sparse_mode=0,
         )
-        output = torch.
+        output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
         softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
 
         torch_npu.npu_fused_infer_attention_score.out(
@@ -429,7 +518,24 @@ class AscendAttnBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if is_mla_preprocess_enabled():
+            # MLAPO does saving kv_cache
+            save_kv_cache = False
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
+
         if self.graph_mode:
             return self.forward_decode_graph(
                 q,
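Note: with these hunks, both `forward_extend` and `forward_decode` short-circuit into `forward_sparse` whenever `topk_indices` is supplied (by the NSA indexer), and keep the dense path otherwise. A sketch of the resulting call contract, assuming an already constructed backend and batch; nothing beyond the argument names comes from the diff:

from typing import Optional

import torch

def run_decode(backend, q, k, v, layer, forward_batch,
               q_rope: Optional[torch.Tensor] = None,
               k_rope: Optional[torch.Tensor] = None,
               topk_indices: Optional[torch.Tensor] = None):
    # topk_indices=None -> dense path, identical to the 0.5.2 behavior.
    # topk_indices set  -> forward_sparse, which on NPU calls
    # torch.ops.custom.npu_sparse_flash_attention.
    return backend.forward_decode(
        q, k, v, layer, forward_batch,
        q_rope=q_rope, k_rope=k_rope, topk_indices=topk_indices,
    )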
sglang/srt/layers/attention/attention_registry.py (new file)

@@ -0,0 +1,215 @@
+import logging
+from typing import TYPE_CHECKING
+
+logger = logging.getLogger(__name__)
+
+
+if TYPE_CHECKING:
+    # evade circular imports
+    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+ATTENTION_BACKENDS = {}
+
+
+def register_attention_backend(name):
+    def decorator(fn):
+        ATTENTION_BACKENDS[name] = fn
+        return fn
+
+    return decorator
+
+
+@register_attention_backend("flashinfer")
+def create_flashinfer_backend(runner):
+    import torch
+
+    if not runner.use_mla_backend:
+        from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+
+        # Init streams
+        if runner.server_args.speculative_algorithm == "EAGLE":
+            if (
+                not hasattr(runner, "plan_stream_for_flashinfer")
+                or not runner.plan_stream_for_flashinfer
+            ):
+                runner.plan_stream_for_flashinfer = torch.cuda.Stream()
+        return FlashInferAttnBackend(runner)
+    else:
+        from sglang.srt.layers.attention.flashinfer_mla_backend import (
+            FlashInferMLAAttnBackend,
+        )
+
+        return FlashInferMLAAttnBackend(runner)
+
+
+@register_attention_backend("trtllm_mla")
+def create_trtllm_mla_backend(runner):
+    if not runner.use_mla_backend:
+        raise ValueError("trtllm_mla backend can only be used with MLA models.")
+    from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLABackend
+
+    return TRTLLMMLABackend(runner)
+
+
+@register_attention_backend("aiter")
+def create_aiter_backend(runner):
+    from sglang.srt.layers.attention.aiter_backend import AiterAttnBackend
+
+    return AiterAttnBackend(runner)
+
+
+@register_attention_backend("wave")
+def create_wave_backend(runner):
+    from sglang.srt.layers.attention.wave_backend import WaveAttnBackend
+
+    return WaveAttnBackend(runner)
+
+
+@register_attention_backend("ascend")
+def create_ascend_backend(runner):
+    from sglang.srt.layers.attention.ascend_backend import AscendAttnBackend
+
+    return AscendAttnBackend(runner)
+
+
+@register_attention_backend("nsa")
+def create_nsa_backend(runner):
+    from sglang.srt.layers.attention.nsa_backend import NativeSparseAttnBackend
+
+    return NativeSparseAttnBackend(runner)
+
+
+@register_attention_backend("triton")
+def create_triton_backend(runner):
+    assert not runner.model_config.is_encoder_decoder, (
+        "Cross attention is not supported in the triton attention backend. "
+        "Please use `--attention-backend flashinfer`."
+    )
+    if runner.server_args.enable_double_sparsity:
+        from sglang.srt.layers.attention.double_sparsity_backend import (
+            DoubleSparseAttnBackend,
+        )
+
+        return DoubleSparseAttnBackend(runner)
+    else:
+        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
+        return TritonAttnBackend(runner)
+
+
+@register_attention_backend("torch_native")
+def create_torch_native_backend(runner):
+    from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+
+    return TorchNativeAttnBackend(runner)
+
+
+@register_attention_backend("flex_attention")
+def create_flex_attention_backend(runner):
+    from sglang.srt.layers.attention.torch_flex_backend import TorchFlexAttnBackend
+
+    return TorchFlexAttnBackend(runner)
+
+
+@register_attention_backend("flashmla")
+def create_flashmla_backend(runner):
+    from sglang.srt.layers.attention.flashmla_backend import FlashMLABackend
+
+    return FlashMLABackend(runner)
+
+
+@register_attention_backend("fa3")
+def create_flashattention_v3_backend(runner):
+    import torch
+
+    assert (
+        torch.cuda.get_device_capability()[0] == 8 and not runner.use_mla_backend
+    ) or torch.cuda.get_device_capability()[0] == 9, (
+        "FlashAttention v3 Backend requires SM>=80 and SM<=90. "
+        "Please use `--attention-backend flashinfer`."
+    )
+    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+
+    return FlashAttentionBackend(runner)
+
+
+@register_attention_backend("fa4")
+def create_flashattention_v4_backend(runner):
+    from sglang.srt.layers.attention.flashattention_backend import FlashAttentionBackend
+
+    return FlashAttentionBackend(runner, fa_impl_ver=4)
+
+
+@register_attention_backend("cutlass_mla")
+def create_cutlass_mla_backend(runner):
+    from sglang.srt.layers.attention.cutlass_mla_backend import CutlassMLABackend
+
+    return CutlassMLABackend(runner)
+
+
+@register_attention_backend("trtllm_mha")
+def create_trtllm_mha_backend(runner):
+    if runner.use_mla_backend:
+        raise ValueError("trtllm_mha backend can only be used with non-MLA models.")
+    from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnBackend
+
+    return TRTLLMHAAttnBackend(runner)
+
+
+@register_attention_backend("intel_amx")
+def create_intel_amx_backend(runner):
+    from sglang.srt.layers.attention.intel_amx_backend import IntelAMXAttnBackend
+
+    return IntelAMXAttnBackend(runner)
+
+
+@register_attention_backend("dual_chunk_flash_attn")
+def create_dual_chunk_flash_attn_backend(runner):
+    from sglang.srt.layers.attention.dual_chunk_flashattention_backend import (
+        DualChunkFlashAttentionBackend,
+    )
+
+    return DualChunkFlashAttentionBackend(runner)
+
+
+def attn_backend_wrapper(runner: "ModelRunner", full_attn_backend: "AttentionBackend"):
+    """
+    Wrapper for special models like hybrid GDN, so we don't
+    need to change the code of the original attention backend.
+    """
+    assert not (
+        runner.hybrid_gdn_config is not None and runner.use_mla_backend
+    ), "hybrid_gdn can only be used with non-MLA models."
+
+    if cfg := runner.mambaish_config:
+        from sglang.srt.layers.attention.hybrid_linear_attn_backend import (
+            GDNAttnBackend,
+            HybridLinearAttnBackend,
+            Mamba2AttnBackend,
+        )
+        from sglang.srt.utils import is_blackwell, is_npu
+
+        if runner.hybrid_gdn_config is not None:
+            if is_blackwell():
+                assert (
+                    runner.server_args.attention_backend == "triton"
+                ), "triton backend is the only supported backend on Blackwell GPUs for hybrid GDN models, use --attention-backend triton to specify the backend."
+            if is_npu():
+                assert (
+                    runner.server_args.attention_backend == "ascend"
+                ), "ascend backend is the only supported backend on NPU for hybrid GDN models, use --attention-backend ascend to specify the backend."
+            logger.info(f"Using hybrid linear attention backend for hybrid GDN models.")
+            linear_attn_backend = GDNAttnBackend(runner)
+        elif runner.mamba2_config is not None:
+            linear_attn_backend = Mamba2AttnBackend(runner)
+        else:
+            raise ValueError(
+                "Expected hybrid GDN or NemotronH models, but got unknown model."
+            )
+        full_attn_layers = cfg.full_attention_layer_ids
+        return HybridLinearAttnBackend(
+            full_attn_backend, linear_attn_backend, full_attn_layers
+        )
+
+    return full_attn_backend
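Note: `ATTENTION_BACKENDS` is a plain dict filled by the decorator, so backend construction elsewhere presumably reduces to a name lookup plus a factory call. A minimal sketch of extending the registry the same way, where the `my_sparse` name and `MyAttnBackend` module are hypothetical:

from sglang.srt.layers.attention.attention_registry import (
    ATTENTION_BACKENDS,
    register_attention_backend,
)

@register_attention_backend("my_sparse")  # hypothetical backend name
def create_my_sparse_backend(runner):
    from my_pkg.attn import MyAttnBackend  # hypothetical module

    return MyAttnBackend(runner)

# Resolution is then a dict lookup plus a call with the model runner:
#   backend = ATTENTION_BACKENDS[server_args.attention_backend](model_runner)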
sglang/srt/layers/attention/base_attn_backend.py

@@ -6,9 +6,10 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
-    from sglang.srt.speculative.
+    from sglang.srt.speculative.spec_info import SpecInput
 
 
 class AttentionBackend(ABC):
@@ -31,7 +32,7 @@ class AttentionBackend(ABC):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         """Init the metadata for a forward pass for capturing a cuda graph."""
         raise NotImplementedError()
@@ -44,7 +45,7 @@ class AttentionBackend(ABC):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Init the metadata for a forward pass for replaying a cuda graph."""
@@ -115,3 +116,11 @@ class AttentionBackend(ABC):
     def support_triton(self):
         """Check if the current backend supports triton."""
         return True
+
+    def get_indexer_metadata(
+        self,
+        layer_id: int,
+        forward_batch: ForwardBatch,
+    ) -> Optional[BaseIndexerMetadata]:
+        """Get the indexer metadata. None means don't support indexer."""
+        return None
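Note: the default `get_indexer_metadata` returns `None`, which callers read as "this backend has no indexer support". A backend that supports the NSA indexer would override it; a sketch under that assumption (`_build_indexer_metadata` is an illustrative helper, and the imports are shown eagerly rather than under TYPE_CHECKING for brevity):

from typing import Optional

from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
from sglang.srt.model_executor.forward_batch_info import ForwardBatch

class IndexerAwareBackend(AttentionBackend):  # hypothetical subclass
    def get_indexer_metadata(
        self,
        layer_id: int,
        forward_batch: ForwardBatch,
    ) -> Optional[BaseIndexerMetadata]:
        # Return the per-layer metadata the NSA indexer needs, or None for
        # batches that do not use sparse indexing.
        return self._build_indexer_metadata(layer_id, forward_batch)  # illustrative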
sglang/srt/layers/attention/cutlass_mla_backend.py

@@ -20,7 +20,7 @@ from sglang.srt.utils import is_cuda
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import
+    from sglang.srt.speculative.spec_info import SpecInput
 
 _is_cuda = is_cuda()
 if _is_cuda:
@@ -151,7 +151,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             if spec_info is None:
@@ -190,7 +190,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
 
sglang/srt/layers/attention/dual_chunk_flashattention_backend.py

@@ -1537,7 +1537,7 @@ class DualChunkFlashAttentionBackend(AttentionBackend):
             query_inter,
             key_cache,
             value_cache,
-            block_table
+            block_table,
             decode_meta.seq_lens_inter,
             softmax_scale,
             causal=False,