sglang 0.4.6.post5__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
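The per-file summary below lists, for each changed path, the number of added and removed lines. As a rough illustration of what such a summary captures (this is not the registry's own tooling, and the file names in the usage comment are hypothetical), the sketch below unpacks two wheels, which are ordinary zip archives, and counts added/removed lines per member with difflib:

# Minimal sketch: per-file "+added -removed" summary for two wheel files.
# Not the registry's implementation; wheel file names below are hypothetical.
import difflib
import zipfile


def summarize_wheel_diff(old_whl: str, new_whl: str) -> None:
    with zipfile.ZipFile(old_whl) as old, zipfile.ZipFile(new_whl) as new:
        old_names, new_names = set(old.namelist()), set(new.namelist())
        for name in sorted(old_names | new_names):
            old_lines = (
                old.read(name).decode("utf-8", "replace").splitlines()
                if name in old_names
                else []
            )
            new_lines = (
                new.read(name).decode("utf-8", "replace").splitlines()
                if name in new_names
                else []
            )
            diff = list(
                difflib.unified_diff(
                    old_lines, new_lines, fromfile=name, tofile=name, lineterm=""
                )
            )
            if diff:
                # Skip the "---"/"+++" header lines, then count additions and removals.
                added = sum(1 for line in diff[2:] if line.startswith("+"))
                removed = sum(1 for line in diff[2:] if line.startswith("-"))
                print(f"{name} +{added} -{removed}")


# Hypothetical usage with locally downloaded wheels:
# summarize_wheel_diff(
#     "sglang-0.4.6.post5-py3-none-any.whl", "sglang-0.4.7.post1-py3-none-any.whl"
# )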
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_offline_throughput.py +10 -4
- sglang/bench_one_batch_server.py +67 -11
- sglang/bench_serving.py +86 -75
- sglang/lang/backend/runtime_endpoint.py +24 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/profiler.py +167 -0
- sglang/srt/_custom_ops.py +34 -0
- sglang/srt/configs/internvl.py +8 -12
- sglang/srt/configs/model_config.py +33 -1
- sglang/srt/constrained/base_grammar_backend.py +5 -2
- sglang/srt/constrained/llguidance_backend.py +9 -8
- sglang/srt/constrained/outlines_backend.py +5 -4
- sglang/srt/constrained/xgrammar_backend.py +18 -18
- sglang/srt/conversation.py +52 -8
- sglang/srt/custom_op.py +38 -3
- sglang/srt/debug_utils.py +74 -0
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -0
- sglang/srt/disaggregation/common/conn.py +407 -0
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +261 -52
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +16 -9
- sglang/srt/disaggregation/kv_events.py +60 -5
- sglang/srt/disaggregation/launch_lb.py +140 -0
- sglang/srt/disaggregation/mini_lb.py +29 -48
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +446 -149
- sglang/srt/disaggregation/mooncake/transfer_engine.py +32 -16
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +134 -437
- sglang/srt/disaggregation/prefill.py +130 -43
- sglang/srt/disaggregation/utils.py +127 -86
- sglang/srt/distributed/device_communicators/pymscclpp.py +315 -0
- sglang/srt/distributed/parallel_state.py +52 -5
- sglang/srt/entrypoints/EngineBase.py +6 -0
- sglang/srt/entrypoints/engine.py +116 -5
- sglang/srt/entrypoints/http_server.py +28 -4
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +138 -86
- sglang/srt/function_call/deepseekv3_detector.py +54 -6
- sglang/srt/function_call/ebnf_composer.py +33 -19
- sglang/srt/function_call/function_call_parser.py +27 -0
- sglang/srt/function_call/llama32_detector.py +33 -14
- sglang/srt/function_call/mistral_detector.py +73 -26
- sglang/srt/function_call/pythonic_detector.py +86 -20
- sglang/srt/function_call/qwen25_detector.py +64 -10
- sglang/srt/function_call/utils.py +17 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +19 -0
- sglang/srt/layers/attention/aiter_backend.py +503 -125
- sglang/srt/layers/attention/base_attn_backend.py +4 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +40 -34
- sglang/srt/layers/attention/flashattention_backend.py +137 -63
- sglang/srt/layers/attention/flashinfer_backend.py +46 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +59 -25
- sglang/srt/layers/attention/flashmla_backend.py +2 -10
- sglang/srt/layers/attention/intel_amx_backend.py +128 -0
- sglang/srt/layers/attention/tbo_backend.py +232 -0
- sglang/srt/layers/attention/torch_native_backend.py +3 -0
- sglang/srt/layers/attention/triton_backend.py +304 -65
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/triton_ops/extend_attention.py +12 -4
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +281 -197
- sglang/srt/layers/dp_attention.py +6 -5
- sglang/srt/layers/layernorm.py +30 -19
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +0 -12
- sglang/srt/layers/moe/cutlass_moe.py +170 -7
- sglang/srt/layers/moe/cutlass_moe_params.py +169 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +136 -72
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +24 -45
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +221 -29
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -4
- sglang/srt/layers/moe/topk.py +60 -26
- sglang/srt/layers/multimodal.py +3 -3
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/__init__.py +3 -2
- sglang/srt/layers/quantization/blockwise_int8.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +5 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +69 -127
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +28 -23
- sglang/srt/layers/quantization/fp8_kernel.py +156 -75
- sglang/srt/layers/quantization/fp8_utils.py +250 -69
- sglang/srt/layers/quantization/modelopt_quant.py +334 -7
- sglang/srt/layers/quantization/moe_wna16.py +3 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +3 -0
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +6 -12
- sglang/srt/layers/sampler.py +80 -79
- sglang/srt/layers/utils.py +6 -0
- sglang/srt/lora/layers.py +12 -15
- sglang/srt/lora/lora.py +49 -5
- sglang/srt/lora/lora_manager.py +98 -39
- sglang/srt/lora/mem_pool.py +28 -21
- sglang/srt/lora/utils.py +17 -13
- sglang/srt/managers/cache_controller.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +13 -5
- sglang/srt/managers/eplb_algorithms/__init__.py +63 -0
- sglang/srt/managers/eplb_algorithms/deepseek.py +223 -0
- sglang/srt/managers/{deepseek_eplb.py → eplb_algorithms/deepseek_vec.py} +5 -7
- sglang/srt/managers/eplb_manager.py +55 -14
- sglang/srt/managers/expert_distribution.py +220 -46
- sglang/srt/managers/expert_location.py +110 -56
- sglang/srt/managers/expert_location_dispatch.py +23 -6
- sglang/srt/managers/io_struct.py +43 -8
- sglang/srt/managers/mm_utils.py +88 -38
- sglang/srt/managers/multimodal_processors/base_processor.py +190 -18
- sglang/srt/managers/multimodal_processors/gemma3.py +4 -31
- sglang/srt/managers/multimodal_processors/internvl.py +4 -0
- sglang/srt/managers/multimodal_processors/kimi_vl.py +15 -34
- sglang/srt/managers/multimodal_processors/minicpm.py +2 -1
- sglang/srt/managers/multimodal_processors/phi4mm.py +87 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -64
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +173 -38
- sglang/srt/managers/scheduler.py +376 -127
- sglang/srt/managers/tokenizer_manager.py +163 -19
- sglang/srt/managers/utils.py +0 -4
- sglang/srt/mem_cache/chunk_cache.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +4 -2
- sglang/srt/mem_cache/memory_pool.py +111 -407
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +36 -12
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +191 -113
- sglang/srt/model_executor/expert_location_updater.py +157 -22
- sglang/srt/model_executor/forward_batch_info.py +52 -22
- sglang/srt/model_executor/model_runner.py +102 -62
- sglang/srt/model_loader/loader.py +8 -1
- sglang/srt/model_loader/utils.py +67 -1
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +1 -1
- sglang/srt/models/deepseek_v2.py +623 -290
- sglang/srt/models/gemma3_causal.py +7 -0
- sglang/srt/models/gemma3_mm.py +19 -14
- sglang/srt/models/idefics2.py +342 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/kimi_vl.py +4 -4
- sglang/srt/models/llama.py +1 -1
- sglang/srt/models/minicpmo.py +2 -5
- sglang/srt/models/minicpmv.py +3 -295
- sglang/srt/models/phi4mm.py +512 -0
- sglang/srt/models/qwen2.py +38 -9
- sglang/srt/models/qwen2_5_vl.py +3 -9
- sglang/srt/models/qwen2_eagle.py +4 -1
- sglang/srt/models/qwen2_moe.py +58 -191
- sglang/srt/models/qwen2_vl.py +3 -9
- sglang/srt/models/qwen3.py +41 -10
- sglang/srt/models/qwen3_moe.py +230 -191
- sglang/srt/models/registry.py +9 -1
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/transformers.py +291 -0
- sglang/srt/models/vila.py +305 -0
- sglang/srt/openai_api/adapter.py +248 -28
- sglang/srt/openai_api/protocol.py +68 -3
- sglang/srt/openai_api/utils.py +172 -0
- sglang/srt/operations.py +37 -2
- sglang/srt/operations_strategy.py +200 -24
- sglang/srt/sampling/sampling_batch_info.py +37 -1
- sglang/srt/sampling/sampling_params.py +4 -1
- sglang/srt/server_args.py +381 -209
- sglang/srt/speculative/build_eagle_tree.py +9 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +12 -14
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +256 -0
- sglang/srt/speculative/eagle_utils.py +440 -200
- sglang/srt/speculative/eagle_worker.py +234 -63
- sglang/srt/two_batch_overlap.py +637 -0
- sglang/srt/utils.py +187 -7
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +54 -10
- sglang/test/send_one.py +4 -0
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +1 -0
- sglang/test/test_cutlass_moe.py +3 -3
- sglang/test/test_fp4_moe.py +248 -0
- sglang/test/test_utils.py +82 -7
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +17 -14
- {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +359 -321
- {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +1 -1
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H200.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI325X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Radeon_Graphics.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI325X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Radeon_Graphics.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI325X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Radeon_Graphics.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_L40S.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_L40S.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI325X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Radeon_Graphics.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H20.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=96,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=96,device_name=NVIDIA_H20.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
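Many of the fused MoE Triton tuning configs above were moved under per-Triton-version subdirectories (triton_3_1_0/, triton_3_2_0/, triton_3_3_1/) while keeping the same E=/N=/device_name=/dtype=/block_shape= naming scheme. A hypothetical helper (not sglang's actual lookup code) that reconstructs such a path from its parameters, just to make the convention explicit:

# Illustration of the config-file naming convention visible in the list above;
# this is not sglang's own config-lookup routine.
from typing import List, Optional


def moe_config_filename(
    E: int,
    N: int,
    device_name: str,
    dtype: Optional[str] = None,
    block_shape: Optional[List[int]] = None,
    triton_version: Optional[str] = None,
) -> str:
    name = f"E={E},N={N},device_name={device_name}"
    if dtype is not None:
        name += f",dtype={dtype}"
    if block_shape is not None:
        name += f",block_shape={block_shape}"
    name += ".json"
    if triton_version is not None:
        # e.g. "3.2.0" -> "triton_3_2_0/"
        name = f"triton_{triton_version.replace('.', '_')}/" + name
    return name


# Example: reproduces one of the renamed paths above:
# moe_config_filename(257, 128, "NVIDIA_H20", "fp8_w8a8", [128, 128], "3.2.0")
# -> "triton_3_2_0/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json"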
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -51,12 +51,13 @@ from sglang.srt.layers.linear import (
|
|
51
51
|
RowParallelLinear,
|
52
52
|
)
|
53
53
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
54
|
-
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
|
54
|
+
from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
|
55
55
|
from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
|
56
56
|
from sglang.srt.layers.moe.topk import select_experts
|
57
|
+
from sglang.srt.layers.quantization import deep_gemm_wrapper
|
57
58
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
58
|
-
from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
|
59
59
|
from sglang.srt.layers.quantization.fp8_kernel import (
|
60
|
+
is_fp8_fnuz,
|
60
61
|
per_tensor_quant_mla_fp8,
|
61
62
|
per_token_group_quant_mla_deep_gemm_masked_fp8,
|
62
63
|
)
|
@@ -65,6 +66,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
|
|
65
66
|
block_quant_to_tensor_quant,
|
66
67
|
channel_quant_to_tensor_quant,
|
67
68
|
normalize_e4m3fn_to_e4m3fnuz,
|
69
|
+
requant_weight_ue8m0_inplace,
|
68
70
|
)
|
69
71
|
from sglang.srt.layers.quantization.int8_utils import (
|
70
72
|
block_dequant as int8_block_dequant,
|
@@ -83,28 +85,31 @@ from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchI
|
|
83
85
|
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
84
86
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
85
87
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
86
|
-
from sglang.srt.
|
87
|
-
|
88
|
+
from sglang.srt.two_batch_overlap import (
|
89
|
+
MaybeTboDeepEPDispatcher,
|
90
|
+
model_forward_maybe_tbo,
|
91
|
+
)
|
88
92
|
from sglang.srt.utils import (
|
89
93
|
BumpAllocator,
|
90
94
|
DeepEPMode,
|
95
|
+
LazyValue,
|
91
96
|
add_prefix,
|
97
|
+
bind_or_assign,
|
92
98
|
get_bool_env_var,
|
93
99
|
get_int_env_var,
|
94
100
|
is_cuda,
|
95
101
|
is_hip,
|
102
|
+
is_non_idle_and_non_empty,
|
96
103
|
log_info_on_rank0,
|
97
104
|
)
|
98
105
|
|
99
106
|
_is_hip = is_hip()
|
100
107
|
_is_cuda = is_cuda()
|
108
|
+
_is_fp8_fnuz = is_fp8_fnuz()
|
109
|
+
_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
|
101
110
|
|
102
111
|
if _is_cuda:
|
103
112
|
from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
|
104
|
-
|
105
|
-
from sglang.srt.layers.quantization.deep_gemm import (
|
106
|
-
grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
|
107
|
-
)
|
108
113
|
else:
|
109
114
|
from vllm._custom_ops import awq_dequantize
|
110
115
|
|
@@ -113,6 +118,9 @@ if _is_hip:
|
|
113
118
|
decode_attention_fwd_grouped_rope,
|
114
119
|
)
|
115
120
|
|
121
|
+
if _use_aiter:
|
122
|
+
from aiter.rotary_embedding import get_rope
|
123
|
+
|
116
124
|
logger = logging.getLogger(__name__)
|
117
125
|
|
118
126
|
|
@@ -204,14 +212,6 @@ class MoEGate(nn.Module):
|
|
204
212
|
return logits
|
205
213
|
|
206
214
|
|
207
|
-
def is_non_idle_and_non_empty(forward_mode, hidden_states):
|
208
|
-
return (
|
209
|
-
(forward_mode is not None)
|
210
|
-
and not forward_mode.is_idle()
|
211
|
-
and hidden_states.shape[0] > 0
|
212
|
-
)
|
213
|
-
|
214
|
-
|
215
215
|
class DeepseekV2MoE(nn.Module):
|
216
216
|
|
217
217
|
def __init__(
|
@@ -225,7 +225,12 @@ class DeepseekV2MoE(nn.Module):
|
|
225
225
|
self.tp_size = get_tensor_model_parallel_world_size()
|
226
226
|
self.routed_scaling_factor = config.routed_scaling_factor
|
227
227
|
self.n_shared_experts = config.n_shared_experts
|
228
|
-
self.
|
228
|
+
self.num_fused_shared_experts = (
|
229
|
+
0
|
230
|
+
if global_server_args_dict["disable_shared_experts_fusion"]
|
231
|
+
else config.n_shared_experts
|
232
|
+
)
|
233
|
+
self.config = config
|
229
234
|
self.layer_id = layer_id
|
230
235
|
|
231
236
|
if self.tp_size > config.n_routed_experts:
|
@@ -244,9 +249,9 @@ class DeepseekV2MoE(nn.Module):
|
|
244
249
|
|
245
250
|
self.experts = get_moe_impl_class()(
|
246
251
|
num_experts=config.n_routed_experts
|
247
|
-
+ self.
|
252
|
+
+ self.num_fused_shared_experts
|
248
253
|
+ global_server_args_dict["ep_num_redundant_experts"],
|
249
|
-
top_k=config.num_experts_per_tok +
|
254
|
+
top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
|
250
255
|
hidden_size=config.hidden_size,
|
251
256
|
intermediate_size=config.moe_intermediate_size,
|
252
257
|
layer_id=self.layer_id,
|
@@ -254,6 +259,7 @@ class DeepseekV2MoE(nn.Module):
|
|
254
259
|
quant_config=quant_config,
|
255
260
|
use_grouped_topk=True,
|
256
261
|
num_expert_group=config.n_group,
|
262
|
+
num_fused_shared_experts=self.num_fused_shared_experts,
|
257
263
|
topk_group=config.topk_group,
|
258
264
|
correction_bias=self.gate.e_score_correction_bias,
|
259
265
|
routed_scaling_factor=self.routed_scaling_factor,
|
@@ -265,7 +271,7 @@ class DeepseekV2MoE(nn.Module):
|
|
265
271
|
),
|
266
272
|
)
|
267
273
|
|
268
|
-
if config.n_shared_experts is not None and self.
|
274
|
+
if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
|
269
275
|
intermediate_size = config.moe_intermediate_size * config.n_shared_experts
|
270
276
|
# disable tp for shared experts when enable deepep moe
|
271
277
|
self.shared_experts = DeepseekV2MLP(
|
@@ -300,7 +306,7 @@ class DeepseekV2MoE(nn.Module):
|
|
300
306
|
else None
|
301
307
|
)
|
302
308
|
|
303
|
-
self.deepep_dispatcher =
|
309
|
+
self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
|
304
310
|
group=parallel_state.get_tp_group().device_group,
|
305
311
|
router_topk=self.top_k,
|
306
312
|
permute_fusion=True,
|
@@ -309,13 +315,11 @@ class DeepseekV2MoE(nn.Module):
|
|
309
315
|
hidden_size=config.hidden_size,
|
310
316
|
params_dtype=config.torch_dtype,
|
311
317
|
deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
|
312
|
-
async_finish=True,
|
318
|
+
async_finish=True,
|
313
319
|
return_recv_hook=True,
|
314
320
|
)
|
315
321
|
|
316
|
-
|
317
|
-
def _enable_deepep_moe(self):
|
318
|
-
return global_server_args_dict["enable_deepep_moe"]
|
322
|
+
self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
|
319
323
|
|
320
324
|
def get_moe_weights(self):
|
321
325
|
return [
|
@@ -324,8 +328,114 @@ class DeepseekV2MoE(nn.Module):
|
|
324
328
|
if name not in ["correction_bias"]
|
325
329
|
]
|
326
330
|
|
331
|
+
def forward(
|
332
|
+
self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
|
333
|
+
) -> torch.Tensor:
|
334
|
+
if not self._enable_deepep_moe:
|
335
|
+
return self.forward_normal(hidden_states)
|
336
|
+
else:
|
337
|
+
return self.forward_deepep(hidden_states, forward_batch)
|
338
|
+
|
339
|
+
def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
|
340
|
+
shared_output = self._forward_shared_experts(hidden_states)
|
341
|
+
# router_logits: (num_tokens, n_experts)
|
342
|
+
router_logits = self.gate(hidden_states)
|
343
|
+
final_hidden_states = self.experts(
|
344
|
+
hidden_states=hidden_states, router_logits=router_logits
|
345
|
+
)
|
346
|
+
if not _is_cuda:
|
347
|
+
final_hidden_states *= self.routed_scaling_factor
|
348
|
+
if shared_output is not None:
|
349
|
+
final_hidden_states = final_hidden_states + shared_output
|
350
|
+
if self.tp_size > 1:
|
351
|
+
final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
|
352
|
+
return final_hidden_states
|
353
|
+
|
354
|
+
def forward_deepep(
|
355
|
+
self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
|
356
|
+
) -> torch.Tensor:
|
357
|
+
forward_mode = forward_batch.forward_mode
|
358
|
+
shared_output = None
|
359
|
+
if is_non_idle_and_non_empty(forward_mode, hidden_states):
|
360
|
+
# router_logits: (num_tokens, n_experts)
|
361
|
+
router_logits = self.gate(hidden_states)
|
362
|
+
shared_output = self._forward_shared_experts(hidden_states)
|
363
|
+
topk_weights, topk_idx = select_experts(
|
364
|
+
hidden_states=hidden_states,
|
365
|
+
router_logits=router_logits,
|
366
|
+
top_k=self.top_k,
|
367
|
+
use_grouped_topk=True,
|
368
|
+
renormalize=self.renormalize,
|
369
|
+
topk_group=self.topk_group,
|
370
|
+
num_expert_group=self.num_expert_group,
|
371
|
+
num_fused_shared_experts=self.num_fused_shared_experts,
|
372
|
+
correction_bias=self.correction_bias,
|
373
|
+
routed_scaling_factor=self.routed_scaling_factor,
|
374
|
+
num_token_non_padded=forward_batch.num_token_non_padded,
|
375
|
+
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
376
|
+
layer_id=self.layer_id,
|
377
|
+
),
|
378
|
+
)
|
379
|
+
else:
|
380
|
+
topk_idx = torch.full(
|
381
|
+
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
382
|
+
)
|
383
|
+
topk_weights = torch.empty(
|
384
|
+
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
385
|
+
)
|
386
|
+
if self.ep_size > 1:
|
387
|
+
# TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
|
388
|
+
(
|
389
|
+
hidden_states,
|
390
|
+
topk_idx,
|
391
|
+
topk_weights,
|
392
|
+
reorder_topk_ids,
|
393
|
+
num_recv_tokens_per_expert,
|
394
|
+
seg_indptr,
|
395
|
+
masked_m,
|
396
|
+
expected_m,
|
397
|
+
) = self.deepep_dispatcher.dispatch(
|
398
|
+
hidden_states=hidden_states,
|
399
|
+
topk_idx=topk_idx,
|
400
|
+
topk_weights=topk_weights,
|
401
|
+
forward_mode=forward_mode,
|
402
|
+
)
|
403
|
+
final_hidden_states = self.experts(
|
404
|
+
hidden_states=hidden_states,
|
405
|
+
topk_idx=topk_idx,
|
406
|
+
topk_weights=topk_weights,
|
407
|
+
reorder_topk_ids=reorder_topk_ids,
|
408
|
+
seg_indptr=seg_indptr,
|
409
|
+
masked_m=masked_m,
|
410
|
+
expected_m=expected_m,
|
411
|
+
num_recv_tokens_per_expert=num_recv_tokens_per_expert,
|
412
|
+
forward_mode=forward_mode,
|
413
|
+
)
|
414
|
+
if self.ep_size > 1:
|
415
|
+
final_hidden_states = self.deepep_dispatcher.combine(
|
416
|
+
hidden_states=final_hidden_states,
|
417
|
+
topk_idx=topk_idx,
|
418
|
+
topk_weights=topk_weights,
|
419
|
+
forward_mode=forward_mode,
|
420
|
+
)
|
421
|
+
|
422
|
+
if shared_output is not None:
|
423
|
+
x = shared_output
|
424
|
+
x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
|
425
|
+
final_hidden_states = x
|
426
|
+
else:
|
427
|
+
final_hidden_states *= self.routed_scaling_factor
|
428
|
+
|
429
|
+
return final_hidden_states
|
430
|
+
|
431
|
+
def _forward_shared_experts(self, hidden_states):
|
432
|
+
if self.num_fused_shared_experts == 0:
|
433
|
+
return self.shared_experts(hidden_states)
|
434
|
+
else:
|
435
|
+
return None
|
436
|
+
|
327
437
|
def op_gate(self, state):
|
328
|
-
if
|
438
|
+
if is_non_idle_and_non_empty(
|
329
439
|
state.forward_batch.forward_mode, state.hidden_states_mlp_input
|
330
440
|
):
|
331
441
|
# router_logits: (num_tokens, n_experts)
|
@@ -334,22 +444,22 @@ class DeepseekV2MoE(nn.Module):
|
|
334
444
|
state.router_logits = None
|
335
445
|
|
336
446
|
def op_shared_experts(self, state):
|
337
|
-
|
338
|
-
|
339
|
-
|
340
|
-
state.forward_batch.forward_mode, state.hidden_states_mlp_input
|
341
|
-
)
|
447
|
+
hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
|
448
|
+
if (self.num_fused_shared_experts == 0) and is_non_idle_and_non_empty(
|
449
|
+
state.forward_batch.forward_mode, hidden_states_mlp_input
|
342
450
|
):
|
343
|
-
state.shared_output = self.shared_experts(
|
451
|
+
state.shared_output = self.shared_experts(hidden_states_mlp_input)
|
344
452
|
else:
|
345
453
|
state.shared_output = None
|
346
454
|
|
347
455
|
def op_select_experts(self, state):
|
348
|
-
router_logits = state.router_logits
|
456
|
+
router_logits = state.pop("router_logits")
|
349
457
|
hidden_states = state.hidden_states_mlp_input
|
350
458
|
|
351
|
-
if
|
352
|
-
|
459
|
+
if router_logits is not None:
|
460
|
+
with get_global_expert_distribution_recorder().with_current_layer(
|
461
|
+
self.layer_id
|
462
|
+
):
|
353
463
|
state.topk_weights_local, state.topk_idx_local = select_experts(
|
354
464
|
hidden_states=hidden_states,
|
355
465
|
router_logits=router_logits,
|
@@ -358,90 +468,89 @@ class DeepseekV2MoE(nn.Module):
|
|
358
468
|
renormalize=self.renormalize,
|
359
469
|
topk_group=self.topk_group,
|
360
470
|
num_expert_group=self.num_expert_group,
|
471
|
+
num_fused_shared_experts=self.num_fused_shared_experts,
|
361
472
|
correction_bias=self.correction_bias,
|
362
473
|
routed_scaling_factor=self.routed_scaling_factor,
|
474
|
+
num_token_non_padded=state.forward_batch.num_token_non_padded,
|
363
475
|
expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
|
364
476
|
layer_id=self.layer_id,
|
365
477
|
),
|
366
478
|
)
|
367
|
-
|
368
|
-
|
369
|
-
|
370
|
-
|
371
|
-
|
372
|
-
|
373
|
-
|
479
|
+
else:
|
480
|
+
state.topk_idx_local = torch.full(
|
481
|
+
(0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
|
482
|
+
)
|
483
|
+
state.topk_weights_local = torch.empty(
|
484
|
+
(0, self.top_k), dtype=torch.float32, device=hidden_states.device
|
485
|
+
)
|
374
486
|
|
375
487
|
def op_dispatch_a(self, state):
|
376
|
-
if self.
+        if self.ep_size > 1:
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
             self.deepep_dispatcher.dispatch_a(
-                hidden_states=state.
+                hidden_states=state.hidden_states_mlp_input,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
                 forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
     def op_dispatch_b(self, state):
-        if self.
-            (
-
-
-
-
-
-
-
-
-
+        if self.ep_size > 1:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                (
+                    state.hidden_states_experts_input,
+                    state.topk_idx_dispatched,
+                    state.topk_weights_dispatched,
+                    state.reorder_topk_ids,
+                    state.num_recv_tokens_per_expert,
+                    state.seg_indptr,
+                    state.masked_m,
+                    state.expected_m,
+                ) = self.deepep_dispatcher.dispatch_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
 
     def op_experts(self, state):
-
-            state.pop("
-            state.
-
-
-
-
-
-
-
-
-                forward_mode=state.forward_batch.forward_mode,
-            )
-        else:
-            state.hidden_states_experts_output = self.experts(
-                hidden_states=state.pop("hidden_states_mlp_input"),
-                router_logits=state.pop("router_logits"),
-            )
+        state.hidden_states_experts_output = self.experts(
+            hidden_states=state.pop("hidden_states_experts_input"),
+            topk_idx=state.topk_idx_dispatched,
+            topk_weights=state.topk_weights_dispatched,
+            reorder_topk_ids=state.pop("reorder_topk_ids"),
+            seg_indptr=state.pop("seg_indptr"),
+            masked_m=state.pop("masked_m"),
+            expected_m=state.pop("expected_m"),
+            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
+            forward_mode=state.forward_batch.forward_mode,
+        )
 
     def op_combine_a(self, state):
-        if self.
+        if self.ep_size > 1:
             self.deepep_dispatcher.combine_a(
-                state.pop("hidden_states_experts_output"),
+                hidden_states=state.pop("hidden_states_experts_output"),
                 topk_idx=state.pop("topk_idx_dispatched"),
                 topk_weights=state.pop("topk_weights_dispatched"),
                 forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )
 
     def op_combine_b(self, state):
-        if self.
-            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
+        if self.ep_size > 1:
+            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
 
     def op_output(self, state):
-        final_hidden_states = (
-            state.pop("hidden_states_after_combine")
-            if self._enable_deepep_moe
-            else state.pop("hidden_states_experts_output")
-        )
-
-        final_hidden_states *= self.routed_scaling_factor
-
-        if (s := state.pop("shared_output")) is not None:
-            final_hidden_states = final_hidden_states + s
+        final_hidden_states = state.pop("hidden_states_after_combine")
 
-        if (
-
+        if (shared_output := state.pop("shared_output")) is not None:
+            x = shared_output
+            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            final_hidden_states = x
+        else:
+            final_hidden_states *= self.routed_scaling_factor
 
         state.hidden_states_mlp_output = final_hidden_states
 
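The new op_output path above folds the routed scaling factor into one in-place add on the shared-expert output instead of scaling the combined states and then adding. A minimal self-contained sketch of that pattern (illustrative only; combine_outputs and its arguments are hypothetical names, not sglang APIs):

    import torch
    from typing import Optional

    def combine_outputs(
        hidden_states_after_combine: torch.Tensor,
        shared_output: Optional[torch.Tensor],
        routed_scaling_factor: float,
    ) -> torch.Tensor:
        final = hidden_states_after_combine
        if shared_output is not None:
            # One fused in-place op: shared_output += final * routed_scaling_factor,
            # reusing the shared-expert buffer instead of allocating a new tensor.
            shared_output.add_(final, alpha=routed_scaling_factor)
            return shared_output
        final *= routed_scaling_factor
        return final

    routed = torch.ones(2, 4)
    shared = torch.full((2, 4), 2.0)
    print(combine_outputs(routed.clone(), shared.clone(), 0.5))  # all entries 2.5
    print(combine_outputs(routed.clone(), None, 0.5))            # all entries 0.5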
@@ -596,10 +705,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
 
         self.alt_stream = alt_stream
+        self.attn_mha.kv_b_proj = None
 
         self.w_kc = None
         self.w_vc = None
-        self.w_scale =
+        self.w_scale = 1.0
 
         self.w_scale_k = None
         self.w_scale_v = None
@@ -665,6 +775,15 @@ class DeepseekV2AttentionMLA(nn.Module):
                 return AttnForwardMethod.MHA_CHUNKED_KV
             else:
                 return _dispatch_mla_subtype()
+        elif self.attention_backend == "aiter":
+            if (
+                forward_batch.forward_mode.is_extend()
+                and not forward_batch.forward_mode.is_target_verify()
+                and not forward_batch.forward_mode.is_draft_extend()
+            ):
+                return AttnForwardMethod.MHA
+            else:
+                return AttnForwardMethod.MLA
         else:
             # Triton: Use normal computation for prefill and use weight absorption for extend/decode
             if (
@@ -677,44 +796,97 @@ class DeepseekV2AttentionMLA(nn.Module):
         else:
             return _dispatch_mla_subtype()
 
+    def op_prepare(self, state):
+        state.attn_intermediate_state = self.forward_prepare(
+            positions=state.positions,
+            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+            forward_batch=state.forward_batch,
+            zero_allocator=state.zero_allocator,
+        )
+
+    def op_core(self, state):
+        state.hidden_states_after_attn = self.forward_core(
+            state.pop("attn_intermediate_state")
+        )
+
     def forward(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    )
+    ):
+        s = self.forward_prepare(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
+        )
+        return self.forward_core(s)
+
+    def forward_prepare(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        zero_allocator: BumpAllocator,
+    ):
+        if self.attn_mha.kv_b_proj is None:
+            self.attn_mha.kv_b_proj = self.kv_b_proj
+
         if hidden_states.shape[0] == 0:
             assert (
                 not self.o_proj.reduce_results
             ), "short-circuiting allreduce will lead to hangs"
-            return hidden_states
+            return hidden_states, None, forward_batch, None
 
         attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
 
         if attn_forward_method == AttnForwardMethod.MHA:
-
+            inner_state = self.forward_normal_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
+            )
         elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
-
-                positions, hidden_states, forward_batch
+            inner_state = self.forward_normal_chunked_kv_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA:
-
+            inner_state = self.forward_absorb_prepare(
                 positions, hidden_states, forward_batch, zero_allocator
             )
         elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
-
-                positions, hidden_states, forward_batch
+            inner_state = self.forward_absorb_fused_mla_rope_prepare(
+                positions, hidden_states, forward_batch, zero_allocator
             )
         else:
             raise NotImplementedError
+        return None, attn_forward_method, forward_batch, inner_state
+
+    def forward_core(self, intermediate_state):
+        hidden_states, attn_forward_method, forward_batch, inner_state = (
+            intermediate_state
+        )
+        if inner_state is None:
+            return hidden_states
+
+        if attn_forward_method == AttnForwardMethod.MHA:
+            return self.forward_normal_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
+            return self.forward_normal_chunked_kv_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA:
+            return self.forward_absorb_core(*inner_state)
+        elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+            return self.forward_absorb_fused_mla_rope_core(*inner_state)
+        else:
+            raise NotImplementedError
 
-    def
+    def forward_normal_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-
+        zero_allocator: BumpAllocator,
+    ):
         if self.q_lora_rank is not None:
             q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
                 [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
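The hunk above splits the attention forward pass into a forward_prepare step that returns an intermediate-state tuple and a forward_core step that dispatches on it, which is what lets op_prepare and op_core run as separate pipeline operations. A minimal sketch of the same prepare/core dispatch shape (illustrative only; the tuple is shortened to three fields, and the enum values and toy math are hypothetical stand-ins, not the real attention code):

    from enum import Enum, auto
    from typing import Any, List, Optional, Tuple

    class Method(Enum):
        MHA = auto()
        MLA = auto()

    def forward_prepare(x: Optional[List[int]]) -> Tuple[Any, Optional[Method], Any]:
        if not x:  # short-circuit, mirroring the empty-batch early return
            return x, None, None
        method = Method.MHA if len(x) < 4 else Method.MLA
        inner_state = [v * 2 for v in x]  # stand-in for the q/k/v projections
        return None, method, inner_state

    def forward_core(intermediate: Tuple[Any, Optional[Method], Any]) -> Any:
        hidden, method, inner_state = intermediate
        if inner_state is None:
            return hidden
        if method is Method.MHA:
            return sum(inner_state)
        return max(inner_state)

    print(forward_core(forward_prepare([1, 2, 3])))  # MHA path -> 12
    print(forward_core(forward_prepare([])))         # short-circuit -> []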
@@ -749,18 +921,22 @@ class DeepseekV2AttentionMLA(nn.Module):
         forward_batch.token_to_kv_pool.set_kv_buffer(
             self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
         )
+
+        return q, k, v, forward_batch
+
+    def forward_normal_core(self, q, k, v, forward_batch):
         attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
         attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
         output, _ = self.o_proj(attn_output)
         return output
 
-    def
+    def forward_absorb_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    )
+    ):
         from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
 
         if self.q_lora_rank is not None:
@@ -801,7 +977,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             q_nope_out = q_nope.new_empty(
                 (self.num_local_heads, aligned_m, self.kv_lora_rank)
             )
-
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                 (q_nope_val, q_nope_scale),
                 (self.w_kc, self.w_scale_k),
                 q_nope_out,
@@ -809,8 +985,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 expected_m,
             )
             q_nope_out = q_nope_out[:, :expected_m, :]
-        elif
-            # TODO(
+        elif _is_hip:
+            # TODO(haishaw): add bmm_fp8 to ROCm
             q_nope_out = torch.bmm(
                 q_nope.to(torch.bfloat16).transpose(0, 1),
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
@@ -829,7 +1005,16 @@ class DeepseekV2AttentionMLA(nn.Module):
         q_nope_out = q_nope_out.transpose(0, 1)
         q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
 
-
+        return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+
+    def forward_absorb_core(
+        self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+    ):
+        if (
+            self.attention_backend == "fa3"
+            or self.attention_backend == "flashinfer"
+            or self.attention_backend == "cutlass_mla"
+        ):
             attn_output = self.attn_mqa(
                 q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
             )
@@ -848,7 +1033,7 @@ class DeepseekV2AttentionMLA(nn.Module):
             attn_bmm_output = attn_output.new_empty(
                 (self.num_local_heads, aligned_m, self.v_head_dim)
             )
-
+            deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
                 (attn_output_val, attn_output_scale),
                 (self.w_vc, self.w_scale_v),
                 attn_bmm_output,
@@ -856,8 +1041,8 @@ class DeepseekV2AttentionMLA(nn.Module):
                 expected_m,
             )
             attn_bmm_output = attn_bmm_output[:, :expected_m, :]
-        elif
-            # TODO(
+        elif _is_hip:
+            # TODO(haishaw): add bmm_fp8 to ROCm
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
@@ -881,13 +1066,13 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return output
 
-    def
+    def forward_absorb_fused_mla_rope_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         zero_allocator: BumpAllocator,
-    )
+    ):
         enable_rope_fusion = (
             os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
         )
@@ -908,8 +1093,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
         q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
 
-        if
-            # TODO(
+        if _is_hip:
+            # TODO(haishaw): add bmm_fp8 to ROCm
             q_nope_out = torch.bmm(
                 q_nope.to(torch.bfloat16).transpose(0, 1),
                 self.w_kc.to(torch.bfloat16) * self.w_scale,
@@ -976,6 +1161,44 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
         val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
 
+        return (
+            q_input,
+            key_cache_buf,
+            val_cache_buf,
+            attn_output,
+            kv_indptr,
+            kv_indices,
+            k_pe_output,
+            cos_sin_cache,
+            positions,
+            attn_logits,
+            num_kv_split,
+            sm_scale,
+            enable_rope_fusion,
+            k_input,
+            forward_batch,
+            zero_allocator,
+        )
+
+    def forward_absorb_fused_mla_rope_core(
+        self,
+        q_input,
+        key_cache_buf,
+        val_cache_buf,
+        attn_output,
+        kv_indptr,
+        kv_indices,
+        k_pe_output,
+        cos_sin_cache,
+        positions,
+        attn_logits,
+        num_kv_split,
+        sm_scale,
+        enable_rope_fusion,
+        k_input,
+        forward_batch,
+        zero_allocator,
+    ):
         decode_attention_fwd_grouped_rope(
             q_input,
             key_cache_buf,
@@ -1004,8 +1227,8 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
 
-        if
-            # TODO(
+        if _is_hip:
+            # TODO(haishaw): add bmm_fp8 to ROCm
             attn_bmm_output = torch.bmm(
                 attn_output.to(torch.bfloat16).transpose(0, 1),
                 self.w_vc.to(torch.bfloat16) * self.w_scale,
@@ -1082,12 +1305,13 @@ class DeepseekV2AttentionMLA(nn.Module):
 
         return accum_output
 
-    def
+    def forward_normal_chunked_kv_prepare(
         self,
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
-
+        zero_allocator: BumpAllocator,
+    ):
         # In normal mha, the k and v tensors will become overly large when the prefix length is long.
         # To avoid this, we split the kv cache into chunks and process them one after another.
         # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
@@ -1130,6 +1354,9 @@ class DeepseekV2AttentionMLA(nn.Module):
             self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
         )
 
+        return q, k, v, forward_batch
+
+    def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
         # Do mha for extended part without prefix
         forward_batch.set_attn_attend_prefix_cache(False)
         attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
@@ -1252,17 +1479,29 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
     ) -> torch.Tensor:
-
-
-                positions=positions,
-                hidden_states=hidden_states,
-                forward_batch=forward_batch,
-                residual=residual,
-                zero_allocator=zero_allocator,
-            ),
-            operations=compute_layer_operations(self),
+        hidden_states, residual = self.layer_communicator.prepare_attn(
+            hidden_states, residual, forward_batch
         )
 
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+            zero_allocator=zero_allocator,
+        )
+
+        hidden_states, residual = self.layer_communicator.prepare_mlp(
+            hidden_states, residual, forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch)
+
+        hidden_states, residual = self.layer_communicator.postprocess_layer(
+            hidden_states, residual, forward_batch
+        )
+
+        return hidden_states, residual
+
     def op_comm_prepare_attn(
         self,
         state,
@@ -1271,6 +1510,7 @@ class DeepseekV2DecoderLayer(nn.Module):
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
         zero_allocator: BumpAllocator,
+        tbo_subbatch_index: Optional[int] = None,
     ):
         state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
             self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
@@ -1280,17 +1520,10 @@ class DeepseekV2DecoderLayer(nn.Module):
                 forward_batch=forward_batch,
                 positions=positions,
                 zero_allocator=zero_allocator,
+                tbo_subbatch_index=tbo_subbatch_index,
             )
         )
 
-    def op_attn(self, state):
-        state.hidden_states_after_attn = self.self_attn(
-            positions=state.positions,
-            hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
-            forward_batch=state.forward_batch,
-            zero_allocator=state.zero_allocator,
-        )
-
     def op_comm_prepare_mlp(self, state):
         state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
             self.layer_communicator.prepare_mlp(
@@ -1320,8 +1553,24 @@ class DeepseekV2DecoderLayer(nn.Module):
                 state.forward_batch,
             )
 
-
-
+        output = dict(
+            positions=state.positions,
+            hidden_states=hidden_states,
+            residual=residual,
+            forward_batch=state.forward_batch,
+            zero_allocator=state.zero_allocator,
+            tbo_subbatch_index=state.tbo_subbatch_index,
+        )
+
+        state.clear(
+            expect_keys={
+                "positions",
+                "forward_batch",
+                "zero_allocator",
+                "tbo_subbatch_index",
+            }
+        )
+        return output
 
 
 class DeepseekV2Model(nn.Module):
@@ -1336,6 +1585,7 @@ class DeepseekV2Model(nn.Module):
         super().__init__()
         self.padding_id = config.pad_token_id
         self.vocab_size = config.vocab_size
+        self.first_k_dense_replace = config.first_k_dense_replace
 
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
@@ -1369,13 +1619,12 @@ class DeepseekV2Model(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
+        total_num_layers = len(self.layers)
+        device = input_embeds.device if input_embeds is not None else input_ids.device
         zero_allocator = BumpAllocator(
-
-            buffer_size=len(self.layers) * 2,
+            buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
             dtype=torch.float32,
-            device=
-                input_embeds.device if input_embeds is not None else input_ids.device
-            ),
+            device=device,
         )
 
         if input_embeds is None:
@@ -1384,12 +1633,33 @@ class DeepseekV2Model(nn.Module):
             hidden_states = input_embeds
 
         residual = None
-
+
+        normal_num_layers = (
+            self.first_k_dense_replace
+            if forward_batch.can_run_tbo
+            else total_num_layers
+        )
+        for i in range(normal_num_layers):
             with get_global_expert_distribution_recorder().with_current_layer(i):
                 layer = self.layers[i]
                 hidden_states, residual = layer(
                     positions, hidden_states, forward_batch, residual, zero_allocator
                 )
+
+        if normal_num_layers != total_num_layers:
+            hidden_states, residual = model_forward_maybe_tbo(
+                layers=self.layers[normal_num_layers:],
+                enable_tbo=True,
+                positions=positions,
+                forward_batch=forward_batch,
+                hidden_states=hidden_states,
+                residual=residual,
+                input_data_scatter_mode=self.layers[
+                    normal_num_layers - 1
+                ].layer_scatter_modes.layer_output_mode,
+                zero_allocator=zero_allocator,
+            )
+
         if not forward_batch.forward_mode.is_idle():
             if residual is None:
                 hidden_states = self.norm(hidden_states)
@@ -1410,7 +1680,7 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()
         self.quant_config = quant_config
-        self.
+        self.determine_num_fused_shared_experts()
         self.model = DeepseekV2Model(
             config, quant_config, prefix=add_prefix("model", prefix)
         )
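With the changes above, the dense prefix layers (first_k_dense_replace) stay on the normal per-layer loop and only the remaining MoE layers are handed to the two-batch-overlap path, while the zero-value bump allocator is sized for two micro-batches whenever TBO can run. A minimal sketch of that planning logic (illustrative only; plan_layers is a hypothetical helper and the layer counts are example values):

    def plan_layers(total_num_layers: int, first_k_dense_replace: int, can_run_tbo: bool):
        # Layers before first_k_dense_replace always run the normal loop.
        normal_num_layers = first_k_dense_replace if can_run_tbo else total_num_layers
        # Two float32 slots per layer, doubled when two micro-batches overlap.
        zero_buffer_size = total_num_layers * 2 * (2 if can_run_tbo else 1)
        normal_layers = list(range(normal_num_layers))
        tbo_layers = list(range(normal_num_layers, total_num_layers))
        return normal_layers, tbo_layers, zero_buffer_size

    print(plan_layers(61, 3, can_run_tbo=True)[2])   # 244
    print(plan_layers(61, 3, can_run_tbo=False)[2])  # 122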
@@ -1424,41 +1694,50 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.logits_processor = LogitsProcessor(config)
         self.dp_size = get_local_attention_dp_size()
 
-
+        self._routed_experts_weights_of_layer = LazyValue(
+            lambda: {
+                layer_id: layer.mlp.get_moe_weights()
+                for layer_id, layer in enumerate(self.model.layers)
+                if isinstance(layer.mlp, DeepseekV2MoE)
+            }
+        )
+
+    @property
+    def routed_experts_weights_of_layer(self):
+        return self._routed_experts_weights_of_layer.value
+
+    def determine_num_fused_shared_experts(
         self, architecture: str = "DeepseekV3ForCausalLM"
     ):
-        self.
-        if
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-            )
-
-
-
-                logger,
-                "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
-            )
+        self.num_fused_shared_experts = 0
+        if global_server_args_dict["disable_shared_experts_fusion"]:
+            return
+
+        # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+        disable_reason = None
+        if (
+            not _is_cuda
+            or torch.cuda.get_device_capability("cuda") < (9, 0)
+            or self.config.architectures[0] != architecture
+            or self.config.n_routed_experts != 256
+            or self.config.n_shared_experts != 1
+        ):
+            disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 90 can use shared experts fusion optimization."
+        elif (
+            global_server_args_dict["enable_deepep_moe"]
+            or global_server_args_dict["enable_ep_moe"]
+        ):
+            disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+
+        if disable_reason is not None:
+            global_server_args_dict["disable_shared_experts_fusion"] = True
+            log_info_on_rank0(
+                logger,
+                f"{disable_reason} Shared experts fusion optimization is disabled.",
+            )
+            return
+
+        self.num_fused_shared_experts = self.config.n_shared_experts
 
     def get_input_embeddings(self) -> nn.Embedding:
         return self.model.embed_tokens
@@ -1471,21 +1750,28 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
         input_embeds: torch.Tensor = None,
     ) -> torch.Tensor:
-
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
 
         return self.logits_processor(
             input_ids, hidden_states, self.lm_head, forward_batch
         )
 
-    def post_load_weights(self, is_nextn=False):
+    def post_load_weights(self, is_nextn=False, weight_names=None):
 
         # Perform post-processing after loading weights
-
-
-
-
-
+        if is_nextn:
+            layer_ids = [self.config.num_hidden_layers]
+        else:
+            if weight_names is None:
+                layer_ids = range(self.config.num_hidden_layers)
+            else:
+                layer_ids = set()
+                for name in weight_names:
+                    if "kv_b_proj" in name:
+                        layer_id = int(name.split(".")[2])
+                        if layer_id < self.config.num_hidden_layers:
+                            layer_ids.add(layer_id)
+
         for layer_id in layer_ids:
             self_attn = (
                 self.model.layers[layer_id].self_attn
@@ -1521,46 +1807,58 @@ class DeepseekV2ForCausalLM(nn.Module):
                 torch.float8_e4m3fn,
                 torch.float8_e4m3fnuz,
             ):
-                if
+                if (
+                    hasattr(self.quant_config, "weight_block_size")
+                    and self.quant_config.weight_block_size is not None
+                ):
                     weight_block_size = self.quant_config.weight_block_size
-
-
-
-                            weight,
-
-
-
-
-
-
-                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
+                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                    if _is_fp8_fnuz:
+                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                            weight=w,
+                            weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+                            input_scale=None,
+                        )
+                    else:
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale_inv
 
+                    if (
+                        _is_cuda
+                        and weight_block_size[0] == 128
+                        and weight_block_size[1] == 128
+                        and model_dtype == torch.bfloat16
+                    ):
                         if (
-
-                            and
-                            and
-                            and model_dtype == torch.bfloat16
+                            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+                            and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL
+                            and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
                         ):
-
-
-                        ):
-                            block_scale = weight_scale
-                            use_deep_gemm_bmm = True
-                        else:
-                            w = block_quant_dequant(
-                                weight,
-                                weight_scale,
-                                weight_block_size,
-                                model_dtype,
-                            )
+                            block_scale = weight_scale
+                            use_deep_gemm_bmm = True
                         else:
-                            w
-                                weight,
+                            w = block_quant_dequant(
+                                weight,
+                                weight_scale,
+                                weight_block_size,
+                                model_dtype,
                             )
-
+                    else:
+                        w, scale = block_quant_to_tensor_quant(
+                            weight, weight_scale, weight_block_size
+                        )
+                        self_attn.w_scale = scale
                 else:
-
-
+                    if _is_fp8_fnuz:
+                        weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+                            weight=w,
+                            weight_scale=self_attn.kv_b_proj.weight_scale,
+                            input_scale=None,
+                        )
+                    else:
+                        weight = w
+                        weight_scale = self_attn.kv_b_proj.weight_scale
+
                     w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
                     self_attn.w_scale = scale
 
@@ -1585,13 +1883,19 @@ class DeepseekV2ForCausalLM(nn.Module):
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
             if not use_deep_gemm_bmm:
-                self_attn.w_kc =
-
+                self_attn.w_kc = bind_or_assign(
+                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+                )
+                self_attn.w_vc = bind_or_assign(
+                    self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
+                )
                 if (
                     hasattr(self_attn.kv_b_proj, "weight_scale")
                     and self_attn.w_scale is None
                 ):
-                    self_attn.w_scale =
+                    self_attn.w_scale = bind_or_assign(
+                        self_attn.w_scale, self_attn.kv_b_proj.weight_scale
+                    )
                     if _is_hip:
                         self_attn.w_scale *= 2.0
             else:
@@ -1600,21 +1904,79 @@ class DeepseekV2ForCausalLM(nn.Module):
                 ws_kc, ws_vc = block_scale.unflatten(
                     0, (-1, (num_tiles_k + num_tiles_n))
                 ).split([num_tiles_k, num_tiles_n], dim=1)
-                self_attn.w_scale_k =
-
-
-                self_attn.
+                self_attn.w_scale_k = bind_or_assign(
+                    self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
+                )
+                self_attn.w_scale_v = bind_or_assign(
+                    self_attn.w_scale_v, ws_vc.contiguous()
+                )
+                self_attn.w_kc = bind_or_assign(
+                    self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
+                )
+                self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
                 self_attn.use_deep_gemm_bmm = True
 
-
-
-
-
-
-
-
+        if (
+            deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+            and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+        ):
+            self._weight_requant_ue8m0()
+
+    def _weight_requant_ue8m0(self):
+        weight_block_size = self.quant_config.weight_block_size
+
+        moe_layers = list(
+            range(
+                self.config.first_k_dense_replace,
+                self.config.num_hidden_layers,
+                self.config.moe_layer_freq,
+            )
+        )
+
+        for layer_id in range(self.config.num_hidden_layers):
+            layer = self.model.layers[layer_id]
+
+            for module in [
+                layer.self_attn.fused_qkv_a_proj_with_mqa,
+                layer.self_attn.q_b_proj,
+                layer.self_attn.kv_b_proj,
+                layer.self_attn.o_proj,
+            ]:
+                requant_weight_ue8m0_inplace(
+                    module.weight, module.weight_scale_inv, weight_block_size
+                )
+
+            if layer_id in moe_layers:
+                shared_experts = getattr(layer.mlp, "shared_experts", None)
+                if shared_experts is not None:
+                    for module in [
+                        shared_experts.gate_up_proj,
+                        shared_experts.down_proj,
+                    ]:
+                        requant_weight_ue8m0_inplace(
+                            module.weight, module.weight_scale_inv, weight_block_size
+                        )
+
+                experts = layer.mlp.experts
+                if isinstance(experts, DeepEPMoE):
+                    for w in [
+                        experts.w13_weight_fp8,
+                        experts.w2_weight_fp8,
+                    ]:
+                        requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
+            else:
+                mlp = layer.mlp
+                assert isinstance(mlp, DeepseekV2MLP)
+                for module in [
+                    mlp.gate_up_proj,
+                    mlp.down_proj,
+                ]:
+                    requant_weight_ue8m0_inplace(
+                        module.weight, module.weight_scale_inv, weight_block_size
+                    )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+
         if is_nextn:
             if hasattr(self.config, "num_nextn_predict_layers"):
                 num_nextn_layers = self.config.num_nextn_predict_layers
@@ -1633,60 +1995,6 @@ class DeepseekV2ForCausalLM(nn.Module):
             ("gate_up_proj", "gate_proj", 0),
             ("gate_up_proj", "up_proj", 1),
         ]
-        if self.n_share_experts_fusion > 0:
-            weights_list = list(weights)
-            weights_dict = dict(weights_list)
-            if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale",
-                    "up_proj.weight",
-                    "up_proj.weight_scale",
-                ]
-            else:
-                suffix_list = [
-                    "down_proj.weight",
-                    "down_proj.weight_scale_inv",
-                    "gate_proj.weight",
-                    "gate_proj.weight_scale_inv",
-                    "up_proj.weight",
-                    "up_proj.weight_scale_inv",
-                ]
-            names_to_remove = []
-
-            moe_layers = (
-                range(
-                    self.config.first_k_dense_replace,
-                    self.config.num_hidden_layers,
-                    self.config.moe_layer_freq,
-                )
-                if not is_nextn
-                else [nextn_layer_id]
-            )
-
-            for moe_layer in tqdm(
-                moe_layers,
-                desc=f"Cloning {self.n_share_experts_fusion} "
-                "replicas of the shared expert into MoE",
-            ):
-                for suffix in suffix_list:
-                    shared_expert_weight_name = (
-                        f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
-                    )
-                    for num_repeat in range(self.n_share_experts_fusion):
-                        weights_list.append(
-                            (
-                                f"model.layers.{moe_layer}."
-                                f"mlp.experts."
-                                f"{self.config.n_routed_experts + num_repeat}"
-                                f".{suffix}",
-                                weights_dict[shared_expert_weight_name],
-                            )
-                        )
-                    names_to_remove += [shared_expert_weight_name]
-            weights = [w for w in weights_list if w[0] not in names_to_remove]
 
         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
@@ -1694,7 +2002,7 @@ class DeepseekV2ForCausalLM(nn.Module):
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
-            num_experts=self.config.n_routed_experts + self.
+            num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
         )
 
         # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
@@ -1712,8 +2020,21 @@ class DeepseekV2ForCausalLM(nn.Module):
             "hnorm",
         ]
 
+        if self.num_fused_shared_experts > 0:
+            assert self.num_fused_shared_experts == 1
+            logger.info("Shared experts fusion optimization enabled.")
+
         params_dict = dict(self.named_parameters())
+        weight_names = []
         for name, loaded_weight in weights:
+            if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+                name = name.replace(
+                    "mlp.shared_experts",
+                    f"mlp.experts.{self.config.n_routed_experts}",
+                )
+
+            weight_names.append(name)
+
             if not is_nextn:
                 if hasattr(self.config, "num_nextn_predict_layers"):
                     num_nextn_layers = self.config.num_nextn_predict_layers
@@ -1785,7 +2106,6 @@ class DeepseekV2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
-
                 if fuse_qkv_a_proj and (
                     "q_a_proj" in name or "kv_a_proj_with_mqa" in name
                 ):
@@ -1811,9 +2131,12 @@ class DeepseekV2ForCausalLM(nn.Module):
                         fused_weight = torch.cat(
                             [q_a_proj_weight, kv_a_proj_weight], dim=0
                         )
-
-
-                            "q_a_proj"
+                        param_name = (
+                            name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+                            if "q_a_proj" in name
+                            else name.replace(
+                                "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+                            )
                         )
                         param = params_dict[param_name]
 
@@ -1824,13 +2147,23 @@ class DeepseekV2ForCausalLM(nn.Module):
                         cached_a_proj.pop(q_a_proj_name)
                         cached_a_proj.pop(kv_a_proj_name)
                 else:
+                    if (
+                        "k_scale" in name or "v_scale" in name
+                    ) and name not in params_dict:
+                        # modelopt attn kv scale is named differently
+                        if any(scale in name for scale in ["k_scale", "v_scale"]):
+                            name = name.replace("_proj", "attn_mqa")
+                        else:
+                            logger.warning(
+                                f"Unknown scale found in checkpoint: {name}"
+                            )
                     param = params_dict[name]
                     weight_loader = getattr(
                         param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
 
-        self.post_load_weights(is_nextn=is_nextn)
+        self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
 
     def get_embed_and_head(self):
         return self.model.embed_tokens.weight, self.lm_head.weight
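The weight_names list that load_weights now collects lets post_load_weights recompute the absorbed w_kc/w_vc matrices only for layers whose kv_b_proj weights were actually reloaded. A minimal sketch of that name-to-layer filtering (illustrative only; the weight names are example strings in the usual "model.layers.<id>.…" format):

    def layers_to_refresh(weight_names, num_hidden_layers):
        layer_ids = set()
        for name in weight_names:
            if "kv_b_proj" in name:
                # "model.layers.<id>.self_attn.kv_b_proj.weight" -> <id>
                layer_id = int(name.split(".")[2])
                if layer_id < num_hidden_layers:
                    layer_ids.add(layer_id)
        return sorted(layer_ids)

    names = [
        "model.layers.0.self_attn.kv_b_proj.weight",
        "model.layers.61.self_attn.kv_b_proj.weight",  # nextn layer id, filtered out
        "model.layers.3.mlp.gate_proj.weight",
    ]
    print(layers_to_refresh(names, num_hidden_layers=61))  # [0]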