sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -67,7 +67,10 @@ from sglang.srt.layers.moe import (
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
-from sglang.srt.layers.moe.fused_moe_triton.layer import
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
+    FusedMoE,
+    _is_fp4_quantization_enabled,
+)
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -87,8 +90,8 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import
-from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.rotary_embedding import get_rope_wrapper
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -112,8 +115,11 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_flashinfer_available,
+    is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
+    is_npu,
+    is_sm100_supported,
     log_info_on_rank0,
     make_layers,
     use_intel_amx_backend,
@@ -121,11 +127,28 @@ from sglang.srt.utils import (

 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()
+_is_gfx95_supported = is_gfx95_supported()
+
+_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
+
+if _use_aiter_gfx95:
+    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
+    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
+        batched_gemm_afp4wfp4_pre_quant,
+        fused_flatten_mxfp4_quant,
+        fused_rms_mxfp4_quant,
+    )
+    from sglang.srt.layers.rocm_linear_utils import (
+        aiter_dsv3_router_gemm,
+        fused_qk_rope_cat,
+        get_dsv3_gemm_output_zero_allocator_size,
+    )

 if _is_cuda:
     from sgl_kernel import (
@@ -221,10 +244,21 @@ class DeepseekV2MLP(nn.Module):
         forward_batch=None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

+        if (
+            gemm_output_zero_allocator is not None
+            and x.shape[0] <= 256
+            and self.gate_up_proj.weight.dtype == torch.uint8
+        ):
+            y = gemm_output_zero_allocator.allocate(
+                x.shape[0] * self.gate_up_proj.output_size_per_partition
+            ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
+            x = (x, None, y)
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
@@ -254,7 +288,7 @@ class MoEGate(nn.Module):
         if _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(weight_names=["weight"])

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
         if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
@@ -272,7 +306,13 @@ class MoEGate(nn.Module):
             and _device_sm >= 90
         ):
             # router gemm output float32
-            logits = dsv3_router_gemm(
+            logits = dsv3_router_gemm(
+                hidden_states, self.weight, out_dtype=torch.float32
+            )
+        elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
+            logits = aiter_dsv3_router_gemm(
+                hidden_states, self.weight, gemm_output_zero_allocator
+            )
         else:
             logits = F.linear(hidden_states, self.weight, None)

@@ -333,6 +373,9 @@ class DeepseekV2MoE(nn.Module):
             prefix=add_prefix("experts", prefix),
         )

+        correction_bias = self.gate.e_score_correction_bias
+        if _is_fp4_quantization_enabled():
+            correction_bias = correction_bias.to(torch.bfloat16)
         self.topk = TopK(
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             renormalize=config.norm_topk_prob,
@@ -340,7 +383,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
-            correction_bias=
+            correction_bias=correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
             force_topk=quant_config is None,
@@ -436,6 +479,7 @@ class DeepseekV2MoE(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -449,12 +493,14 @@ class DeepseekV2MoE(nn.Module):
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
             else:
                 return self.forward_normal(
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -464,15 +510,18 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:

         current_stream = torch.cuda.current_stream()
         self.alt_stream.wait_stream(current_stream)
-        shared_output = self._forward_shared_experts(
+        shared_output = self._forward_shared_experts(
+            hidden_states, gemm_output_zero_allocator
+        )

         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
@@ -499,6 +548,7 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -506,9 +556,11 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)

         if hidden_states.shape[0] > 0:
-            shared_output = self._forward_shared_experts(
+            shared_output = self._forward_shared_experts(
+                hidden_states, gemm_output_zero_allocator
+            )
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
         else:
             shared_output = None
@@ -628,9 +680,13 @@ class DeepseekV2MoE(nn.Module):

         return final_hidden_states

-    def _forward_shared_experts(
+    def _forward_shared_experts(
+        self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
+    ):
         if self.num_fused_shared_experts == 0:
-            return self.shared_experts(
+            return self.shared_experts(
+                hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
+            )
         else:
             return None

@@ -989,17 +1045,32 @@ class DeepseekV2AttentionMLA(nn.Module):
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
             attention_backend = global_server_args_dict["decode_attention_backend"]
+        elif (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            # Use the specified backend for speculative operations (both verify and draft extend)
+            if global_server_args_dict["speculative_attention_mode"] == "decode":
+                attention_backend = global_server_args_dict["decode_attention_backend"]
+            else:  # default to prefill
+                attention_backend = global_server_args_dict["prefill_attention_backend"]
         else:
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend

         if attention_backend == "ascend":
-
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         elif (
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
             or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
             or attention_backend == "cutlass_mla"
         ):
             # Use MHA with chunked KV cache when prefilling on long sequences.
@@ -1028,13 +1099,28 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif attention_backend == "trtllm_mla":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return _dispatch_mla_subtype()
         elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
             ):
-
+                if is_dp_attention_enabled():
+                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                        return AttnForwardMethod.MHA
+                    else:
+                        return AttnForwardMethod.MLA
+                else:
+                    return AttnForwardMethod.MHA
             else:
                 return AttnForwardMethod.MLA
         else:
@@ -1087,11 +1173,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.attn_mha.kv_b_proj is None:
             self.attn_mha.kv_b_proj = self.kv_b_proj

-
-
-
-
-
+        # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor
+        if isinstance(hidden_states, tuple):
+            if hidden_states[0].shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states[0]
+        else:
+            if hidden_states.shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states, None, forward_batch, None

         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

@@ -1173,13 +1267,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         k[..., : self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim :] = k_pe

-
-
+        if not _is_npu:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe

-
-
-
-
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+        else:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )

         return q, k, v, forward_batch

@@ -1209,7 +1309,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

         if self.q_lora_rank is not None:
-            if
+            if (
+                (not isinstance(hidden_states, tuple))
+                and hidden_states.shape[0] <= 16
+                and self.use_min_latency_fused_a_gemm
+            ):
                 fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                     hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                 )
@@ -1229,8 +1333,18 @@ class DeepseekV2AttentionMLA(nn.Module):
                 k_nope = self.kv_a_layernorm(k_nope)
             current_stream.wait_stream(self.alt_stream)
         else:
-
-
+            if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+                q, k_nope = fused_rms_mxfp4_quant(
+                    q,
+                    self.q_a_layernorm.weight,
+                    self.q_a_layernorm.variance_epsilon,
+                    k_nope,
+                    self.kv_a_layernorm.weight,
+                    self.kv_a_layernorm.variance_epsilon,
+                )
+            else:
+                q = self.q_a_layernorm(q)
+                k_nope = self.kv_a_layernorm(k_nope)

         k_nope = k_nope.unsqueeze(1)
         q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
@@ -1262,10 +1376,27 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope_out[:, :expected_m, :]
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-
-                q_nope.
-
-
+            if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+                x = q_nope.transpose(0, 1)
+                q_nope_out = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_kc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_kc.transpose(-2, -1),
+                    self.w_scale_k.transpose(-2, -1),
+                    torch.bfloat16,
+                    q_nope_out,
+                )
+            else:
+                q_nope_out = torch.bmm(
+                    q_nope.to(torch.bfloat16).transpose(0, 1),
+                    self.w_kc.to(torch.bfloat16) * self.w_scale,
+                )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
@@ -1279,19 +1410,22 @@ class DeepseekV2AttentionMLA(nn.Module):

         q_nope_out = q_nope_out.transpose(0, 1)

-        if not self._fuse_rope_for_trtllm_mla(forward_batch)
+        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+            not _use_aiter or not _is_gfx95_supported
+        ):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

-        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions

     def forward_absorb_core(
-        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
     ):
         if (
             self.current_attention_backend == "fa3"
             or self.current_attention_backend == "flashinfer"
             or self.current_attention_backend == "cutlass_mla"
             or self.current_attention_backend == "trtllm_mla"
+            or self.current_attention_backend == "ascend"
         ):
             extra_args = {}
             if self._fuse_rope_for_trtllm_mla(forward_batch):
@@ -1309,8 +1443,23 @@ class DeepseekV2AttentionMLA(nn.Module):
                 **extra_args,
             )
         else:
-
-
+            if _use_aiter_gfx95:
+                cos = self.rotary_emb.cos_cache
+                sin = self.rotary_emb.sin_cache
+                q, k = fused_qk_rope_cat(
+                    q_nope_out,
+                    q_pe,
+                    k_nope,
+                    k_pe,
+                    positions,
+                    cos,
+                    sin,
+                    self.rotary_emb.is_neox_style,
+                )
+            else:
+                q = torch.cat([q_nope_out, q_pe], dim=-1)
+                k = torch.cat([k_nope, k_pe], dim=-1)
+
             attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
             attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

@@ -1335,11 +1484,34 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-
-                attn_output.
-
-
-
+            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
+                x = attn_output.transpose(0, 1)
+                attn_bmm_output = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_vc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_vc.transpose(-2, -1),
+                    self.w_scale_v.transpose(-2, -1),
+                    torch.bfloat16,
+                    attn_bmm_output,
+                )
+            else:
+                attn_bmm_output = torch.bmm(
+                    attn_output.to(torch.bfloat16).transpose(0, 1),
+                    self.w_vc.to(torch.bfloat16) * self.w_scale,
+                )
+
+            if self.o_proj.weight.dtype == torch.uint8:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1)
+                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
+            else:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+

         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1661,9 +1833,11 @@ class DeepseekV2AttentionMLA(nn.Module):
             latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                 self.attn_mha.layer_id
             )
-            latent_cache =
-                forward_batch.prefix_chunk_kv_indices[i]
-
+            latent_cache = (
+                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
+                .contiguous()
+                .to(q.dtype)
+            )

             kv_a_normed, k_pe = latent_cache.split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
@@ -1847,10 +2021,24 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:

+        quant_format = (
+            "mxfp4"
+            if _is_gfx95_supported
+            and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None
+            and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None)
+            is not None
+            and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
+            else ""
+        )
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states,
+            hidden_states,
+            residual,
+            forward_batch,
+            quant_format,
         )

         hidden_states = self.self_attn(
@@ -1874,8 +2062,16 @@ class DeepseekV2DecoderLayer(nn.Module):
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
+
+        if isinstance(self.mlp, DeepseekV2MLP):
+            gemm_output_zero_allocator = None
+
         hidden_states = self.mlp(
-            hidden_states,
+            hidden_states,
+            forward_batch,
+            should_allreduce_fusion,
+            use_reduce_scatter,
+            gemm_output_zero_allocator,
         )

         if should_allreduce_fusion:
@@ -2019,6 +2215,37 @@ class DeepseekV2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)

+        self.gemm_output_zero_allocator_size = 0
+        if (
+            _use_aiter_gfx95
+            and config.n_routed_experts == 256
+            and self.embed_tokens.embedding_dim == 7168
+        ):
+            num_moe_layers = sum(
+                [
+                    1
+                    for i in range(len(self.layers))
+                    if isinstance(self.layers[i].mlp, DeepseekV2MoE)
+                ]
+            )
+
+            allocate_size = 0
+            for i in range(len(self.layers)):
+                if isinstance(self.layers[i].mlp, DeepseekV2MoE):
+                    allocate_size = self.layers[
+                        i
+                    ].mlp.shared_experts.gate_up_proj.output_size_per_partition
+                    break
+
+            self.gemm_output_zero_allocator_size = (
+                get_dsv3_gemm_output_zero_allocator_size(
+                    config.n_routed_experts,
+                    num_moe_layers,
+                    allocate_size,
+                    self.embed_tokens.embedding_dim,
+                )
+            )
+
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens

@@ -2038,6 +2265,21 @@ class DeepseekV2Model(nn.Module):
             device=device,
         )

+        has_gemm_output_zero_allocator = hasattr(
+            self, "gemm_output_zero_allocator_size"
+        )
+
+        gemm_output_zero_allocator = (
+            BumpAllocator(
+                buffer_size=self.gemm_output_zero_allocator_size,
+                dtype=torch.float32,
+                device=device,
+            )
+            if has_gemm_output_zero_allocator
+            and self.gemm_output_zero_allocator_size > 0
+            else None
+        )
+
         if self.pp_group.is_first_rank:
             if input_embeds is None:
                 hidden_states = self.embed_tokens(input_ids)
@@ -2064,7 +2306,12 @@ class DeepseekV2Model(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
-                    positions,
+                    positions,
+                    hidden_states,
+                    forward_batch,
+                    residual,
+                    zero_allocator,
+                    gemm_output_zero_allocator,
                 )

         if normal_end_layer != self.end_layer:
@@ -2168,6 +2415,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
+        elif self.quant_config.get_name() == "w4afp8":
+            disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."

         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
@@ -2335,6 +2584,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                 w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+
+                if (
+                    _use_aiter_gfx95
+                    and self.quant_config is not None
+                    and self.quant_config.get_name() == "quark"
+                ):
+                    w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
+                        quark_post_load_weights(self_attn, w, "mxfp4")
+                    )
+
                 if not use_deep_gemm_bmm:
                     self_attn.w_kc = bind_or_assign(
                         self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
@@ -2397,18 +2656,26 @@ class DeepseekV2ForCausalLM(nn.Module):
             )

         num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
         for layer_id in range(num_hidden_layers):
             if is_nextn:
                 layer = self.model.decoder
             else:
                 layer = self.model.layers[layer_id]

-
-                layer.self_attn.fused_qkv_a_proj_with_mqa,
-                layer.self_attn.q_b_proj,
+            module_list = [
                 layer.self_attn.kv_b_proj,
                 layer.self_attn.o_proj,
-            ]
+            ]
+
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.fused_qkv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_b_proj)
+            else:
+                module_list.append(layer.self_attn.kv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_proj)
+
+            for module in module_list:
                 requant_weight_ue8m0_inplace(
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
@@ -2471,6 +2738,9 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
+        # Params for special naming rules in mixed-precision models, for example:
+        # model.layers.xx.mlp.experts.xx.w1.input_scale. For details,
+        # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main.
         if self.quant_config and self.quant_config.get_name() == "w4afp8":
             expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
                 num_experts=self.config.n_routed_experts
sglang/srt/models/gemma3n_mm.py
CHANGED
@@ -499,7 +499,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def should_apply_lora(self, module_name: str) -> bool:
         return bool(self.lora_pattern.match(module_name))

-    def get_hidden_dim(self, module_name):
+    def get_hidden_dim(self, module_name, layer_idx):
         # return input_dim, output_dim
         if module_name == "qkv_proj":
             return (