sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +7 -9
- sglang/bench_one_batch_server.py +321 -31
- sglang/bench_serving.py +10 -3
- sglang/global_config.py +2 -2
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/launch_server.py +14 -0
- sglang/profiler.py +2 -2
- sglang/srt/batch_invariant_ops/__init__.py +27 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/dots_ocr.py +64 -0
- sglang/srt/configs/falcon_h1.py +360 -0
- sglang/srt/configs/load_config.py +8 -0
- sglang/srt/configs/model_config.py +160 -105
- sglang/srt/configs/qwen3_vl.py +586 -0
- sglang/srt/constrained/base_grammar_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +1 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -4
- sglang/srt/debug_utils/dumper.py +10 -3
- sglang/srt/disaggregation/ascend/conn.py +2 -2
- sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
- sglang/srt/disaggregation/common/conn.py +266 -98
- sglang/srt/disaggregation/decode.py +50 -9
- sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
- sglang/srt/disaggregation/mooncake/conn.py +51 -541
- sglang/srt/disaggregation/nixl/conn.py +148 -39
- sglang/srt/disaggregation/prefill.py +31 -14
- sglang/srt/disaggregation/utils.py +36 -5
- sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
- sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
- sglang/srt/distributed/parallel_state.py +135 -80
- sglang/srt/entrypoints/engine.py +23 -3
- sglang/srt/entrypoints/grpc_request_manager.py +330 -55
- sglang/srt/entrypoints/grpc_server.py +232 -102
- sglang/srt/entrypoints/http_server.py +49 -9
- sglang/srt/entrypoints/openai/protocol.py +110 -5
- sglang/srt/entrypoints/openai/serving_base.py +25 -6
- sglang/srt/entrypoints/openai/serving_chat.py +178 -49
- sglang/srt/entrypoints/openai/serving_completions.py +5 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/entrypoints/openai/serving_responses.py +42 -0
- sglang/srt/environ.py +285 -0
- sglang/srt/eplb/expert_location.py +30 -5
- sglang/srt/function_call/function_call_parser.py +3 -2
- sglang/srt/function_call/glm4_moe_detector.py +3 -3
- sglang/srt/function_call/gpt_oss_detector.py +23 -0
- sglang/srt/function_call/json_array_parser.py +63 -0
- sglang/srt/function_call/kimik2_detector.py +17 -4
- sglang/srt/function_call/utils.py +96 -5
- sglang/srt/grpc/compile_proto.py +245 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
- sglang/srt/layers/activation.py +7 -6
- sglang/srt/layers/attention/aiter_backend.py +14 -15
- sglang/srt/layers/attention/ascend_backend.py +108 -9
- sglang/srt/layers/attention/attention_registry.py +206 -0
- sglang/srt/layers/attention/base_attn_backend.py +12 -3
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
- sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
- sglang/srt/layers/attention/flashattention_backend.py +41 -8
- sglang/srt/layers/attention/flashinfer_backend.py +112 -194
- sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
- sglang/srt/layers/attention/flashmla_backend.py +7 -5
- sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
- sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
- sglang/srt/layers/attention/mamba/mamba.py +566 -1
- sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
- sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
- sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
- sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
- sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
- sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
- sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
- sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
- sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
- sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
- sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
- sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
- sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
- sglang/srt/layers/attention/nsa/transform_index.py +144 -0
- sglang/srt/layers/attention/nsa/utils.py +24 -0
- sglang/srt/layers/attention/nsa_backend.py +887 -0
- sglang/srt/layers/attention/tbo_backend.py +6 -6
- sglang/srt/layers/attention/torch_flex_backend.py +325 -0
- sglang/srt/layers/attention/triton_backend.py +42 -9
- sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
- sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
- sglang/srt/layers/attention/vision.py +58 -0
- sglang/srt/layers/attention/wave_backend.py +4 -4
- sglang/srt/layers/communicator.py +8 -0
- sglang/srt/layers/dp_attention.py +11 -1
- sglang/srt/layers/elementwise.py +3 -1
- sglang/srt/layers/layernorm.py +2 -0
- sglang/srt/layers/linear.py +21 -4
- sglang/srt/layers/logits_processor.py +15 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +147 -74
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
- sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
- sglang/srt/layers/moe/utils.py +10 -0
- sglang/srt/layers/parameter.py +23 -6
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
- sglang/srt/layers/quantization/fp8.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +1 -1
- sglang/srt/layers/quantization/modelopt_quant.py +44 -9
- sglang/srt/layers/quantization/mxfp4.py +12 -4
- sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
- sglang/srt/layers/quantization/w4afp8.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +15 -3
- sglang/srt/layers/rotary_embedding.py +78 -31
- sglang/srt/layers/sampler.py +52 -4
- sglang/srt/layers/utils.py +23 -0
- sglang/srt/lora/backend/base_backend.py +3 -3
- sglang/srt/lora/backend/chunked_backend.py +348 -0
- sglang/srt/lora/backend/triton_backend.py +10 -4
- sglang/srt/lora/lora.py +7 -5
- sglang/srt/lora/lora_manager.py +17 -6
- sglang/srt/lora/mem_pool.py +1 -1
- sglang/srt/lora/triton_ops/__init__.py +4 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
- sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
- sglang/srt/lora/utils.py +7 -5
- sglang/srt/managers/cache_controller.py +42 -142
- sglang/srt/managers/data_parallel_controller.py +11 -46
- sglang/srt/managers/detokenizer_manager.py +11 -11
- sglang/srt/managers/io_struct.py +162 -118
- sglang/srt/managers/mm_utils.py +43 -6
- sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
- sglang/srt/managers/multimodal_processor.py +1 -2
- sglang/srt/managers/overlap_utils.py +53 -0
- sglang/srt/managers/schedule_batch.py +167 -86
- sglang/srt/managers/schedule_policy.py +143 -16
- sglang/srt/managers/scheduler.py +359 -214
- sglang/srt/managers/scheduler_input_blocker.py +1 -1
- sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
- sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
- sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
- sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
- sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
- sglang/srt/managers/tokenizer_manager.py +84 -136
- sglang/srt/managers/tp_worker.py +39 -29
- sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
- sglang/srt/managers/utils.py +1 -45
- sglang/srt/mem_cache/allocator.py +14 -20
- sglang/srt/mem_cache/allocator_ascend.py +41 -27
- sglang/srt/mem_cache/base_prefix_cache.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -1
- sglang/srt/mem_cache/evict_policy.py +23 -0
- sglang/srt/mem_cache/hicache_storage.py +40 -1
- sglang/srt/mem_cache/hiradix_cache.py +119 -32
- sglang/srt/mem_cache/memory_pool.py +188 -10
- sglang/srt/mem_cache/memory_pool_host.py +134 -182
- sglang/srt/mem_cache/radix_cache.py +222 -71
- sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
- sglang/srt/mem_cache/storage/__init__.py +10 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
- sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
- sglang/srt/mem_cache/storage/backend_factory.py +223 -0
- sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
- sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
- sglang/srt/mem_cache/swa_radix_cache.py +25 -34
- sglang/srt/metrics/collector.py +82 -120
- sglang/srt/metrics/func_timer.py +2 -7
- sglang/srt/metrics/utils.py +8 -1
- sglang/srt/model_executor/cpu_graph_runner.py +2 -2
- sglang/srt/model_executor/cuda_graph_runner.py +39 -32
- sglang/srt/model_executor/forward_batch_info.py +23 -38
- sglang/srt/model_executor/model_runner.py +131 -183
- sglang/srt/model_executor/npu_graph_runner.py +12 -5
- sglang/srt/model_loader/loader.py +14 -10
- sglang/srt/model_loader/weight_utils.py +156 -2
- sglang/srt/models/bailing_moe.py +27 -4
- sglang/srt/models/deepseek_nextn.py +6 -1
- sglang/srt/models/deepseek_v2.py +536 -153
- sglang/srt/models/dots_ocr.py +173 -0
- sglang/srt/models/falcon_h1.py +576 -0
- sglang/srt/models/gemma3_causal.py +0 -2
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +3 -3
- sglang/srt/models/glm4_moe_nextn.py +2 -2
- sglang/srt/models/glm4v.py +1 -1
- sglang/srt/models/glm4v_moe.py +1 -1
- sglang/srt/models/gpt_oss.py +7 -30
- sglang/srt/models/kimi_vl_moonvit.py +2 -2
- sglang/srt/models/llama.py +4 -0
- sglang/srt/models/longcat_flash.py +1 -1
- sglang/srt/models/longcat_flash_nextn.py +1 -1
- sglang/srt/models/mllama4.py +15 -4
- sglang/srt/models/qwen2.py +0 -7
- sglang/srt/models/qwen2_5_vl.py +2 -2
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +64 -1
- sglang/srt/models/qwen2_vl.py +1 -1
- sglang/srt/models/qwen3.py +18 -3
- sglang/srt/models/qwen3_moe.py +31 -3
- sglang/srt/models/qwen3_next.py +36 -9
- sglang/srt/models/qwen3_vl.py +787 -0
- sglang/srt/models/qwen3_vl_moe.py +471 -0
- sglang/srt/models/registry.py +15 -3
- sglang/srt/models/sarashina2_vision.py +269 -0
- sglang/srt/models/solar.py +505 -0
- sglang/srt/models/starcoder2.py +357 -0
- sglang/srt/models/torch_native_llama.py +9 -2
- sglang/srt/models/utils.py +51 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -7
- sglang/srt/multimodal/processors/dots_vlm.py +2 -3
- sglang/srt/multimodal/processors/internvl.py +20 -8
- sglang/srt/multimodal/processors/qwen_vl.py +8 -1
- sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
- sglang/srt/parser/jinja_template_utils.py +6 -0
- sglang/srt/sampling/sampling_batch_info.py +20 -2
- sglang/srt/sampling/sampling_params.py +7 -0
- sglang/srt/server_args.py +753 -295
- sglang/srt/server_args_config_parser.py +146 -0
- sglang/srt/single_batch_overlap.py +151 -0
- sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
- sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
- sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
- sglang/srt/speculative/cpp_ngram/param.h +125 -0
- sglang/srt/speculative/cpp_ngram/queue.h +71 -0
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
- sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
- sglang/srt/speculative/eagle_worker.py +57 -25
- sglang/srt/speculative/ngram_utils.py +428 -0
- sglang/srt/speculative/ngram_worker.py +245 -0
- sglang/srt/speculative/spec_info.py +47 -0
- sglang/srt/speculative/spec_utils.py +606 -0
- sglang/srt/torch_memory_saver_adapter.py +5 -7
- sglang/srt/tracing/trace.py +32 -6
- sglang/srt/two_batch_overlap.py +8 -5
- sglang/srt/utils/__init__.py +2 -0
- sglang/srt/{utils.py → utils/common.py} +399 -74
- sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
- sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
- sglang/srt/utils/rpd_utils.py +452 -0
- sglang/srt/utils/slow_rank_detector.py +71 -0
- sglang/srt/warmup.py +8 -4
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/get_logits_ut.py +57 -0
- sglang/test/run_eval.py +79 -11
- sglang/test/runners.py +1 -1
- sglang/test/simple_eval_common.py +5 -2
- sglang/test/simple_eval_mmmu_vlm.py +441 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deterministic.py +297 -0
- sglang/test/test_disaggregation_utils.py +12 -1
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +355 -4
- sglang/utils.py +10 -1
- sglang/version.py +1 -1
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
- sglang/srt/mem_cache/lora_radix_cache.py +0 -421
- /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
- /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -15,6 +15,7 @@
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
+from __future__ import annotations

 import concurrent.futures
 import logging
@@ -25,9 +26,16 @@ from typing import Any, Dict, Iterable, Optional, Tuple, Union
 import torch
 import torch.nn.functional as F
 from torch import nn
-from tqdm import tqdm
 from transformers import PretrainedConfig

+from sglang.srt import single_batch_overlap
+from sglang.srt.configs.model_config import (
+    get_nsa_index_head_dim,
+    get_nsa_index_n_heads,
+    get_nsa_index_topk,
+    is_deepseek_nsa,
+)
+from sglang.srt.debug_utils.dumper import dumper
 from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_pp_group,
@@ -43,6 +51,11 @@ from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.amx_utils import PackWeightMethod
+from sglang.srt.layers.attention.npu_ops.mla_preprocess import (
+    NPUFusedMLAPreprocess,
+    is_mla_preprocess_enabled,
+)
+from sglang.srt.layers.attention.nsa.nsa_indexer import Indexer
 from sglang.srt.layers.communicator import (
     LayerCommunicator,
     LayerScatterModes,
@@ -97,6 +110,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.single_batch_overlap import SboFlags
 from sglang.srt.two_batch_overlap import (
     MaybeTboDeepEPDispatcher,
     model_forward_maybe_tbo,
@@ -160,16 +174,18 @@ if _is_cuda:
 elif _is_cpu and _is_cpu_amx_available:
     pass
 elif _is_hip:
+    from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
+        decode_attention_fwd_grouped_rope,
+    )
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
     )
+elif _is_npu:
+    import custom_ops
+    import sgl_kernel_npu
+    import torch_npu
 else:
-
-
-    if _is_hip:
-        from sglang.srt.layers.attention.triton_ops.rocm_mla_decode_rope import (
-            decode_attention_fwd_grouped_rope,
-        )
+    pass

 _is_flashinfer_available = is_flashinfer_available()
 _is_sm100_supported = is_cuda() and is_sm100_supported()
@@ -177,6 +193,21 @@ _is_sm100_supported = is_cuda() and is_sm100_supported()

 logger = logging.getLogger(__name__)

+FORWARD_ABSORB_CORE_ATTENTION_BACKENDS = [
+    "fa3",
+    "nsa",
+    "flashinfer",
+    "cutlass_mla",
+    "trtllm_mla",
+    "ascend",
+]
+
+
+def add_forward_absorb_core_attention_backend(backend_name):
+    if backend_name not in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
+        FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.append(backend_name)
+        logger.info(f"Added {backend_name} to FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.")
+

 class AttnForwardMethod(IntEnum):
     # Use multi-head attention
@@ -185,6 +216,9 @@ class AttnForwardMethod(IntEnum):
     # Use absorbed multi-latent attention
     MLA = auto()

+    # Use Deepseek V3.2 sparse multi-latent attention
+    NPU_MLA_SPARSE = auto()
+
     # Use multi-head attention, but with KV cache chunked.
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
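The two additions above give backend integrations an explicit extension point: `FORWARD_ABSORB_CORE_ATTENTION_BACKENDS` is the list checked inside `forward_absorb_core` later in this diff, and `add_forward_absorb_core_attention_backend` appends to it idempotently. A minimal sketch of how out-of-tree code might use it, assuming sglang 0.5.3rc2 is installed; the backend name `my_mla_backend` is hypothetical:

```python
from sglang.srt.models.deepseek_v2 import (
    FORWARD_ABSORB_CORE_ATTENTION_BACKENDS,
    add_forward_absorb_core_attention_backend,
)

# First call appends the (hypothetical) backend name and logs it; repeat calls are no-ops.
add_forward_absorb_core_attention_backend("my_mla_backend")
add_forward_absorb_core_attention_backend("my_mla_backend")
assert FORWARD_ABSORB_CORE_ATTENTION_BACKENDS.count("my_mla_backend") == 1
```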
@@ -196,6 +230,146 @@ class AttnForwardMethod(IntEnum):
     MLA_FUSED_ROPE_CPU = auto()


+def _dispatch_mla_subtype(attn, forward_batch):
+    if _is_hip:
+        if attn.rocm_fused_decode_mla and forward_batch.forward_mode.is_decode():
+            return AttnForwardMethod.MLA_FUSED_ROPE
+        else:
+            return AttnForwardMethod.MLA
+    else:
+        if hasattr(attn, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(attn):
+            return AttnForwardMethod.MLA_FUSED_ROPE_CPU
+        else:
+            return AttnForwardMethod.MLA
+
+
+class AttentionBackendRegistry:
+    _handlers = {}
+
+    @classmethod
+    def register(cls, backend_name, handler_func):
+        cls._handlers[backend_name] = handler_func
+
+    @classmethod
+    def get_handler(cls, backend_name):
+        return cls._handlers.get(backend_name, cls._handlers.get("triton"))
+
+
+def handle_attention_ascend(attn, forward_batch):
+    if (
+        forward_batch.forward_mode.is_extend()
+        and not forward_batch.forward_mode.is_target_verify()
+        and not forward_batch.forward_mode.is_draft_extend()
+    ):
+        if hasattr(attn, "indexer"):
+            return AttnForwardMethod.NPU_MLA_SPARSE
+        else:
+            return AttnForwardMethod.MHA
+    else:
+        if hasattr(attn, "indexer"):
+            return AttnForwardMethod.NPU_MLA_SPARSE
+        else:
+            return AttnForwardMethod.MLA
+
+
+def _get_sum_extend_prefix_lens(forward_batch):
+    return (
+        sum(forward_batch.extend_prefix_lens_cpu)
+        if forward_batch.extend_prefix_lens_cpu is not None
+        else 0
+    )
+
+
+def _is_extend_without_speculative(forward_batch):
+    return (
+        forward_batch.forward_mode.is_extend()
+        and not forward_batch.forward_mode.is_target_verify()
+        and not forward_batch.forward_mode.is_draft_extend()
+    )
+
+
+def _handle_attention_backend(
+    attn: DeepseekV2AttentionMLA, forward_batch, backend_name
+):
+    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
+    disable_ragged = (
+        backend_name in ["flashinfer", "flashmla"]
+    ) and attn.flashinfer_mla_disable_ragged
+
+    if (
+        not disable_ragged
+        and _is_extend_without_speculative(forward_batch)
+        and (
+            (
+                sum_extend_prefix_lens >= attn.chunked_prefix_cache_threshold
+                and not attn.disable_chunked_prefix_cache
+            )
+            or sum_extend_prefix_lens == 0
+        )
+    ):
+        return AttnForwardMethod.MHA_CHUNKED_KV
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
+def handle_attention_flashinfer(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashinfer")
+
+
+def handle_attention_fa3(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "fa3")
+
+
+def handle_attention_flashmla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "flashmla")
+
+
+def handle_attention_cutlass_mla(attn, forward_batch):
+    return _handle_attention_backend(attn, forward_batch, "cutlass_mla")
+
+
+def handle_attention_fa4(attn, forward_batch):
+    # TODO(cicirori): use FA4 MHA for DeepSeekV3 for now
+    return AttnForwardMethod.MHA_CHUNKED_KV
+
+
+def handle_attention_trtllm_mla(attn, forward_batch):
+    sum_extend_prefix_lens = _get_sum_extend_prefix_lens(forward_batch)
+    if _is_extend_without_speculative(forward_batch) and (
+        not attn.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
+    ):
+        return AttnForwardMethod.MHA_CHUNKED_KV
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
+def handle_attention_aiter(attn, forward_batch):
+    if _is_extend_without_speculative(forward_batch):
+        if is_dp_attention_enabled():
+            if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
+        else:
+            return AttnForwardMethod.MHA
+    else:
+        return AttnForwardMethod.MLA
+
+
+def handle_attention_nsa(attn, forward_batch):
+    return AttnForwardMethod.MLA
+
+
+def handle_attention_triton(attn, forward_batch):
+    if (
+        _is_extend_without_speculative(forward_batch)
+        and sum(forward_batch.extend_prefix_lens_cpu) == 0
+    ):
+        return AttnForwardMethod.MHA
+    else:
+        return _dispatch_mla_subtype(attn, forward_batch)
+
+
 class DeepseekV2MLP(nn.Module):
     def __init__(
         self,
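The block above replaces the per-backend `if/elif` chain that previously lived inside `dispatch_attn_forward_method` with free functions plus a small registry: each handler maps `(attn, forward_batch)` to an `AttnForwardMethod`, and `AttentionBackendRegistry.get_handler` falls back to the `"triton"` handler for unknown names. A minimal sketch of registering a custom handler, assuming sglang 0.5.3rc2 is installed; the `"my_backend"` name and its trivial policy are hypothetical:

```python
from sglang.srt.models.deepseek_v2 import AttentionBackendRegistry, AttnForwardMethod


def handle_attention_my_backend(attn, forward_batch):
    # Toy policy: always use absorbed multi-latent attention.
    return AttnForwardMethod.MLA


AttentionBackendRegistry.register("my_backend", handle_attention_my_backend)
assert AttentionBackendRegistry.get_handler("my_backend") is handle_attention_my_backend
```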
@@ -309,7 +483,7 @@ class MoEGate(nn.Module):
             _is_cuda
             and hidden_states.shape[0] <= 16
             and hidden_states.shape[1] == 7168
-            and self.weight.shape[0] == 256
+            and (self.weight.shape[0] == 256 or self.weight.shape[0] == 384)
             and _device_sm >= 90
         ):
             # router gemm output float32
@@ -393,7 +567,7 @@ class DeepseekV2MoE(nn.Module):
             correction_bias=self.gate.e_score_correction_bias,
             quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
-            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk
+            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
             # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
             # and requires the output format to be standard. We use quant_config to determine the output format.
             output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
@@ -660,7 +834,8 @@ class DeepseekV2MoE(nn.Module):
         if hidden_states.shape[0] > 0:
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
-
+            if not SboFlags.fuse_shared_experts_inside_sbo():
+                shared_output = self._forward_shared_experts(hidden_states)
             topk_weights, topk_idx, _ = self.topk(
                 hidden_states,
                 router_logits,
@@ -674,22 +849,28 @@ class DeepseekV2MoE(nn.Module):
                 hidden_states.device
             )

-        final_hidden_states =
+        final_hidden_states, sbo_shared_output = single_batch_overlap.execute_sbo(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
             topk_weights=topk_weights,
             forward_batch=forward_batch,
+            # SBO args
+            forward_shared_experts=lambda: self._forward_shared_experts(hidden_states),
+            experts=self.experts,
+            alt_stream=self.alt_stream,
         )
+        if sbo_shared_output is not None:
+            shared_output = sbo_shared_output

         if shared_output is not None:
             x = shared_output
-            if self.experts.should_fuse_routed_scaling_factor_in_topk
+            if self.experts.should_fuse_routed_scaling_factor_in_topk:
                 x.add_(final_hidden_states)
             else:
                 x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-            if not self.experts.should_fuse_routed_scaling_factor_in_topk
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk:
                 final_hidden_states *= self.routed_scaling_factor

         return final_hidden_states
@@ -697,7 +878,7 @@ class DeepseekV2MoE(nn.Module):
     def _forward_shared_experts(
         self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
     ):
-        if self.num_fused_shared_experts == 0:
+        if (hidden_states.shape[0] > 0) and (self.num_fused_shared_experts == 0):
             return self.shared_experts(
                 hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
             )
@@ -750,6 +931,7 @@ class DeepseekV2MoE(nn.Module):
         if self.ep_size > 1:
             self.experts.deepep_dispatcher.dispatch_a(
                 hidden_states=state.hidden_states_mlp_input,
+                input_global_scale=None,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
                 forward_batch=state.forward_batch,
@@ -850,6 +1032,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings

+        # NOTE modification to rope_scaling must be done early enough, b/c e.g. Indexer needs it
+        if rope_scaling:
+            rope_scaling["rope_type"] = "deepseek_yarn"
+
         # For tensor parallel attention
         if self.q_lora_rank is not None:
             self.fused_qkv_a_proj_with_mqa = ReplicatedLinear(
@@ -887,6 +1073,26 @@ class DeepseekV2AttentionMLA(nn.Module):
                 prefix=add_prefix("kv_a_proj_with_mqa", prefix),
             )

+        self.use_nsa = is_deepseek_nsa(config)
+        if self.use_nsa:
+            self.indexer = Indexer(
+                hidden_size=hidden_size,
+                index_n_heads=get_nsa_index_n_heads(config),
+                index_head_dim=get_nsa_index_head_dim(config),
+                rope_head_dim=qk_rope_head_dim,
+                index_topk=get_nsa_index_topk(config),
+                q_lora_rank=q_lora_rank,
+                max_position_embeddings=max_position_embeddings,
+                rope_theta=rope_theta,
+                scale_fmt="ue8m0",
+                block_size=128,
+                rope_scaling=rope_scaling,
+                prefix=add_prefix("indexer", prefix),
+                quant_config=quant_config,
+                layer_id=layer_id,
+                alt_stream=alt_stream,
+            )
+
         self.kv_b_proj = ColumnParallelLinear(
             self.kv_lora_rank,
             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
@@ -909,9 +1115,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)

-        if rope_scaling:
-            rope_scaling["rope_type"] = "deepseek_yarn"
-
         self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
@@ -1035,27 +1238,16 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.weight_block_size = (
                 self.fused_qkv_a_proj_with_mqa.quant_method.quant_config.weight_block_size
             )
+        self.is_mla_preprocess_enabled = is_mla_preprocess_enabled()
+        if self.is_mla_preprocess_enabled:
+            assert (
+                quant_config is None or quant_config.get_name() == "w8a8_int8"
+            ), "MLA Preprocess only works with Unquant or W8A8Int8"
+        self.mla_preprocess = None

     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
-        def _dispatch_mla_subtype():
-            if _is_hip:
-                if (
-                    self.rocm_fused_decode_mla
-                    and forward_batch.forward_mode.is_decode()
-                ):
-                    return AttnForwardMethod.MLA_FUSED_ROPE
-                else:
-                    return AttnForwardMethod.MLA
-            else:
-                if hasattr(self, "fused_qkv_a_proj_with_mqa") and use_intel_amx_backend(
-                    self
-                ):
-                    return AttnForwardMethod.MLA_FUSED_ROPE_CPU
-                else:
-                    return AttnForwardMethod.MLA
-
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
             attention_backend = global_server_args_dict["decode_attention_backend"]
@@ -1072,109 +1264,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend

-
-
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return AttnForwardMethod.MLA
-        elif (
-            attention_backend == "flashinfer"
-            or attention_backend == "fa3"
-            or attention_backend == "flashmla"
-            or attention_backend == "cutlass_mla"
-        ):
-            # Use MHA with chunked KV cache when prefilling on long sequences.
-            sum_extend_prefix_lens = (
-                sum(forward_batch.extend_prefix_lens_cpu)
-                if forward_batch.extend_prefix_lens_cpu is not None
-                else 0
-            )
-            # Flashinfer MLA: Do not absorb when enabling ragged prefill
-            disable_ragged = (
-                attention_backend == "flashinfer" or attention_backend == "flashmla"
-            ) and self.flashinfer_mla_disable_ragged
-
-            original_mode = getattr(forward_batch, "_original_forward_mode", None)
-            if (
-                not disable_ragged
-                and forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and (
-                    (
-                        sum_extend_prefix_lens >= self.chunked_prefix_cache_threshold
-                        and not self.disable_chunked_prefix_cache
-                    )
-                    or sum_extend_prefix_lens == 0
-                )
-                # TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
-                # dp case. Redirect to mla kernel as a workaround.
-                # Tracked by https://github.com/sgl-project/sglang/issues/9806.
-                and not (
-                    original_mode is not None
-                    and original_mode.is_decode()
-                    and is_sm100_supported()
-                    and self.current_attention_backend in ("cutlass_mla", "flashinfer")
-                )
-            ):
-                return AttnForwardMethod.MHA_CHUNKED_KV
-            else:
-                return _dispatch_mla_subtype()
-        elif attention_backend == "trtllm_mla":
-            original_mode = getattr(forward_batch, "_original_forward_mode", None)
-            if (
-                original_mode is not None
-                and original_mode.is_decode()
-                and is_sm100_supported()
-            ):
-                return _dispatch_mla_subtype()
-
-            sum_extend_prefix_lens = (
-                sum(forward_batch.extend_prefix_lens_cpu)
-                if forward_batch.extend_prefix_lens_cpu is not None
-                else 0
-            )
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and (
-                    not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
-                )
-            ):
-                return AttnForwardMethod.MHA_CHUNKED_KV
-            else:
-                return _dispatch_mla_subtype()
-        elif attention_backend == "aiter":
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-            ):
-                if is_dp_attention_enabled():
-                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
-                        return AttnForwardMethod.MHA
-                    else:
-                        return AttnForwardMethod.MLA
-                else:
-                    return AttnForwardMethod.MHA
-            else:
-                return AttnForwardMethod.MLA
-        else:
-            # Triton: Use normal computation for prefill and use weight absorption for extend/decode
-            if (
-                forward_batch.forward_mode.is_extend()
-                and not forward_batch.forward_mode.is_target_verify()
-                and not forward_batch.forward_mode.is_draft_extend()
-                and sum(forward_batch.extend_prefix_lens_cpu) == 0
-            ):
-                return AttnForwardMethod.MHA
-            else:
-                return _dispatch_mla_subtype()
+        handler = AttentionBackendRegistry.get_handler(attention_backend)
+        return handler(self, forward_batch)

     def op_prepare(self, state):
         state.attn_intermediate_state = self.forward_prepare(
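With the handlers registered at module scope (see the `AttentionBackendRegistry.register(...)` calls near the end of this diff), the dispatch body shrinks to the two added lines above: look up the handler for the active backend and call it. A small sketch of the fallback behavior, assuming sglang 0.5.3rc2 is importable in the environment:

```python
from sglang.srt.models.deepseek_v2 import (
    AttentionBackendRegistry,
    handle_attention_triton,
)

# Backend names without a registered handler resolve to the "triton" handler.
handler = AttentionBackendRegistry.get_handler("some_unregistered_backend")
assert handler is handle_attention_triton
```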
@@ -1229,7 +1320,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             return hidden_states, None, forward_batch, None

         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
-
         if attn_forward_method == AttnForwardMethod.MHA:
             inner_state = self.forward_normal_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
@@ -1239,7 +1329,30 @@ class DeepseekV2AttentionMLA(nn.Module):
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA:
-
+            if not self.is_mla_preprocess_enabled:
+                inner_state = self.forward_absorb_prepare(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
+            else:
+                # TODO(iforgetmyname): to be separated as a standalone func
+                if self.mla_preprocess is None:
+                    self.mla_preprocess = NPUFusedMLAPreprocess(
+                        self.fused_qkv_a_proj_with_mqa,
+                        self.q_a_layernorm,
+                        self.kv_a_layernorm,
+                        self.q_b_proj,
+                        self.w_kc,
+                        self.rotary_emb,
+                        self.layer_id,
+                        self.num_local_heads,
+                        self.qk_nope_head_dim,
+                        self.qk_rope_head_dim,
+                    )
+                inner_state = self.mla_preprocess.forward(
+                    positions, hidden_states, forward_batch, zero_allocator
+                )
+        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+            inner_state = self.forward_npu_sparse_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
@@ -1267,6 +1380,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             return self.forward_normal_chunked_kv_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA:
             return self.forward_absorb_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.NPU_MLA_SPARSE:
+            return self.forward_npu_sparse_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
             return self.forward_absorb_fused_mla_rope_core(*inner_state)
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE_CPU:
@@ -1346,7 +1461,10 @@ class DeepseekV2AttentionMLA(nn.Module):
         """
         return (
             self.current_attention_backend == "trtllm_mla"
-            and
+            and (
+                forward_batch.forward_mode.is_decode_or_idle()
+                or forward_batch.forward_mode.is_target_verify()
+            )
             and forward_batch.attn_backend.data_type == torch.float8_e4m3fn
         )

@@ -1359,6 +1477,7 @@ class DeepseekV2AttentionMLA(nn.Module):
     ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

+        q_lora = None
         if self.q_lora_rank is not None:
             if (
                 (not isinstance(hidden_states, tuple))
@@ -1397,6 +1516,10 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q = self.q_a_layernorm(q)
                 k_nope = self.kv_a_layernorm(k_nope)

+            # q_lora needed by indexer
+            if self.use_nsa:
+                q_lora = q
+
             k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
         else:
@@ -1462,28 +1585,50 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_nope_out.transpose(0, 1)

         if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
-            not _use_aiter or not _is_gfx95_supported
+            not _use_aiter or not _is_gfx95_supported or self.use_nsa
         ):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

-
+        topk_indices = None
+        if q_lora is not None:
+            topk_indices = self.indexer(
+                x=hidden_states,
+                q_lora=q_lora,
+                positions=positions,
+                forward_batch=forward_batch,
+                layer_id=self.layer_id,
+            )
+
+        return (
+            q_pe,
+            k_pe,
+            q_nope_out,
+            k_nope,
+            forward_batch,
+            zero_allocator,
+            positions,
+            topk_indices,
+        )

     def forward_absorb_core(
-        self,
+        self,
+        q_pe,
+        k_pe,
+        q_nope_out,
+        k_nope,
+        forward_batch,
+        zero_allocator,
+        positions,
+        topk_indices,
     ):
-        if
-            self.current_attention_backend == "fa3"
-            or self.current_attention_backend == "flashinfer"
-            or self.current_attention_backend == "cutlass_mla"
-            or self.current_attention_backend == "trtllm_mla"
-            or self.current_attention_backend == "ascend"
-        ):
+        if self.current_attention_backend in FORWARD_ABSORB_CORE_ATTENTION_BACKENDS:
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
                 extra_args = {
                     "cos_sin_cache": self.rotary_emb.cos_sin_cache,
                     "is_neox": self.rotary_emb.is_neox_style,
                 }
+
             attn_output = self.attn_mqa(
                 q_nope_out,
                 k_nope,
@@ -1492,6 +1637,7 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q_rope=q_pe,
                 k_rope=k_pe,
                 **extra_args,
+                **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
             )
         else:
             if _use_aiter_gfx95:
@@ -1511,7 +1657,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                 q = torch.cat([q_nope_out, q_pe], dim=-1)
                 k = torch.cat([k_nope, k_pe], dim=-1)

-            attn_output = self.attn_mqa(
+            attn_output = self.attn_mqa(
+                q,
+                k,
+                k_nope,
+                forward_batch,
+                **(dict(topk_indices=topk_indices) if topk_indices is not None else {}),
+            )
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

         if self.use_deep_gemm_bmm:
@@ -1593,6 +1745,221 @@ class DeepseekV2AttentionMLA(nn.Module):

         return output

+    def forward_npu_sparse_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        """
+        Reuse `self.q_lora_rank is not None` branch from forward_absorb_prepare
+        """
+        if self.is_mla_preprocess_enabled and forward_batch.forward_mode.is_decode():
+            if self.mla_preprocess is None:
+                self.mla_preprocess = NPUFusedMLAPreprocess(
+                    self.fused_qkv_a_proj_with_mqa,
+                    self.q_a_layernorm,
+                    self.kv_a_layernorm,
+                    self.q_b_proj,
+                    self.w_kc,
+                    self.rotary_emb,
+                    self.layer_id,
+                    self.num_local_heads,
+                    self.qk_nope_head_dim,
+                    self.qk_rope_head_dim,
+                )
+            (
+                q_pe,
+                k_pe,
+                q_nope_out,
+                k_nope,
+                forward_batch,
+                zero_allocator,
+                positions,
+            ) = self.mla_preprocess.forward(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
+
+            fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, _ = fused_qkv_a_proj_out.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            q_lora = self.q_a_layernorm(q)
+        else:
+            from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
+            if (
+                (not isinstance(hidden_states, tuple))
+                and hidden_states.shape[0] <= 16
+                and self.use_min_latency_fused_a_gemm
+            ):
+                fused_qkv_a_proj_out = dsv3_fused_a_gemm(
+                    hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
+                )
+            else:
+                fused_qkv_a_proj_out = self.fused_qkv_a_proj_with_mqa(hidden_states)[0]
+            q, latent_cache = fused_qkv_a_proj_out.split(
+                [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
+            )
+            k_nope = latent_cache[..., : self.kv_lora_rank]
+
+            # overlap qk norm
+            if self.alt_stream is not None and get_is_capture_mode():
+                current_stream = torch.cuda.current_stream()
+                self.alt_stream.wait_stream(current_stream)
+                q = self.q_a_layernorm(q)
+                with torch.cuda.stream(self.alt_stream):
+                    k_nope = self.kv_a_layernorm(k_nope)
+                current_stream.wait_stream(self.alt_stream)
+            else:
+                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+                    q, k_nope = fused_rms_mxfp4_quant(
+                        q,
+                        self.q_a_layernorm.weight,
+                        self.q_a_layernorm.variance_epsilon,
+                        k_nope,
+                        self.kv_a_layernorm.weight,
+                        self.kv_a_layernorm.variance_epsilon,
+                    )
+                else:
+                    q = self.q_a_layernorm(q)
+                    k_nope = self.kv_a_layernorm(k_nope)
+
+            q_lora = q.clone()  # required for topk_indices
+            k_nope = k_nope.unsqueeze(1)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+
+            q_nope, q_pe = q.split(
+                [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
+            )
+            k_pe = latent_cache[..., self.kv_lora_rank :].unsqueeze(1)
+
+            if self.use_deep_gemm_bmm:
+                q_nope_val, q_nope_scale, masked_m, expected_m, aligned_m = (
+                    per_token_group_quant_mla_deep_gemm_masked_fp8(
+                        q_nope.transpose(0, 1)
+                    )
+                )
+                q_nope_out = q_nope.new_empty(
+                    (self.num_local_heads, aligned_m, self.kv_lora_rank)
+                )
+                deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+                    (q_nope_val, q_nope_scale),
+                    (self.w_kc, self.w_scale_k),
+                    q_nope_out,
+                    masked_m,
+                    expected_m,
+                )
+                q_nope_out = q_nope_out[:, :expected_m, :]
+            elif _is_hip:
+                # TODO(haishaw): add bmm_fp8 to ROCm
+                if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+                    x = q_nope.transpose(0, 1)
+                    q_nope_out = torch.empty(
+                        x.shape[0],
+                        x.shape[1],
+                        self.w_kc.shape[2],
+                        device=x.device,
+                        dtype=torch.bfloat16,
+                    )
+                    batched_gemm_afp4wfp4_pre_quant(
+                        x,
+                        self.w_kc.transpose(-2, -1),
+                        self.w_scale_k.transpose(-2, -1),
+                        torch.bfloat16,
+                        q_nope_out,
+                    )
+                else:
+                    q_nope_out = torch.bmm(
+                        q_nope.to(torch.bfloat16).transpose(0, 1),
+                        self.w_kc.to(torch.bfloat16) * self.w_scale,
+                    )
+            elif self.w_kc.dtype == torch.float8_e4m3fn:
+                q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
+                    q_nope.transpose(0, 1),
+                    zero_allocator.allocate(1),
+                )
+                q_nope_out = bmm_fp8(
+                    q_nope_val, self.w_kc, q_nope_scale, self.w_scale, torch.bfloat16
+                )
+            else:
+                q_nope_out = torch.bmm(q_nope.transpose(0, 1), self.w_kc)
+
+            q_nope_out = q_nope_out.transpose(0, 1)
+
+            if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+                not _use_aiter or not _is_gfx95_supported
+            ):
+                q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+
+        # TODO: multi-stream indexer
+        topk_indices = self.indexer(
+            hidden_states, q_lora, positions, forward_batch, self.layer_id
+        )
+
+        return (
+            q_pe,
+            k_pe,
+            q_nope_out,
+            k_nope,
+            topk_indices,
+            forward_batch,
+            zero_allocator,
+            positions,
+        )
+
+    def forward_npu_sparse_core(
+        self,
+        q_pe,
+        k_pe,
+        q_nope_out,
+        k_nope,
+        topk_indices,
+        forward_batch,
+        zero_allocator,
+        positions,
+    ):
+        attn_output = self.attn_mqa(
+            q_nope_out.contiguous(),
+            k_nope.contiguous(),
+            k_nope.contiguous(),
+            forward_batch,
+            save_kv_cache=True,  # False if forward_batch.forward_mode.is_extend() else True,
+            q_rope=q_pe.contiguous(),
+            k_rope=k_pe.contiguous(),
+            topk_indices=topk_indices,
+        )
+        attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
+
+        attn_bmm_output = torch.empty(
+            (attn_output.shape[0], self.num_local_heads, self.v_head_dim),
+            dtype=attn_output.dtype,
+            device=attn_output.device,
+        )
+
+        if not forward_batch.forward_mode.is_decode():
+            attn_output = attn_output.transpose(0, 1)
+            torch.bmm(
+                attn_output,
+                self.w_vc,
+                out=attn_bmm_output.view(
+                    -1, self.num_local_heads, self.v_head_dim
+                ).transpose(0, 1),
+            )
+        else:
+            attn_output = attn_output.contiguous()
+            torch.ops.npu.batch_matmul_transpose(
+                attn_output, self.w_vc, attn_bmm_output
+            )
+
+        attn_bmm_output = attn_bmm_output.reshape(
+            -1, self.num_local_heads * self.v_head_dim
+        )
+
+        output, _ = self.o_proj(attn_bmm_output)
+        return output
+
     def forward_absorb_fused_mla_rope_prepare(
         self,
         positions: torch.Tensor,
@@ -1918,6 +2285,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             tmp_lse = torch.empty_like(accum_lse)
             merge_state_v2(output, lse, accum_output, accum_lse, tmp_output, tmp_lse)
             accum_output, accum_lse = tmp_output, tmp_lse
+            del kv, k, v, output, lse, tmp_output, tmp_lse

         return accum_output

@@ -2074,7 +2442,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         zero_allocator: BumpAllocator,
         gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
-
         quant_format = (
             "mxfp4"
             if _is_gfx95_supported
@@ -3031,8 +3398,24 @@ class DeepseekV2ForCausalLM(nn.Module):
         )


+AttentionBackendRegistry.register("ascend", handle_attention_ascend)
+AttentionBackendRegistry.register("flashinfer", handle_attention_flashinfer)
+AttentionBackendRegistry.register("fa3", handle_attention_fa3)
+AttentionBackendRegistry.register("flashmla", handle_attention_flashmla)
+AttentionBackendRegistry.register("cutlass_mla", handle_attention_cutlass_mla)
+AttentionBackendRegistry.register("fa4", handle_attention_fa4)
+AttentionBackendRegistry.register("trtllm_mla", handle_attention_trtllm_mla)
+AttentionBackendRegistry.register("aiter", handle_attention_aiter)
+AttentionBackendRegistry.register("nsa", handle_attention_nsa)
+AttentionBackendRegistry.register("triton", handle_attention_triton)
+
+
 class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
     pass


-
+class DeepseekV32ForCausalLM(DeepseekV2ForCausalLM):
+    pass
+
+
+EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM, DeepseekV32ForCausalLM]