sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +156 -80
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +130 -59
- sglang/srt/entrypoints/openai/protocol.py +112 -4
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +204 -55
- sglang/srt/entrypoints/openai/serving_completions.py +14 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +48 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -6
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +190 -55
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +74 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +199 -301
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +357 -407
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +144 -17
- sglang/srt/managers/scheduler.py +502 -209
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +99 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +320 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +43 -24
- sglang/srt/mem_cache/hiradix_cache.py +222 -75
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -228
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +259 -62
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +166 -17
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +511 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +82 -40
- sglang/srt/model_executor/model_runner.py +432 -157
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +607 -130
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +4 -4
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +120 -13
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +32 -4
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +966 -267
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +99 -28
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +433 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +21 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +375 -245
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
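The listing above is machine-generated from the two wheels. As a quick sanity check after upgrading, one can probe whether the newly added modules from the listing (for example the ngram speculative-decoding worker, Qwen3-Next, the NSA indexer, or the gRPC proto compiler) are importable in the installed build. This is only an illustrative sketch, not part of the package; the module paths are copied from the listing, and note that find_spec also imports the parent packages.

    import importlib.util

    # Module paths copied from the file listing above (new in 0.5.3rc2).
    new_modules = [
        "sglang.srt.speculative.ngram_worker",
        "sglang.srt.models.qwen3_next",
        "sglang.srt.layers.attention.nsa.nsa_indexer",
        "sglang.srt.grpc.compile_proto",
    ]

    for name in new_modules:
        # find_spec imports parent packages (sglang, sglang.srt, ...) as a side effect.
        found = importlib.util.find_spec(name) is not None
        print(f"{name}: {'present' if found else 'missing'}")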
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -15,6 +15,7 @@
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
+from __future__ import annotations
 
 import concurrent.futures
 import logging
@@ -25,9 +26,16 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from torch import nn
-from tqdm import tqdm
 from transformers import PretrainedConfig
 
+from sglang.srt import single_batch_overlap
+from sglang.srt.configs.model_config import (
+    get_nsa_index_head_dim,
+    get_nsa_index_n_heads,
+    get_nsa_index_topk,
+    is_deepseek_nsa,
+)
+from sglang.srt.debug_utils.dumper import dumper
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
@@ -43,6 +51,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.amx_utils import PackWeightMethod
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
+    NPUFusedMLAPreprocess,
+    is_mla_preprocess_enabled,
+)
+from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -65,10 +78,11 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -96,6 +110,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.single_batch_overlap import SboFlags
 from sglang.srt.two_batch_overlap import (
     MaybeTboDeepEPDispatcher,
     model_forward_maybe_tbo,
@@ -151,6 +166,7 @@ if _is_cuda:
     from sgl_kernel import (
         awq_dequantize,
         bmm_fp8,
+        concat_mla_k,
         dsv3_fused_a_gemm,
         dsv3_router_gemm,
         merge_state_v2,
@@ -158,16 +174,18 @@ if _is_cuda:
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
+    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+        decode_attention_fwd_grouped_rope,
+    )
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
     )
+elif _is_npu:
+    import custom_ops
+    import sgl_kernel_npu
+    import torch_npu
 else:
-
-
-if _is_hip:
-    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
-        decode_attention_fwd_grouped_rope,
-    )
+    pass
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
@@ -175,6 +193,21 @@ _is_sm100_supported = is_cuda() and is_sm100_supported()
 
 logger = logging.getLogger(__name__)
 
+FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
+    "fa3",
+    "nsa",
+    "flashinfer",
+    "cutlass_mla",
+    "trtllm_mla",
+    "ascend",
+]
+
+
+def add_forward_absorb_core_attention_backend(backend_name):
+    if backend_name not in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
+        FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
+
 
 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
@@ -183,6 +216,9 @@ class AttnForwardMethod(IntEnum):
     # Use absorbed multi-latent attention
     MLA = auto()
 
+    # Use Deepseek V3.2 sparse multi-latent attention
+    NPU_MLA_SPARSE = auto()
+
     # Use multi-head attention, but with KV cache chunked.
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
@@ -194,6 +230,146 @@ class AttnForwardMethod(IntEnum):
     MLA_FUSED_ROPE_CPU = auto()
 
 
+def _dispatch_mla_subtype(attn, forward_batch):
+    if _is_hip:
+        if attn.rocm_fused_decode_mla and forward_batch.forward_mode.is_decode():
+            return AttnForwardMethod.MLA_FUSED_ROPE
+        else:
+            return AttnForwardMethod.MLA
+    else:
+        if hasattr(attn, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(attn):
+            return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+        else:
+            return AttnForwardMethod.MLA
+
+
+class AttentionBackendRegistry:
+    _handlers = {}
+
+    @classmethod
+    def register(cls, backend_name, handler_func):
+        cls._handlers[backend_name] = handler_func
+
+    @classmethod
+    def get_handler(cls, backend_name):
+        return cls._handlers.get(backend_name, cls._handlers.get("triton"))
+
+
+def handle_attention_ascend(attn, forward_batch):
+    if (
+        forward_batch.forward_mode.is_extend()
+        and not forward_batch.forward_mode.is_target_verify()
+        and not forward_batch.forward_mode.is_draft_extend()
+    ):
+        if hasattr(attn, "indexer"):
+            return AttnForwardMethod.NPU_MLA_SPARSE
+        else:
+            return AttnForwardMethod.MHA
+    else:
+        if hasattr(attn, "indexer"):
+            return AttnForwardMethod.NPU_MLA_SPARSE
+        else:
+            return AttnForwardMethod.MLA
+
+
+def _get_sum_extend_prefix_lens(forward_batch):
+    return (
+        sum(forward_batch.extend_prefix_lens_cpu)
+        if forward_batch.extend_prefix_lens_cpu is not None
+        else 0
+    )
+
+
+def _is_extend_without_speculative(forward_batch):
+    return (
+        forward_batch.forward_mode.is_extend()
+        and not forward_batch.forward_mode.is_target_verify()
+        and not forward_batch.forward_mode.is_draft_extend()
+    )
+
+
+def _handle_attention_backend(
+    attn: DeepseekV2AttentionMLA, forward_batch, backend_name
+):
+    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
+    disable_ragged = (
+        backend_name in ["flashinfer", "flashmla"]
+    ) and attn.flashinfer_mla_disable_ragged
+
+    if (
+        not disable_ragged
+        and _is_extend_without_speculative(forward_batch)
+        and (
+            (
+                sum_extend_prefix_lens >= attn.chunked_prefix_cache_threshold
+                and not attn.disable_chunked_prefix_cache
+            )
+            or sum_extend_prefix_lens == 0
+        )
+    ):
+        return AttnForwardMethod.MHA_CHUNKED_KV
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
+def handle_attention_flashinfer(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashinfer")
+
+
+def handle_attention_fa3(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "fa3")
+
+
+def handle_attention_flashmla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashmla")
+
+
+def handle_attention_cutlass_mla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
+
+
+def handle_attention_fa4(attn, forward_batch):
+    # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
+    return AttnForwardMethod.MHA_CHUNKED_KV
+
+
+def handle_attention_trtllm_mla(attn, forward_batch):
+    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
+    if _is_extend_without_speculative(forward_batch) and (
+        not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
+    ):
+        return AttnForwardMethod.MHA_CHUNKED_KV
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
+def handle_attention_aiter(attn, forward_batch):
+    if _is_extend_without_speculative(forward_batch):
+        if is_dp_attention_enabled():
+            if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
+        else:
+            return AttnForwardMethod.MHA
+    else:
+        return AttnForwardMethod.MLA
+
+
+def handle_attention_nsa(attn, forward_batch):
+    return AttnForwardMethod.MLA
+
+
+def handle_attention_triton(attn, forward_batch):
+    if (
+        _is_extend_without_speculative(forward_batch)
+        and sum(forward_batch.extend_prefix_lens_cpu) == 0
+    ):
+        return AttnForwardMethod.MHA
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
 class DeepseekV2MLP(nn.Module):
     def __init__(
         self,
@@ -246,7 +422,11 @@ class DeepseekV2MLP(nn.Module):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
-        if
+        if (
+            gemm_output_zero_allocator is not None
+            and x.shape[0] <= 256
+            and self.gate_up_proj.weight.dtype == torch.uint8
+        ):
             y = gemm_output_zero_allocator.allocate(
                 x.shape[0] * self.gate_up_proj.output_size_per_partition
             ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
@@ -264,6 +444,7 @@ class MoEGate(nn.Module):
     def __init__(
         self,
        config,
+        quant_config,
         prefix: str = "",
        is_nextn: bool = False,
     ):
@@ -273,8 +454,15 @@ class MoEGate(nn.Module):
            torch.empty((config.n_routed_experts, config.hidden_size))
        )
        if config.topk_method == "noaux_tc":
+            correction_bias_dtype = (
+                torch.bfloat16
+                if quant_config is not None
+                and quant_config.get_name() == "modelopt_fp4"
+                and should_use_flashinfer_trtllm_moe()
+                else torch.float32
+            )
            self.e_score_correction_bias = nn.Parameter(
-                torch.empty((config.n_routed_experts), dtype=
+                torch.empty((config.n_routed_experts), dtype=correction_bias_dtype)
            )
        else:
            self.e_score_correction_bias = None
@@ -295,11 +483,13 @@ class MoEGate(nn.Module):
            _is_cuda
            and hidden_states.shape[0] <= 16
            and hidden_states.shape[1] == 7168
-            and self.weight.shape[0] == 256
+            and (self.weight.shape[0] == 256 or self.weight.shape[0] == 384)
            and _device_sm >= 90
        ):
            # router gemm output float32
-            logits = dsv3_router_gemm(
+            logits = dsv3_router_gemm(
+                hidden_states, self.weight, out_dtype=torch.float32
+            )
        elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
            logits = aiter_dsv3_router_gemm(
                hidden_states, self.weight, gemm_output_zero_allocator
@@ -347,7 +537,10 @@ class DeepseekV2MoE(nn.Module):
        )
 
        self.gate = MoEGate(
-            config=config,
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("gate", prefix),
+            is_nextn=is_nextn,
        )
 
        self.experts = get_moe_impl_class(quant_config)(
@@ -372,9 +565,12 @@ class DeepseekV2MoE(nn.Module):
            num_fused_shared_experts=self.num_fused_shared_experts,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
+            quant_config=quant_config,
            routed_scaling_factor=self.routed_scaling_factor,
-            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk
-
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
+            # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
+            # and requires the output format to be standard. We use quant_config to determine the output format.
+            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
        )
 
        self.shared_experts_is_int8 = False
@@ -638,7 +834,8 @@ class DeepseekV2MoE(nn.Module):
        if hidden_states.shape[0] > 0:
            # router_logits: (num_tokens, n_experts)
            router_logits = self.gate(hidden_states)
-
+            if not SboFlags.fuse_shared_experts_inside_sbo():
+                shared_output = self._forward_shared_experts(hidden_states)
            topk_weights, topk_idx, _ = self.topk(
                hidden_states,
                router_logits,
@@ -652,26 +849,36 @@ class DeepseekV2MoE(nn.Module):
                hidden_states.device
            )
 
-        final_hidden_states =
+        final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
            hidden_states=hidden_states,
            topk_idx=topk_idx,
            topk_weights=topk_weights,
            forward_batch=forward_batch,
+            # SBO args
+            forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
+            experts=self.experts,
+            alt_stream=self.alt_stream,
        )
+        if sbo_shared_output is not None:
+            shared_output = sbo_shared_output
 
        if shared_output is not None:
            x = shared_output
-
+            if self.experts.should_fuse_routed_scaling_factor_in_topk:
+                x.add_(final_hidden_states)
+            else:
+                x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
            final_hidden_states = x
        else:
-
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk:
+                final_hidden_states *= self.routed_scaling_factor
 
        return final_hidden_states
 
    def _forward_shared_experts(
        self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
    ):
-        if self.num_fused_shared_experts == 0:
+        if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
            return self.shared_experts(
                hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
            )
@@ -724,6 +931,7 @@ class DeepseekV2MoE(nn.Module):
        if self.ep_size > 1:
            self.experts.deepep_dispatcher.dispatch_a(
                hidden_states=state.hidden_states_mlp_input,
+                input_global_scale=None,
                topk_idx=state.pop("topk_idx_local"),
                topk_weights=state.pop("topk_weights_local"),
                forward_batch=state.forward_batch,
@@ -824,6 +1032,10 @@ class DeepseekV2AttentionMLA(nn.Module):
        self.rope_theta = rope_theta
        self.max_position_embeddings = max_position_embeddings
 
+        # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
+        if rope_scaling:
+            rope_scaling["rope_type"] = "deepseek_yarn"
+
        # For tensor parallel attention
        if self.q_lora_rank is not None:
            self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
@@ -861,6 +1073,26 @@ class DeepseekV2AttentionMLA(nn.Module):
                prefix=add_prefix("kv_a_proj_with_mqa", prefix),
            )
 
+        self.use_nsa = is_deepseek_nsa(config)
+        if self.use_nsa:
+            self.indexer = Indexer(
+                hidden_size=hidden_size,
+                index_n_heads=get_nsa_index_n_heads(config),
+                index_head_dim=get_nsa_index_head_dim(config),
+                rope_head_dim=qk_rope_head_dim,
+                index_topk=get_nsa_index_topk(config),
+                q_lora_rank=q_lora_rank,
+                max_position_embeddings=max_position_embeddings,
+                rope_theta=rope_theta,
+                scale_fmt="ue8m0",
+                block_size=128,
+                rope_scaling=rope_scaling,
+                prefix=add_prefix("indexer", prefix),
+                quant_config=quant_config,
+                layer_id=layer_id,
+                alt_stream=alt_stream,
+            )
+
        self.kv_b_proj = ColumnParallelLinear(
            self.kv_lora_rank,
            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -883,9 +1115,6 @@ class DeepseekV2AttentionMLA(nn.Module):
        )
        self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
 
-        if rope_scaling:
-            rope_scaling["rope_type"] = "deepseek_yarn"
-
        self.rotary_emb = get_rope_wrapper(
            qk_rope_head_dim,
            rotary_dim=qk_rope_head_dim,
@@ -1009,102 +1238,34 @@ class DeepseekV2AttentionMLA(nn.Module):
            self.weight_block_size = (
                self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
            )
+        self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
+        if self.is_mla_preprocess_enabled:
+            assert (
+                quant_config is None or quant_config.get_name() == "w8a8_int8"
+            ), "MLA Preprocess only works with Unquant or W8A8Int8"
+        self.mla_preprocess = None
 
    def dispatch_attn_forward_method(
        self, forward_batch: ForwardBatch
    ) -> AttnForwardMethod:
-        def _dispatch_mla_subtype():
-            if _is_hip:
-                if (
-                    self.rocm_fused_decode_mla
-                    and forward_batch.forward_mode.is_decode()
-                ):
-                    return AttnForwardMethod.MLA_FUSED_ROPE
-                else:
-                    return AttnForwardMethod.MLA
-            else:
-                if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
-                    self
-                ):
-                    return AttnForwardMethod.MLA_FUSED_ROPE_CPU
-                else:
-                    return AttnForwardMethod.MLA
-
        # Determine attention backend used by current forward batch
        if forward_batch.forward_mode.is_decode_or_idle():
            attention_backend = global_server_args_dict["decode_attention_backend"]
+        elif (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            # Use the specified backend for speculative operations (both verify and draft extend)
+            if global_server_args_dict["speculative_attention_mode"] == "decode":
+                attention_backend = global_server_args_dict["decode_attention_backend"]
+            else:  # default to prefill
+                attention_backend = global_server_args_dict["prefill_attention_backend"]
        else:
            attention_backend = global_server_args_dict["prefill_attention_backend"]
        self.current_attention_backend = attention_backend
 
-
-
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return AttnForwardMethod.MLA
-        elif (
-            attention_backend == "flashinfer"
-            or attention_backend == "fa3"
-            or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
-            or attention_backend == "cutlass_mla"
-        ):
-            # Use MHA with chunked KV cache when prefilling on long sequences.
-            sum_extend_prefix_lens = (
-                sum(forward_batch.extend_prefix_lens_cpu)
-                if forward_batch.extend_prefix_lens_cpu is not None
-                else 0
-            )
-            # Flashinfer MLA: Do not absorb when enabling ragged prefill
-            disable_ragged = (
-                attention_backend == "flashinfer" or attention_backend == "flashmla"
-            ) and self.flashinfer_mla_disable_ragged
-            if (
-                not disable_ragged
-                and forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and (
-                    (
-                        sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
-                        and not self.disable_chunked_prefix_cache
-                    )
-                    or sum_extend_prefix_lens == 0
-                )
-            ):
-                return AttnForwardMethod.MHA_CHUNKED_KV
-            else:
-                return _dispatch_mla_subtype()
-        elif attention_backend == "aiter":
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-            ):
-                if is_dp_attention_enabled():
-                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
-                        return AttnForwardMethod.MHA
-                    else:
-                        return AttnForwardMethod.MLA
-                else:
-                    return AttnForwardMethod.MHA
-            else:
-                return AttnForwardMethod.MLA
-        else:
-            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return _dispatch_mla_subtype()
+        handler = AttentionBackendRegistry.get_handler(attention_backend)
+        return handler(self, forward_batch)
 
    def op_prepare(self, state):
        state.attn_intermediate_state = self.forward_prepare(
@@ -1159,7 +1320,6 @@ class DeepseekV2AttentionMLA(nn.Module):
            return hidden_states, None, forward_batch, None
 
        attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
-
        if attn_forward_method == AttnForwardMethod.MHA:
            inner_state = self.forward_normal_prepare(
                positions, hidden_states, forward_batch, zero_allocator
@@ -1169,7 +1329,30 @@ class DeepseekV2AttentionMLA(nn.Module):
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA:
-
+            if not self.is_mla_preprocess_enabled:
+                inner_state = self.forward_absorb_prepare(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
+            else:
+                # TODO(iforgetmyname): to be separated as a standalone func
+                if self.mla_preprocess is None:
+                    self.mla_preprocess = NPUFusedMLAPreprocess(
+                        self.fused_qkv_a_proj_with_mqa,
+                        self.q_a_layernorm,
+                        self.kv_a_layernorm,
+                        self.q_b_proj,
+                        self.w_kc,
+                        self.rotary_emb,
+                        self.layer_id,
+                        self.num_local_heads,
+                        self.qk_nope_head_dim,
+                        self.qk_rope_head_dim,
+                    )
+                inner_state = self.mla_preprocess.forward(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
+        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+            inner_state = self.forward_npu_sparse_prepare(
                positions, hidden_states, forward_batch, zero_allocator
            )
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
@@ -1197,6 +1380,8 @@ class DeepseekV2AttentionMLA(nn.Module):
            return self.forward_normal_chunked_kv_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA:
            return self.forward_absorb_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+            return self.forward_npu_sparse_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
            return self.forward_absorb_fused_mla_rope_core(*inner_state)
        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
@@ -1235,8 +1420,18 @@ class DeepseekV2AttentionMLA(nn.Module):
        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
        q[..., self.qk_nope_head_dim :] = q_pe
        k = torch.empty_like(q)
-
-
+
+        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        if (
+            _is_cuda
+            and (self.num_local_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        ):
+            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+        else:
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
 
        if not _is_npu:
            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
@@ -1266,7 +1461,10 @@ class DeepseekV2AttentionMLA(nn.Module):
        """
        return (
            self.current_attention_backend == "trtllm_mla"
-            and
+            and (
+                forward_batch.forward_mode.is_decode_or_idle()
+                or forward_batch.forward_mode.is_target_verify()
+            )
            and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
        )
 
@@ -1279,6 +1477,7 @@ class DeepseekV2AttentionMLA(nn.Module):
    ):
        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
+        q_lora = None
        if self.q_lora_rank is not None:
            if (
                (not isinstance(hidden_states, tuple))
@@ -1317,6 +1516,10 @@ class DeepseekV2AttentionMLA(nn.Module):
                q = self.q_a_layernorm(q)
                k_nope = self.kv_a_layernorm(k_nope)
 
+            # q_lora needed by indexer
+            if self.use_nsa:
+                q_lora = q
+
            k_nope = k_nope.unsqueeze(1)
            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
        else:
@@ -1382,28 +1585,50 @@ class DeepseekV2AttentionMLA(nn.Module):
        q_nope_out = q_nope_out.transpose(0, 1)
 
        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
-            not _use_aiter or not _is_gfx95_supported
+            not _use_aiter or not _is_gfx95_supported or self.use_nsa
        ):
            q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-
+        topk_indices = None
+        if q_lora is not None:
+            topk_indices = self.indexer(
+                x=hidden_states,
+                q_lora=q_lora,
+                positions=positions,
+                forward_batch=forward_batch,
+                layer_id=self.layer_id,
+            )
+
+        return (
+            q_pe,
+            k_pe,
+            q_nope_out,
+            k_nope,
+            forward_batch,
+            zero_allocator,
+            positions,
+            topk_indices,
+        )
 
    def forward_absorb_core(
-        self,
+        self,
+        q_pe,
+        k_pe,
+        q_nope_out,
+        k_nope,
+        forward_batch,
+        zero_allocator,
+        positions,
+        topk_indices,
    ):
-        if
-            self.current_attention_backend == "fa3"
-            or self.current_attention_backend == "flashinfer"
-            or self.current_attention_backend == "cutlass_mla"
-            or self.current_attention_backend == "trtllm_mla"
-            or self.current_attention_backend == "ascend"
-        ):
+        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
            extra_args = {}
            if self._fuse_rope_for_trtllm_mla(forward_batch):
                extra_args = {
                    "cos_sin_cache": self.rotary_emb.cos_sin_cache,
                    "is_neox": self.rotary_emb.is_neox_style,
                }
+
            attn_output = self.attn_mqa(
                q_nope_out,
                k_nope,
@@ -1412,6 +1637,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                q_rope=q_pe,
                k_rope=k_pe,
                **extra_args,
+                **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
            )
        else:
            if _use_aiter_gfx95:
@@ -1431,7 +1657,13 @@ class DeepseekV2AttentionMLA(nn.Module):
            q = torch.cat([q_nope_out, q_pe], dim=-1)
            k = torch.cat([k_nope, k_pe], dim=-1)
 
-            attn_output = self.attn_mqa(
+            attn_output = self.attn_mqa(
+                q,
+                k,
+                k_nope,
+                forward_batch,
+                **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
+            )
        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
        if self.use_deep_gemm_bmm:
@@ -1513,6 +1745,221 @@ class DeepseekV2AttentionMLA(nn.Module):
 
        return output
 
+    def forward_npu_sparse_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        """
+        Reuse `self.q_lora_rank is not None` branch from forward_absorb_prepare
+        """
+        if self.is_mla_preprocess_enabled and forward_batch.forward_mode.is_decode():
+            if self.mla_preprocess is None:
+                self.mla_preprocess = NPUFusedMLAPreprocess(
+                    self.fused_qkv_a_proj_with_mqa,
+                    self.q_a_layernorm,
+                    self.kv_a_layernorm,
+                    self.q_b_proj,
+                    self.w_kc,
+                    self.rotary_emb,
+                    self.layer_id,
+                    self.num_local_heads,
+                    self.qk_nope_head_dim,
+                    self.qk_rope_head_dim,
+                )
+            (
+                q_pe,
+                k_pe,
+                q_nope_out,
+                k_nope,
+                forward_batch,
+                zero_allocator,
+                positions,
+            ) = self.mla_preprocess.forward(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+
+            fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, _ = fused_qkv_a_proj_out.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            q_lora = self.q_a_layernorm(q)
+        else:
+            from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
+            if (
+                (not isinstance(hidden_states, tuple))
+                and hidden_states.shape[0] <= 16
+                and self.use_min_latency_fused_a_gemm
+            ):
+                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+                )
+            else:
+                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, latent_cache = fused_qkv_a_proj_out.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and get_is_capture_mode():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+                    q, k_nope = fused_rms_mxfp4_quant(
+                        q,
+                        self.q_a_layernorm.weight,
+                        self.q_a_layernorm.variance_epsilon,
+                        k_nope,
+                        self.kv_a_layernorm.weight,
+                        self.kv_a_layernorm.variance_epsilon,
+                    )
+                else:
+                    q = self.q_a_layernorm(q)
+                    k_nope = self.kv_a_layernorm(k_nope)
+
+            q_lora = q.clone()  # required for topk_indices
+            k_nope = k_nope.unsqueeze(1)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+
+            q_nope, q_pe = q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+            k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
+
+            if self.use_deep_gemm_bmm:
+                q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+                    per_token_group_quant_mla_deep_gemm_masked_fp8(
+                        q_nope.transpose(0, 1)
+                    )
+                )
+                q_nope_out = q_nope.new_empty(
+                    (self.num_local_heads, aligned_m, self.kv_lora_rank)
+                )
+                deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+                    (q_nope_val, q_nope_scale),
+                    (self.w_kc, self.w_scale_k),
+                    q_nope_out,
+                    masked_m,
+                    expected_m,
+                )
+                q_nope_out = q_nope_out[:, :expected_m, :]
+            elif _is_hip:
+                # TODO(haishaw): add bmm_fp8 to ROCm
+                if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+                    x = q_nope.transpose(0, 1)
+                    q_nope_out = torch.empty(
+                        x.shape[0],
+                        x.shape[1],
+                        self.w_kc.shape[2],
+                        device=x.device,
+                        dtype=torch.bfloat16,
+                    )
+                    batched_gemm_afp4wfp4_pre_quant(
+                        x,
+                        self.w_kc.transpose(-2, -1),
+                        self.w_scale_k.transpose(-2, -1),
+                        torch.bfloat16,
+                        q_nope_out,
+                    )
+                else:
+                    q_nope_out = torch.bmm(
+                        q_nope.to(torch.bfloat16).transpose(0, 1),
+                        self.w_kc.to(torch.bfloat16) * self.w_scale,
+                    )
+            elif self.w_kc.dtype == torch.float8_e4m3fn:
+                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                    q_nope.transpose(0, 1),
+                    zero_allocator.allocate(1),
+                )
+                q_nope_out = bmm_fp8(
+                    q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+                )
+            else:
+                q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+
+            q_nope_out = q_nope_out.transpose(0, 1)
+
+            if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+                not _use_aiter or not _is_gfx95_supported
+            ):
+                q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+        # TODO: multi-stream indexer
+        topk_indices = self.indexer(
+            hidden_states, q_lora, positions, forward_batch, self.layer_id
+        )
+
+        return (
+            q_pe,
+            k_pe,
+            q_nope_out,
+            k_nope,
+            topk_indices,
+            forward_batch,
+            zero_allocator,
+            positions,
+        )
+
+    def forward_npu_sparse_core(
+        self,
+        q_pe,
+        k_pe,
+        q_nope_out,
+        k_nope,
+        topk_indices,
+        forward_batch,
+        zero_allocator,
+        positions,
+    ):
+        attn_output = self.attn_mqa(
+            q_nope_out.contiguous(),
+            k_nope.contiguous(),
+            k_nope.contiguous(),
+            forward_batch,
+            save_kv_cache=True,  # False if forward_batch.forward_mode.is_extend() else True,
+            q_rope=q_pe.contiguous(),
+            k_rope=k_pe.contiguous(),
+            topk_indices=topk_indices,
+        )
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        attn_bmm_output = torch.empty(
+            (attn_output.shape[0], self.num_local_heads, self.v_head_dim),
+            dtype=attn_output.dtype,
+            device=attn_output.device,
+        )
+
+        if not forward_batch.forward_mode.is_decode():
+            attn_output = attn_output.transpose(0, 1)
+            torch.bmm(
+                attn_output,
+                self.w_vc,
+                out=attn_bmm_output.view(
+                    -1, self.num_local_heads, self.v_head_dim
+                ).transpose(0, 1),
+            )
+        else:
+            attn_output = attn_output.contiguous()
+            torch.ops.npu.batch_matmul_transpose(
+                attn_output, self.w_vc, attn_bmm_output
+            )
+
+        attn_bmm_output = attn_bmm_output.reshape(
+            -1, self.num_local_heads * self.v_head_dim
+        )
+
+        output, _ = self.o_proj(attn_bmm_output)
+        return output
+
    def forward_absorb_fused_mla_rope_prepare(
        self,
        positions: torch.Tensor,
@@ -1838,6 +2285,7 @@ class DeepseekV2AttentionMLA(nn.Module):
            tmp_lse = torch.empty_like(accum_lse)
            merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
            accum_output, accum_lse = tmp_output, tmp_lse
+            del kv, k, v, output, lse, tmp_output, tmp_lse
 
        return accum_output
 
@@ -1994,11 +2442,13 @@ class DeepseekV2DecoderLayer(nn.Module):
        zero_allocator: BumpAllocator,
        gemm_output_zero_allocator: BumpAllocator = None,
    ) -> torch.Tensor:
-
        quant_format = (
            "mxfp4"
            if _is_gfx95_supported
-            and self.self_attn
+            and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None
+            and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None)
+            is not None
+            and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
            else ""
        )
 
@@ -2170,8 +2620,15 @@ class DeepseekV2Model(nn.Module):
                [
                    "w13_weight",
                    "w2_weight",
-
-
+                    # only for nvfp4
+                    *(
+                        [
+                            "w13_blockscale_swizzled",
+                            "w2_blockscale_swizzled",
+                        ]
+                        if hasattr(module, "w13_blockscale_swizzled")
+                        else []
+                    ),
                ]
                if isinstance(module, FusedMoE)
                else []
@@ -2553,7 +3010,11 @@ class DeepseekV2ForCausalLM(nn.Module):
                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
 
-                if
+                if (
+                    _use_aiter_gfx95
+                    and self.quant_config is not None
+                    and self.quant_config.get_name() == "quark"
+                ):
                    w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
                        quark_post_load_weights(self_attn, w, "mxfp4")
                    )
@@ -2937,8 +3398,24 @@ class DeepseekV2ForCausalLM(nn.Module):
        )
 
 
+AttentionBackendRegistry.register("ascend", handle_attention_ascend)
+AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
+AttentionBackendRegistry.register("fa3", handle_attention_fa3)
+AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
+AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
+AttentionBackendRegistry.register("fa4", handle_attention_fa4)
+AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
+AttentionBackendRegistry.register("aiter", handle_attention_aiter)
+AttentionBackendRegistry.register("nsa", handle_attention_nsa)
+AttentionBackendRegistry.register("triton", handle_attention_triton)
+
+
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass
 
 
-
+class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
+    pass
+
+
+EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM]