sglang 0.5.2rc2__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/nsa_backend.py
@@ -0,0 +1,887 @@
+from __future__ import annotations
+
+import sys
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Dict, List, Literal, Optional, TypeAlias
+
+import torch
+
+from sglang.srt.configs.model_config import get_nsa_index_topk, is_deepseek_nsa
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.nsa.nsa_indexer import BaseIndexerMetadata
+from sglang.srt.layers.attention.nsa.quant_k_cache import quantize_k_cache
+from sglang.srt.layers.attention.nsa.transform_index import (
+    transform_index_page_table_decode,
+    transform_index_page_table_prefill,
+)
+from sglang.srt.layers.attention.nsa.utils import (
+    NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+    NSA_FUSE_TOPK,
+    compute_nsa_seqlens,
+)
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_hip
+
+# from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInput
+
+_is_hip = is_hip()
+
+if _is_hip:
+    try:
+        from aiter import (
+            flash_attn_varlen_func,
+            mha_batch_prefill_func,
+            paged_attention_ragged,
+        )
+        from aiter.mla import mla_decode_fwd, mla_prefill_fwd
+    except ImportError:
+        print(
+            "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
+        )
+else:
+    from sgl_kernel.flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache
+
+
+@dataclass(frozen=True)
+class NSAFlashMLAMetadata:
+    """Metadata only needed by FlashMLA"""
+
+    flashmla_metadata: torch.Tensor
+    num_splits: torch.Tensor
+
+    def slice(self, sli):
+        return NSAFlashMLAMetadata(
+            flashmla_metadata=self.flashmla_metadata,
+            num_splits=self.num_splits[sli],
+        )
+
+    def copy_(self, other: "NSAFlashMLAMetadata"):
+        self.flashmla_metadata.copy_(other.flashmla_metadata)
+        self.num_splits.copy_(other.num_splits)
+
+
+@dataclass(frozen=True)
+class NSAMetadata:
+    page_size: int
+
+    # Sequence lengths for the forward batch
+    cache_seqlens_int32: torch.Tensor
+    # Maximum sequence length for query
+    max_seq_len_q: int
+    # Maximum sequence length for key
+    max_seq_len_k: int
+    # Cumulative sequence lengths for query
+    cu_seqlens_q: torch.Tensor
+    # Cumulative sequence lengths for key
+    cu_seqlens_k: torch.Tensor
+    # Page table, the index of KV Cache Tables/Blocks
+    # this table is always with page_size = 1
+    page_table_1: torch.Tensor
+
+    # NOTE(dark): This will property be used in:
+    # 1. dense decode/prefill, we use paged flash attention, need real_page_table
+    # 2. sparse decode/prefill, indexer need real_page_table to compute the score
+    real_page_table: torch.Tensor
+
+    # NSA metadata (nsa prefill are expanded)
+    nsa_cache_seqlens_int32: torch.Tensor  # this seqlens is clipped to `topk`
+    nsa_cu_seqlens_q: torch.Tensor  # must be arange(0, len(nsa_cu_seqlens_k))
+    nsa_cu_seqlens_k: torch.Tensor  # cumsum of `nsa_cache_seqlens_int32`
+    nsa_extend_seq_lens_list: List[int]
+    nsa_seqlens_expanded: torch.Tensor  # expanded, unclipped `seqlens`
+    nsa_max_seqlen_q: Literal[1] = 1  # always 1 for decode, variable for extend
+
+    flashmla_metadata: Optional[NSAFlashMLAMetadata] = None
+
+
+@dataclass(frozen=True)
+class NSAIndexerMetadata(BaseIndexerMetadata):
+    attn_metadata: NSAMetadata
+
+    def get_seqlens_int32(self) -> torch.Tensor:
+        return self.attn_metadata.cache_seqlens_int32
+
+    def get_page_table_64(self) -> torch.Tensor:
+        return self.attn_metadata.real_page_table
+
+    def get_seqlens_expanded(self) -> torch.Tensor:
+        return self.attn_metadata.nsa_seqlens_expanded
+
+    def topk_transform(
+        self,
+        logits: torch.Tensor,
+        topk: int,
+    ) -> torch.Tensor:
+        from sgl_kernel import fast_topk_transform_fused, fast_topk_v2
+
+        if not NSA_FUSE_TOPK:
+            return fast_topk_v2(logits, self.get_seqlens_expanded(), topk)
+
+        # NOTE(dark): if fused, we return a transformed page table directly
+        return fast_topk_transform_fused(
+            score=logits,
+            lengths=self.get_seqlens_expanded(),
+            page_table_size_1=self.attn_metadata.page_table_1,
+            cu_seqlens_q=self.attn_metadata.cu_seqlens_q,
+            topk=topk,
+        )
+
+
+def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
+    assert seqlens.dtype == torch.int32 and seqlens.is_cuda
+    return torch.nn.functional.pad(
+        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
+    )
+
+
+_NSA_IMPL_T: TypeAlias = Literal[
+    "flashmla_prefill", "flashmla_decode", "fa3", "tilelang"
+]
+
+NSA_PREFILL_IMPL: _NSA_IMPL_T
+NSA_DECODE_IMPL: _NSA_IMPL_T
+
+
+class NativeSparseAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.forward_metadata: NSAMetadata
+        self.device = model_runner.device
+        assert isinstance(model_runner.page_size, int)
+        self.real_page_size = model_runner.page_size
+        self.num_splits = (
+            1 if model_runner.server_args.enable_deterministic_inference else 0
+        )
+        self.use_nsa = is_deepseek_nsa(model_runner.model_config.hf_config)
+        assert self.use_nsa, "NSA backend only supports DeepSeek NSA"
+        self.nsa_kv_cache_store_fp8 = (
+            model_runner.token_to_kv_pool.nsa_kv_cache_store_fp8
+        )
+        self.nsa_index_topk = get_nsa_index_topk(model_runner.model_config.hf_config)
+        self.max_context_len = model_runner.model_config.context_len
+        self.num_q_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.kv_cache_dim = model_runner.token_to_kv_pool.kv_cache_dim
+
+        assert model_runner.req_to_token_pool is not None
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
+        global NSA_PREFILL_IMPL, NSA_DECODE_IMPL
+        NSA_PREFILL_IMPL = model_runner.server_args.nsa_prefill
+        NSA_DECODE_IMPL = model_runner.server_args.nsa_decode
+
+        self._arange_buf = torch.arange(16384, device=self.device, dtype=torch.int32)
+
+        if _is_hip:
+            max_bs = model_runner.req_to_token_pool.size
+
+            self.kv_indptr = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+
+    def get_device_int32_arange(self, l: int) -> torch.Tensor:
+        if l > len(self._arange_buf):
+            next_pow_of_2 = 1 << (l - 1).bit_length()
+            self._arange_buf = torch.arange(
+                next_pow_of_2, device=self.device, dtype=torch.int32
+            )
+        return self._arange_buf[:l]
+
+    def _transform_table_1_to_real(self, page_table: torch.Tensor) -> torch.Tensor:
+        page_size = self.real_page_size
+        if page_size == 1:
+            return page_table
+        max_seqlen_k = page_table.shape[1]
+        strided_indices = torch.arange(
+            0, max_seqlen_k, page_size, device=page_table.device, dtype=torch.int32
+        )
+        return page_table[:, strided_indices] // page_size
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        batch_size = forward_batch.batch_size
+        device = forward_batch.seq_lens.device
+
+        assert (
+            forward_batch.spec_info is None
+        ), "Spec decoding is not supported for NSA backend now"
+        cache_seqlens_int32 = forward_batch.seq_lens.to(torch.int32)
+        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+        assert forward_batch.seq_lens_cpu is not None
+        max_seqlen_k = int(forward_batch.seq_lens_cpu.max().item())
+        page_table = forward_batch.req_to_token_pool.req_to_token[
+            forward_batch.req_pool_indices, :max_seqlen_k
+        ]
+
+        if forward_batch.forward_mode.is_decode_or_idle():
+            extend_seq_lens_cpu = [1] * batch_size
+            max_seqlen_q = 1
+            cu_seqlens_q = self.get_device_int32_arange(batch_size + 1)
+            seqlens_expanded = cache_seqlens_int32
+        elif forward_batch.forward_mode.is_extend():
+            assert (
+                forward_batch.extend_seq_lens_cpu is not None
+                and forward_batch.extend_seq_lens is not None
+                and forward_batch.extend_prefix_lens_cpu is not None
+            ), "All of them must not be None"
+            extend_seq_lens_cpu = forward_batch.extend_seq_lens_cpu
+            assert forward_batch.extend_seq_lens is not None
+            if any(forward_batch.extend_prefix_lens_cpu):
+                max_seqlen_q = max(extend_seq_lens_cpu)
+                cu_seqlens_q = compute_cu_seqlens(
+                    forward_batch.extend_seq_lens.to(torch.int32)
+                )
+            else:
+                max_seqlen_q = max_seqlen_k
+                cu_seqlens_q = cu_seqlens_k
+            seqlens_expanded = torch.cat(
+                [
+                    torch.arange(
+                        kv_len - qo_len + 1,
+                        kv_len + 1,
+                        dtype=torch.int32,
+                        device=device,
+                    )
+                    for qo_len, kv_len in zip(
+                        forward_batch.extend_seq_lens_cpu,
+                        forward_batch.seq_lens_cpu.tolist(),
+                        strict=True,
+                    )
+                ]
+            )
+        else:
+            assert False, f"Unsupported {forward_batch.forward_mode = }"
+
+        # 1D, expanded seqlens (1D means cheap to compute, so always compute it)
+        nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+            original_seq_lens=seqlens_expanded,
+            nsa_index_topk=self.nsa_index_topk,
+        )
+        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
+        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
+
+        metadata = NSAMetadata(
+            page_size=self.real_page_size,
+            cache_seqlens_int32=cache_seqlens_int32,
+            max_seq_len_q=max_seqlen_q,
+            max_seq_len_k=max_seqlen_k,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            page_table_1=page_table,
+            flashmla_metadata=(
+                self._compute_flashmla_metadata(
+                    cache_seqlens=nsa_cache_seqlens_int32,
+                    seq_len_q=1,  # TODO handle MTP which is not 1
+                )
+                if NSA_DECODE_IMPL == "flashmla_decode"
+                else None
+            ),
+            nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
+            nsa_cu_seqlens_q=nsa_cu_seqlens_q,
+            nsa_cu_seqlens_k=nsa_cu_seqlens_k,
+            nsa_seqlens_expanded=seqlens_expanded,
+            nsa_extend_seq_lens_list=extend_seq_lens_cpu,
+            real_page_table=self._transform_table_1_to_real(page_table),
+        )
+
+        self.forward_metadata = metadata
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        """Initialize CUDA graph state for the attention backend.
+
+        Args:
+            max_bs (int): Maximum batch size to support in CUDA graphs
+
+        This creates fixed-size tensors that will be reused during CUDA graph replay
+        to avoid memory allocations.
+        """
+        self.decode_cuda_graph_metadata: Dict = {
+            "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
+            "cu_seqlens_q": torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+            "cu_seqlens_k": torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            ),
+            # fake page_table for sparse_prefill
+            "page_table": torch.zeros(
+                max_bs,
+                self.max_context_len,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "flashmla_metadata": (
+                self._compute_flashmla_metadata(
+                    cache_seqlens=torch.ones(
+                        max_bs, dtype=torch.int32, device=self.device
+                    ),
+                    seq_len_q=1,  # TODO handle MTP which is not 1
+                )
+                if NSA_DECODE_IMPL == "flashmla_decode"
+                else None
+            ),
+        }
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInput],
+    ):
+        """Initialize forward metadata for capturing CUDA graph."""
+        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
+        assert (
+            spec_info is None
+        ), "Speculative decoding is not supported for NSA backend now"
+
+        # Normal Decode
+        # Get sequence information
+        cache_seqlens_int32 = seq_lens.to(torch.int32)
+        cu_seqlens_k = compute_cu_seqlens(cache_seqlens_int32)
+
+        # Use max context length for seq_len_k
+        page_table_1 = self.decode_cuda_graph_metadata["page_table"][:bs, :]
+        max_seq_len_k = page_table_1.shape[1]
+
+        # Precompute page table
+        # Precompute cumulative sequence lengths
+
+        # NOTE(dark): this is always arange, since we are decoding
+        cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][: bs + 1]
+        nsa_cache_seqlens_int32 = compute_nsa_seqlens(
+            cache_seqlens_int32, nsa_index_topk=self.nsa_index_topk
+        )
+        nsa_cu_seqlens_k = compute_cu_seqlens(nsa_cache_seqlens_int32)
+        nsa_cu_seqlens_q = self.get_device_int32_arange(len(nsa_cu_seqlens_k))
+        real_page_table = self._transform_table_1_to_real(page_table_1)
+
+        if NSA_DECODE_IMPL == "flashmla_decode":
+            flashmla_metadata = self.decode_cuda_graph_metadata[
+                "flashmla_metadata"
+            ].slice(slice(0, bs + 1))
+            flashmla_metadata.copy_(
+                self._compute_flashmla_metadata(
+                    cache_seqlens=nsa_cache_seqlens_int32,
+                    seq_len_q=1,  # TODO handle MTP which is not 1
+                )
+            )
+        else:
+            flashmla_metadata = None
+
+        metadata = NSAMetadata(
+            page_size=self.real_page_size,
+            cache_seqlens_int32=cache_seqlens_int32,
+            max_seq_len_q=1,
+            max_seq_len_k=max_seq_len_k,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k=cu_seqlens_k,
+            page_table_1=page_table_1,
+            flashmla_metadata=flashmla_metadata,
+            nsa_cache_seqlens_int32=nsa_cache_seqlens_int32,
+            nsa_cu_seqlens_q=nsa_cu_seqlens_q,
+            nsa_cu_seqlens_k=nsa_cu_seqlens_k,
+            nsa_seqlens_expanded=cache_seqlens_int32,
+            real_page_table=real_page_table,
+            nsa_extend_seq_lens_list=[1] * bs,
+        )
+        self.decode_cuda_graph_metadata[bs] = metadata
+        self.forward_metadata = metadata
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInput],
+        seq_lens_cpu: Optional[torch.Tensor],
+        out_cache_loc: Optional[torch.Tensor] = None,
+    ):
+        """Initialize forward metadata for replaying CUDA graph."""
+        assert seq_lens_cpu is not None
+        assert forward_mode.is_decode_or_idle(), "Only support decode for now"
+        assert (
+            spec_info is None
+        ), "Speculative decoding is not supported for NSA backend now"
+        seq_lens = seq_lens[:bs]
+        seq_lens_cpu = seq_lens_cpu[:bs]
+        req_pool_indices = req_pool_indices[:bs]
+
+        # Normal Decode
+        metadata: NSAMetadata = self.decode_cuda_graph_metadata[bs]
+        max_len = int(seq_lens_cpu.max().item())
+
+        cache_seqlens = seq_lens.to(torch.int32)
+        metadata.cache_seqlens_int32.copy_(cache_seqlens)
+        metadata.cu_seqlens_k[1:].copy_(
+            torch.cumsum(cache_seqlens, dim=0, dtype=torch.int32)
+        )
+        page_indices = self.req_to_token[req_pool_indices, :max_len]
+        metadata.page_table_1[:, :max_len].copy_(page_indices)
+        assert (
+            metadata.nsa_cache_seqlens_int32 is not None
+            and metadata.nsa_cu_seqlens_k is not None
+            and self.nsa_index_topk is not None
+        )
+        nsa_cache_seqlens = compute_nsa_seqlens(cache_seqlens, self.nsa_index_topk)
+        metadata.nsa_cache_seqlens_int32.copy_(nsa_cache_seqlens)
+        metadata.nsa_cu_seqlens_k[1:].copy_(
+            torch.cumsum(nsa_cache_seqlens, dim=0, dtype=torch.int32)
+        )
+        # NOTE(dark): (nsa-) cu_seqlens_q is always arange, no need to copy
+
+        assert self.real_page_size == metadata.page_size
+        if self.real_page_size > 1:
+            real_table = self._transform_table_1_to_real(page_indices)
+            new_len = real_table.shape[1]
+            metadata.real_page_table[:, :new_len].copy_(real_table)
+        else:
+            assert metadata.real_page_table is metadata.page_table_1
+
+        if NSA_DECODE_IMPL == "flashmla_decode":
+            metadata.flashmla_metadata.copy_(
+                self._compute_flashmla_metadata(
+                    cache_seqlens=nsa_cache_seqlens,
+                    seq_len_q=1,  # TODO handle MTP which is not 1
+                )
+            )
+
+        self.forward_metadata = metadata
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        assert (
+            not forward_batch.forward_mode.is_target_verify()
+            and not forward_batch.forward_mode.is_draft_extend()
+        ), "NSA backend doesn't support speculative decoding"
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
+                )
+                forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore
+                    layer,
+                    cache_loc,
+                    k,
+                    k_rope,
+                )
+
+        metadata = self.forward_metadata
+        causal = not layer.is_cross_attention
+        assert causal, "NSA is causal only"
+
+        # For fa3 interface version compatibility, we put new fields into conditional keyword args
+        kwargs = {}
+
+        # Do absorbed multi-latent attention
+        assert q_rope is not None
+        kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+        # when store in fp8 and compute in fp8, no need to convert dtype
+        if not (
+            NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and self.nsa_kv_cache_store_fp8
+        ):
+            kv_cache = kv_cache.to(q.dtype)
+
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+
+        # NOTE(dark): here, we use page size = 1
+
+        if NSA_FUSE_TOPK:
+            page_table_1 = topk_indices
+        else:
+            assert metadata.nsa_extend_seq_lens_list is not None
+            page_table_1 = transform_index_page_table_prefill(
+                page_table=metadata.page_table_1,
+                topk_indices=topk_indices,
+                extend_lens_cpu=metadata.nsa_extend_seq_lens_list,
+                page_size=1,
+            )
+        if NSA_PREFILL_IMPL == "tilelang":
+            if q_rope is not None:
+                q_all = torch.cat([q_nope, q_rope], dim=-1)
+            return self._forward_tilelang(
+                q_all=q_all,
+                kv_cache=kv_cache,
+                page_table_1=page_table_1,
+                sm_scale=layer.scaling,
+                v_head_dim=layer.v_head_dim,
+            )
+        elif NSA_PREFILL_IMPL == "flashmla_prefill":
+            if q_rope is not None:
+                q_all = torch.cat([q_nope, q_rope], dim=-1)
+            return self._forward_flashmla_prefill(
+                q_all=q_all,
+                kv_cache=kv_cache,
+                page_table_1=page_table_1,
+                sm_scale=layer.scaling,
+                v_head_dim=layer.v_head_dim,
+            )
+        elif NSA_PREFILL_IMPL == "flashmla_decode":
+            if q_rope is not None:
+                q_all = torch.cat([q_nope, q_rope], dim=-1)
+            return self._forward_flashmla_decode(
+                q_all=q_all,
+                kv_cache=kv_cache,
+                sm_scale=layer.scaling,
+                v_head_dim=layer.v_head_dim,
+                # TODO optimize args
+                layer=layer,
+                metadata=metadata,
+                page_table_1=page_table_1,
+            )
+        elif NSA_PREFILL_IMPL == "fa3":
+            return self._forward_fa3(
+                q_rope=q_rope,
+                kv_cache=kv_cache,
+                v_head_dim=layer.v_head_dim,
+                q_nope=q_nope,
+                page_table=page_table_1,
+                cache_seqlens=metadata.nsa_cache_seqlens_int32,
+                cu_seqlens_q=metadata.nsa_cu_seqlens_q,
+                cu_seqlens_k=metadata.nsa_cu_seqlens_k,
+                max_seqlen_q=metadata.nsa_max_seqlen_q,
+                sm_scale=layer.scaling,
+                logit_cap=layer.logit_cap,
+                page_size=1,
+            )
+        else:
+            raise ValueError(f"Unsupported {NSA_PREFILL_IMPL = }")
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+        # For multi-head latent attention
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+        topk_indices: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
+                )
+                forward_batch.token_to_kv_pool.set_mla_kv_buffer(  # type: ignore
+                    layer,
+                    cache_loc,
+                    k,
+                    k_rope,
+                )
+
+        metadata = self.forward_metadata
+        causal = not layer.is_cross_attention
+        assert causal, "NSA is causal only"
+
+        # Do absorbed multi-latent attention
+        kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        if q_rope is not None:
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+        else:
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+
+        if NSA_FUSE_TOPK:
+            page_table_1 = topk_indices
+        else:
+            page_table_1 = transform_index_page_table_decode(
+                page_table=metadata.page_table_1,
+                topk_indices=topk_indices,
+                page_size=1,
+            )
+
+        if NSA_DECODE_IMPL == "flashmla_prefill":
+            if q_rope is not None:
+                q_all = torch.cat([q_nope, q_rope], dim=-1)
+            return self._forward_flashmla_prefill(
+                q_all=q_all,
+                kv_cache=kv_cache,
+                page_table_1=page_table_1,
+                sm_scale=layer.scaling,
+                v_head_dim=layer.v_head_dim,
+            )
+        elif NSA_DECODE_IMPL == "flashmla_decode":
+            if q_rope is not None:
+                q_all = torch.cat([q_nope, q_rope], dim=-1)
+            return self._forward_flashmla_decode(
+                q_all=q_all,
+                kv_cache=kv_cache,
+                sm_scale=layer.scaling,
+                v_head_dim=layer.v_head_dim,
+                # TODO optimize args
+                layer=layer,
+                metadata=metadata,
+                page_table_1=page_table_1,
+            )
+        elif NSA_DECODE_IMPL == "tilelang":
+            if q_rope is not None:
+                q_all = torch.cat([q_nope, q_rope], dim=-1)
+            return self._forward_tilelang(
+                q_all=q_all,
+                kv_cache=kv_cache,
+                page_table_1=page_table_1,
+                sm_scale=layer.scaling,
+                v_head_dim=layer.v_head_dim,
+            )
+        elif NSA_DECODE_IMPL == "fa3":
+            return self._forward_fa3(
+                q_rope=q_rope,
+                kv_cache=kv_cache,
+                v_head_dim=layer.v_head_dim,
+                q_nope=q_nope,
+                page_table=page_table_1,
+                cache_seqlens=metadata.nsa_cache_seqlens_int32,
+                cu_seqlens_q=metadata.nsa_cu_seqlens_q,
+                cu_seqlens_k=metadata.nsa_cu_seqlens_k,
+                max_seqlen_q=metadata.nsa_max_seqlen_q,
+                sm_scale=layer.scaling,
+                logit_cap=layer.logit_cap,
+                page_size=1,
+            )
+        elif NSA_DECODE_IMPL == "aiter":
+            if q_rope is not None:
+                q_all = torch.cat([q_nope, q_rope], dim=-1)
+            return self._forward_aiter(
+                q_all=q_all,
+                kv_cache=kv_cache,
+                page_table_1=page_table_1,
+                layer=layer,
+                metadata=metadata,
+                bs=forward_batch.batch_size,
+            )
+
+        else:
+            assert False, f"Unsupported {NSA_DECODE_IMPL = }"
+
+    def _forward_fa3(
+        self,
+        q_rope: torch.Tensor,
+        kv_cache: torch.Tensor,
+        v_head_dim: int,
+        q_nope: torch.Tensor,
+        page_table: torch.Tensor,
+        cache_seqlens: torch.Tensor,
+        cu_seqlens_q: torch.Tensor,
+        cu_seqlens_k: torch.Tensor,
+        max_seqlen_q: int,
+        sm_scale: float,
+        logit_cap: float,
+        page_size: int,
+    ) -> torch.Tensor:
+        k_rope_cache = kv_cache[:, :, v_head_dim:]
+        c_kv_cache = kv_cache[:, :, :v_head_dim]
+        qk_rope_dim = k_rope_cache.shape[-1]
+        k_rope_cache = k_rope_cache.view(-1, page_size, 1, qk_rope_dim)
+        c_kv_cache = c_kv_cache.view(-1, page_size, 1, v_head_dim)
+        o = flash_attn_with_kvcache(
+            q=q_rope,
+            k_cache=k_rope_cache,
+            v_cache=c_kv_cache,
+            qv=q_nope,
+            page_table=page_table,
+            cache_seqlens=cache_seqlens,
+            cu_seqlens_q=cu_seqlens_q,
+            cu_seqlens_k_new=cu_seqlens_k,
+            max_seqlen_q=max_seqlen_q,
+            softmax_scale=sm_scale,
+            causal=True,
+            softcap=logit_cap,
+            return_softmax_lse=False,
+            num_splits=self.num_splits,
+        )
+        return o  # type: ignore
+
+    def _forward_flashmla_prefill(
+        self,
+        q_all: torch.Tensor,
+        kv_cache: torch.Tensor,
+        v_head_dim: int,
+        page_table_1: torch.Tensor,
+        sm_scale: float,
+    ) -> torch.Tensor:
+        from flash_mla import flash_mla_sparse_fwd
+
+        o, _, _ = flash_mla_sparse_fwd(
+            q=q_all,
+            kv=kv_cache,
+            indices=page_table_1.unsqueeze(1),
+            sm_scale=sm_scale,
+            d_v=v_head_dim,
+        )
+        return o
+
+    def _forward_flashmla_decode(
+        self,
+        q_all: torch.Tensor,
+        kv_cache: torch.Tensor,
+        v_head_dim: int,
+        sm_scale: float,
+        layer,
+        metadata: NSAMetadata,
+        page_table_1,
+    ) -> torch.Tensor:
+        from flash_mla import flash_mla_with_kvcache
+
+        cache_seqlens = metadata.nsa_cache_seqlens_int32
+
+        # TODO the 2nd dim is seq_len_q, need to be >1 when MTP
+        q_all = q_all.view(-1, 1, layer.tp_q_head_num, layer.head_dim)
+        kv_cache = kv_cache.view(-1, self.real_page_size, 1, self.kv_cache_dim)
+        assert self.real_page_size == 64, "only page size 64 is supported"
+
+        if NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8 and not self.nsa_kv_cache_store_fp8:
+            # inefficiently quantize the whole cache
+            kv_cache = quantize_k_cache(kv_cache)
+
+        indices = page_table_1.unsqueeze(1)
+        assert (
+            indices.shape[-1] == self.nsa_index_topk
+        )  # requirement of FlashMLA decode kernel
+
+        o, _ = flash_mla_with_kvcache(
+            q=q_all,
+            k_cache=kv_cache,
+            cache_seqlens=cache_seqlens,
+            head_dim_v=v_head_dim,
+            tile_scheduler_metadata=metadata.flashmla_metadata.flashmla_metadata,
+            num_splits=metadata.flashmla_metadata.num_splits,
+            softmax_scale=sm_scale,
+            indices=indices,
+            # doc says it is not used, but if pass in None then error
+            block_table=torch.empty(
+                (q_all.shape[0], 0), dtype=torch.int32, device=q_all.device
+            ),
+            is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+        )
+        return o
+
+    def _forward_tilelang(
+        self,
+        q_all: torch.Tensor,
+        kv_cache: torch.Tensor,
+        v_head_dim: int,
+        page_table_1: torch.Tensor,
+        sm_scale: float,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.attention.nsa.tilelang_kernel import tilelang_sparse_fwd
+
+        return tilelang_sparse_fwd(
+            q=q_all,
+            kv=kv_cache,
+            indices=page_table_1.unsqueeze(1),
+            sm_scale=sm_scale,
+            d_v=v_head_dim,
+        )
+
+    def _forward_aiter(
+        self,
+        q_all: torch.Tensor,
+        kv_cache: torch.Tensor,
+        page_table_1: torch.Tensor,
+        layer: RadixAttention,
+        metadata: NSAMetadata,
+        bs: int,
+    ) -> torch.Tensor:
+        q = q_all.reshape(-1, layer.tp_q_head_num * layer.head_dim)
+
+        if layer.head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        kv_indptr = self.kv_indptr
+
+        non_minus1_mask = page_table_1 != -1
+        non_minus1_counts = non_minus1_mask.sum(dim=1)
+        kv_indptr[1 : bs + 1] = torch.cumsum(non_minus1_counts, dim=0)
+
+        kv_indices = page_table_1[page_table_1 != -1]
+
+        mla_decode_fwd(
+            q.view(-1, layer.tp_q_head_num, layer.head_dim),
+            kv_cache.view(-1, 1, 1, layer.head_dim),
+            o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+            metadata.cu_seqlens_q,
+            kv_indptr,
+            kv_indices,
+            metadata.cu_seqlens_q,
+            metadata.max_seq_len_q,
+            layer.scaling,
+            layer.logit_cap,
+        )
+        # kv_cache = kv_cache.view(-1, 1, layer.head_dim)
+        return o
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        """Get the fill value for sequence length in CUDA graph."""
+        return 1
+
+    def get_indexer_metadata(
+        self, layer_id: int, forward_batch: ForwardBatch
+    ) -> NSAIndexerMetadata:
+        return NSAIndexerMetadata(attn_metadata=self.forward_metadata)
+
+    def _compute_flashmla_metadata(self, cache_seqlens: torch.Tensor, seq_len_q: int):
+        from flash_mla import get_mla_metadata
+
+        flashmla_metadata, num_splits = get_mla_metadata(
+            cache_seqlens=cache_seqlens,
+            # TODO doc says `num_q_tokens_per_q_seq * num_heads_q // num_heads_k`
+            # but the name looks like need seq_len_q?
+            num_q_tokens_per_head_k=seq_len_q * self.num_q_heads // 1,
+            num_heads_k=1,
+            num_heads_q=self.num_q_heads,
+            is_fp8_kvcache=NSA_FLASHMLA_BACKEND_DECODE_COMPUTE_FP8,
+            topk=self.nsa_index_topk,
+        )
+
+        return NSAFlashMLAMetadata(
+            flashmla_metadata=flashmla_metadata,
+            num_splits=num_splits,
+        )
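
Note: for orientation, below is a minimal, self-contained sketch of the index arithmetic the new NSA backend builds its metadata from. compute_cu_seqlens and the page-table stride transform mirror the helpers in the diff above (the CUDA-only assert is dropped so the demo runs on CPU); clip_to_topk is a hypothetical stand-in for compute_nsa_seqlens, inferred only from the "clipped to `topk`" comments on NSAMetadata, not from the actual sglang.srt.layers.attention.nsa.utils implementation.

import torch

def compute_cu_seqlens(seqlens: torch.Tensor) -> torch.Tensor:
    # Mirrors the diff's helper: prepend a zero to the inclusive cumsum so
    # entry i is the start offset of sequence i (an exclusive prefix sum).
    return torch.nn.functional.pad(
        torch.cumsum(seqlens, dim=0, dtype=torch.int32), (1, 0)
    )

def clip_to_topk(seqlens: torch.Tensor, topk: int) -> torch.Tensor:
    # ASSUMPTION: stand-in for compute_nsa_seqlens. The NSAMetadata comments
    # say the NSA seqlens are "clipped to `topk`", which a clamp models:
    # each query attends to at most topk selected KV entries.
    return seqlens.clamp(max=topk)

def table_1_to_real(page_table: torch.Tensor, page_size: int) -> torch.Tensor:
    # Same stride trick as _transform_table_1_to_real: keep every
    # page_size-th token slot, then divide to get page indices.
    if page_size == 1:
        return page_table
    strided = torch.arange(0, page_table.shape[1], page_size, dtype=torch.int32)
    return page_table[:, strided] // page_size

seqlens = torch.tensor([5, 3000, 17], dtype=torch.int32)
print(compute_cu_seqlens(seqlens))       # tensor([   0,    5, 3005, 3022])
print(clip_to_topk(seqlens, topk=2048))  # tensor([   5, 2048,   17])

# One request whose first 8 token slots live in pages 0 and 1 (page_size=4).
table_1 = torch.arange(8, dtype=torch.int32).reshape(1, 8)
print(table_1_to_real(table_1, page_size=4))  # tensor([[0, 1]])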