sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +192 -113
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +132 -57
- sglang/srt/entrypoints/openai/protocol.py +115 -7
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +207 -58
- sglang/srt/entrypoints/openai/serving_completions.py +17 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +49 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +106 -82
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +53 -7
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +225 -57
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +78 -49
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +215 -314
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +147 -19
- sglang/srt/managers/scheduler.py +501 -304
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +321 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +15 -21
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +58 -34
- sglang/srt/mem_cache/hiradix_cache.py +227 -80
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -223
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +519 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +55 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +98 -57
- sglang/srt/model_executor/model_runner.py +433 -158
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +833 -152
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +14 -5
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +124 -14
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +26 -5
- sglang/srt/models/qwen3_moe.py +71 -12
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +10 -3
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1030 -254
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +253 -136
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +445 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +22 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/aiter_backend.py

@@ -4,27 +4,25 @@ from __future__ import annotations
 end to end attention solution with aiter kernels
 """
 
-import math
-import os
 from dataclasses import dataclass
 from enum import Enum, auto
-from …
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Optional
 
 import torch
 import triton
-import triton.language as tl
 
-from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
-from sglang.srt.layers.dp_attention import …
+from sglang.srt.layers.dp_attention import (
+    get_attention_tp_size,
+    is_dp_attention_enabled,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import …
+    from sglang.srt.speculative.spec_info import SpecInput
 
 try:
     from aiter import (
@@ -154,6 +152,8 @@ class AiterAttnBackend(AttentionBackend):
             (max_bs + 1,), dtype=torch.int32, device=model_runner.device
         )
 
+        self.enable_dp_attention = is_dp_attention_enabled()
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init auxiliary variables for triton attention backend."""
 
@@ -302,19 +302,19 @@ class AiterAttnBackend(AttentionBackend):
             if self.use_mla:
                 self.mla_indices_updater_prefill.update(
                     forward_batch.req_pool_indices,
-                    forward_batch.…
-…
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
                     forward_batch.extend_seq_lens,
-                    max(
-                        forward_batch.…
+                    forward_batch.extend_seq_lens.max().item(),
+                    forward_batch.seq_lens.max().item(),
                     spec_info=None,
                 )
-…
-…
-…
+
+                kv_indices = self.mla_indices_updater_prefill.kv_indices
+
                 self.forward_metadata = ForwardMetadata(
                     self.mla_indices_updater_prefill.kv_indptr,
-…
+                    kv_indices,
                     self.mla_indices_updater_prefill.qo_indptr,
                     self.kv_last_page_len[:bs],
                     self.mla_indices_updater_prefill.max_q_len,
@@ -369,7 +369,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[…
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
@@ -504,7 +504,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[…
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -614,66 +614,90 @@ class AiterAttnBackend(AttentionBackend):
         assert len(k.shape) == 3
         assert len(v.shape) == 3
 
-        if …
-… (18 removed lines from the previous implementation, not captured in this view)
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
+            if kv_indices.shape[0] == 0:
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    qo_indptr,
+                    max_q_len,
+                    max_q_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+            elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
+                K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
+                kvc, k_pe = torch.split(
+                    K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
+                )
+                kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
 
-… (40 removed lines from the previous implementation, not captured in this view)
+                kvprefix = kvprefix.view(
+                    -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
+                )
+                k_prefix, v_prefix = torch.split(
+                    kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
+                )
+                k_prefix = torch.cat(
+                    [
+                        k_prefix,
+                        torch.broadcast_to(
+                            k_pe,
+                            (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
+                        ),
+                    ],
+                    dim=-1,
+                )
+                assert (
+                    forward_batch.extend_prefix_lens.shape
+                    == forward_batch.extend_seq_lens.shape
+                )
+
+                k = k_prefix
+                v = v_prefix
+
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    kv_indptr,
+                    max_q_len,
+                    max_kv_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+
+            else:
+                if layer.qk_head_dim != layer.v_head_dim:
+                    o = q.new_empty(
+                        (q.shape[0], layer.tp_q_head_num * layer.v_head_dim)
+                    )
+                else:
+                    o = torch.empty_like(q)
+
+                mla_prefill_fwd(
+                    q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                    K_Buffer.view(-1, 1, 1, layer.qk_head_dim),
+                    o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                    qo_indptr,
+                    kv_indptr,
+                    kv_indices,
+                    self.forward_metadata.kv_last_page_len,
+                    self.forward_metadata.max_q_len,
+                    layer.scaling,
+                    layer.logit_cap,
+                )
+                K_Buffer = K_Buffer.view(-1, layer.tp_k_head_num, layer.qk_head_dim)
+                return o
         elif forward_batch.forward_mode.is_target_verify():
             o = q.new_empty((q.shape[0], layer.tp_q_head_num, layer.v_head_dim))
             mla_decode_fwd(
@@ -859,7 +883,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[…
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -871,7 +895,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[…
+        spec_info: Optional[SpecInput],
     ):
 
         kv_start_idx = None
@@ -955,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[…
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -968,7 +992,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[…
+        spec_info: Optional[SpecInput],
     ):
         bs = len(req_pool_indices)
 
@@ -1025,7 +1049,7 @@ class AiterMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.…
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
 
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
sglang/srt/layers/attention/ascend_backend.py

@@ -5,13 +5,15 @@ from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch_npu
-from torch.nn.functional import scaled_dot_product_attention
 
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import get_bool_env_var
 
 if TYPE_CHECKING:
@@ -33,6 +35,9 @@ class ForwardMetadata:
     extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_list: Optional[List[int]] = None
+    seq_lens_list_cumsum: Optional[List[int]] = None
+    seq_lens: Optional[torch.Tensor] = None
+    actual_seq_lengths_q: Optional[torch.Tensor] = None
 
 
 class AscendAttnBackend(AttentionBackend):
@@ -64,6 +69,9 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+            self.q_head_dim = (
+                self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
+            )
             self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
@@ -83,6 +91,7 @@ class AscendAttnBackend(AttentionBackend):
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
+        tp_size = get_attention_tp_size()
         self.forward_metadata = ForwardMetadata()
 
         self.forward_metadata.block_tables = (
@@ -96,9 +105,9 @@ class AscendAttnBackend(AttentionBackend):
                 forward_batch.extend_seq_lens.cpu().int()
             )
             self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
-…
-…
-…
+
+            seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
+            self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
 
         self.graph_mode = False
 
@@ -119,12 +128,16 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[…
+        spec_info: Optional[SpecInput],
     ):
         metadata = ForwardMetadata()
 
         metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
         metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
+        metadata.seq_lens = seq_lens
+        metadata.actual_seq_lengths_q = torch.tensor(
+            [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
+        )
 
         self.graph_metadata[bs] = metadata
         self.forward_metadata = metadata
@@ -139,7 +152,7 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[…
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         metadata = self.graph_metadata[bs]
@@ -153,6 +166,8 @@ class AscendAttnBackend(AttentionBackend):
         metadata.block_tables[:bs, max_seq_pages:].fill_(0)
         metadata.block_tables[bs:, :].fill_(0)
 
+        metadata.seq_lens[:bs].copy_(seq_lens[:bs])
+
         self.forward_metadata = metadata
 
         self.graph_mode = True
@@ -160,6 +175,64 @@ class AscendAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 0
 
+    def forward_sparse(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: torch.Tensor = None,
+    ):
+
+        is_prefill = forward_batch.forward_mode.is_extend()
+
+        if save_kv_cache:
+            k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+            k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
+        q_nope, q_pe = q, q_rope
+        k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        block_table = self.forward_metadata.block_tables
+        if is_prefill:
+            actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
+        else:
+            if self.forward_metadata.actual_seq_lengths_q is None:
+                actual_seq_qlen = (
+                    torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
+                )
+            else:
+                actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
+        if self.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens
+        else:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
+
+        attn_out = torch.ops.custom.npu_sparse_flash_attention(
+            query=q_nope,
+            key=k_nope,
+            value=k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            sparse_indices=topk_indices,
+            scale_value=layer.scaling,
+            actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
+            actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
+            block_table=block_table,
+            sparse_block_size=1,
+            layout_query="TND",
+            layout_kv="PA_BSND",
+            sparse_mode=3,
+        )
+
+        return attn_out
+
     def forward_extend(
         self,
         q,
@@ -168,7 +241,23 @@ class AscendAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
         if not self.use_mla:
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -368,7 +457,7 @@ class AscendAttnBackend(AttentionBackend):
                 -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
             )
 
-            q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
+            q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank).contiguous()
            q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
             if self.forward_metadata.seq_lens_cpu_int is None:
                 actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
@@ -394,7 +483,7 @@ class AscendAttnBackend(AttentionBackend):
                 antiquant_scale=None,
                 sparse_mode=0,
             )
-            output = torch.…
+            output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
             softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
 
             torch_npu.npu_fused_infer_attention_score.out(
@@ -429,7 +518,24 @@ class AscendAttnBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if is_mla_preprocess_enabled():
+            # MLAPO does saving kv_cache
+            save_kv_cache = False
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
+
         if self.graph_mode:
             return self.forward_decode_graph(
                 q,