sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +67 -43
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +88 -53
- sglang/srt/entrypoints/openai/protocol.py +7 -4
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +39 -19
- sglang/srt/entrypoints/openai/serving_completions.py +15 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -7
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +182 -49
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +68 -41
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +0 -18
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +200 -199
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +191 -139
- sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +260 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +18 -33
- sglang/srt/mem_cache/hiradix_cache.py +108 -48
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +121 -57
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +502 -77
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +75 -19
- sglang/srt/model_executor/model_runner.py +357 -30
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +346 -48
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +11 -2
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +60 -13
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +40 -9
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +355 -37
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +197 -112
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +46 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +12 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
- sglang/srt/disaggregation/launch_lb.py +0 -118
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -65,10 +65,11 @@ from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
     should_use_flashinfer_cutlass_moe_fp4_allgather,
+    should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.moe.topk import TopK
+from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8_kernel import (
@@ -112,6 +113,7 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_flashinfer_available,
+    is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
     is_npu,
@@ -129,11 +131,28 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()
+_is_gfx95_supported = is_gfx95_supported()
+
+_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
+
+if _use_aiter_gfx95:
+    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
+    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
+        batched_gemm_afp4wfp4_pre_quant,
+        fused_flatten_mxfp4_quant,
+        fused_rms_mxfp4_quant,
+    )
+    from sglang.srt.layers.rocm_linear_utils import (
+        aiter_dsv3_router_gemm,
+        fused_qk_rope_cat,
+        get_dsv3_gemm_output_zero_allocator_size,
+    )
 
 if _is_cuda:
     from sgl_kernel import (
         awq_dequantize,
         bmm_fp8,
+        concat_mla_k,
         dsv3_fused_a_gemm,
         dsv3_router_gemm,
         merge_state_v2,
@@ -224,10 +243,21 @@ class DeepseekV2MLP(nn.Module):
         forward_batch=None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x
 
+        if (
+            gemm_output_zero_allocator is not None
+            and x.shape[0] <= 256
+            and self.gate_up_proj.weight.dtype == torch.uint8
+        ):
+            y = gemm_output_zero_allocator.allocate(
+                x.shape[0] * self.gate_up_proj.output_size_per_partition
+            ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
+            x = (x, None, y)
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
@@ -240,6 +270,7 @@ class MoEGate(nn.Module):
     def __init__(
         self,
        config,
+        quant_config,
         prefix: str = "",
         is_nextn: bool = False,
     ):
@@ -249,15 +280,22 @@ class MoEGate(nn.Module):
             torch.empty((config.n_routed_experts, config.hidden_size))
         )
         if config.topk_method == "noaux_tc":
+            correction_bias_dtype = (
+                torch.bfloat16
+                if quant_config is not None
+                and quant_config.get_name() == "modelopt_fp4"
+                and should_use_flashinfer_trtllm_moe()
+                else torch.float32
+            )
             self.e_score_correction_bias = nn.Parameter(
-                torch.empty((config.n_routed_experts), dtype=
+                torch.empty((config.n_routed_experts), dtype=correction_bias_dtype)
             )
         else:
             self.e_score_correction_bias = None
         if _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(weight_names=["weight"])
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
         if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
@@ -275,7 +313,13 @@ class MoEGate(nn.Module):
             and _device_sm >= 90
         ):
             # router gemm output float32
-            logits = dsv3_router_gemm(
+            logits = dsv3_router_gemm(
+                hidden_states, self.weight, out_dtype=torch.float32
+            )
+        elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
+            logits = aiter_dsv3_router_gemm(
+                hidden_states, self.weight, gemm_output_zero_allocator
+            )
         else:
             logits = F.linear(hidden_states, self.weight, None)
 
@@ -319,7 +363,10 @@ class DeepseekV2MoE(nn.Module):
         )
 
         self.gate = MoEGate(
-            config=config,
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("gate", prefix),
+            is_nextn=is_nextn,
         )
 
         self.experts = get_moe_impl_class(quant_config)(
@@ -344,9 +391,12 @@ class DeepseekV2MoE(nn.Module):
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
+            quant_config=quant_config,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
-
+            # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
+            # and requires the output format to be standard. We use quant_config to determine the output format.
+            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
         )
 
         self.shared_experts_is_int8 = False
@@ -439,6 +489,7 @@ class DeepseekV2MoE(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -452,12 +503,14 @@ class DeepseekV2MoE(nn.Module):
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
             else:
                 return self.forward_normal(
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -467,15 +520,18 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
 
         current_stream = torch.cuda.current_stream()
         self.alt_stream.wait_stream(current_stream)
-        shared_output = self._forward_shared_experts(
+        shared_output = self._forward_shared_experts(
+            hidden_states, gemm_output_zero_allocator
+        )
 
         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
@@ -502,6 +558,7 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -509,9 +566,11 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)
 
         if hidden_states.shape[0] > 0:
-            shared_output = self._forward_shared_experts(
+            shared_output = self._forward_shared_experts(
+                hidden_states, gemm_output_zero_allocator
+            )
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
         else:
             shared_output = None
@@ -624,16 +683,24 @@ class DeepseekV2MoE(nn.Module):
 
         if shared_output is not None:
             x = shared_output
-
+            if self.experts.should_fuse_routed_scaling_factor_in_topk():
+                x.add_(final_hidden_states)
+            else:
+                x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
             final_hidden_states = x
         else:
-
+            if not self.experts.should_fuse_routed_scaling_factor_in_topk():
+                final_hidden_states *= self.routed_scaling_factor
 
         return final_hidden_states
 
-    def _forward_shared_experts(
+    def _forward_shared_experts(
+        self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
+    ):
         if self.num_fused_shared_experts == 0:
-            return self.shared_experts(
+            return self.shared_experts(
+                hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
+            )
         else:
             return None
 
@@ -992,6 +1059,15 @@ class DeepseekV2AttentionMLA(nn.Module):
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
             attention_backend = global_server_args_dict["decode_attention_backend"]
+        elif (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            # Use the specified backend for speculative operations (both verify and draft extend)
+            if global_server_args_dict["speculative_attention_mode"] == "decode":
+                attention_backend = global_server_args_dict["decode_attention_backend"]
+            else:  # default to prefill
+                attention_backend = global_server_args_dict["prefill_attention_backend"]
         else:
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend
@@ -1009,7 +1085,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
             or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
             or attention_backend == "cutlass_mla"
         ):
             # Use MHA with chunked KV cache when prefilling on long sequences.
@@ -1022,6 +1097,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             disable_ragged = (
                 attention_backend == "flashinfer" or attention_backend == "flashmla"
             ) and self.flashinfer_mla_disable_ragged
+
+            original_mode = getattr(forward_batch, "_original_forward_mode", None)
             if (
                 not disable_ragged
                 and forward_batch.forward_mode.is_extend()
@@ -1034,6 +1111,40 @@ class DeepseekV2AttentionMLA(nn.Module):
                    )
                    or sum_extend_prefix_lens == 0
                )
+                # TODO(shuw@nvidia.com) Flashinfer cutlass and trtllm_mla backend have accuracy issue on blackwell for
+                # dp case. Redirect to mla kernel as a workaround.
+                # Tracked by https://github.com/sgl-project/sglang/issues/9806.
+                and not (
+                    original_mode is not None
+                    and original_mode.is_decode()
+                    and is_sm100_supported()
+                    and self.current_attention_backend in ("cutlass_mla", "flashinfer")
+                )
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return _dispatch_mla_subtype()
+        elif attention_backend == "trtllm_mla":
+            original_mode = getattr(forward_batch, "_original_forward_mode", None)
+            if (
+                original_mode is not None
+                and original_mode.is_decode()
+                and is_sm100_supported()
+            ):
+                return _dispatch_mla_subtype()
+
+            sum_extend_prefix_lens = (
+                sum(forward_batch.extend_prefix_lens_cpu)
+                if forward_batch.extend_prefix_lens_cpu is not None
+                else 0
+            )
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+                and (
+                    not self.disable_chunked_prefix_cache or sum_extend_prefix_lens == 0
+                )
             ):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
@@ -1044,7 +1155,13 @@ class DeepseekV2AttentionMLA(nn.Module):
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
             ):
-
+                if is_dp_attention_enabled():
+                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                        return AttnForwardMethod.MHA
+                    else:
+                        return AttnForwardMethod.MLA
+                else:
+                    return AttnForwardMethod.MHA
             else:
                 return AttnForwardMethod.MLA
         else:
@@ -1097,11 +1214,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.attn_mha.kv_b_proj is None:
             self.attn_mha.kv_b_proj = self.kv_b_proj
 
-
-
-
-
-
+        # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor
+        if isinstance(hidden_states, tuple):
+            if hidden_states[0].shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states[0]
+        else:
+            if hidden_states.shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states, None, forward_batch, None
 
         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
 
@@ -1180,8 +1305,18 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
         q[..., self.qk_nope_head_dim :] = q_pe
         k = torch.empty_like(q)
-
-
+
+        # Temporary for DeepSeek V3/R1 only, but can generalize if needed
+        if (
+            _is_cuda
+            and (self.num_local_heads == 128)
+            and (self.qk_nope_head_dim == 128)
+            and (self.qk_rope_head_dim == 64)
+        ):
+            concat_mla_k(k=k, k_nope=k_nope, k_rope=k_pe)
+        else:
+            k[..., : self.qk_nope_head_dim] = k_nope
+            k[..., self.qk_nope_head_dim :] = k_pe
 
         if not _is_npu:
             latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
@@ -1225,7 +1360,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
-            if
+            if (
+                (not isinstance(hidden_states, tuple))
+                and hidden_states.shape[0] <= 16
+                and self.use_min_latency_fused_a_gemm
+            ):
                 fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                     hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                 )
@@ -1245,8 +1384,18 @@ class DeepseekV2AttentionMLA(nn.Module):
                 k_nope = self.kv_a_layernorm(k_nope)
                 current_stream.wait_stream(self.alt_stream)
             else:
-
-
+                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+                    q, k_nope = fused_rms_mxfp4_quant(
+                        q,
+                        self.q_a_layernorm.weight,
+                        self.q_a_layernorm.variance_epsilon,
+                        k_nope,
+                        self.kv_a_layernorm.weight,
+                        self.kv_a_layernorm.variance_epsilon,
+                    )
+                else:
+                    q = self.q_a_layernorm(q)
+                    k_nope = self.kv_a_layernorm(k_nope)
 
             k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
@@ -1278,10 +1427,27 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope_out[:, :expected_m, :]
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-
-            q_nope.
-
-
+            if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+                x = q_nope.transpose(0, 1)
+                q_nope_out = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_kc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_kc.transpose(-2, -1),
+                    self.w_scale_k.transpose(-2, -1),
+                    torch.bfloat16,
+                    q_nope_out,
+                )
+            else:
+                q_nope_out = torch.bmm(
+                    q_nope.to(torch.bfloat16).transpose(0, 1),
+                    self.w_kc.to(torch.bfloat16) * self.w_scale,
+                )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
@@ -1295,13 +1461,15 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         q_nope_out = q_nope_out.transpose(0, 1)
 
-        if not self._fuse_rope_for_trtllm_mla(forward_batch)
+        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+            not _use_aiter or not _is_gfx95_supported
+        ):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
 
     def forward_absorb_core(
-        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
     ):
         if (
             self.current_attention_backend == "fa3"
@@ -1326,8 +1494,23 @@ class DeepseekV2AttentionMLA(nn.Module):
                 **extra_args,
             )
         else:
-
-
+            if _use_aiter_gfx95:
+                cos = self.rotary_emb.cos_cache
+                sin = self.rotary_emb.sin_cache
+                q, k = fused_qk_rope_cat(
+                    q_nope_out,
+                    q_pe,
+                    k_nope,
+                    k_pe,
+                    positions,
+                    cos,
+                    sin,
+                    self.rotary_emb.is_neox_style,
+                )
+            else:
+                q = torch.cat([q_nope_out, q_pe], dim=-1)
+                k = torch.cat([k_nope, k_pe], dim=-1)
+
         attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
@@ -1352,11 +1535,34 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-
-            attn_output.
-
-
-
+            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
+                x = attn_output.transpose(0, 1)
+                attn_bmm_output = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_vc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_vc.transpose(-2, -1),
+                    self.w_scale_v.transpose(-2, -1),
+                    torch.bfloat16,
+                    attn_bmm_output,
+                )
+            else:
+                attn_bmm_output = torch.bmm(
+                    attn_output.to(torch.bfloat16).transpose(0, 1),
+                    self.w_vc.to(torch.bfloat16) * self.w_scale,
+                )
+
+            if self.o_proj.weight.dtype == torch.uint8:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1)
+                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
+            else:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1678,9 +1884,11 @@ class DeepseekV2AttentionMLA(nn.Module):
             latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                 self.attn_mha.layer_id
             )
-            latent_cache =
-                forward_batch.prefix_chunk_kv_indices[i]
-
+            latent_cache = (
+                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
+                .contiguous()
+                .to(q.dtype)
+            )
 
             kv_a_normed, k_pe = latent_cache.split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
@@ -1864,10 +2072,24 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
 
+        quant_format = (
+            "mxfp4"
+            if _is_gfx95_supported
+            and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None
+            and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None)
+            is not None
+            and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
+            else ""
+        )
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states,
+            hidden_states,
+            residual,
+            forward_batch,
+            quant_format,
         )
 
         hidden_states = self.self_attn(
@@ -1891,8 +2113,16 @@ class DeepseekV2DecoderLayer(nn.Module):
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
+
+        if isinstance(self.mlp, DeepseekV2MLP):
+            gemm_output_zero_allocator = None
+
         hidden_states = self.mlp(
-            hidden_states,
+            hidden_states,
+            forward_batch,
+            should_allreduce_fusion,
+            use_reduce_scatter,
+            gemm_output_zero_allocator,
         )
 
         if should_allreduce_fusion:
@@ -2023,8 +2253,15 @@ class DeepseekV2Model(nn.Module):
                 [
                     "w13_weight",
                     "w2_weight",
-
-
+                    # only for nvfp4
+                    *(
+                        [
+                            "w13_blockscale_swizzled",
+                            "w2_blockscale_swizzled",
+                        ]
+                        if hasattr(module, "w13_blockscale_swizzled")
+                        else []
+                    ),
                 ]
                 if isinstance(module, FusedMoE)
                 else []
@@ -2036,6 +2273,37 @@ class DeepseekV2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
+        self.gemm_output_zero_allocator_size = 0
+        if (
+            _use_aiter_gfx95
+            and config.n_routed_experts == 256
+            and self.embed_tokens.embedding_dim == 7168
+        ):
+            num_moe_layers = sum(
+                [
+                    1
+                    for i in range(len(self.layers))
+                    if isinstance(self.layers[i].mlp, DeepseekV2MoE)
+                ]
+            )
+
+            allocate_size = 0
+            for i in range(len(self.layers)):
+                if isinstance(self.layers[i].mlp, DeepseekV2MoE):
+                    allocate_size = self.layers[
+                        i
+                    ].mlp.shared_experts.gate_up_proj.output_size_per_partition
+                    break
+
+            self.gemm_output_zero_allocator_size = (
+                get_dsv3_gemm_output_zero_allocator_size(
+                    config.n_routed_experts,
+                    num_moe_layers,
+                    allocate_size,
+                    self.embed_tokens.embedding_dim,
+                )
+            )
+
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens
 
@@ -2055,6 +2323,21 @@ class DeepseekV2Model(nn.Module):
             device=device,
         )
 
+        has_gemm_output_zero_allocator = hasattr(
+            self, "gemm_output_zero_allocator_size"
+        )
+
+        gemm_output_zero_allocator = (
+            BumpAllocator(
+                buffer_size=self.gemm_output_zero_allocator_size,
+                dtype=torch.float32,
+                device=device,
+            )
+            if has_gemm_output_zero_allocator
+            and self.gemm_output_zero_allocator_size > 0
+            else None
+        )
+
         if self.pp_group.is_first_rank:
             if input_embeds is None:
                 hidden_states = self.embed_tokens(input_ids)
@@ -2081,7 +2364,12 @@ class DeepseekV2Model(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
-                    positions,
+                    positions,
+                    hidden_states,
+                    forward_batch,
+                    residual,
+                    zero_allocator,
+                    gemm_output_zero_allocator,
                 )
 
         if normal_end_layer != self.end_layer:
@@ -2354,6 +2642,16 @@ class DeepseekV2ForCausalLM(nn.Module):
             w_kc, w_vc = w.unflatten(
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+
+            if (
+                _use_aiter_gfx95
+                and self.quant_config is not None
+                and self.quant_config.get_name() == "quark"
+            ):
+                w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
+                    quark_post_load_weights(self_attn, w, "mxfp4")
+                )
+
             if not use_deep_gemm_bmm:
                 self_attn.w_kc = bind_or_assign(
                     self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)