sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff compares the contents of two publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
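Several parser-related modules in the list above moved from sglang/srt/ into a new sglang/srt/parser/ package (reasoning_parser.py, code_completion_parser.py, conversation.py, harmony_parser.py, jinja_template_utils.py), and model_parallel.py moved under sglang/srt/layers/. The snippet below is a hypothetical compatibility shim for downstream code that imported these modules by their old paths; it is not shipped by sglang, and it assumes the public class name ReasoningParser is unchanged by the move.

```python
# Hypothetical import shim (not part of sglang): prefer the 0.5.2 package layout
# under sglang.srt.parser and fall back to the flat 0.5.1.post3 module path.
try:
    from sglang.srt.parser.reasoning_parser import ReasoningParser
except ImportError:  # pre-0.5.2 layout
    from sglang.srt.reasoning_parser import ReasoningParser

# Used exactly as before; only the import path differs.
# Constructor arguments here are illustrative.
parser = ReasoningParser(model_type="deepseek-r1")
```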
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -67,7 +67,10 @@ from sglang.srt.layers.moe import (
     should_use_flashinfer_cutlass_moe_fp4_allgather,
 )
 from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
-from sglang.srt.layers.moe.fused_moe_triton.layer import
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
+    FusedMoE,
+    _is_fp4_quantization_enabled,
+)
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -112,8 +115,10 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_flashinfer_available,
+    is_gfx95_supported,
     is_hip,
     is_non_idle_and_non_empty,
+    is_npu,
     is_sm100_supported,
     log_info_on_rank0,
     make_layers,
@@ -122,11 +127,28 @@ from sglang.srt.utils import (

 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 _device_sm = get_device_sm()
+_is_gfx95_supported = is_gfx95_supported()
+
+_use_aiter_gfx95 = _use_aiter and _is_gfx95_supported
+
+if _use_aiter_gfx95:
+    from sglang.srt.layers.quantization.quark.utils import quark_post_load_weights
+    from sglang.srt.layers.quantization.rocm_mxfp4_utils import (
+        batched_gemm_afp4wfp4_pre_quant,
+        fused_flatten_mxfp4_quant,
+        fused_rms_mxfp4_quant,
+    )
+    from sglang.srt.layers.rocm_linear_utils import (
+        aiter_dsv3_router_gemm,
+        fused_qk_rope_cat,
+        get_dsv3_gemm_output_zero_allocator_size,
+    )

 if _is_cuda:
     from sgl_kernel import (
@@ -222,10 +244,21 @@ class DeepseekV2MLP(nn.Module):
         forward_batch=None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ):
         if (self.tp_size == 1) and x.shape[0] == 0:
             return x

+        if (
+            gemm_output_zero_allocator is not None
+            and x.shape[0] <= 256
+            and self.gate_up_proj.weight.dtype == torch.uint8
+        ):
+            y = gemm_output_zero_allocator.allocate(
+                x.shape[0] * self.gate_up_proj.output_size_per_partition
+            ).view(x.shape[0], self.gate_up_proj.output_size_per_partition)
+            x = (x, None, y)
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(
@@ -255,7 +288,7 @@ class MoEGate(nn.Module):
         if _is_cpu and _is_cpu_amx_available:
             self.quant_method = PackWeightMethod(weight_names=["weight"])

-    def forward(self, hidden_states):
+    def forward(self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None):
         if use_intel_amx_backend(self):
             return torch.ops.sgl_kernel.weight_packed_linear(
                 hidden_states,
@@ -273,7 +306,13 @@ class MoEGate(nn.Module):
             and _device_sm >= 90
         ):
             # router gemm output float32
-            logits = dsv3_router_gemm(
+            logits = dsv3_router_gemm(
+                hidden_states, self.weight, out_dtype=torch.float32
+            )
+        elif _use_aiter_gfx95 and hidden_states.shape[0] <= 256:
+            logits = aiter_dsv3_router_gemm(
+                hidden_states, self.weight, gemm_output_zero_allocator
+            )
         else:
             logits = F.linear(hidden_states, self.weight, None)

@@ -334,6 +373,9 @@ class DeepseekV2MoE(nn.Module):
             prefix=add_prefix("experts", prefix),
         )

+        correction_bias = self.gate.e_score_correction_bias
+        if _is_fp4_quantization_enabled():
+            correction_bias = correction_bias.to(torch.bfloat16)
         self.topk = TopK(
             top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             renormalize=config.norm_topk_prob,
@@ -341,7 +383,7 @@ class DeepseekV2MoE(nn.Module):
             num_expert_group=config.n_group,
             num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
-            correction_bias=
+            correction_bias=correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
             apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk(),
             force_topk=quant_config is None,
@@ -437,6 +479,7 @@ class DeepseekV2MoE(nn.Module):
         forward_batch: Optional[ForwardBatch] = None,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if not self._enable_deepep_moe:
             DUAL_STREAM_TOKEN_THRESHOLD = 1024
@@ -450,12 +493,14 @@ class DeepseekV2MoE(nn.Module):
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
             else:
                 return self.forward_normal(
                     hidden_states,
                     should_allreduce_fusion,
                     use_reduce_scatter,
+                    gemm_output_zero_allocator,
                 )
         else:
             return self.forward_deepep(hidden_states, forward_batch)
@@ -465,15 +510,18 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:

         current_stream = torch.cuda.current_stream()
         self.alt_stream.wait_stream(current_stream)
-        shared_output = self._forward_shared_experts(
+        shared_output = self._forward_shared_experts(
+            hidden_states, gemm_output_zero_allocator
+        )

         with torch.cuda.stream(self.alt_stream):
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
             final_hidden_states = self.experts(hidden_states, topk_output)
             if not _is_cuda:
@@ -500,6 +548,7 @@ class DeepseekV2MoE(nn.Module):
         hidden_states: torch.Tensor,
         should_allreduce_fusion: bool = False,
         use_reduce_scatter: bool = False,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:
         if hasattr(self, "shared_experts") and use_intel_amx_backend(
             self.shared_experts.gate_up_proj
@@ -507,9 +556,11 @@ class DeepseekV2MoE(nn.Module):
             return self.forward_cpu(hidden_states, should_allreduce_fusion)

         if hidden_states.shape[0] > 0:
-            shared_output = self._forward_shared_experts(
+            shared_output = self._forward_shared_experts(
+                hidden_states, gemm_output_zero_allocator
+            )
             # router_logits: (num_tokens, n_experts)
-            router_logits = self.gate(hidden_states)
+            router_logits = self.gate(hidden_states, gemm_output_zero_allocator)
             topk_output = self.topk(hidden_states, router_logits)
         else:
             shared_output = None
@@ -629,9 +680,13 @@ class DeepseekV2MoE(nn.Module):

         return final_hidden_states

-    def _forward_shared_experts(
+    def _forward_shared_experts(
+        self, hidden_states, gemm_output_zero_allocator: BumpAllocator = None
+    ):
         if self.num_fused_shared_experts == 0:
-            return self.shared_experts(
+            return self.shared_experts(
+                hidden_states, gemm_output_zero_allocator=gemm_output_zero_allocator
+            )
         else:
             return None

@@ -990,6 +1045,15 @@ class DeepseekV2AttentionMLA(nn.Module):
         # Determine attention backend used by current forward batch
         if forward_batch.forward_mode.is_decode_or_idle():
             attention_backend = global_server_args_dict["decode_attention_backend"]
+        elif (
+            forward_batch.forward_mode.is_target_verify()
+            or forward_batch.forward_mode.is_draft_extend()
+        ):
+            # Use the specified backend for speculative operations (both verify and draft extend)
+            if global_server_args_dict["speculative_attention_mode"] == "decode":
+                attention_backend = global_server_args_dict["decode_attention_backend"]
+            else:  # default to prefill
+                attention_backend = global_server_args_dict["prefill_attention_backend"]
         else:
             attention_backend = global_server_args_dict["prefill_attention_backend"]
         self.current_attention_backend = attention_backend
@@ -1007,7 +1071,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             attention_backend == "flashinfer"
             or attention_backend == "fa3"
             or attention_backend == "flashmla"
-            or attention_backend == "trtllm_mla"
             or attention_backend == "cutlass_mla"
         ):
             # Use MHA with chunked KV cache when prefilling on long sequences.
@@ -1036,13 +1099,28 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif attention_backend == "trtllm_mla":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA_CHUNKED_KV
+            else:
+                return _dispatch_mla_subtype()
         elif attention_backend == "aiter":
             if (
                 forward_batch.forward_mode.is_extend()
                 and not forward_batch.forward_mode.is_target_verify()
                 and not forward_batch.forward_mode.is_draft_extend()
             ):
-
+                if is_dp_attention_enabled():
+                    if sum(forward_batch.extend_prefix_lens_cpu) == 0:
+                        return AttnForwardMethod.MHA
+                    else:
+                        return AttnForwardMethod.MLA
+                else:
+                    return AttnForwardMethod.MHA
             else:
                 return AttnForwardMethod.MLA
         else:
@@ -1095,11 +1173,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         if self.attn_mha.kv_b_proj is None:
             self.attn_mha.kv_b_proj = self.kv_b_proj

-
-
-
-
-
+        # when hidden_states is a tuple of tensors, the tuple will include quantized weight and scale tensor
+        if isinstance(hidden_states, tuple):
+            if hidden_states[0].shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states[0]
+        else:
+            if hidden_states.shape[0] == 0:
+                assert (
+                    not self.o_proj.reduce_results
+                ), "short-circuiting allreduce will lead to hangs"
+                return hidden_states, None, forward_batch, None

         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

@@ -1181,13 +1267,19 @@ class DeepseekV2AttentionMLA(nn.Module):
         k[..., : self.qk_nope_head_dim] = k_nope
         k[..., self.qk_nope_head_dim :] = k_pe

-
-
+        if not _is_npu:
+            latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+            latent_cache[:, :, self.kv_lora_rank :] = k_pe

-
-
-
-
+            # Save latent cache
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+            )
+        else:
+            # To reduce a time-costing split operation
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                self.attn_mha, forward_batch.out_cache_loc, kv_a.unsqueeze(1), k_pe
+            )

         return q, k, v, forward_batch

@@ -1217,7 +1309,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

         if self.q_lora_rank is not None:
-            if
+            if (
+                (not isinstance(hidden_states, tuple))
+                and hidden_states.shape[0] <= 16
+                and self.use_min_latency_fused_a_gemm
+            ):
                 fused_qkv_a_proj_out = dsv3_fused_a_gemm(
                     hidden_states, self.fused_qkv_a_proj_with_mqa.weight.T
                 )
@@ -1237,8 +1333,18 @@ class DeepseekV2AttentionMLA(nn.Module):
                 k_nope = self.kv_a_layernorm(k_nope)
                 current_stream.wait_stream(self.alt_stream)
             else:
-
-
+                if _use_aiter_gfx95 and self.q_b_proj.weight.dtype == torch.uint8:
+                    q, k_nope = fused_rms_mxfp4_quant(
+                        q,
+                        self.q_a_layernorm.weight,
+                        self.q_a_layernorm.variance_epsilon,
+                        k_nope,
+                        self.kv_a_layernorm.weight,
+                        self.kv_a_layernorm.variance_epsilon,
+                    )
+                else:
+                    q = self.q_a_layernorm(q)
+                    k_nope = self.kv_a_layernorm(k_nope)

             k_nope = k_nope.unsqueeze(1)
             q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
@@ -1270,10 +1376,27 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope_out[:, :expected_m, :]
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-
-            q_nope.
-
-
+            if _use_aiter_gfx95 and self.w_kc.dtype == torch.uint8:
+                x = q_nope.transpose(0, 1)
+                q_nope_out = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_kc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_kc.transpose(-2, -1),
+                    self.w_scale_k.transpose(-2, -1),
+                    torch.bfloat16,
+                    q_nope_out,
+                )
+            else:
+                q_nope_out = torch.bmm(
+                    q_nope.to(torch.bfloat16).transpose(0, 1),
+                    self.w_kc.to(torch.bfloat16) * self.w_scale,
+                )
         elif self.w_kc.dtype == torch.float8_e4m3fn:
             q_nope_val, q_nope_scale = per_tensor_quant_mla_fp8(
                 q_nope.transpose(0, 1),
@@ -1287,13 +1410,15 @@ class DeepseekV2AttentionMLA(nn.Module):

         q_nope_out = q_nope_out.transpose(0, 1)

-        if not self._fuse_rope_for_trtllm_mla(forward_batch)
+        if not self._fuse_rope_for_trtllm_mla(forward_batch) and (
+            not _use_aiter or not _is_gfx95_supported
+        ):
             q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

-        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions

     def forward_absorb_core(
-        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator, positions
     ):
         if (
             self.current_attention_backend == "fa3"
@@ -1318,8 +1443,23 @@ class DeepseekV2AttentionMLA(nn.Module):
                 **extra_args,
             )
         else:
-
-
+            if _use_aiter_gfx95:
+                cos = self.rotary_emb.cos_cache
+                sin = self.rotary_emb.sin_cache
+                q, k = fused_qk_rope_cat(
+                    q_nope_out,
+                    q_pe,
+                    k_nope,
+                    k_pe,
+                    positions,
+                    cos,
+                    sin,
+                    self.rotary_emb.is_neox_style,
+                )
+            else:
+                q = torch.cat([q_nope_out, q_pe], dim=-1)
+                k = torch.cat([k_nope, k_pe], dim=-1)
+
             attn_output = self.attn_mqa(q, k, k_nope, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

@@ -1344,11 +1484,34 @@ class DeepseekV2AttentionMLA(nn.Module):
             )
         elif _is_hip:
             # TODO(haishaw): add bmm_fp8 to ROCm
-
-            attn_output.
-
-
-
+            if _use_aiter_gfx95 and self.w_vc.dtype == torch.uint8:
+                x = attn_output.transpose(0, 1)
+                attn_bmm_output = torch.empty(
+                    x.shape[0],
+                    x.shape[1],
+                    self.w_vc.shape[2],
+                    device=x.device,
+                    dtype=torch.bfloat16,
+                )
+                batched_gemm_afp4wfp4_pre_quant(
+                    x,
+                    self.w_vc.transpose(-2, -1),
+                    self.w_scale_v.transpose(-2, -1),
+                    torch.bfloat16,
+                    attn_bmm_output,
+                )
+            else:
+                attn_bmm_output = torch.bmm(
+                    attn_output.to(torch.bfloat16).transpose(0, 1),
+                    self.w_vc.to(torch.bfloat16) * self.w_scale,
+                )
+
+            if self.o_proj.weight.dtype == torch.uint8:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1)
+                attn_bmm_output = fused_flatten_mxfp4_quant(attn_bmm_output)
+            else:
+                attn_bmm_output = attn_bmm_output.transpose(0, 1).flatten(1, 2)
+
         elif self.w_vc.dtype == torch.float8_e4m3fn:
             attn_output_val, attn_output_scale = per_tensor_quant_mla_fp8(
                 attn_output.transpose(0, 1),
@@ -1670,9 +1833,11 @@ class DeepseekV2AttentionMLA(nn.Module):
             latent_cache_buf = forward_batch.token_to_kv_pool.get_key_buffer(
                 self.attn_mha.layer_id
             )
-            latent_cache =
-                forward_batch.prefix_chunk_kv_indices[i]
-
+            latent_cache = (
+                latent_cache_buf[forward_batch.prefix_chunk_kv_indices[i]]
+                .contiguous()
+                .to(q.dtype)
+            )

             kv_a_normed, k_pe = latent_cache.split(
                 [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1
@@ -1856,10 +2021,24 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        gemm_output_zero_allocator: BumpAllocator = None,
     ) -> torch.Tensor:

+        quant_format = (
+            "mxfp4"
+            if _is_gfx95_supported
+            and getattr(self.self_attn, "fused_qkv_a_proj_with_mqa", None) is not None
+            and getattr(self.self_attn.fused_qkv_a_proj_with_mqa, "weight", None)
+            is not None
+            and self.self_attn.fused_qkv_a_proj_with_mqa.weight.dtype == torch.uint8
+            else ""
+        )
+
         hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states,
+            hidden_states,
+            residual,
+            forward_batch,
+            quant_format,
         )

         hidden_states = self.self_attn(
@@ -1883,8 +2062,16 @@ class DeepseekV2DecoderLayer(nn.Module):
         use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
             forward_batch
         )
+
+        if isinstance(self.mlp, DeepseekV2MLP):
+            gemm_output_zero_allocator = None
+
         hidden_states = self.mlp(
-            hidden_states,
+            hidden_states,
+            forward_batch,
+            should_allreduce_fusion,
+            use_reduce_scatter,
+            gemm_output_zero_allocator,
         )

         if should_allreduce_fusion:
@@ -2028,6 +2215,37 @@ class DeepseekV2Model(nn.Module):
         else:
             self.norm = PPMissingLayer(return_tuple=True)

+        self.gemm_output_zero_allocator_size = 0
+        if (
+            _use_aiter_gfx95
+            and config.n_routed_experts == 256
+            and self.embed_tokens.embedding_dim == 7168
+        ):
+            num_moe_layers = sum(
+                [
+                    1
+                    for i in range(len(self.layers))
+                    if isinstance(self.layers[i].mlp, DeepseekV2MoE)
+                ]
+            )
+
+            allocate_size = 0
+            for i in range(len(self.layers)):
+                if isinstance(self.layers[i].mlp, DeepseekV2MoE):
+                    allocate_size = self.layers[
+                        i
+                    ].mlp.shared_experts.gate_up_proj.output_size_per_partition
+                    break
+
+            self.gemm_output_zero_allocator_size = (
+                get_dsv3_gemm_output_zero_allocator_size(
+                    config.n_routed_experts,
+                    num_moe_layers,
+                    allocate_size,
+                    self.embed_tokens.embedding_dim,
+                )
+            )
+
     def get_input_embeddings(self) -> torch.Tensor:
         return self.embed_tokens

@@ -2047,6 +2265,21 @@ class DeepseekV2Model(nn.Module):
             device=device,
         )

+        has_gemm_output_zero_allocator = hasattr(
+            self, "gemm_output_zero_allocator_size"
+        )
+
+        gemm_output_zero_allocator = (
+            BumpAllocator(
+                buffer_size=self.gemm_output_zero_allocator_size,
+                dtype=torch.float32,
+                device=device,
+            )
+            if has_gemm_output_zero_allocator
+            and self.gemm_output_zero_allocator_size > 0
+            else None
+        )
+
         if self.pp_group.is_first_rank:
             if input_embeds is None:
                 hidden_states = self.embed_tokens(input_ids)
@@ -2073,7 +2306,12 @@ class DeepseekV2Model(nn.Module):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
-                    positions,
+                    positions,
+                    hidden_states,
+                    forward_batch,
+                    residual,
+                    zero_allocator,
+                    gemm_output_zero_allocator,
                 )

         if normal_end_layer != self.end_layer:
@@ -2177,6 +2415,8 @@ class DeepseekV2ForCausalLM(nn.Module):
             disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 80 can use shared experts fusion optimization."
         elif get_moe_expert_parallel_world_size() > 1:
             disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization under expert parallelism."
+        elif self.quant_config.get_name() == "w4afp8":
+            disable_reason = "Deepseek V3/R1 W4AFP8 model uses different quant method for routed experts and shared experts."

         if disable_reason is not None:
             global_server_args_dict["disable_shared_experts_fusion"] = True
@@ -2344,6 +2584,16 @@ class DeepseekV2ForCausalLM(nn.Module):
                 w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
+
+                if (
+                    _use_aiter_gfx95
+                    and self.quant_config is not None
+                    and self.quant_config.get_name() == "quark"
+                ):
+                    w_kc, self_attn.w_scale_k, w_vc, self_attn.w_scale_v = (
+                        quark_post_load_weights(self_attn, w, "mxfp4")
+                    )
+
                 if not use_deep_gemm_bmm:
                     self_attn.w_kc = bind_or_assign(
                         self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
@@ -2406,18 +2656,26 @@ class DeepseekV2ForCausalLM(nn.Module):
         )

         num_hidden_layers = 1 if is_nextn else self.config.num_hidden_layers
+
         for layer_id in range(num_hidden_layers):
             if is_nextn:
                 layer = self.model.decoder
             else:
                 layer = self.model.layers[layer_id]

-
-                layer.self_attn.fused_qkv_a_proj_with_mqa,
-                layer.self_attn.q_b_proj,
+            module_list = [
                 layer.self_attn.kv_b_proj,
                 layer.self_attn.o_proj,
-            ]
+            ]
+
+            if self.config.q_lora_rank is not None:
+                module_list.append(layer.self_attn.fused_qkv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_b_proj)
+            else:
+                module_list.append(layer.self_attn.kv_a_proj_with_mqa)
+                module_list.append(layer.self_attn.q_proj)
+
+            for module in module_list:
                 requant_weight_ue8m0_inplace(
                     module.weight, module.weight_scale_inv, weight_block_size
                 )
@@ -2480,6 +2738,9 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_up_proj_name="up_proj",
             num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
+        # Params for special naming rules in mixed-precision models, for example:
+        # model.layers.xx.mlp.experts.xx.w1.input_scale. For details,
+        # see https://huggingface.co/Barrrrry/DeepSeek-R1-W4AFP8/blob/main.
         if self.quant_config and self.quant_config.get_name() == "w4afp8":
            expert_params_mapping += FusedMoE.make_expert_input_scale_params_mapping(
                num_experts=self.config.n_routed_experts
sglang/srt/models/gemma3n_mm.py
CHANGED
@@ -499,7 +499,7 @@ class Gemma3nForConditionalGeneration(PreTrainedModel):
     def should_apply_lora(self, module_name: str) -> bool:
         return bool(self.lora_pattern.match(module_name))

-    def get_hidden_dim(self, module_name):
+    def get_hidden_dim(self, module_name, layer_idx):
         # return input_dim, output_dim
         if module_name == "qkv_proj":
             return (