sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/activation.py
CHANGED
@@ -35,6 +35,7 @@ from sglang.srt.utils import (
     is_cuda,
     is_hip,
     is_npu,
+    is_xpu,
     set_weight_attrs,
 )
 from sglang.utils import resolve_obj_by_qualname
@@ -44,8 +45,9 @@ _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _is_hip = is_hip()
+_is_xpu = is_xpu()
 
-if _is_cuda:
+if _is_cuda or _is_xpu:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 elif _is_hip:
     from sgl_kernel import gelu_and_mul, gelu_quick, gelu_tanh_and_mul, silu_and_mul
@@ -70,8 +72,6 @@ class SiluAndMul(CustomOp):
 
     def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
         if _is_cpu_amx_available:
-            d = x.shape[-1] // 2
-            output_shape = x.shape[:-1] + (d,)
             out = torch.ops.sgl_kernel.silu_and_mul_cpu(x)
             return out
         else:
@@ -81,17 +81,20 @@ class SiluAndMul(CustomOp):
         out = torch_npu.npu_swiglu(x)
         return out
 
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        output_shape = x.shape[:-1] + (d,)
+        out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
+        silu_and_mul(x, out)
+        return out
+
 
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
         self.approximate = approximate
 
-    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
-        d = x.shape[-1] // 2
-        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
-
-    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+    def _forward_impl(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
         output_shape = x.shape[:-1] + (d,)
         out = torch.empty(output_shape, dtype=x.dtype, device=x.device)
@@ -103,6 +106,24 @@ class GeluAndMul(CustomOp):
             raise RuntimeError("GeluAndMul only support tanh or none")
         return out
 
+    def forward_native(self, x: torch.Tensor) -> torch.Tensor:
+        d = x.shape[-1] // 2
+        return F.gelu(x[..., :d], approximate=self.approximate) * x[..., d:]
+
+    def forward_cpu(self, x: torch.Tensor) -> torch.Tensor:
+        if _is_cpu_amx_available and self.approximate == "tanh":
+            return torch.ops.sgl_kernel.gelu_tanh_and_mul_cpu(x)
+        elif _is_cpu_amx_available and self.approximate == "none":
+            return torch.ops.sgl_kernel.gelu_and_mul_cpu(x)
+        else:
+            return self.forward_native(x)
+
+    def forward_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
+    def forward_xpu(self, x: torch.Tensor) -> torch.Tensor:
+        return self._forward_impl(x)
+
     def forward_npu(self, x: torch.Tensor) -> torch.Tensor:
         y_npu, gelu_npu = torch_npu.npu_geglu(
             x,
@@ -150,6 +171,116 @@ class QuickGELU(CustomOp):
         return torch_npu.npu_fast_gelu(x)
 
 
+class XIELU(CustomOp):
+    """
+    Applies the xIELU activation function introduced in https://arxiv.org/abs/2411.13010
+    If the user has installed the nickjbrowning/XIELU, we import xIELU CUDA
+    Otherwise, we emit a single warning and use xIELU Python
+    """
+
+    def __init__(
+        self,
+        alpha_p_init: float = 0.8,
+        alpha_n_init: float = 0.8,
+        beta: float = 0.5,
+        eps: float = -1e-6,
+        dtype: torch.dtype = torch.bfloat16,
+        with_vector_loads: bool = False,
+    ):
+        super().__init__()
+        self.alpha_p = nn.Parameter(
+            torch.log(torch.exp(torch.tensor(alpha_p_init, dtype=dtype)) - 1).unsqueeze(
+                0
+            )
+        )
+        self.alpha_n = nn.Parameter(
+            torch.log(
+                torch.exp(torch.tensor(alpha_n_init - beta, dtype=dtype)) - 1
+            ).unsqueeze(0)
+        )
+        self.register_buffer("beta", torch.tensor(beta, dtype=dtype))
+        self.register_buffer("eps", torch.tensor(eps, dtype=dtype))
+        self.with_vector_loads = with_vector_loads
+        # Temporary until xIELU CUDA fully implemented
+        self._beta_scalar = float(self.beta.detach().cpu().float().item())
+        self._eps_scalar = float(self.eps.detach().cpu().float().item())
+
+        self._xielu_cuda_obj = None
+        try:
+            import xielu.ops  # noqa: F401
+
+            self._xielu_cuda_obj = torch.classes.xielu.XIELU()
+            msg = "Using experimental xIELU CUDA."
+            try:
+                from torch._dynamo import allow_in_graph
+
+                self._xielu_cuda_fn = allow_in_graph(self._xielu_cuda)
+                msg += " Enabled torch._dynamo for xIELU CUDA."
+            except Exception as err:
+                msg += (
+                    f" Could not enable torch._dynamo for xIELU ({err}) - "
+                    "this may result in slower performance."
+                )
+                self._xielu_cuda_fn = self._xielu_cuda
+            logger.warning_once(msg)
+        except Exception as err:
+            pass
+            # logger.warning_once(
+            #     "CUDA-fused xIELU not available (%s) –"
+            #     " falling back to a Python version.\n"
+            #     "For CUDA xIELU (experimental), `pip install git+https://github.com/nickjbrowning/XIELU`",
+            #     str(err),
+            # )
+
+    def _xielu_python(self, x: torch.Tensor) -> torch.Tensor:
+        alpha_p = nn.functional.softplus(self.alpha_p)
+        alpha_n = self.beta + nn.functional.softplus(self.alpha_n)
+        return torch.where(
+            x > 0,
+            alpha_p * x * x + self.beta * x,
+            (torch.expm1(torch.min(x, self.eps)) - x) * alpha_n + self.beta * x,
+        )
+
+    def _xielu_cuda(self, x: torch.Tensor) -> torch.Tensor:
+        """Firewall function to prevent torch.compile from seeing .item()"""
+        assert self._xielu_cuda_obj is not None, "XIELU CUDA object must not be None"
+        original_shape = x.shape
+        # CUDA kernel expects 3D tensors, reshape if needed
+        while x.dim() < 3:
+            x = x.unsqueeze(0)
+        if x.dim() > 3:
+            x = x.view(-1, 1, x.size(-1))
+        if original_shape != x.shape:
+            logger.warning_once(
+                "Warning: xIELU input tensor expects 3 dimensions"
+                " but got (shape: %s). Reshaping to (shape: %s).\n"
+                "Note: For SGLang this may be expected if sending"
+                "[B*S,D] instead of [B,S,D].",
+                original_shape,
+                x.shape,
+            )
+        result = self._xielu_cuda_obj.forward(
+            x,
+            self.alpha_p,
+            self.alpha_n,
+            # Temporary until xIELU CUDA fully implemented -> self.{beta,eps}.item()
+            self._beta_scalar,
+            self._eps_scalar,
+            self.with_vector_loads,
+        )
+        return result.view(original_shape)
+
+    def forward(self, input: torch.Tensor) -> torch.Tensor:
+        if self._xielu_cuda_obj is not None and input.is_cuda:
+            if not torch._dynamo.is_compiling():
+                return self._xielu_cuda_fn(input)
+            else:
+                logger.warning_once(
+                    "torch._dynamo is compiling, using Python version of xIELU."
+                )
+        return self._xielu_python(input)
+
+
 class ScaledActivation(nn.Module):
     """An activation function with post-scale parameters.
 
@@ -197,6 +328,7 @@ _ACTIVATION_REGISTRY = {
     "gelu_pytorch_tanh": nn.GELU(approximate="tanh"),
     "gelu_new": NewGELU(),
     "relu2": ReLU2(),
+    "xielu": XIELU(),
 }
 
 
@@ -242,7 +374,9 @@ def get_cross_encoder_activation_function(config: PretrainedConfig):
         return nn.Identity()
 
 
-if not (
+if not (
+    _is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip or _is_xpu
+):
     logger.info(
         "sgl-kernel is not available on Non-NV, Non-AMD platforms or Non-AMX CPUs. Fallback to other kernel libraries."
     )
sglang/srt/layers/attention/aiter_backend.py
CHANGED
@@ -4,18 +4,13 @@ from __future__ import annotations
 end to end attention solution with aiter kernels
 """
 
-import math
-import os
 from dataclasses import dataclass
 from enum import Enum, auto
-from
-from typing import TYPE_CHECKING, List, Optional, Union
+from typing import TYPE_CHECKING, Optional
 
 import torch
 import triton
-import triton.language as tl
 
-from sglang.global_config import global_config
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import (
@@ -27,7 +22,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMo
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
-    from sglang.srt.speculative.spec_info import
+    from sglang.srt.speculative.spec_info import SpecInput
 
 try:
     from aiter import (
@@ -374,7 +369,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         if forward_mode.is_decode_or_idle():
             qo_indptr = None
@@ -509,7 +504,7 @@ class AiterAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         if forward_mode.is_decode_or_idle():
@@ -619,7 +614,11 @@ class AiterAttnBackend(AttentionBackend):
         assert len(k.shape) == 3
         assert len(v.shape) == 3
 
-        if
+        if (
+            forward_batch.forward_mode.is_extend()
+            and not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ):
             if kv_indices.shape[0] == 0:
                 o = flash_attn_varlen_func(
                     q,
@@ -884,7 +883,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -896,7 +895,7 @@ class AiterIndicesUpdaterPrefill:
         seq_lens_sum: int,
         prefix_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
 
         kv_start_idx = None
@@ -980,7 +979,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         # Keep the signature for type checking. It will be assigned during runtime.
         raise NotImplementedError()
@@ -993,7 +992,7 @@ class AiterMlaIndicesUpdaterPrefill:
         extend_lens: torch.Tensor,
         max_q_len: int,
         max_kv_len: int,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         bs = len(req_pool_indices)
 
@@ -1050,7 +1049,7 @@ class AiterMultiStepDraftBackend:
         topk: int,
         speculative_num_steps: int,
     ):
-        from sglang.srt.speculative.
+        from sglang.srt.speculative.spec_utils import generate_draft_decode_kv_indices
 
         self.topk = topk
         self.speculative_num_steps = speculative_num_steps
sglang/srt/layers/attention/ascend_backend.py
CHANGED
@@ -5,13 +5,15 @@ from typing import TYPE_CHECKING, List, Optional
 
 import torch
 import torch_npu
-from torch.nn.functional import scaled_dot_product_attention
 
 from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import is_mla_preprocess_enabled
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
+from sglang.srt.layers.dp_attention import get_attention_tp_size
 from sglang.srt.layers.radix_attention import AttentionType
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.speculative.spec_info import SpecInput
 from sglang.srt.utils import get_bool_env_var
 
 if TYPE_CHECKING:
@@ -33,6 +35,9 @@ class ForwardMetadata:
     extend_seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_int: Optional[torch.Tensor] = None
     seq_lens_cpu_list: Optional[List[int]] = None
+    seq_lens_list_cumsum: Optional[List[int]] = None
+    seq_lens: Optional[torch.Tensor] = None
+    actual_seq_lengths_q: Optional[torch.Tensor] = None
 
 
 class AscendAttnBackend(AttentionBackend):
@@ -64,6 +69,9 @@ class AscendAttnBackend(AttentionBackend):
         if self.use_mla:
             self.kv_lora_rank = model_runner.model_config.kv_lora_rank
             self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+            self.q_head_dim = (
+                self.qk_rope_head_dim + model_runner.model_config.qk_nope_head_dim
+            )
         self.native_attn = TorchNativeAttnBackend(model_runner)
         self.graph_metadata = {}
         self.max_context_len = model_runner.model_config.context_len
@@ -83,6 +91,7 @@ class AscendAttnBackend(AttentionBackend):
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Init the metadata for a forward pass."""
+        tp_size = get_attention_tp_size()
        self.forward_metadata = ForwardMetadata()
 
         self.forward_metadata.block_tables = (
@@ -96,9 +105,9 @@ class AscendAttnBackend(AttentionBackend):
                 forward_batch.extend_seq_lens.cpu().int()
             )
         self.forward_metadata.seq_lens_cpu_int = forward_batch.seq_lens_cpu.int()
-
-
-
+
+        seq_lens_list_cumsum = np.cumsum(forward_batch.extend_seq_lens_cpu)
+        self.forward_metadata.seq_lens_list_cumsum = seq_lens_list_cumsum
 
         self.graph_mode = False
 
@@ -119,12 +128,16 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens: torch.Tensor,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
     ):
         metadata = ForwardMetadata()
 
         metadata.block_tables = self.graph_metadata["block_tables"][:bs, :]
         metadata.seq_lens_cpu_list = seq_lens.cpu().int().tolist()
+        metadata.seq_lens = seq_lens
+        metadata.actual_seq_lengths_q = torch.tensor(
+            [1 + i * 1 for i in range(bs)], dtype=torch.int32, device=seq_lens.device
+        )
 
         self.graph_metadata[bs] = metadata
         self.forward_metadata = metadata
@@ -139,7 +152,7 @@ class AscendAttnBackend(AttentionBackend):
         seq_lens_sum: int,
         encoder_lens: Optional[torch.Tensor],
         forward_mode: ForwardMode,
-        spec_info: Optional[
+        spec_info: Optional[SpecInput],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         metadata = self.graph_metadata[bs]
@@ -153,6 +166,8 @@ class AscendAttnBackend(AttentionBackend):
             metadata.block_tables[:bs, max_seq_pages:].fill_(0)
             metadata.block_tables[bs:, :].fill_(0)
 
+        metadata.seq_lens[:bs].copy_(seq_lens[:bs])
+
         self.forward_metadata = metadata
 
         self.graph_mode = True
@@ -160,6 +175,64 @@ class AscendAttnBackend(AttentionBackend):
     def get_cuda_graph_seq_len_fill_value(self):
         return 0
 
+    def forward_sparse(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: torch.Tensor = None,
+    ):
+
+        is_prefill = forward_batch.forward_mode.is_extend()
+
+        if save_kv_cache:
+            k = k.view(-1, layer.tp_k_head_num, self.kv_lora_rank)
+            k_rope = k_rope.view(-1, layer.tp_k_head_num, self.qk_rope_head_dim)
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, k_rope
+            )
+        q_nope, q_pe = q, q_rope
+        k_nope, k_pe = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        block_table = self.forward_metadata.block_tables
+        if is_prefill:
+            actual_seq_qlen = torch.cumsum(forward_batch.seq_lens, dim=0)
+        else:
+            if self.forward_metadata.actual_seq_lengths_q is None:
+                actual_seq_qlen = (
+                    torch.arange(1, q.shape[0] + 1).to(q.device).to(torch.int32)
+                )
+            else:
+                actual_seq_qlen = self.forward_metadata.actual_seq_lengths_q
+        if self.forward_metadata.seq_lens_cpu_int is None:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens
+        else:
+            actual_seq_lengths_kv = self.forward_metadata.seq_lens_cpu_int
+
+        attn_out = torch.ops.custom.npu_sparse_flash_attention(
+            query=q_nope,
+            key=k_nope,
+            value=k_nope,
+            query_rope=q_pe,
+            key_rope=k_pe,
+            sparse_indices=topk_indices,
+            scale_value=layer.scaling,
+            actual_seq_lengths_query=actual_seq_qlen.to(torch.int32),
+            actual_seq_lengths_kv=actual_seq_lengths_kv.to(q.device),
+            block_table=block_table,
+            sparse_block_size=1,
+            layout_query="TND",
+            layout_kv="PA_BSND",
+            sparse_mode=3,
+        )
+
+        return attn_out
+
     def forward_extend(
         self,
         q,
@@ -168,7 +241,23 @@ class AscendAttnBackend(AttentionBackend):
         layer: RadixAttention,
         forward_batch: ForwardBatch,
         save_kv_cache: bool = True,
+        # For multi_head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
         if not self.use_mla:
             if save_kv_cache:
                 forward_batch.token_to_kv_pool.set_kv_buffer(
@@ -368,7 +457,7 @@ class AscendAttnBackend(AttentionBackend):
                 -1, layer.tp_v_head_num, self.page_size, self.kv_lora_rank
             )
 
-            q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank)
+            q_nope = q.view(-1, layer.tp_q_head_num, 1, self.kv_lora_rank).contiguous()
             q_rope = q_rope.view(-1, layer.tp_q_head_num, 1, self.qk_rope_head_dim)
             if self.forward_metadata.seq_lens_cpu_int is None:
                 actual_seq_len_kv = self.forward_metadata.seq_lens_cpu_list
@@ -394,7 +483,7 @@ class AscendAttnBackend(AttentionBackend):
                 antiquant_scale=None,
                 sparse_mode=0,
             )
-            output = torch.
+            output = torch.empty_like(q_nope, dtype=q.dtype, device=q.device)
             softmax_lse = torch.empty(1, dtype=q.dtype, device=q.device)
 
             torch_npu.npu_fused_infer_attention_score.out(
@@ -429,7 +518,24 @@ class AscendAttnBackend(AttentionBackend):
         # For multi-head latent attention
         q_rope: Optional[torch.Tensor] = None,
         k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
     ):
+        if is_mla_preprocess_enabled():
+            # MLAPO does saving kv_cache
+            save_kv_cache = False
+        if topk_indices is not None:
+            return self.forward_sparse(
+                q,
+                k,
+                v,
+                layer,
+                forward_batch,
+                save_kv_cache,
+                q_rope,
+                k_rope,
+                topk_indices,
+            )
+
         if self.graph_mode:
             return self.forward_decode_graph(
                 q,