sglang 0.5.2rc1__py3-none-any.whl → 0.5.3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +330 -31
- sglang/bench_serving.py +267 -32
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/interpreter.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +9 -0
- sglang/srt/configs/model_config.py +181 -82
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +49 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +30 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +21 -6
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +279 -108
- sglang/srt/disaggregation/decode.py +71 -19
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +29 -17
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +55 -537
- sglang/srt/disaggregation/nixl/conn.py +326 -53
- sglang/srt/disaggregation/prefill.py +36 -17
- sglang/srt/disaggregation/utils.py +40 -54
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +192 -113
- sglang/srt/entrypoints/engine.py +59 -18
- sglang/srt/entrypoints/grpc_request_manager.py +855 -0
- sglang/srt/entrypoints/grpc_server.py +810 -0
- sglang/srt/entrypoints/http_server.py +132 -57
- sglang/srt/entrypoints/openai/protocol.py +115 -7
- sglang/srt/entrypoints/openai/serving_base.py +65 -3
- sglang/srt/entrypoints/openai/serving_chat.py +207 -58
- sglang/srt/entrypoints/openai/serving_completions.py +17 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +10 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +49 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +38 -8
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +9 -2
- sglang/srt/function_call/glm4_moe_detector.py +4 -4
- sglang/srt/function_call/gpt_oss_detector.py +24 -1
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +111 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +434 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +239 -0
- sglang/srt/layers/activation.py +143 -9
- sglang/srt/layers/attention/aiter_backend.py +106 -82
- sglang/srt/layers/attention/ascend_backend.py +115 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +118 -198
- sglang/srt/layers/attention/flashinfer_mla_backend.py +27 -27
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +68 -53
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +129 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +969 -0
- sglang/srt/layers/attention/mamba/mamba.py +629 -0
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +57 -7
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +276 -39
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +53 -7
- sglang/srt/layers/dp_attention.py +41 -2
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +34 -15
- sglang/srt/layers/linear.py +55 -7
- sglang/srt/layers/logits_processor.py +44 -12
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -2
- sglang/srt/layers/moe/ep_moe/layer.py +256 -63
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +183 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +7 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +71 -70
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +118 -56
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +22 -7
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +78 -49
- sglang/srt/layers/quantization/fp8_utils.py +51 -32
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +225 -57
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -42
- sglang/srt/layers/quantization/quark/quark_moe.py +48 -30
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +26 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +91 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +78 -49
- sglang/srt/layers/sampler.py +213 -21
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +99 -5
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +8 -3
- sglang/srt/lora/lora_manager.py +52 -118
- sglang/srt/lora/mem_pool.py +25 -11
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +22 -11
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +215 -314
- sglang/srt/managers/data_parallel_controller.py +115 -80
- sglang/srt/managers/detokenizer_manager.py +19 -15
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +340 -109
- sglang/srt/managers/mm_utils.py +44 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +358 -404
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +240 -138
- sglang/srt/managers/schedule_policy.py +147 -19
- sglang/srt/managers/scheduler.py +501 -304
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +119 -40
- sglang/srt/managers/scheduler_output_processor_mixin.py +75 -22
- sglang/srt/managers/scheduler_profiler_mixin.py +6 -6
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +675 -0
- sglang/srt/managers/tokenizer_manager.py +321 -632
- sglang/srt/managers/tp_worker.py +81 -22
- sglang/srt/managers/tp_worker_overlap_thread.py +71 -56
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +15 -21
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +58 -34
- sglang/srt/mem_cache/hiradix_cache.py +227 -80
- sglang/srt/mem_cache/memory_pool.py +535 -58
- sglang/srt/mem_cache/memory_pool_host.py +239 -223
- sglang/srt/mem_cache/radix_cache.py +222 -73
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +268 -63
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +284 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +198 -30
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +25 -36
- sglang/srt/metrics/collector.py +519 -132
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +55 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +52 -37
- sglang/srt/model_executor/forward_batch_info.py +98 -57
- sglang/srt/model_executor/model_runner.py +433 -158
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +133 -5
- sglang/srt/model_loader/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/model_loader/weight_utils.py +158 -3
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +820 -217
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +833 -152
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -2
- sglang/srt/models/glm4_moe.py +14 -5
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +5 -3
- sglang/srt/models/glm4v_moe.py +4 -1
- sglang/srt/models/gpt_oss.py +8 -31
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +3 -3
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +40 -4
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2_5_vl.py +29 -5
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +124 -14
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +26 -5
- sglang/srt/models/qwen3_moe.py +71 -12
- sglang/srt/models/qwen3_next.py +1069 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +10 -3
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +98 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +153 -129
- sglang/srt/multimodal/processors/qwen_vl.py +23 -6
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/offloader.py +27 -3
- sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +6 -0
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/sampling_batch_info.py +38 -17
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +1030 -254
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +13 -2
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -757
- sglang/srt/speculative/eagle_worker.py +253 -136
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +52 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +578 -0
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +445 -77
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +53 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +2 -2
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +5 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +77 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +383 -5
- sglang/utils.py +22 -1
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/METADATA +69 -124
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/RECORD +392 -258
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -15,6 +15,7 @@
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
+from __future__ import annotations
 
 import concurrent.futures
 import logging
@@ -25,9 +26,16 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from torch import nn
-from tqdm import tqdm
 from transformers import PretrainedConfig
 
+from sglang.srt import single_batch_overlap
+from sglang.srt.configs.model_config import (
+    get_nsa_index_head_dim,
+    get_nsa_index_n_heads,
+    get_nsa_index_topk,
+    is_deepseek_nsa,
+)
+from sglang.srt.debug_utils.dumper import dumper
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
@@ -43,6 +51,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.amx_utils import PackWeightMethod
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
+    NPUFusedMLAPreprocess,
+    is_mla_preprocess_enabled,
+)
+from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -65,10 +78,11 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -96,6 +110,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.single_batch_overlap import SboFlags
 from sglang.srt.two_batch_overlap import (
     MaybeTboDeepEPDispatcher,
     model_forward_maybe_tbo,
@@ -112,6 +127,7 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_flashinfer_available,
+    is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
@@ -129,11 +145,28 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()
+_is_gfx95_supported = is_gfx95_supported()
+
+_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
+
+if _use_aiter_gfx95:
+    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
+    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
+        batched_gemm_afp4wfp4_pre_quant,
+        fused_flatten_mxfp4_quant,
+        fused_rms_mxfp4_quant,
+    )
+    from sglang.srt.layers.rocm_linear_utils import (
+        aiter_dsv3_router_gemm,
+        fused_qk_rope_cat,
+        get_dsv3_gemm_output_zero_allocator_size,
+    )
 
 if _is_cuda:
     from sgl_kernel import (
         awq_dequantize,
         bmm_fp8,
+        concat_mla_k,
         dsv3_fused_a_gemm,
         dsv3_router_gemm,
         merge_state_v2,
@@ -141,16 +174,18 @@ if _is_cuda:
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
+    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+        decode_attention_fwd_grouped_rope,
+    )
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
     )
+elif _is_npu:
+    import custom_ops
+    import sgl_kernel_npu
+    import torch_npu
 else:
-
-
-if _is_hip:
-    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
-        decode_attention_fwd_grouped_rope,
-    )
+    pass
 
 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
@@ -158,6 +193,21 @@ _is_sm100_supported = is_cuda() and is_sm100_supported()
 
 logger = logging.getLogger(__name__)
 
+FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
+    "fa3",
+    "nsa",
+    "flashinfer",
+    "cutlass_mla",
+    "trtllm_mla",
+    "ascend",
+]
+
+
+def add_forward_absorb_core_attention_backend(backend_name):
+    if backend_name not in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
+        FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
+
 
 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
@@ -166,6 +216,9 @@ class AttnForwardMethod(IntEnum):
     # Use absorbed multi-latent attention
     MLA = auto()
 
+    # Use Deepseek V3.2 sparse multi-latent attention
+    NPU_MLA_SPARSE = auto()
+
     # Use multi-head attention, but with KV cache chunked.
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
@@ -177,6 +230,146 @@ class AttnForwardMethod(IntEnum):
     MLA_FUSED_ROPE_CPU = auto()
 
 
+def _dispatch_mla_subtype(attn, forward_batch):
+    if _is_hip:
+        if attn.rocm_fused_decode_mla and forward_batch.forward_mode.is_decode():
+            return AttnForwardMethod.MLA_FUSED_ROPE
+        else:
+            return AttnForwardMethod.MLA
+    else:
+        if hasattr(attn, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(attn):
+            return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+        else:
+            return AttnForwardMethod.MLA
+
+
+class AttentionBackendRegistry:
+    _handlers = {}
+
+    @classmethod
+    def register(cls, backend_name, handler_func):
+        cls._handlers[backend_name] = handler_func
+
+    @classmethod
+    def get_handler(cls, backend_name):
+        return cls._handlers.get(backend_name, cls._handlers.get("triton"))
+
+
+def handle_attention_ascend(attn, forward_batch):
+    if (
+        forward_batch.forward_mode.is_extend()
+        and not forward_batch.forward_mode.is_target_verify()
+        and not forward_batch.forward_mode.is_draft_extend()
+    ):
+        if hasattr(attn, "indexer"):
+            return AttnForwardMethod.NPU_MLA_SPARSE
+        else:
+            return AttnForwardMethod.MHA
+    else:
+        if hasattr(attn, "indexer"):
+            return AttnForwardMethod.NPU_MLA_SPARSE
+        else:
+            return AttnForwardMethod.MLA
+
+
+def _get_sum_extend_prefix_lens(forward_batch):
+    return (
+        sum(forward_batch.extend_prefix_lens_cpu)
+        if forward_batch.extend_prefix_lens_cpu is not None
+        else 0
+    )
+
+
+def _is_extend_without_speculative(forward_batch):
+    return (
+        forward_batch.forward_mode.is_extend()
+        and not forward_batch.forward_mode.is_target_verify()
+        and not forward_batch.forward_mode.is_draft_extend()
+    )
+
+
+def _handle_attention_backend(
+    attn: DeepseekV2AttentionMLA, forward_batch, backend_name
+):
+    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
+    disable_ragged = (
+        backend_name in ["flashinfer", "flashmla"]
+    ) and attn.flashinfer_mla_disable_ragged
+
+    if (
+        not disable_ragged
+        and _is_extend_without_speculative(forward_batch)
+        and (
+            (
+                sum_extend_prefix_lens >= attn.chunked_prefix_cache_threshold
+                and not attn.disable_chunked_prefix_cache
+            )
+            or sum_extend_prefix_lens == 0
+        )
+    ):
+        return AttnForwardMethod.MHA_CHUNKED_KV
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
+def handle_attention_flashinfer(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashinfer")
+
+
+def handle_attention_fa3(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "fa3")
+
+
+def handle_attention_flashmla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashmla")
+
+
+def handle_attention_cutlass_mla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
+
+
+def handle_attention_fa4(attn, forward_batch):
+    # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
+    return AttnForwardMethod.MHA_CHUNKED_KV
+
+
+def handle_attention_trtllm_mla(attn, forward_batch):
+    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
+    if _is_extend_without_speculative(forward_batch) and (
+        not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
+    ):
+        return AttnForwardMethod.MHA_CHUNKED_KV
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
+def handle_attention_aiter(attn, forward_batch):
+    if _is_extend_without_speculative(forward_batch):
+        if is_dp_attention_enabled():
+            if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
+        else:
+            return AttnForwardMethod.MHA
+    else:
+        return AttnForwardMethod.MLA
+
+
+def handle_attention_nsa(attn, forward_batch):
+    return AttnForwardMethod.MLA
+
+
+def handle_attention_triton(attn, forward_batch):
+    if (
+        _is_extend_without_speculative(forward_batch)
+        and sum(forward_batch.extend_prefix_lens_cpu) == 0
+    ):
+        return AttnForwardMethod.MHA
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
 class DeepseekV2MLP(nn.Module):
     def __init__(
         self,
@@ -224,10 +417,21 @@ class DeepseekV2MLP(nn.Module):
         forward_batch=None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
+        if (
+            gemm_output_zero_allocator is not None
+            and x.shape[0] <= 256
+            and self.gate_up_proj.weight.dtype == torch.uint8
+        ):
+            y = gemm_output_zero_allocator.allocate(
+                x.shape[0] * self.gate_up_proj.output_size_per_partition
+            ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
+            x = (x, None, y)
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
@@ -240,6 +444,7 @@ class MoEGate(nn.Module):
     def __init__(
         self,
        config,
+        quant_config,
         prefix: str = "",
         is_nextn: bool = False,
     ):
@@ -249,15 +454,22 @@ class MoEGate(nn.Module):
             torch.empty((config.n_routed_experts, config.hidden_size))
         )
         if config.topk_method == "noaux_tc":
+            correction_bias_dtype = (
+                torch.bfloat16
+                if quant_config is not None
+                and quant_config.get_name() == "modelopt_fp4"
+                and should_use_flashinfer_trtllm_moe()
+                else torch.float32
+            )
             self.e_score_correction_bias = nn.Parameter(
-                torch.empty((config.n_routed_experts), dtype=
+                torch.empty((config.n_routed_experts), dtype=correction_bias_dtype)
             )
         else:
             self.e_score_correction_bias = None
         if _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(weight_names=["weight"])
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
         if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
@@ -271,11 +483,17 @@ class MoEGate(nn.Module):
             _is_cuda
             and hidden_states.shape[0] <= 16
             and hidden_states.shape[1] == 7168
-            and self.weight.shape[0] == 256
+            and (self.weight.shape[0] == 256 or self.weight.shape[0] == 384)
            and _device_sm >= 90
         ):
             # router gemm output float32
-            logits = dsv3_router_gemm(
+            logits = dsv3_router_gemm(
+                hidden_states, self.weight, out_dtype=torch.float32
+            )
+        elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
+            logits = aiter_dsv3_router_gemm(
+                hidden_states, self.weight, gemm_output_zero_allocator
+            )
         else:
             logits = F.linear(hidden_states, self.weight, None)
 
@@ -319,7 +537,10 @@ class DeepseekV2MoE(nn.Module):
         )
 
         self.gate = MoEGate(
-            config=config,
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("gate", prefix),
+            is_nextn=is_nextn,
         )
 
         self.experts = get_moe_impl_class(quant_config)(
@@ -344,9 +565,12 @@ class DeepseekV2MoE(nn.Module):
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
-            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk
-
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
+            # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
+            # and requires the output format to be standard. We use quant_config to determine the output format.
+            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
         )
 
         self.shared_experts_is_int8 = False
@@ -439,6 +663,7 @@ class DeepseekV2MoE(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -452,12 +677,14 @@ class DeepseekV2MoE(nn.Module):
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
             else:
                 return self.forward_normal(
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -467,15 +694,18 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
 
         current_stream = torch.cuda.current_stream()
         self.alt_stream.wait_stream(current_stream)
-        shared_output = self._forward_shared_experts(
+        shared_output = self._forward_shared_experts(
+            hidden_states, gemm_output_zero_allocator
+        )
 
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
@@ -502,6 +732,7 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -509,9 +740,11 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
         if hidden_states.shape[0] > 0:
-            shared_output = self._forward_shared_experts(
+            shared_output = self._forward_shared_experts(
+                hidden_states, gemm_output_zero_allocator
+            )
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
         else:
             shared_output = None
@@ -601,7 +834,8 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-
+            if not SboFlags.fuse_shared_experts_inside_sbo():
+                shared_output = self._forward_shared_experts(hidden_states)
             topk_weights, topk_idx, _ = self.topk(
                 hidden_states,
                 router_logits,
@@ -615,25 +849,39 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states.device
            )
 
-            final_hidden_states =
+            final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
                 hidden_states=hidden_states,
                 topk_idx=topk_idx,
                 topk_weights=topk_weights,
                 forward_batch=forward_batch,
+                # SBO args
+                forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
+                experts=self.experts,
+                alt_stream=self.alt_stream,
             )
+            if sbo_shared_output is not None:
+                shared_output = sbo_shared_output
 
         if shared_output is not None:
             x = shared_output
-
+            if self.experts.should_fuse_routed_scaling_factor_in_topk:
+                x.add_(final_hidden_states)
+            else:
+                x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk:
+                final_hidden_states *= self.routed_scaling_factor
 
         return final_hidden_states
 
-    def _forward_shared_experts(
-
-
+    def _forward_shared_experts(
+        self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
+    ):
+        if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
+            return self.shared_experts(
+                hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
+            )
         else:
             return None
 
@@ -683,6 +931,7 @@ class DeepseekV2MoE(nn.Module):
         if self.ep_size > 1:
             self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
+                input_global_scale=None,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
                 forward_batch=state.forward_batch,
@@ -783,6 +1032,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
+        # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
+        if rope_scaling:
+            rope_scaling["rope_type"] = "deepseek_yarn"
+
         # For tensor parallel attention
         if self.q_lora_rank is not None:
             self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
@@ -820,6 +1073,26 @@ class DeepseekV2AttentionMLA(nn.Module):
                 prefix=add_prefix("kv_a_proj_with_mqa", prefix),
             )
 
+        self.use_nsa = is_deepseek_nsa(config)
+        if self.use_nsa:
+            self.indexer = Indexer(
+                hidden_size=hidden_size,
+                index_n_heads=get_nsa_index_n_heads(config),
+                index_head_dim=get_nsa_index_head_dim(config),
+                rope_head_dim=qk_rope_head_dim,
+                index_topk=get_nsa_index_topk(config),
+                q_lora_rank=q_lora_rank,
+                max_position_embeddings=max_position_embeddings,
+                rope_theta=rope_theta,
+                scale_fmt="ue8m0",
+                block_size=128,
+                rope_scaling=rope_scaling,
+                prefix=add_prefix("indexer", prefix),
+                quant_config=quant_config,
+                layer_id=layer_id,
+                alt_stream=alt_stream,
+            )
+
         self.kv_b_proj = ColumnParallelLinear(
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -842,9 +1115,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
 
-        if rope_scaling:
-            rope_scaling["rope_type"] = "deepseek_yarn"
-
         self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
@@ -968,96 +1238,34 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.weight_block_size = (
             self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
         )
+        self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
+        if self.is_mla_preprocess_enabled:
+            assert (
+                quant_config is None or quant_config.get_name() == "w8a8_int8"
+            ), "MLA Preprocess only works with Unquant or W8A8Int8"
+            self.mla_preprocess = None
 
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
-        def _dispatch_mla_subtype():
-            if _is_hip:
-                if (
-                    self.rocm_fused_decode_mla
-                    and forward_batch.forward_mode.is_decode()
-                ):
-                    return AttnForwardMethod.MLA_FUSED_ROPE
-                else:
-                    return AttnForwardMethod.MLA
-            else:
-                if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
-                    self
-                ):
-                    return AttnForwardMethod.MLA_FUSED_ROPE_CPU
-                else:
-                    return AttnForwardMethod.MLA
-
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
             attention_backend = global_server_args_dict["decode_attention_backend"]
+        elif (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            # Use the specified backend for speculative operations (both verify and draft extend)
+            if global_server_args_dict["speculative_attention_mode"] == "decode":
+                attention_backend = global_server_args_dict["decode_attention_backend"]
+            else:  # default to prefill
+                attention_backend = global_server_args_dict["prefill_attention_backend"]
         else:
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend
 
-
-
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return AttnForwardMethod.MLA
-        elif (
-            attention_backend == "flashinfer"
-            or attention_backend == "fa3"
-            or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
-            or attention_backend == "cutlass_mla"
-        ):
-            # Use MHA with chunked KV cache when prefilling on long sequences.
-            sum_extend_prefix_lens = (
-                sum(forward_batch.extend_prefix_lens_cpu)
-                if forward_batch.extend_prefix_lens_cpu is not None
-                else 0
-            )
-            # Flashinfer MLA: Do not absorb when enabling ragged prefill
-            disable_ragged = (
-                attention_backend == "flashinfer" or attention_backend == "flashmla"
-            ) and self.flashinfer_mla_disable_ragged
-            if (
-                not disable_ragged
-                and forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and (
-                    (
-                        sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
-                        and not self.disable_chunked_prefix_cache
-                    )
-                    or sum_extend_prefix_lens == 0
-                )
-            ):
-                return AttnForwardMethod.MHA_CHUNKED_KV
-            else:
-                return _dispatch_mla_subtype()
-        elif attention_backend == "aiter":
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return AttnForwardMethod.MLA
-        else:
-            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return _dispatch_mla_subtype()
+        handler = AttentionBackendRegistry.get_handler(attention_backend)
+        return handler(self, forward_batch)
 
     def op_prepare(self, state):
         state.attn_intermediate_state = self.forward_prepare(
@@ -1097,14 +1305,21 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.attn_mha.kv_b_proj is None:
             self.attn_mha.kv_b_proj = self.kv_b_proj
 
-
-
-
-
-
+        # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor
+        if isinstance(hidden_states, tuple):
+            if hidden_states[0].shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states[0]
+        else:
+            if hidden_states.shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states, None, forward_batch, None
 
         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
-
         if attn_forward_method == AttnForwardMethod.MHA:
             inner_state = self.forward_normal_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
@@ -1114,7 +1329,30 @@ class DeepseekV2AttentionMLA(nn.Module):
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA:
-
+            if not self.is_mla_preprocess_enabled:
+                inner_state = self.forward_absorb_prepare(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
+            else:
+                # TODO(iforgetmyname): to be separated as a standalone func
+                if self.mla_preprocess is None:
+                    self.mla_preprocess = NPUFusedMLAPreprocess(
+                        self.fused_qkv_a_proj_with_mqa,
+                        self.q_a_layernorm,
+                        self.kv_a_layernorm,
+                        self.q_b_proj,
+                        self.w_kc,
+                        self.rotary_emb,
+                        self.layer_id,
+                        self.num_local_heads,
+                        self.qk_nope_head_dim,
+                        self.qk_rope_head_dim,
+                    )
+                inner_state = self.mla_preprocess.forward(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
+        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+            inner_state = self.forward_npu_sparse_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
@@ -1142,6 +1380,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_chunked_kv_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA:
             return self.forward_absorb_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+            return self.forward_npu_sparse_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             return self.forward_absorb_fused_mla_rope_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
@@ -1180,8 +1420,18 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q[..., self.qk_nope_head_dim :] = q_pe
         k = torch.empty_like(q)
-
-
+
+        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        if (
+            _is_cuda
+            and (self.num_local_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        ):
+            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+        else:
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
 
         if not _is_npu:
             latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
@@ -1211,7 +1461,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         """
         return (
             self.current_attention_backend == "trtllm_mla"
-            and
+            and (
+                forward_batch.forward_mode.is_decode_or_idle()
+                or forward_batch.forward_mode.is_target_verify()
+            )
             and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
         )
 
@@ -1224,8 +1477,13 @@ class DeepseekV2AttentionMLA(nn.Module):
     ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
+        q_lora = None
         if self.q_lora_rank is not None:
-            if
+            if (
+                (not isinstance(hidden_states, tuple))
+                and hidden_states.shape[0] <= 16
+                and self.use_min_latency_fused_a_gemm
+            ):
                 fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                     hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                 )
@@ -1245,8 +1503,22 @@ class DeepseekV2AttentionMLA(nn.Module):
                 k_nope = self.kv_a_layernorm(k_nope)
                 current_stream.wait_stream(self.alt_stream)
             else:
-
-
+                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+                    q, k_nope = fused_rms_mxfp4_quant(
+                        q,
+                        self.q_a_layernorm.weight,
+                        self.q_a_layernorm.variance_epsilon,
+                        k_nope,
+                        self.kv_a_layernorm.weight,
+                        self.kv_a_layernorm.variance_epsilon,
+                    )
+                else:
+                    q = self.q_a_layernorm(q)
+                    k_nope = self.kv_a_layernorm(k_nope)
+
+            # q_lora needed by indexer
+            if self.use_nsa:
+                q_lora = q
 
             k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
@@ -1278,10 +1550,27 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope_out[:, :expected_m, :]
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            q_nope_out = torch.bmm(
-                q_nope.to(torch.bfloat16).transpose(0, 1),
-                self.w_kc.to(torch.bfloat16) * self.w_scale,
-            )
+            if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+                x = q_nope.transpose(0, 1)
+                q_nope_out = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_kc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_kc.transpose(-2, -1),
+                    self.w_scale_k.transpose(-2, -1),
+                    torch.bfloat16,
+                    q_nope_out,
+                )
+            else:
+                q_nope_out = torch.bmm(
+                    q_nope.to(torch.bfloat16).transpose(0, 1),
+                    self.w_kc.to(torch.bfloat16) * self.w_scale,
+                )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
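`batched_gemm_afp4wfp4_pre_quant` writes the absorbed `q_nope` projection straight into a preallocated bf16 buffer, while the `else` branch keeps the bf16 `torch.bmm` fallback. A shape-only sketch of that per-head contraction, with hypothetical sizes standing in for `num_local_heads`, `qk_nope_head_dim`, and `kv_lora_rank`:

import torch

T, H, D, R = 4, 8, 128, 512     # tokens, heads, qk_nope_head_dim, kv_lora_rank (hypothetical)
q_nope = torch.randn(T, H, D)
w_kc = torch.randn(H, D, R)
w_scale = 1.0

# torch.bmm contracts per head, so the token and head axes are swapped around the call.
q_nope_out = torch.bmm(
    q_nope.to(torch.bfloat16).transpose(0, 1),   # [H, T, D]
    w_kc.to(torch.bfloat16) * w_scale,           # [H, D, R]
)
q_nope_out = q_nope_out.transpose(0, 1)          # back to [T, H, R]
assert q_nope_out.shape == (T, H, R)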
@@ -1295,27 +1584,51 @@ class DeepseekV2AttentionMLA(nn.Module):

         q_nope_out = q_nope_out.transpose(0, 1)

-        if not self._fuse_rope_for_trtllm_mla(forward_batch):
+        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+            not _use_aiter or not _is_gfx95_supported or self.use_nsa
+        ):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

-        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
+        topk_indices = None
+        if q_lora is not None:
+            topk_indices = self.indexer(
+                x=hidden_states,
+                q_lora=q_lora,
+                positions=positions,
+                forward_batch=forward_batch,
+                layer_id=self.layer_id,
+            )
+
+        return (
+            q_pe,
+            k_pe,
+            q_nope_out,
+            k_nope,
+            forward_batch,
+            zero_allocator,
+            positions,
+            topk_indices,
+        )

     def forward_absorb_core(
-        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
+        self,
+        q_pe,
+        k_pe,
+        q_nope_out,
+        k_nope,
+        forward_batch,
+        zero_allocator,
+        positions,
+        topk_indices,
     ):
-        if (
-            self.current_attention_backend == "fa3"
-            or self.current_attention_backend == "flashinfer"
-            or self.current_attention_backend == "cutlass_mla"
-            or self.current_attention_backend == "trtllm_mla"
-            or self.current_attention_backend == "ascend"
-        ):
+        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
            extra_args = {}
            if self._fuse_rope_for_trtllm_mla(forward_batch):
                extra_args = {
                    "cos_sin_cache": self.rotary_emb.cos_sin_cache,
                    "is_neox": self.rotary_emb.is_neox_style,
                }
+
            attn_output = self.attn_mqa(
                q_nope_out,
                k_nope,
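The chained backend comparisons are replaced by a membership test against `FORWARD_ABSORB_CORE_ATTENTION_BACKENDS`, which is defined elsewhere in this module and is not shown in this diff; judging from the removed lines it presumably covers at least these five entries:

# Presumed contents only; the real definition lives earlier in the module
# and may include additional backends.
FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
    "fa3",
    "flashinfer",
    "cutlass_mla",
    "trtllm_mla",
    "ascend",
]

assert "trtllm_mla" in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS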
@@ -1324,11 +1637,33 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q_rope=q_pe,
                 k_rope=k_pe,
                 **extra_args,
+                **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
             )
         else:
-            q = torch.cat([q_nope_out, q_pe], dim=-1)
-            k = torch.cat([k_nope, k_pe], dim=-1)
-            attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
+            if _use_aiter_gfx95:
+                cos = self.rotary_emb.cos_cache
+                sin = self.rotary_emb.sin_cache
+                q, k = fused_qk_rope_cat(
+                    q_nope_out,
+                    q_pe,
+                    k_nope,
+                    k_pe,
+                    positions,
+                    cos,
+                    sin,
+                    self.rotary_emb.is_neox_style,
+                )
+            else:
+                q = torch.cat([q_nope_out, q_pe], dim=-1)
+                k = torch.cat([k_nope, k_pe], dim=-1)
+
+            attn_output = self.attn_mqa(
+                q,
+                k,
+                k_nope,
+                forward_batch,
+                **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
+            )
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

         if self.use_deep_gemm_bmm:
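`topk_indices` is forwarded with a conditional `**` expansion, presumably so that backends whose `attn_mqa` signature has no `topk_indices` parameter are unaffected when the indexer is disabled. A standalone illustration of the pattern (the handler names here are made up):

def attn_mqa_dense(q, k, v, forward_batch):
    # Stand-in for a backend whose signature has no topk_indices parameter.
    return "dense"

def attn_mqa_sparse(q, k, v, forward_batch, topk_indices=None):
    # Stand-in for an NSA-capable backend.
    return "sparse" if topk_indices is not None else "dense"

def call(attn, topk_indices):
    return attn(
        "q", "k", "v", "forward_batch",
        **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
    )

assert call(attn_mqa_dense, None) == "dense"          # kwarg never reaches the dense backend
assert call(attn_mqa_sparse, [3, 1, 4]) == "sparse"   # kwarg forwarded only when it is set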
@@ -1352,11 +1687,34 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-            attn_bmm_output = torch.bmm(
-                attn_output.to(torch.bfloat16).transpose(0, 1),
-                self.w_vc.to(torch.bfloat16) * self.w_scale,
-            )
-            attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
+                x = attn_output.transpose(0, 1)
+                attn_bmm_output = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_vc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_vc.transpose(-2, -1),
+                    self.w_scale_v.transpose(-2, -1),
+                    torch.bfloat16,
+                    attn_bmm_output,
+                )
+            else:
+                attn_bmm_output = torch.bmm(
+                    attn_output.to(torch.bfloat16).transpose(0, 1),
+                    self.w_vc.to(torch.bfloat16) * self.w_scale,
+                )
+
+            if self.o_proj.weight.dtype == torch.uint8:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1)
+                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
+            else:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1387,6 +1745,221 @@ class DeepseekV2AttentionMLA(nn.Module):

         return output

+    def forward_npu_sparse_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        """
+        Reuse `self.q_lora_rank is not None` branch from forward_absorb_prepare
+        """
+        if self.is_mla_preprocess_enabled and forward_batch.forward_mode.is_decode():
+            if self.mla_preprocess is None:
+                self.mla_preprocess = NPUFusedMLAPreprocess(
+                    self.fused_qkv_a_proj_with_mqa,
+                    self.q_a_layernorm,
+                    self.kv_a_layernorm,
+                    self.q_b_proj,
+                    self.w_kc,
+                    self.rotary_emb,
+                    self.layer_id,
+                    self.num_local_heads,
+                    self.qk_nope_head_dim,
+                    self.qk_rope_head_dim,
+                )
+            (
+                q_pe,
+                k_pe,
+                q_nope_out,
+                k_nope,
+                forward_batch,
+                zero_allocator,
+                positions,
+            ) = self.mla_preprocess.forward(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+
+            fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, _ = fused_qkv_a_proj_out.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            q_lora = self.q_a_layernorm(q)
+        else:
+            from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
+            if (
+                (not isinstance(hidden_states, tuple))
+                and hidden_states.shape[0] <= 16
+                and self.use_min_latency_fused_a_gemm
+            ):
+                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+                )
+            else:
+                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, latent_cache = fused_qkv_a_proj_out.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and get_is_capture_mode():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+                    q, k_nope = fused_rms_mxfp4_quant(
+                        q,
+                        self.q_a_layernorm.weight,
+                        self.q_a_layernorm.variance_epsilon,
+                        k_nope,
+                        self.kv_a_layernorm.weight,
+                        self.kv_a_layernorm.variance_epsilon,
+                    )
+                else:
+                    q = self.q_a_layernorm(q)
+                    k_nope = self.kv_a_layernorm(k_nope)
+
+            q_lora = q.clone()  # required for topk_indices
+            k_nope = k_nope.unsqueeze(1)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+
+            q_nope, q_pe = q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+            k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
+
+            if self.use_deep_gemm_bmm:
+                q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+                    per_token_group_quant_mla_deep_gemm_masked_fp8(
+                        q_nope.transpose(0, 1)
+                    )
+                )
+                q_nope_out = q_nope.new_empty(
+                    (self.num_local_heads, aligned_m, self.kv_lora_rank)
+                )
+                deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+                    (q_nope_val, q_nope_scale),
+                    (self.w_kc, self.w_scale_k),
+                    q_nope_out,
+                    masked_m,
+                    expected_m,
+                )
+                q_nope_out = q_nope_out[:, :expected_m, :]
+            elif _is_hip:
+                # TODO(haishaw): add bmm_fp8 to ROCm
+                if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+                    x = q_nope.transpose(0, 1)
+                    q_nope_out = torch.empty(
+                        x.shape[0],
+                        x.shape[1],
+                        self.w_kc.shape[2],
+                        device=x.device,
+                        dtype=torch.bfloat16,
+                    )
+                    batched_gemm_afp4wfp4_pre_quant(
+                        x,
+                        self.w_kc.transpose(-2, -1),
+                        self.w_scale_k.transpose(-2, -1),
+                        torch.bfloat16,
+                        q_nope_out,
+                    )
+                else:
+                    q_nope_out = torch.bmm(
+                        q_nope.to(torch.bfloat16).transpose(0, 1),
+                        self.w_kc.to(torch.bfloat16) * self.w_scale,
+                    )
+            elif self.w_kc.dtype == torch.float8_e4m3fn:
+                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                    q_nope.transpose(0, 1),
+                    zero_allocator.allocate(1),
+                )
+                q_nope_out = bmm_fp8(
+                    q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+                )
+            else:
+                q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+
+            q_nope_out = q_nope_out.transpose(0, 1)
+
+            if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+                not _use_aiter or not _is_gfx95_supported
+            ):
+                q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+        # TODO: multi-stream indexer
+        topk_indices = self.indexer(
+            hidden_states, q_lora, positions, forward_batch, self.layer_id
+        )
+
+        return (
+            q_pe,
+            k_pe,
+            q_nope_out,
+            k_nope,
+            topk_indices,
+            forward_batch,
+            zero_allocator,
+            positions,
+        )
+
+    def forward_npu_sparse_core(
+        self,
+        q_pe,
+        k_pe,
+        q_nope_out,
+        k_nope,
+        topk_indices,
+        forward_batch,
+        zero_allocator,
+        positions,
+    ):
+        attn_output = self.attn_mqa(
+            q_nope_out.contiguous(),
+            k_nope.contiguous(),
+            k_nope.contiguous(),
+            forward_batch,
+            save_kv_cache=True,  # False if forward_batch.forward_mode.is_extend() else True,
+            q_rope=q_pe.contiguous(),
+            k_rope=k_pe.contiguous(),
+            topk_indices=topk_indices,
+        )
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        attn_bmm_output = torch.empty(
+            (attn_output.shape[0], self.num_local_heads, self.v_head_dim),
+            dtype=attn_output.dtype,
+            device=attn_output.device,
+        )
+
+        if not forward_batch.forward_mode.is_decode():
+            attn_output = attn_output.transpose(0, 1)
+            torch.bmm(
+                attn_output,
+                self.w_vc,
+                out=attn_bmm_output.view(
+                    -1, self.num_local_heads, self.v_head_dim
+                ).transpose(0, 1),
+            )
+        else:
+            attn_output = attn_output.contiguous()
+            torch.ops.npu.batch_matmul_transpose(
+                attn_output, self.w_vc, attn_bmm_output
+            )
+
+        attn_bmm_output = attn_bmm_output.reshape(
+            -1, self.num_local_heads * self.v_head_dim
+        )
+
+        output, _ = self.o_proj(attn_bmm_output)
+        return output
+
     def forward_absorb_fused_mla_rope_prepare(
         self,
         positions: torch.Tensor,
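The two NPU sparse halves follow the same prepare/core split used by the other forward paths in this class: the tuple returned by `forward_npu_sparse_prepare` matches the positional parameters of `forward_npu_sparse_core`, so the dispatcher (which lives outside this hunk) can simply unpack it. A sketch of that composition, assuming `attn` is a `DeepseekV2AttentionMLA` instance:

# `attn` stands for a DeepseekV2AttentionMLA instance; the real dispatch happens
# in the attention-method dispatcher, not in this hunk.
def forward_npu_sparse(attn, positions, hidden_states, forward_batch, zero_allocator):
    inner_state = attn.forward_npu_sparse_prepare(
        positions, hidden_states, forward_batch, zero_allocator
    )
    return attn.forward_npu_sparse_core(*inner_state)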
@@ -1678,9 +2251,11 @@ class DeepseekV2AttentionMLA(nn.Module):
                latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                    self.attn_mha.layer_id
                )
-               latent_cache = latent_cache_buf[
-                   forward_batch.prefix_chunk_kv_indices[i]
-               ].contiguous()
+               latent_cache = (
+                   latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
+                   .contiguous()
+                   .to(q.dtype)
+               )

                kv_a_normed, k_pe = latent_cache.split(
                    [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
@@ -1710,6 +2285,7 @@ class DeepseekV2AttentionMLA(nn.Module):
            tmp_lse = torch.empty_like(accum_lse)
            merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
            accum_output, accum_lse = tmp_output, tmp_lse
+           del kv, k, v, output, lse, tmp_output, tmp_lse

        return accum_output

@@ -1864,10 +2440,23 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
+        quant_format = (
+            "mxfp4"
+            if _is_gfx95_supported
+            and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None
+            and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None)
+            is not None
+            and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
+            else ""
+        )

         hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+            hidden_states,
+            residual,
+            forward_batch,
+            quant_format,
         )

         hidden_states = self.self_attn(
@@ -1891,8 +2480,16 @@ class DeepseekV2DecoderLayer(nn.Module):
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
+
+        if isinstance(self.mlp, DeepseekV2MLP):
+            gemm_output_zero_allocator = None
+
         hidden_states = self.mlp(
-            hidden_states, forward_batch, should_allreduce_fusion, use_reduce_scatter
+            hidden_states,
+            forward_batch,
+            should_allreduce_fusion,
+            use_reduce_scatter,
+            gemm_output_zero_allocator,
         )

         if should_allreduce_fusion:
@@ -2023,8 +2620,15 @@ class DeepseekV2Model(nn.Module):
                    [
                        "w13_weight",
                        "w2_weight",
-                       "w13_blockscale_swizzled",
-                       "w2_blockscale_swizzled",
+                       # only for nvfp4
+                       *(
+                           [
+                               "w13_blockscale_swizzled",
+                               "w2_blockscale_swizzled",
+                           ]
+                           if hasattr(module, "w13_blockscale_swizzled")
+                           else []
+                       ),
                    ]
                    if isinstance(module, FusedMoE)
                    else []
@@ -2036,6 +2640,37 @@ class DeepseekV2Model(nn.Module):
        else:
            self.norm = PPMissingLayer(return_tuple=True)

+        self.gemm_output_zero_allocator_size = 0
+        if (
+            _use_aiter_gfx95
+            and config.n_routed_experts == 256
+            and self.embed_tokens.embedding_dim == 7168
+        ):
+            num_moe_layers = sum(
+                [
+                    1
+                    for i in range(len(self.layers))
+                    if isinstance(self.layers[i].mlp, DeepseekV2MoE)
+                ]
+            )
+
+            allocate_size = 0
+            for i in range(len(self.layers)):
+                if isinstance(self.layers[i].mlp, DeepseekV2MoE):
+                    allocate_size = self.layers[
+                        i
+                    ].mlp.shared_experts.gate_up_proj.output_size_per_partition
+                    break
+
+            self.gemm_output_zero_allocator_size = (
+                get_dsv3_gemm_output_zero_allocator_size(
+                    config.n_routed_experts,
+                    num_moe_layers,
+                    allocate_size,
+                    self.embed_tokens.embedding_dim,
+                )
+            )
+
    def get_input_embeddings(self) -> torch.Tensor:
        return self.embed_tokens

@@ -2055,6 +2690,21 @@ class DeepseekV2Model(nn.Module):
            device=device,
        )

+        has_gemm_output_zero_allocator = hasattr(
+            self, "gemm_output_zero_allocator_size"
+        )
+
+        gemm_output_zero_allocator = (
+            BumpAllocator(
+                buffer_size=self.gemm_output_zero_allocator_size,
+                dtype=torch.float32,
+                device=device,
+            )
+            if has_gemm_output_zero_allocator
+            and self.gemm_output_zero_allocator_size > 0
+            else None
+        )
+
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                hidden_states = self.embed_tokens(input_ids)
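`BumpAllocator` hands the decoder layers slices of one preallocated zero buffer instead of allocating per layer. The class itself lives elsewhere in sglang; a minimal sketch of the bump-allocation idea behind the constructor and the `allocate(1)` calls seen above (not the real implementation):

import torch

class BumpAllocatorSketch:
    # Hand out consecutive slices of one preallocated zero buffer instead of
    # calling torch.zeros() inside every layer.
    def __init__(self, buffer_size: int, dtype: torch.dtype, device: str) -> None:
        self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, num: int) -> torch.Tensor:
        assert self._offset + num <= self._buffer.numel(), "bump allocator exhausted"
        out = self._buffer[self._offset : self._offset + num]
        self._offset += num
        return out

alloc = BumpAllocatorSketch(buffer_size=8, dtype=torch.float32, device="cpu")
scale_slot = alloc.allocate(1)   # analogous to zero_allocator.allocate(1) in the code above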
@@ -2081,7 +2731,12 @@ class DeepseekV2Model(nn.Module):
            with get_global_expert_distribution_recorder().with_current_layer(i):
                layer = self.layers[i]
                hidden_states, residual = layer(
-                    positions, hidden_states, forward_batch, residual, zero_allocator
+                    positions,
+                    hidden_states,
+                    forward_batch,
+                    residual,
+                    zero_allocator,
+                    gemm_output_zero_allocator,
                )

        if normal_end_layer != self.end_layer:
@@ -2354,6 +3009,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                w_kc, w_vc = w.unflatten(
                    0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+
+                if (
+                    _use_aiter_gfx95
+                    and self.quant_config is not None
+                    and self.quant_config.get_name() == "quark"
+                ):
+                    w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
+                        quark_post_load_weights(self_attn, w, "mxfp4")
+                    )
+
                if not use_deep_gemm_bmm:
                    self_attn.w_kc = bind_or_assign(
                        self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
@@ -2733,8 +3398,24 @@ class DeepseekV2ForCausalLM(nn.Module):
        )


+AttentionBackendRegistry.register("ascend", handle_attention_ascend)
+AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
+AttentionBackendRegistry.register("fa3", handle_attention_fa3)
+AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
+AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
+AttentionBackendRegistry.register("fa4", handle_attention_fa4)
+AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
+AttentionBackendRegistry.register("aiter", handle_attention_aiter)
+AttentionBackendRegistry.register("nsa", handle_attention_nsa)
+AttentionBackendRegistry.register("triton", handle_attention_triton)
+
+
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass


-EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
+class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
+    pass
+
+
+EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM]
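`AttentionBackendRegistry` and the `handle_attention_*` callbacks are defined elsewhere in the module; the registrations above only wire backend names to their handlers. A minimal sketch of the register/lookup pattern they imply (names here are illustrative, not sglang's API):

from typing import Callable, Dict

class AttentionBackendRegistrySketch:
    _handlers: Dict[str, Callable] = {}

    @classmethod
    def register(cls, name: str, handler: Callable) -> None:
        # Map a backend name to its attention-method handler.
        cls._handlers[name] = handler

    @classmethod
    def get(cls, name: str) -> Callable:
        return cls._handlers[name]

def handle_attention_triton_stub(*args, **kwargs):
    return "triton"

AttentionBackendRegistrySketch.register("triton", handle_attention_triton_stub)
assert AttentionBackendRegistrySketch.get("triton")() == "triton"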