sglang 0.4.6.post4__py3-none-any.whl → 0.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +16 -10
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +86 -22
- sglang/bench_serving.py +197 -110
- sglang/compile_deep_gemm.py +4 -4
- sglang/lang/backend/runtime_endpoint.py +24 -1
- sglang/profiler.py +167 -0
- sglang/srt/_custom_ops.py +34 -0
- sglang/srt/configs/internvl.py +8 -12
- sglang/srt/configs/model_config.py +66 -29
- sglang/srt/constrained/base_grammar_backend.py +5 -2
- sglang/srt/constrained/llguidance_backend.py +9 -8
- sglang/srt/constrained/outlines_backend.py +5 -4
- sglang/srt/constrained/xgrammar_backend.py +18 -18
- sglang/srt/conversation.py +47 -9
- sglang/srt/custom_op.py +38 -3
- sglang/srt/debug_utils.py +74 -0
- sglang/srt/disaggregation/common/__init__.py +1 -0
- sglang/srt/disaggregation/common/conn.py +407 -0
- sglang/srt/disaggregation/decode.py +187 -134
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +4 -13
- sglang/srt/disaggregation/kv_events.py +412 -0
- sglang/srt/disaggregation/launch_lb.py +140 -0
- sglang/srt/disaggregation/mini_lb.py +84 -70
- sglang/srt/disaggregation/mooncake/conn.py +441 -140
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -14
- sglang/srt/disaggregation/nixl/conn.py +124 -442
- sglang/srt/disaggregation/prefill.py +128 -44
- sglang/srt/disaggregation/utils.py +154 -6
- sglang/srt/distributed/device_communicators/pymscclpp.py +315 -0
- sglang/srt/distributed/parallel_state.py +52 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +11 -0
- sglang/srt/entrypoints/engine.py +129 -12
- sglang/srt/entrypoints/http_server.py +21 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +302 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +205 -0
- sglang/srt/function_call/ebnf_composer.py +248 -0
- sglang/srt/function_call/function_call_parser.py +202 -0
- sglang/srt/function_call/llama32_detector.py +93 -0
- sglang/srt/function_call/mistral_detector.py +131 -0
- sglang/srt/function_call/pythonic_detector.py +229 -0
- sglang/srt/function_call/qwen25_detector.py +121 -0
- sglang/srt/function_call/utils.py +52 -0
- sglang/srt/hf_transformers_utils.py +50 -7
- sglang/srt/layers/attention/aiter_backend.py +878 -0
- sglang/srt/layers/attention/base_attn_backend.py +4 -0
- sglang/srt/layers/attention/cutlass_mla_backend.py +2 -19
- sglang/srt/layers/attention/flashattention_backend.py +166 -35
- sglang/srt/layers/attention/flashinfer_backend.py +45 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +45 -5
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/intel_amx_backend.py +128 -0
- sglang/srt/layers/attention/tbo_backend.py +232 -0
- sglang/srt/layers/attention/torch_native_backend.py +3 -0
- sglang/srt/layers/attention/triton_backend.py +247 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +12 -4
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +517 -0
- sglang/srt/layers/dp_attention.py +6 -15
- sglang/srt/layers/layernorm.py +30 -19
- sglang/srt/layers/moe/cutlass_moe.py +370 -0
- sglang/srt/layers/moe/cutlass_moe_params.py +169 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +60 -17
- sglang/srt/layers/moe/ep_moe/layer.py +195 -87
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +88 -8
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +220 -25
- sglang/srt/layers/moe/fused_moe_triton/layer.py +48 -4
- sglang/srt/layers/moe/topk.py +107 -24
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +10 -4
- sglang/srt/layers/quantization/blockwise_int8.py +3 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +5 -0
- sglang/srt/layers/quantization/deep_gemm.py +60 -59
- sglang/srt/layers/quantization/fp8.py +113 -18
- sglang/srt/layers/quantization/fp8_kernel.py +118 -66
- sglang/srt/layers/quantization/fp8_utils.py +165 -43
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/modelopt_quant.py +334 -7
- sglang/srt/layers/quantization/moe_wna16.py +3 -0
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +3 -0
- sglang/srt/layers/quantization/w8a8_int8.py +3 -0
- sglang/srt/layers/rotary_embedding.py +6 -12
- sglang/srt/layers/sampler.py +80 -79
- sglang/srt/layers/utils.py +6 -0
- sglang/srt/lora/layers.py +12 -15
- sglang/srt/lora/lora.py +49 -5
- sglang/srt/lora/lora_manager.py +20 -8
- sglang/srt/lora/mem_pool.py +24 -16
- sglang/srt/lora/utils.py +17 -13
- sglang/srt/managers/data_parallel_controller.py +13 -5
- sglang/srt/managers/eplb_algorithms/__init__.py +63 -0
- sglang/srt/managers/eplb_algorithms/deepseek.py +223 -0
- sglang/srt/managers/eplb_algorithms/deepseek_vec.py +276 -0
- sglang/srt/managers/eplb_manager.py +96 -0
- sglang/srt/managers/expert_distribution.py +878 -56
- sglang/srt/managers/expert_location.py +448 -0
- sglang/srt/managers/expert_location_dispatch.py +108 -0
- sglang/srt/managers/io_struct.py +29 -5
- sglang/srt/managers/mm_utils.py +355 -151
- sglang/srt/managers/multimodal_processors/base_processor.py +299 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +15 -17
- sglang/srt/managers/multimodal_processors/internvl.py +18 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +14 -32
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +27 -32
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/phi4mm.py +87 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +35 -35
- sglang/srt/managers/schedule_batch.py +185 -55
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +389 -154
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +231 -39
- sglang/srt/managers/utils.py +0 -4
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +11 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +87 -65
- sglang/srt/model_executor/expert_location_updater.py +557 -0
- sglang/srt/model_executor/forward_batch_info.py +39 -14
- sglang/srt/model_executor/model_runner.py +231 -101
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/model_loader/utils.py +67 -1
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_nextn.py +1 -1
- sglang/srt/models/deepseek_v2.py +732 -403
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_causal.py +7 -0
- sglang/srt/models/gemma3_mm.py +75 -33
- sglang/srt/models/idefics2.py +342 -0
- sglang/srt/models/kimi_vl.py +4 -4
- sglang/srt/models/llama.py +1 -1
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +7 -17
- sglang/srt/models/minicpmv.py +3 -295
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/phi4mm.py +512 -0
- sglang/srt/models/qwen2.py +133 -35
- sglang/srt/models/qwen2_5_vl.py +5 -3
- sglang/srt/models/qwen2_eagle.py +4 -1
- sglang/srt/models/qwen2_moe.py +206 -69
- sglang/srt/models/qwen2_vl.py +3 -3
- sglang/srt/models/qwen3.py +92 -19
- sglang/srt/models/qwen3_moe.py +457 -55
- sglang/srt/models/registry.py +9 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/transformers.py +291 -0
- sglang/srt/openai_api/adapter.py +114 -40
- sglang/srt/openai_api/protocol.py +37 -2
- sglang/srt/openai_api/utils.py +172 -0
- sglang/srt/operations.py +189 -0
- sglang/srt/operations_strategy.py +207 -0
- sglang/srt/sampling/sampling_batch_info.py +13 -1
- sglang/srt/sampling/sampling_params.py +2 -1
- sglang/srt/server_args.py +235 -38
- sglang/srt/speculative/build_eagle_tree.py +8 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -11
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +253 -0
- sglang/srt/speculative/eagle_utils.py +181 -90
- sglang/srt/speculative/eagle_worker.py +146 -21
- sglang/srt/two_batch_overlap.py +635 -0
- sglang/srt/utils.py +197 -19
- sglang/test/runners.py +16 -7
- sglang/test/send_one.py +4 -0
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_fp4_moe.py +248 -0
- sglang/test/test_utils.py +81 -42
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.7.dist-info}/METADATA +31 -19
- sglang-0.4.7.dist-info/RECORD +699 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.7.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- sglang-0.4.6.post4.dist-info/RECORD +0 -646
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H200.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI325X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Radeon_Graphics.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI325X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Radeon_Graphics.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI325X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Radeon_Graphics.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_L40S.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_L40S.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI325X.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Radeon_Graphics.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H20.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=96,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=96,device_name=NVIDIA_H20.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/layers/moe/fused_moe_triton/configs/{E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.7.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.7.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -18,8 +18,7 @@
 
 import logging
 import os
-from
-from enum import Enum, IntEnum, auto
+from enum import IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -29,17 +28,17 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.communicator import (
+    LayerCommunicator,
+    LayerScatterModes,
+    enable_moe_dense_fully_dp,
+)
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -52,13 +51,13 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
+    is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
     per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
@@ -72,28 +71,41 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.two_batch_overlap import (
+    MaybeTboDeepEPDispatcher,
+    model_forward_maybe_tbo,
+)
 from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
+    LazyValue,
     add_prefix,
+    bind_or_assign,
     get_bool_env_var,
     get_int_env_var,
     is_cuda,
     is_hip,
+    is_non_idle_and_non_empty,
     log_info_on_rank0,
 )
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
@@ -109,7 +121,8 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
 
-
+    if _use_aiter:
+        from aiter.rotary_embedding import get_rope
 
 logger = logging.getLogger(__name__)
 
@@ -125,6 +138,9 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
 
+    # Use MLA but with fused RoPE
+    MLA_FUSED_ROPE = auto()
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -139,6 +155,8 @@ class DeepseekV2MLP(nn.Module):
         tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
+        self.tp_size = tp_size
+
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
@@ -165,7 +183,10 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x,
+    def forward(self, x, forward_batch=None):
+        if (self.tp_size == 1) and x.shape[0] == 0:
+            return x
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -199,6 +220,7 @@ class DeepseekV2MoE(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -206,7 +228,13 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.
+        self.num_fused_shared_experts = (
+            0
+            if global_server_args_dict["disable_shared_experts_fusion"]
+            else config.n_shared_experts
+        )
+        self.config = config
+        self.layer_id = layer_id
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -222,21 +250,19 @@ class DeepseekV2MoE(nn.Module):
 
         self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
 
-
-
-
-
-
-
-        self.experts = MoEImpl(
-            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
-            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
+        self.experts = get_moe_impl_class()(
+            num_experts=config.n_routed_experts
+            + self.num_fused_shared_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
+            layer_id=self.layer_id,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
             num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
@@ -248,35 +274,32 @@ class DeepseekV2MoE(nn.Module):
             ),
         )
 
-        if config.n_shared_experts is not None and self.
+        if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                reduce_results=False,
-                prefix=add_prefix("shared_experts", prefix),
-                tp_rank=0,
-                tp_size=1,
-            )
+            self.shared_experts = DeepseekV2MLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if global_server_args_dict["enable_deepep_moe"]
+                    else {}
+                ),
+            )
+
+        self.top_k = config.num_experts_per_tok
 
         if global_server_args_dict["enable_deepep_moe"]:
             # TODO: we will support tp < ep in the future
             self.ep_size = get_tensor_model_parallel_world_size()
-            self.num_experts =
-
+            self.num_experts = (
+                config.n_routed_experts
+                + global_server_args_dict["ep_num_redundant_experts"]
+            )
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
             self.num_expert_group = config.n_group
@@ -286,35 +309,45 @@ class DeepseekV2MoE(nn.Module):
                 else None
             )
 
-            self.deepep_dispatcher =
+            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                 group=parallel_state.get_tp_group().device_group,
                 router_topk=self.top_k,
                 permute_fusion=True,
-                num_experts=
+                num_experts=self.num_experts,
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
                 deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
-                async_finish=True,
+                async_finish=True,
                 return_recv_hook=True,
             )
 
+        self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
     def forward(
-        self, hidden_states: torch.Tensor,
+        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
-        if not
+        if not self._enable_deepep_moe:
             return self.forward_normal(hidden_states)
         else:
-            return self.forward_deepep(hidden_states,
+            return self.forward_deepep(hidden_states, forward_batch)
 
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        final_hidden_states = (
-
-            * self.routed_scaling_factor
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
         )
+        if not _is_cuda:
+            final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
@@ -322,14 +355,11 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states
 
     def forward_deepep(
-        self, hidden_states: torch.Tensor,
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
+        forward_mode = forward_batch.forward_mode
         shared_output = None
-        if (
-            forward_mode is not None
-            and not forward_mode.is_idle()
-            and hidden_states.shape[0] > 0
-        ):
+        if is_non_idle_and_non_empty(forward_mode, hidden_states):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
             shared_output = self._forward_shared_experts(hidden_states)
@@ -341,8 +371,13 @@ class DeepseekV2MoE(nn.Module):
                 renormalize=self.renormalize,
                 topk_group=self.topk_group,
                 num_expert_group=self.num_expert_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
                 correction_bias=self.correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
             )
         else:
             topk_idx = torch.full(
@@ -363,9 +398,9 @@ class DeepseekV2MoE(nn.Module):
                 masked_m,
                 expected_m,
             ) = self.deepep_dispatcher.dispatch(
-                hidden_states,
-                topk_idx,
-                topk_weights,
+                hidden_states=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
                 forward_mode=forward_mode,
             )
         final_hidden_states = self.experts(
@@ -381,24 +416,147 @@ class DeepseekV2MoE(nn.Module):
             )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_mode,
+                hidden_states=final_hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                forward_mode=forward_mode,
             )
-        final_hidden_states *= self.routed_scaling_factor
 
         if shared_output is not None:
-
+            x = shared_output
+            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            final_hidden_states = x
+        else:
+            final_hidden_states *= self.routed_scaling_factor
 
         return final_hidden_states
 
     def _forward_shared_experts(self, hidden_states):
-        if self.
+        if self.num_fused_shared_experts == 0:
             return self.shared_experts(hidden_states)
         else:
             return None
 
+    def op_gate(self, state):
+        if is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, state.hidden_states_mlp_input
+        ):
+            # router_logits: (num_tokens, n_experts)
+            state.router_logits = self.gate(state.hidden_states_mlp_input)
+        else:
+            state.router_logits = None
+
+    def op_shared_experts(self, state):
+        hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
+        if (self.num_fused_shared_experts == 0) and is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, hidden_states_mlp_input
+        ):
+            state.shared_output = self.shared_experts(hidden_states_mlp_input)
+        else:
+            state.shared_output = None
+
+    def op_select_experts(self, state):
+        router_logits = state.pop("router_logits")
+        hidden_states = state.hidden_states_mlp_input
+
+        if router_logits is not None:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.topk_weights_local, state.topk_idx_local = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    top_k=self.top_k,
+                    use_grouped_topk=True,
+                    renormalize=self.renormalize,
+                    topk_group=self.topk_group,
+                    num_expert_group=self.num_expert_group,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
+                    correction_bias=self.correction_bias,
+                    routed_scaling_factor=self.routed_scaling_factor,
+                    num_token_non_padded=state.forward_batch.num_token_non_padded,
+                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                        layer_id=self.layer_id,
+                    ),
+                )
+        else:
+            state.topk_idx_local = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            state.topk_weights_local = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+
+    def op_dispatch_a(self, state):
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            self.deepep_dispatcher.dispatch_a(
+                hidden_states=state.hidden_states_mlp_input,
+                topk_idx=state.pop("topk_idx_local"),
+                topk_weights=state.pop("topk_weights_local"),
+                forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_dispatch_b(self, state):
+        if self.ep_size > 1:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                (
+                    state.hidden_states_experts_input,
+                    state.topk_idx_dispatched,
+                    state.topk_weights_dispatched,
+                    state.reorder_topk_ids,
+                    state.num_recv_tokens_per_expert,
+                    state.seg_indptr,
+                    state.masked_m,
+                    state.expected_m,
+                ) = self.deepep_dispatcher.dispatch_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
+
+    def op_experts(self, state):
+        state.hidden_states_experts_output = self.experts(
+            hidden_states=state.pop("hidden_states_experts_input"),
+            topk_idx=state.topk_idx_dispatched,
+            topk_weights=state.topk_weights_dispatched,
+            reorder_topk_ids=state.pop("reorder_topk_ids"),
+            seg_indptr=state.pop("seg_indptr"),
+            masked_m=state.pop("masked_m"),
+            expected_m=state.pop("expected_m"),
+            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
+            forward_mode=state.forward_batch.forward_mode,
+        )
+
+    def op_combine_a(self, state):
+        if self.ep_size > 1:
+            self.deepep_dispatcher.combine_a(
+                hidden_states=state.pop("hidden_states_experts_output"),
+                topk_idx=state.pop("topk_idx_dispatched"),
+                topk_weights=state.pop("topk_weights_dispatched"),
+                forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_combine_b(self, state):
+        if self.ep_size > 1:
+            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_output(self, state):
+        final_hidden_states = state.pop("hidden_states_after_combine")
+
+        if (shared_output := state.pop("shared_output")) is not None:
+            x = shared_output
+            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            final_hidden_states = x
+        else:
+            final_hidden_states *= self.routed_scaling_factor
+
+        state.hidden_states_mlp_output = final_hidden_states
+
 
 def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     import math
@@ -550,10 +708,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         )
 
         self.alt_stream = alt_stream
+        self.attn_mha.kv_b_proj = None
 
         self.w_kc = None
         self.w_vc = None
-        self.w_scale =
+        self.w_scale = 1.0
 
         self.w_scale_k = None
         self.w_scale_v = None
@@ -578,6 +737,18 @@ class DeepseekV2AttentionMLA(nn.Module):
     def dispatch_attn_forward_method(
         self, forward_batch: ForwardBatch
     ) -> AttnForwardMethod:
+        def _dispatch_mla_subtype():
+            if _is_hip:
+                if (
+                    self.rocm_fused_decode_mla
+                    and forward_batch.forward_mode.is_decode()
+                ):
+                    return AttnForwardMethod.MLA_FUSED_ROPE
+                else:
+                    return AttnForwardMethod.MLA
+            else:
+                return AttnForwardMethod.MLA
+
         if self.attention_backend == "flashinfer":
             # Flashinfer MLA: Do not absorb when enabling ragged prefill
             if (
@@ -589,7 +760,7 @@
|
|
589
760
|
):
|
590
761
|
return AttnForwardMethod.MHA
|
591
762
|
else:
|
592
|
-
return
|
763
|
+
return _dispatch_mla_subtype()
|
593
764
|
elif self.attention_backend == "fa3":
|
594
765
|
# Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
|
595
766
|
if forward_batch.extend_prefix_lens_cpu is not None:
|
@@ -605,6 +776,15 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
605
776
|
)
|
606
777
|
):
|
607
778
|
return AttnForwardMethod.MHA_CHUNKED_KV
|
779
|
+
else:
|
780
|
+
return _dispatch_mla_subtype()
|
781
|
+
elif self.attention_backend == "aiter":
|
782
|
+
if (
|
783
|
+
forward_batch.forward_mode.is_extend()
|
784
|
+
and not forward_batch.forward_mode.is_target_verify()
|
785
|
+
and not forward_batch.forward_mode.is_draft_extend()
|
786
|
+
):
|
787
|
+
return AttnForwardMethod.MHA
|
608
788
|
else:
|
609
789
|
return AttnForwardMethod.MLA
|
610
790
|
else:
|
@@ -617,7 +797,20 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
617
797
|
):
|
618
798
|
return AttnForwardMethod.MHA
|
619
799
|
else:
|
620
|
-
return
|
800
|
+
return _dispatch_mla_subtype()
|
801
|
+
|
802
|
+
def op_prepare(self, state):
|
803
|
+
state.attn_intermediate_state = self.forward_prepare(
|
804
|
+
positions=state.positions,
|
805
|
+
hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
|
806
|
+
forward_batch=state.forward_batch,
|
807
|
+
zero_allocator=state.zero_allocator,
|
808
|
+
)
|
809
|
+
|
810
|
+
def op_core(self, state):
|
811
|
+
state.hidden_states_after_attn = self.forward_core(
|
812
|
+
state.pop("attn_intermediate_state")
|
813
|
+
)
|
621
814
|
|
622
815
|
def forward(
|
623
816
|
self,
|
@@ -625,45 +818,78 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
625
818
|
hidden_states: torch.Tensor,
|
626
819
|
forward_batch: ForwardBatch,
|
627
820
|
zero_allocator: BumpAllocator,
|
628
|
-
)
|
821
|
+
):
|
822
|
+
s = self.forward_prepare(
|
823
|
+
positions=positions,
|
824
|
+
hidden_states=hidden_states,
|
825
|
+
forward_batch=forward_batch,
|
826
|
+
zero_allocator=zero_allocator,
|
827
|
+
)
|
828
|
+
return self.forward_core(s)
|
829
|
+
|
830
|
+
def forward_prepare(
|
831
|
+
self,
|
832
|
+
positions: torch.Tensor,
|
833
|
+
hidden_states: torch.Tensor,
|
834
|
+
forward_batch: ForwardBatch,
|
835
|
+
zero_allocator: BumpAllocator,
|
836
|
+
):
|
837
|
+
if self.attn_mha.kv_b_proj is None:
|
838
|
+
self.attn_mha.kv_b_proj = self.kv_b_proj
|
839
|
+
|
629
840
|
if hidden_states.shape[0] == 0:
|
630
841
|
assert (
|
631
842
|
not self.o_proj.reduce_results
|
632
843
|
), "short-circuiting allreduce will lead to hangs"
|
633
|
-
return hidden_states
|
844
|
+
return hidden_states, None, forward_batch, None
|
634
845
|
|
635
846
|
attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
|
636
847
|
|
637
848
|
if attn_forward_method == AttnForwardMethod.MHA:
|
638
|
-
|
849
|
+
inner_state = self.forward_normal_prepare(
|
850
|
+
positions, hidden_states, forward_batch, zero_allocator
|
851
|
+
)
|
639
852
|
elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
|
640
|
-
|
641
|
-
positions, hidden_states, forward_batch
|
853
|
+
inner_state = self.forward_normal_chunked_kv_prepare(
|
854
|
+
positions, hidden_states, forward_batch, zero_allocator
|
855
|
+
)
|
856
|
+
elif attn_forward_method == AttnForwardMethod.MLA:
|
857
|
+
inner_state = self.forward_absorb_prepare(
|
858
|
+
positions, hidden_states, forward_batch, zero_allocator
|
859
|
+
)
|
860
|
+
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
|
861
|
+
inner_state = self.forward_absorb_fused_mla_rope_prepare(
|
862
|
+
positions, hidden_states, forward_batch, zero_allocator
|
642
863
|
)
|
643
864
|
else:
|
644
|
-
|
645
|
-
|
646
|
-
|
647
|
-
|
648
|
-
|
649
|
-
|
650
|
-
|
651
|
-
|
652
|
-
|
653
|
-
return self.forward_absorb(
|
654
|
-
positions, hidden_states, forward_batch, zero_allocator
|
655
|
-
)
|
656
|
-
else:
|
657
|
-
return self.forward_absorb(
|
658
|
-
positions, hidden_states, forward_batch, zero_allocator
|
659
|
-
)
|
865
|
+
raise NotImplementedError
|
866
|
+
return None, attn_forward_method, forward_batch, inner_state
|
867
|
+
|
868
|
+
def forward_core(self, intermediate_state):
|
869
|
+
hidden_states, attn_forward_method, forward_batch, inner_state = (
|
870
|
+
intermediate_state
|
871
|
+
)
|
872
|
+
if inner_state is None:
|
873
|
+
return hidden_states
|
660
874
|
|
661
|
-
|
875
|
+
if attn_forward_method == AttnForwardMethod.MHA:
|
876
|
+
return self.forward_normal_core(*inner_state)
|
877
|
+
elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
|
878
|
+
return self.forward_normal_chunked_kv_core(*inner_state)
|
879
|
+
elif attn_forward_method == AttnForwardMethod.MLA:
|
880
|
+
return self.forward_absorb_core(*inner_state)
|
881
|
+
elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
|
882
|
+
return self.forward_absorb_fused_mla_rope_core(*inner_state)
|
883
|
+
else:
|
884
|
+
raise NotImplementedError
|
885
|
+
|
886
|
+
def forward_normal_prepare(
|
662
887
|
self,
|
663
888
|
positions: torch.Tensor,
|
664
889
|
hidden_states: torch.Tensor,
|
665
890
|
forward_batch: ForwardBatch,
|
666
|
-
|
891
|
+
zero_allocator: BumpAllocator,
|
892
|
+
):
|
667
893
|
if self.q_lora_rank is not None:
|
668
894
|
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
|
669
895
|
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
@@ -698,18 +924,24 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
698
924
|
forward_batch.token_to_kv_pool.set_kv_buffer(
|
699
925
|
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
|
700
926
|
)
|
927
|
+
|
928
|
+
return q, k, v, forward_batch
|
929
|
+
|
930
|
+
def forward_normal_core(self, q, k, v, forward_batch):
|
701
931
|
attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
|
702
932
|
attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
|
703
933
|
output, _ = self.o_proj(attn_output)
|
704
934
|
return output
|
705
935
|
|
706
|
-
def
|
936
|
+
def forward_absorb_prepare(
|
707
937
|
self,
|
708
938
|
positions: torch.Tensor,
|
709
939
|
hidden_states: torch.Tensor,
|
710
940
|
forward_batch: ForwardBatch,
|
711
941
|
zero_allocator: BumpAllocator,
|
712
|
-
)
|
942
|
+
):
|
943
|
+
from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
|
944
|
+
|
713
945
|
if self.q_lora_rank is not None:
|
714
946
|
q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
|
715
947
|
[self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
|
@@ -717,7 +949,7 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
717
949
|
k_nope = latent_cache[..., : self.kv_lora_rank]
|
718
950
|
|
719
951
|
# overlap qk norm
|
720
|
-
if self.alt_stream is not None and
|
952
|
+
if self.alt_stream is not None and get_is_capture_mode():
|
721
953
|
current_stream = torch.cuda.current_stream()
|
722
954
|
self.alt_stream.wait_stream(current_stream)
|
723
955
|
q = self.q_a_layernorm(q)
|
@@ -756,8 +988,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
756
988
|
expected_m,
|
757
989
|
)
|
758
990
|
q_nope_out = q_nope_out[:, :expected_m, :]
|
759
|
-
elif
|
760
|
-
# TODO(
|
991
|
+
elif _is_hip:
|
992
|
+
# TODO(haishaw): add bmm_fp8 to ROCm
|
761
993
|
q_nope_out = torch.bmm(
|
762
994
|
q_nope.to(torch.bfloat16).transpose(0, 1),
|
763
995
|
self.w_kc.to(torch.bfloat16) * self.w_scale,
|
@@ -776,6 +1008,11 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
776
1008
|
q_nope_out = q_nope_out.transpose(0, 1)
|
777
1009
|
q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
|
778
1010
|
|
1011
|
+
return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
|
1012
|
+
|
1013
|
+
def forward_absorb_core(
|
1014
|
+
self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
|
1015
|
+
):
|
779
1016
|
if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
|
780
1017
|
attn_output = self.attn_mqa(
|
781
1018
|
q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
|
@@ -803,8 +1040,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
803
1040
|
expected_m,
|
804
1041
|
)
|
805
1042
|
attn_bmm_output = attn_bmm_output[:, :expected_m, :]
|
806
|
-
elif
|
807
|
-
# TODO(
|
1043
|
+
elif _is_hip:
|
1044
|
+
# TODO(haishaw): add bmm_fp8 to ROCm
|
808
1045
|
attn_bmm_output = torch.bmm(
|
809
1046
|
attn_output.to(torch.bfloat16).transpose(0, 1),
|
810
1047
|
self.w_vc.to(torch.bfloat16) * self.w_scale,
|
@@ -828,13 +1065,13 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
828
1065
|
|
829
1066
|
return output
|
830
1067
|
|
831
|
-
def
|
1068
|
+
def forward_absorb_fused_mla_rope_prepare(
|
832
1069
|
self,
|
833
1070
|
positions: torch.Tensor,
|
834
1071
|
hidden_states: torch.Tensor,
|
835
1072
|
forward_batch: ForwardBatch,
|
836
1073
|
zero_allocator: BumpAllocator,
|
837
|
-
)
|
1074
|
+
):
|
838
1075
|
enable_rope_fusion = (
|
839
1076
|
os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
|
840
1077
|
)
|
@@ -855,8 +1092,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
855
1092
|
latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
|
856
1093
|
q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
|
857
1094
|
|
858
|
-
if
|
859
|
-
# TODO(
|
1095
|
+
if _is_hip:
|
1096
|
+
# TODO(haishaw): add bmm_fp8 to ROCm
|
860
1097
|
q_nope_out = torch.bmm(
|
861
1098
|
q_nope.to(torch.bfloat16).transpose(0, 1),
|
862
1099
|
self.w_kc.to(torch.bfloat16) * self.w_scale,
|
@@ -923,6 +1160,44 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
923
1160
|
)
|
924
1161
|
val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
|
925
1162
|
|
1163
|
+
return (
|
1164
|
+
q_input,
|
1165
|
+
key_cache_buf,
|
1166
|
+
val_cache_buf,
|
1167
|
+
attn_output,
|
1168
|
+
kv_indptr,
|
1169
|
+
kv_indices,
|
1170
|
+
k_pe_output,
|
1171
|
+
cos_sin_cache,
|
1172
|
+
positions,
|
1173
|
+
attn_logits,
|
1174
|
+
num_kv_split,
|
1175
|
+
sm_scale,
|
1176
|
+
enable_rope_fusion,
|
1177
|
+
k_input,
|
1178
|
+
forward_batch,
|
1179
|
+
zero_allocator,
|
1180
|
+
)
|
1181
|
+
|
1182
|
+
def forward_absorb_fused_mla_rope_core(
|
1183
|
+
self,
|
1184
|
+
q_input,
|
1185
|
+
key_cache_buf,
|
1186
|
+
val_cache_buf,
|
1187
|
+
attn_output,
|
1188
|
+
kv_indptr,
|
1189
|
+
kv_indices,
|
1190
|
+
k_pe_output,
|
1191
|
+
cos_sin_cache,
|
1192
|
+
positions,
|
1193
|
+
attn_logits,
|
1194
|
+
num_kv_split,
|
1195
|
+
sm_scale,
|
1196
|
+
enable_rope_fusion,
|
1197
|
+
k_input,
|
1198
|
+
forward_batch,
|
1199
|
+
zero_allocator,
|
1200
|
+
):
|
926
1201
|
decode_attention_fwd_grouped_rope(
|
927
1202
|
q_input,
|
928
1203
|
key_cache_buf,
|
@@ -951,8 +1226,8 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
951
1226
|
|
952
1227
|
attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
|
953
1228
|
|
954
|
-
if
|
955
|
-
# TODO(
|
1229
|
+
if _is_hip:
|
1230
|
+
# TODO(haishaw): add bmm_fp8 to ROCm
|
956
1231
|
attn_bmm_output = torch.bmm(
|
957
1232
|
attn_output.to(torch.bfloat16).transpose(0, 1),
|
958
1233
|
self.w_vc.to(torch.bfloat16) * self.w_scale,
|
@@ -1029,12 +1304,13 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
1029
1304
|
|
1030
1305
|
return accum_output
|
1031
1306
|
|
1032
|
-
def
|
1307
|
+
def forward_normal_chunked_kv_prepare(
|
1033
1308
|
self,
|
1034
1309
|
positions: torch.Tensor,
|
1035
1310
|
hidden_states: torch.Tensor,
|
1036
1311
|
forward_batch: ForwardBatch,
|
1037
|
-
|
1312
|
+
zero_allocator: BumpAllocator,
|
1313
|
+
):
|
1038
1314
|
# In normal mha, the k and v tensors will become overly large when the prefix length is long.
|
1039
1315
|
# To avoid this, we split the kv cache into chunks and process them one after another.
|
1040
1316
|
# Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
|
@@ -1077,6 +1353,9 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
1077
1353
|
self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
|
1078
1354
|
)
|
1079
1355
|
|
1356
|
+
return q, k, v, forward_batch
|
1357
|
+
|
1358
|
+
def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
|
1080
1359
|
# Do mha for extended part without prefix
|
1081
1360
|
forward_batch.set_attn_attend_prefix_cache(False)
|
1082
1361
|
attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
|
@@ -1101,19 +1380,6 @@ class DeepseekV2AttentionMLA(nn.Module):
|
|
1101
1380
|
return output
|
1102
1381
|
|
1103
1382
|
|
1104
|
-
class _FFNInputMode(Enum):
|
1105
|
-
# The MLP sublayer requires 1/tp_size tokens as input
|
1106
|
-
SCATTERED = auto()
|
1107
|
-
# The MLP sublayer requires all tokens as input
|
1108
|
-
FULL = auto()
|
1109
|
-
|
1110
|
-
|
1111
|
-
@dataclass
|
1112
|
-
class _DecoderLayerInfo:
|
1113
|
-
is_sparse: bool
|
1114
|
-
ffn_input_mode: _FFNInputMode
|
1115
|
-
|
1116
|
-
|
1117
1383
|
class DeepseekV2DecoderLayer(nn.Module):
|
1118
1384
|
|
1119
1385
|
def __init__(
|
@@ -1127,14 +1393,12 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|
1127
1393
|
) -> None:
|
1128
1394
|
super().__init__()
|
1129
1395
|
self.hidden_size = config.hidden_size
|
1396
|
+
self.config = config
|
1130
1397
|
rope_theta = getattr(config, "rope_theta", 10000)
|
1131
1398
|
rope_scaling = getattr(config, "rope_scaling", None)
|
1132
1399
|
max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
|
1133
1400
|
self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
|
1134
1401
|
self.layer_id = layer_id
|
1135
|
-
self.local_dp_size = get_local_attention_dp_size()
|
1136
|
-
self.attn_tp_size = get_attention_tp_size()
|
1137
|
-
self.attn_tp_rank = get_attention_tp_rank()
|
1138
1402
|
self.self_attn = DeepseekV2AttentionMLA(
|
1139
1403
|
config=config,
|
1140
1404
|
hidden_size=self.hidden_size,
|
@@ -1156,19 +1420,25 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|
1156
1420
|
alt_stream=alt_stream,
|
1157
1421
|
)
|
1158
1422
|
|
1159
|
-
self.
|
1160
|
-
|
1161
|
-
|
1423
|
+
self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
|
1424
|
+
is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)
|
1425
|
+
|
1426
|
+
self.layer_scatter_modes = LayerScatterModes.init_new(
|
1427
|
+
layer_id=layer_id,
|
1428
|
+
num_layers=config.num_hidden_layers,
|
1429
|
+
is_layer_sparse=self.is_layer_sparse,
|
1430
|
+
is_previous_layer_sparse=is_previous_layer_sparse,
|
1162
1431
|
)
|
1163
1432
|
|
1164
|
-
if self.
|
1433
|
+
if self.is_layer_sparse:
|
1165
1434
|
self.mlp = DeepseekV2MoE(
|
1166
1435
|
config=config,
|
1167
1436
|
quant_config=quant_config,
|
1168
1437
|
prefix=add_prefix("mlp", prefix),
|
1438
|
+
layer_id=self.layer_id,
|
1169
1439
|
)
|
1170
1440
|
else:
|
1171
|
-
if
|
1441
|
+
if enable_moe_dense_fully_dp():
|
1172
1442
|
mlp_tp_rank, mlp_tp_size = 0, 1
|
1173
1443
|
else:
|
1174
1444
|
mlp_tp_rank, mlp_tp_size = None, None
|
@@ -1182,35 +1452,23 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|
1182
1452
|
tp_size=mlp_tp_size,
|
1183
1453
|
)
|
1184
1454
|
|
1185
|
-
self.input_is_scattered = (
|
1186
|
-
layer_id > 0
|
1187
|
-
and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
|
1188
|
-
)
|
1189
|
-
self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
|
1190
|
-
|
1191
1455
|
self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
1192
1456
|
self.post_attention_layernorm = RMSNorm(
|
1193
1457
|
config.hidden_size, eps=config.rms_norm_eps
|
1194
1458
|
)
|
1195
1459
|
|
1196
|
-
|
1197
|
-
|
1198
|
-
|
1199
|
-
|
1200
|
-
@staticmethod
|
1201
|
-
def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
|
1202
|
-
is_sparse = is_nextn or (
|
1203
|
-
config.n_routed_experts is not None
|
1204
|
-
and layer_id >= config.first_k_dense_replace
|
1205
|
-
and layer_id % config.moe_layer_freq == 0
|
1460
|
+
self.layer_communicator = LayerCommunicator(
|
1461
|
+
layer_scatter_modes=self.layer_scatter_modes,
|
1462
|
+
input_layernorm=self.input_layernorm,
|
1463
|
+
post_attention_layernorm=self.post_attention_layernorm,
|
1206
1464
|
)
|
1207
|
-
|
1208
|
-
|
1209
|
-
|
1210
|
-
|
1211
|
-
|
1465
|
+
|
1466
|
+
def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
|
1467
|
+
return is_nextn or (
|
1468
|
+
self.config.n_routed_experts is not None
|
1469
|
+
and layer_id >= self.config.first_k_dense_replace
|
1470
|
+
and layer_id % self.config.moe_layer_freq == 0
|
1212
1471
|
)
|
1213
|
-
return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
|
1214
1472
|
|
1215
1473
|
def forward(
|
1216
1474
|
self,
|
@@ -1220,164 +1478,98 @@ class DeepseekV2DecoderLayer(nn.Module):
|
|
1220
1478
|
residual: Optional[torch.Tensor],
|
1221
1479
|
zero_allocator: BumpAllocator,
|
1222
1480
|
) -> torch.Tensor:
|
1223
|
-
|
1224
|
-
|
1225
|
-
|
1226
|
-
)
|
1227
|
-
elif self.info.ffn_input_mode == _FFNInputMode.FULL:
|
1228
|
-
return self.forward_ffn_with_full_input(
|
1229
|
-
positions, hidden_states, forward_batch, residual, zero_allocator
|
1230
|
-
)
|
1231
|
-
else:
|
1232
|
-
raise NotImplementedError
|
1233
|
-
|
1234
|
-
def forward_ffn_with_full_input(
|
1235
|
-
self,
|
1236
|
-
positions: torch.Tensor,
|
1237
|
-
hidden_states: torch.Tensor,
|
1238
|
-
forward_batch: ForwardBatch,
|
1239
|
-
residual: Optional[torch.Tensor],
|
1240
|
-
zero_allocator: BumpAllocator,
|
1241
|
-
) -> torch.Tensor:
|
1242
|
-
|
1243
|
-
if hidden_states.shape[0] == 0:
|
1244
|
-
residual = hidden_states
|
1245
|
-
else:
|
1246
|
-
if residual is None:
|
1247
|
-
residual = hidden_states
|
1248
|
-
hidden_states = self.input_layernorm(hidden_states)
|
1249
|
-
else:
|
1250
|
-
hidden_states, residual = self.input_layernorm(hidden_states, residual)
|
1251
|
-
|
1252
|
-
assert not (
|
1253
|
-
self.attn_tp_size != 1 and self.input_is_scattered
|
1254
|
-
), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
|
1481
|
+
hidden_states, residual = self.layer_communicator.prepare_attn(
|
1482
|
+
hidden_states, residual, forward_batch
|
1483
|
+
)
|
1255
1484
|
|
1256
|
-
|
1257
|
-
|
1258
|
-
|
1259
|
-
|
1260
|
-
|
1261
|
-
|
1262
|
-
)
|
1485
|
+
hidden_states = self.self_attn(
|
1486
|
+
positions=positions,
|
1487
|
+
hidden_states=hidden_states,
|
1488
|
+
forward_batch=forward_batch,
|
1489
|
+
zero_allocator=zero_allocator,
|
1490
|
+
)
|
1263
1491
|
|
1264
|
-
|
1265
|
-
|
1266
|
-
|
1267
|
-
if self.local_dp_size != 1:
|
1268
|
-
if self.attn_tp_rank == 0:
|
1269
|
-
hidden_states += residual
|
1270
|
-
hidden_states, local_hidden_states = (
|
1271
|
-
forward_batch.gathered_buffer,
|
1272
|
-
hidden_states,
|
1273
|
-
)
|
1274
|
-
dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
|
1275
|
-
dp_scatter(residual, hidden_states, forward_batch)
|
1276
|
-
hidden_states = self.post_attention_layernorm(hidden_states)
|
1277
|
-
else:
|
1278
|
-
hidden_states = tensor_model_parallel_all_reduce(hidden_states)
|
1279
|
-
hidden_states, residual = self.post_attention_layernorm(
|
1280
|
-
hidden_states, residual
|
1281
|
-
)
|
1282
|
-
else:
|
1283
|
-
hidden_states, residual = self.post_attention_layernorm(
|
1284
|
-
hidden_states, residual
|
1285
|
-
)
|
1492
|
+
hidden_states, residual = self.layer_communicator.prepare_mlp(
|
1493
|
+
hidden_states, residual, forward_batch
|
1494
|
+
)
|
1286
1495
|
|
1287
|
-
|
1288
|
-
hidden_states = self.mlp(hidden_states)
|
1496
|
+
hidden_states = self.mlp(hidden_states, forward_batch)
|
1289
1497
|
|
1290
|
-
|
1291
|
-
|
1292
|
-
|
1293
|
-
# important: forward batch.gathered_buffer is used both after scatter and after gather.
|
1294
|
-
# be careful about this!
|
1295
|
-
hidden_states, global_hidden_states = (
|
1296
|
-
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
1297
|
-
hidden_states,
|
1298
|
-
)
|
1299
|
-
dp_scatter(hidden_states, global_hidden_states, forward_batch)
|
1498
|
+
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
1499
|
+
hidden_states, residual, forward_batch
|
1500
|
+
)
|
1300
1501
|
|
1301
1502
|
return hidden_states, residual
|
1302
1503
|
|
1303
|
-
def
|
1504
|
+
def op_comm_prepare_attn(
|
1304
1505
|
self,
|
1506
|
+
state,
|
1305
1507
|
positions: torch.Tensor,
|
1306
1508
|
hidden_states: torch.Tensor,
|
1307
1509
|
forward_batch: ForwardBatch,
|
1308
1510
|
residual: Optional[torch.Tensor],
|
1309
1511
|
zero_allocator: BumpAllocator,
|
1310
|
-
|
1311
|
-
|
1312
|
-
|
1313
|
-
residual
|
1314
|
-
|
1315
|
-
|
1316
|
-
|
1317
|
-
|
1318
|
-
|
1319
|
-
|
1320
|
-
|
1321
|
-
if self.attn_tp_size != 1 and self.input_is_scattered:
|
1322
|
-
hidden_states, local_hidden_states = (
|
1323
|
-
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
1324
|
-
hidden_states,
|
1325
|
-
)
|
1326
|
-
attn_tp_all_gather(
|
1327
|
-
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
1512
|
+
tbo_subbatch_index: Optional[int] = None,
|
1513
|
+
):
|
1514
|
+
state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
|
1515
|
+
self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
|
1516
|
+
)
|
1517
|
+
state.update(
|
1518
|
+
dict(
|
1519
|
+
forward_batch=forward_batch,
|
1520
|
+
positions=positions,
|
1521
|
+
zero_allocator=zero_allocator,
|
1522
|
+
tbo_subbatch_index=tbo_subbatch_index,
|
1328
1523
|
)
|
1329
|
-
|
1330
|
-
# Self Attention
|
1331
|
-
hidden_states = self.self_attn(
|
1332
|
-
positions=positions,
|
1333
|
-
hidden_states=hidden_states,
|
1334
|
-
forward_batch=forward_batch,
|
1335
|
-
zero_allocator=zero_allocator,
|
1336
1524
|
)
|
1337
1525
|
|
1338
|
-
|
1339
|
-
|
1340
|
-
|
1341
|
-
|
1342
|
-
|
1343
|
-
|
1344
|
-
|
1345
|
-
|
1346
|
-
)
|
1347
|
-
else:
|
1348
|
-
if self.attn_tp_rank == 0:
|
1349
|
-
hidden_states += residual
|
1350
|
-
tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
|
1351
|
-
hidden_states = tensor_list[self.attn_tp_rank]
|
1352
|
-
attn_tp_reduce_scatter(hidden_states, tensor_list)
|
1353
|
-
residual = hidden_states
|
1354
|
-
if hidden_states.shape[0] != 0:
|
1355
|
-
hidden_states = self.post_attention_layernorm(hidden_states)
|
1356
|
-
else:
|
1357
|
-
if hidden_states.shape[0] != 0:
|
1358
|
-
hidden_states, residual = self.post_attention_layernorm(
|
1359
|
-
hidden_states, residual
|
1360
|
-
)
|
1526
|
+
def op_comm_prepare_mlp(self, state):
|
1527
|
+
state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
|
1528
|
+
self.layer_communicator.prepare_mlp(
|
1529
|
+
state.pop("hidden_states_after_attn"),
|
1530
|
+
state.pop("residual_after_input_ln"),
|
1531
|
+
state.forward_batch,
|
1532
|
+
)
|
1533
|
+
)
|
1361
1534
|
|
1535
|
+
def op_mlp(self, state):
|
1536
|
+
hidden_states = state.pop("hidden_states_mlp_input")
|
1362
1537
|
if not (
|
1363
|
-
|
1364
|
-
and (not self.
|
1538
|
+
enable_moe_dense_fully_dp()
|
1539
|
+
and (not self.is_layer_sparse)
|
1365
1540
|
and hidden_states.shape[0] == 0
|
1366
1541
|
):
|
1367
|
-
|
1368
|
-
|
1369
|
-
if self.is_last_layer and self.attn_tp_size != 1:
|
1370
|
-
hidden_states += residual
|
1371
|
-
residual = None
|
1372
|
-
hidden_states, local_hidden_states = (
|
1373
|
-
forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
|
1374
|
-
hidden_states,
|
1375
|
-
)
|
1376
|
-
attn_tp_all_gather(
|
1377
|
-
list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
|
1542
|
+
state.hidden_states_mlp_output = self.mlp(
|
1543
|
+
hidden_states, state.forward_batch.forward_mode
|
1378
1544
|
)
|
1545
|
+
else:
|
1546
|
+
state.hidden_states_mlp_output = hidden_states
|
1379
1547
|
|
1380
|
-
|
1548
|
+
def op_comm_postprocess_layer(self, state):
|
1549
|
+
hidden_states, residual = self.layer_communicator.postprocess_layer(
|
1550
|
+
state.pop("hidden_states_mlp_output"),
|
1551
|
+
state.pop("residual_after_comm_pre_mlp"),
|
1552
|
+
state.forward_batch,
|
1553
|
+
)
|
1554
|
+
|
1555
|
+
output = dict(
|
1556
|
+
positions=state.positions,
|
1557
|
+
hidden_states=hidden_states,
|
1558
|
+
residual=residual,
|
1559
|
+
forward_batch=state.forward_batch,
|
1560
|
+
zero_allocator=state.zero_allocator,
|
1561
|
+
tbo_subbatch_index=state.tbo_subbatch_index,
|
1562
|
+
)
|
1563
|
+
|
1564
|
+
state.clear(
|
1565
|
+
expect_keys={
|
1566
|
+
"positions",
|
1567
|
+
"forward_batch",
|
1568
|
+
"zero_allocator",
|
1569
|
+
"tbo_subbatch_index",
|
1570
|
+
}
|
1571
|
+
)
|
1572
|
+
return output
|
1381
1573
|
|
1382
1574
|
|
1383
1575
|
class DeepseekV2Model(nn.Module):
|
@@ -1392,13 +1584,14 @@ class DeepseekV2Model(nn.Module):
|
|
1392
1584
|
super().__init__()
|
1393
1585
|
self.padding_id = config.pad_token_id
|
1394
1586
|
self.vocab_size = config.vocab_size
|
1587
|
+
self.first_k_dense_replace = config.first_k_dense_replace
|
1395
1588
|
|
1396
1589
|
self.embed_tokens = VocabParallelEmbedding(
|
1397
1590
|
config.vocab_size,
|
1398
1591
|
config.hidden_size,
|
1399
1592
|
enable_tp=not global_server_args_dict["enable_dp_attention"],
|
1400
1593
|
)
|
1401
|
-
self.alt_stream = torch.cuda.Stream()
|
1594
|
+
self.alt_stream = torch.cuda.Stream() if _is_cuda else None
|
1402
1595
|
self.layers = nn.ModuleList(
|
1403
1596
|
[
|
1404
1597
|
DeepseekV2DecoderLayer(
|
@@ -1425,13 +1618,12 @@ class DeepseekV2Model(nn.Module):
|
|
1425
1618
|
forward_batch: ForwardBatch,
|
1426
1619
|
input_embeds: torch.Tensor = None,
|
1427
1620
|
) -> torch.Tensor:
|
1621
|
+
total_num_layers = len(self.layers)
|
1622
|
+
device = input_embeds.device if input_embeds is not None else input_ids.device
|
1428
1623
|
zero_allocator = BumpAllocator(
|
1429
|
-
|
1430
|
-
buffer_size=len(self.layers) * 2,
|
1624
|
+
buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
|
1431
1625
|
dtype=torch.float32,
|
1432
|
-
device=
|
1433
|
-
input_embeds.device if input_embeds is not None else input_ids.device
|
1434
|
-
),
|
1626
|
+
device=device,
|
1435
1627
|
)
|
1436
1628
|
|
1437
1629
|
if input_embeds is None:
|
@@ -1440,12 +1632,33 @@ class DeepseekV2Model(nn.Module):
|
|
1440
1632
|
hidden_states = input_embeds
|
1441
1633
|
|
1442
1634
|
residual = None
|
1443
|
-
|
1444
|
-
|
1445
|
-
|
1446
|
-
|
1447
|
-
|
1635
|
+
|
1636
|
+
normal_num_layers = (
|
1637
|
+
self.first_k_dense_replace
|
1638
|
+
if forward_batch.can_run_tbo
|
1639
|
+
else total_num_layers
|
1640
|
+
)
|
1641
|
+
for i in range(normal_num_layers):
|
1642
|
+
with get_global_expert_distribution_recorder().with_current_layer(i):
|
1643
|
+
layer = self.layers[i]
|
1644
|
+
hidden_states, residual = layer(
|
1645
|
+
positions, hidden_states, forward_batch, residual, zero_allocator
|
1646
|
+
)
|
1647
|
+
|
1648
|
+
if normal_num_layers != total_num_layers:
|
1649
|
+
hidden_states, residual = model_forward_maybe_tbo(
|
1650
|
+
layers=self.layers[normal_num_layers:],
|
1651
|
+
enable_tbo=True,
|
1652
|
+
positions=positions,
|
1653
|
+
forward_batch=forward_batch,
|
1654
|
+
hidden_states=hidden_states,
|
1655
|
+
residual=residual,
|
1656
|
+
input_data_scatter_mode=self.layers[
|
1657
|
+
normal_num_layers - 1
|
1658
|
+
].layer_scatter_modes.layer_output_mode,
|
1659
|
+
zero_allocator=zero_allocator,
|
1448
1660
|
)
|
1661
|
+
|
1449
1662
|
if not forward_batch.forward_mode.is_idle():
|
1450
1663
|
if residual is None:
|
1451
1664
|
hidden_states = self.norm(hidden_states)
|
@@ -1466,7 +1679,7 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|
1466
1679
|
self.config = config
|
1467
1680
|
self.tp_size = get_tensor_model_parallel_world_size()
|
1468
1681
|
self.quant_config = quant_config
|
1469
|
-
self.
|
1682
|
+
self.determine_num_fused_shared_experts()
|
1470
1683
|
self.model = DeepseekV2Model(
|
1471
1684
|
config, quant_config, prefix=add_prefix("model", prefix)
|
1472
1685
|
)
|
@@ -1480,40 +1693,67 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|
1480
1693
|
self.logits_processor = LogitsProcessor(config)
|
1481
1694
|
self.dp_size = get_local_attention_dp_size()
|
1482
1695
|
|
1483
|
-
|
1696
|
+
self._routed_experts_weights_of_layer = LazyValue(
|
1697
|
+
lambda: {
|
1698
|
+
layer_id: layer.mlp.get_moe_weights()
|
1699
|
+
for layer_id, layer in enumerate(self.model.layers)
|
1700
|
+
if isinstance(layer.mlp, DeepseekV2MoE)
|
1701
|
+
}
|
1702
|
+
)
|
1703
|
+
|
1704
|
+
@property
|
1705
|
+
def routed_experts_weights_of_layer(self):
|
1706
|
+
return self._routed_experts_weights_of_layer.value
|
1707
|
+
|
1708
|
+
def determine_num_fused_shared_experts(
|
1484
1709
|
self, architecture: str = "DeepseekV3ForCausalLM"
|
1485
1710
|
):
|
1486
|
-
self.
|
1487
|
-
|
1711
|
+
self.num_fused_shared_experts = (
|
1712
|
+
0
|
1713
|
+
if global_server_args_dict["disable_shared_experts_fusion"]
|
1714
|
+
else self.config.n_shared_experts
|
1715
|
+
)
|
1716
|
+
if self.num_fused_shared_experts > 0:
|
1488
1717
|
# Only Deepseek V3/R1 can use shared experts fusion optimization now.
|
1489
1718
|
if (
|
1490
1719
|
not _is_cuda
|
1491
1720
|
or self.config.architectures[0] != architecture
|
1492
1721
|
or self.config.n_routed_experts != 256
|
1493
1722
|
):
|
1494
|
-
self.
|
1495
|
-
global_server_args_dict["
|
1723
|
+
self.num_fused_shared_experts = 0
|
1724
|
+
global_server_args_dict["disable_shared_experts_fusion"] = True
|
1496
1725
|
log_info_on_rank0(
|
1497
1726
|
logger,
|
1498
1727
|
"Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
|
1499
1728
|
)
|
1500
|
-
|
1501
|
-
|
1502
|
-
|
1503
|
-
|
1504
|
-
|
1729
|
+
elif (
|
1730
|
+
global_server_args_dict["enable_deepep_moe"]
|
1731
|
+
or global_server_args_dict["enable_ep_moe"]
|
1732
|
+
):
|
1733
|
+
self.num_fused_shared_experts = 0
|
1734
|
+
global_server_args_dict["disable_shared_experts_fusion"] = True
|
1735
|
+
log_info_on_rank0(
|
1736
|
+
logger,
|
1737
|
+
"Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode. Shared experts fusion optimization is disabled.",
|
1738
|
+
)
|
1739
|
+
elif self.num_fused_shared_experts == 0:
|
1505
1740
|
if (
|
1506
1741
|
_is_cuda
|
1507
1742
|
and torch.cuda.get_device_capability("cuda") >= (9, 0)
|
1508
1743
|
and self.config.architectures[0] == architecture
|
1509
1744
|
and self.config.n_routed_experts == 256
|
1510
|
-
and (
|
1745
|
+
and (
|
1746
|
+
not (
|
1747
|
+
global_server_args_dict["enable_deepep_moe"]
|
1748
|
+
or global_server_args_dict["enable_ep_moe"]
|
1749
|
+
)
|
1750
|
+
)
|
1511
1751
|
):
|
1512
|
-
self.
|
1513
|
-
global_server_args_dict["
|
1752
|
+
self.num_fused_shared_experts = self.config.n_shared_experts
|
1753
|
+
global_server_args_dict["disable_shared_experts_fusion"] = False
|
1514
1754
|
log_info_on_rank0(
|
1515
1755
|
logger,
|
1516
|
-
"Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
|
1756
|
+
"Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
|
1517
1757
|
)
|
1518
1758
|
|
1519
1759
|
def get_input_embeddings(self) -> nn.Embedding:
|
@@ -1527,21 +1767,29 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|
1527
1767
|
forward_batch: ForwardBatch,
|
1528
1768
|
input_embeds: torch.Tensor = None,
|
1529
1769
|
) -> torch.Tensor:
|
1530
|
-
|
1531
1770
|
hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
|
1532
1771
|
|
1533
1772
|
return self.logits_processor(
|
1534
1773
|
input_ids, hidden_states, self.lm_head, forward_batch
|
1535
1774
|
)
|
1536
1775
|
|
1537
|
-
def post_load_weights(self, is_nextn=False):
|
1776
|
+
def post_load_weights(self, is_nextn=False, weight_names=None):
|
1538
1777
|
|
1539
1778
|
# Perform post-processing after loading weights
|
1540
|
-
|
1541
|
-
|
1542
|
-
|
1543
|
-
|
1544
|
-
|
1779
|
+
if is_nextn:
|
1780
|
+
layer_ids = [self.config.num_hidden_layers]
|
1781
|
+
else:
|
1782
|
+
if weight_names is None:
|
1783
|
+
layer_ids = range(self.config.num_hidden_layers)
|
1784
|
+
else:
|
1785
|
+
layer_ids = set()
|
1786
|
+
for name in weight_names:
|
1787
|
+
if "kv_b_proj" in name:
|
1788
|
+
layer_id = int(name.split(".")[2])
|
1789
|
+
# filter the nextn layer.
|
1790
|
+
if layer_id != self.config.num_hidden_layers:
|
1791
|
+
layer_ids.add(layer_id)
|
1792
|
+
|
1545
1793
|
for layer_id in layer_ids:
|
1546
1794
|
self_attn = (
|
1547
1795
|
self.model.layers[layer_id].self_attn
|
@@ -1577,46 +1825,56 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|
1577
1825
|
torch.float8_e4m3fn,
|
1578
1826
|
torch.float8_e4m3fnuz,
|
1579
1827
|
):
|
1580
|
-
if
|
1828
|
+
if (
|
1829
|
+
hasattr(self.quant_config, "weight_block_size")
|
1830
|
+
and self.quant_config.weight_block_size is not None
|
1831
|
+
):
|
1581
1832
|
weight_block_size = self.quant_config.weight_block_size
|
1582
|
-
|
1583
|
-
|
1584
|
-
|
1585
|
-
weight,
|
1586
|
-
|
1587
|
-
|
1588
|
-
|
1589
|
-
|
1590
|
-
|
1591
|
-
|
1592
|
-
weight_scale = self_attn.kv_b_proj.weight_scale_inv
|
1833
|
+
assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
|
1834
|
+
if _is_fp8_fnuz:
|
1835
|
+
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
1836
|
+
weight=w,
|
1837
|
+
weight_scale=self_attn.kv_b_proj.weight_scale_inv,
|
1838
|
+
input_scale=None,
|
1839
|
+
)
|
1840
|
+
else:
|
1841
|
+
weight = w
|
1842
|
+
weight_scale = self_attn.kv_b_proj.weight_scale_inv
|
1593
1843
|
|
1594
|
-
|
1595
|
-
|
1596
|
-
|
1597
|
-
|
1598
|
-
|
1844
|
+
if (
|
1845
|
+
_is_cuda
|
1846
|
+
and weight_block_size[0] == 128
|
1847
|
+
and weight_block_size[1] == 128
|
1848
|
+
and model_dtype == torch.bfloat16
|
1849
|
+
):
|
1850
|
+
if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
|
1851
|
+
"SGL_USE_DEEPGEMM_BMM", "false"
|
1599
1852
|
):
|
1600
|
-
|
1601
|
-
|
1602
|
-
):
|
1603
|
-
block_scale = weight_scale
|
1604
|
-
use_deep_gemm_bmm = True
|
1605
|
-
else:
|
1606
|
-
w = block_quant_dequant(
|
1607
|
-
weight,
|
1608
|
-
weight_scale,
|
1609
|
-
weight_block_size,
|
1610
|
-
model_dtype,
|
1611
|
-
)
|
1853
|
+
block_scale = weight_scale
|
1854
|
+
use_deep_gemm_bmm = True
|
1612
1855
|
else:
|
1613
|
-
w
|
1614
|
-
weight,
|
1856
|
+
w = block_quant_dequant(
|
1857
|
+
weight,
|
1858
|
+
weight_scale,
|
1859
|
+
weight_block_size,
|
1860
|
+
model_dtype,
|
1615
1861
|
)
|
1616
|
-
|
1862
|
+
else:
|
1863
|
+
w, scale = block_quant_to_tensor_quant(
|
1864
|
+
weight, weight_scale, weight_block_size
|
1865
|
+
)
|
1866
|
+
self_attn.w_scale = scale
|
1617
1867
|
else:
|
1618
|
-
|
1619
|
-
|
1868
|
+
if _is_fp8_fnuz:
|
1869
|
+
weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
|
1870
|
+
weight=w,
|
1871
|
+
weight_scale=self_attn.kv_b_proj.weight_scale,
|
1872
|
+
input_scale=None,
|
1873
|
+
)
|
1874
|
+
else:
|
1875
|
+
weight = w
|
1876
|
+
weight_scale = self_attn.kv_b_proj.weight_scale
|
1877
|
+
|
1620
1878
|
w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
|
1621
1879
|
self_attn.w_scale = scale
|
1622
1880
|
|
@@ -1641,13 +1899,19 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|
1641
1899
|
0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
|
1642
1900
|
).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
|
1643
1901
|
if not use_deep_gemm_bmm:
|
1644
|
-
self_attn.w_kc =
|
1645
|
-
|
1902
|
+
self_attn.w_kc = bind_or_assign(
|
1903
|
+
self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
|
1904
|
+
)
|
1905
|
+
self_attn.w_vc = bind_or_assign(
|
1906
|
+
self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
|
1907
|
+
)
|
1646
1908
|
if (
|
1647
1909
|
hasattr(self_attn.kv_b_proj, "weight_scale")
|
1648
1910
|
and self_attn.w_scale is None
|
1649
1911
|
):
|
1650
|
-
self_attn.w_scale =
|
1912
|
+
self_attn.w_scale = bind_or_assign(
|
1913
|
+
self_attn.w_scale, self_attn.kv_b_proj.weight_scale
|
1914
|
+
)
|
1651
1915
|
if _is_hip:
|
1652
1916
|
self_attn.w_scale *= 2.0
|
1653
1917
|
else:
|
@@ -1656,13 +1920,20 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|
1656
1920
|
ws_kc, ws_vc = block_scale.unflatten(
|
1657
1921
|
0, (-1, (num_tiles_k + num_tiles_n))
|
1658
1922
|
).split([num_tiles_k, num_tiles_n], dim=1)
|
1659
|
-
self_attn.w_scale_k =
|
1660
|
-
|
1661
|
-
|
1662
|
-
self_attn.
|
1923
|
+
self_attn.w_scale_k = bind_or_assign(
|
1924
|
+
self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
|
1925
|
+
)
|
1926
|
+
self_attn.w_scale_v = bind_or_assign(
|
1927
|
+
self_attn.w_scale_v, ws_vc.contiguous()
|
1928
|
+
)
|
1929
|
+
self_attn.w_kc = bind_or_assign(
|
1930
|
+
self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
|
1931
|
+
)
|
1932
|
+
self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
|
1663
1933
|
self_attn.use_deep_gemm_bmm = True
|
1664
1934
|
|
1665
1935
|
def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
|
1936
|
+
|
1666
1937
|
if is_nextn:
|
1667
1938
|
if hasattr(self.config, "num_nextn_predict_layers"):
|
1668
1939
|
num_nextn_layers = self.config.num_nextn_predict_layers
|
@@ -1681,26 +1952,68 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|
1681
1952
|
("gate_up_proj", "gate_proj", 0),
|
1682
1953
|
("gate_up_proj", "up_proj", 1),
|
1683
1954
|
]
|
1684
|
-
if self.
|
1955
|
+
if self.num_fused_shared_experts > 0:
|
1956
|
+
assert self.num_fused_shared_experts == 1
|
1685
1957
|
weights_list = list(weights)
|
1686
1958
|
weights_dict = dict(weights_list)
|
1687
|
-
if self.quant_config is None
|
1688
|
-
|
1689
|
-
|
1690
|
-
|
1691
|
-
|
1692
|
-
|
1693
|
-
|
1694
|
-
|
1695
|
-
|
1959
|
+
if self.quant_config is not None:
|
1960
|
+
if self.quant_config.get_name() == "w8a8_int8":
|
1961
|
+
suffix_list = [
|
1962
|
+
"down_proj.weight",
|
1963
|
+
"down_proj.weight_scale",
|
1964
|
+
"gate_proj.weight",
|
1965
|
+
"gate_proj.weight_scale",
|
1966
|
+
"up_proj.weight",
|
1967
|
+
"up_proj.weight_scale",
|
1968
|
+
]
|
1969
|
+
elif (
|
1970
|
+
self.quant_config.get_name() == "fp8"
|
1971
|
+
or self.quant_config.get_name() == "blockwise_int8"
|
1972
|
+
):
|
1973
|
+
suffix_list = [
|
1974
|
+
"down_proj.weight",
|
1975
|
+
"down_proj.weight_scale_inv",
|
1976
|
+
"gate_proj.weight",
|
1977
|
+
"gate_proj.weight_scale_inv",
|
1978
|
+
"up_proj.weight",
|
1979
|
+
"up_proj.weight_scale_inv",
|
1980
|
+
]
|
1981
|
+
elif self.quant_config.get_name() == "awq":
|
1982
|
+
suffix_list = [
|
1983
|
+
"down_proj.qweight",
|
1984
|
+
"down_proj.qzeros",
|
1985
|
+
"down_proj.scales",
|
1986
|
+
"gate_proj.qweight",
|
1987
|
+
"gate_proj.qzeros",
|
1988
|
+
"gate_proj.scales",
|
1989
|
+
"up_proj.qweight",
|
1990
|
+
"up_proj.qzeros",
|
1991
|
+
"up_proj.scales",
|
1992
|
+
]
|
1993
|
+
elif self.quant_config.get_name() == "modelopt_fp4":
|
1994
|
+
suffix_list = [
|
1995
|
+
"down_proj.weight",
|
1996
|
+
"down_proj.weight_scale",
|
1997
|
+
"down_proj.weight_scale_2",
|
1998
|
+
"down_proj.input_scale",
|
1999
|
+
"gate_proj.weight",
|
2000
|
+
"gate_proj.weight_scale",
|
2001
|
+
"gate_proj.weight_scale_2",
|
2002
|
+
"gate_proj.input_scale",
|
2003
|
+
"up_proj.weight",
|
2004
|
+
"up_proj.weight_scale",
|
2005
|
+
"up_proj.weight_scale_2",
|
2006
|
+
"up_proj.input_scale",
|
2007
|
+
]
|
2008
|
+
else:
|
2009
|
+
raise ValueError(
|
2010
|
+
f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
|
2011
|
+
)
|
1696
2012
|
else:
|
1697
2013
|
suffix_list = [
|
1698
2014
|
"down_proj.weight",
|
1699
|
-
"down_proj.weight_scale_inv",
|
1700
2015
|
"gate_proj.weight",
|
1701
|
-
"gate_proj.weight_scale_inv",
|
1702
2016
|
"up_proj.weight",
|
1703
|
-
"up_proj.weight_scale_inv",
|
1704
2017
|
]
|
1705
2018
|
names_to_remove = []
|
1706
2019
|
|
@@ -1716,38 +2029,32 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|
1716
2029
|
|
1717
2030
|
for moe_layer in tqdm(
|
1718
2031
|
moe_layers,
|
1719
|
-
desc=f"Cloning {self.
|
1720
|
-
"
|
2032
|
+
desc=f"Cloning {self.num_fused_shared_experts} "
|
2033
|
+
"shared expert into MoE",
|
1721
2034
|
):
|
1722
2035
|
for suffix in suffix_list:
|
1723
2036
|
shared_expert_weight_name = (
|
1724
2037
|
f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
|
1725
2038
|
)
|
1726
|
-
|
1727
|
-
|
1728
|
-
|
1729
|
-
|
1730
|
-
|
1731
|
-
|
1732
|
-
|
1733
|
-
weights_dict[shared_expert_weight_name],
|
1734
|
-
)
|
2039
|
+
weights_list.append(
|
2040
|
+
(
|
2041
|
+
f"model.layers.{moe_layer}."
|
2042
|
+
f"mlp.experts."
|
2043
|
+
f"{self.config.n_routed_experts + 0}"
|
2044
|
+
f".{suffix}",
|
2045
|
+
weights_dict[shared_expert_weight_name],
|
1735
2046
|
)
|
2047
|
+
)
|
1736
2048
|
names_to_remove += [shared_expert_weight_name]
|
1737
2049
|
weights = [w for w in weights_list if w[0] not in names_to_remove]
|
1738
2050
|
|
1739
2051
|
# Params for weights, fp8 weight scales, fp8 activation scales
|
1740
2052
|
# (param_name, weight_name, expert_id, shard_id)
|
1741
|
-
|
1742
|
-
DeepEPMoE
|
1743
|
-
if global_server_args_dict["enable_deepep_moe"]
|
1744
|
-
else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
|
1745
|
-
)
|
1746
|
-
expert_params_mapping = MoEImpl.make_expert_params_mapping(
|
2053
|
+
expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
|
1747
2054
|
ckpt_gate_proj_name="gate_proj",
|
1748
2055
|
ckpt_down_proj_name="down_proj",
|
1749
2056
|
ckpt_up_proj_name="up_proj",
|
1750
|
-
num_experts=self.config.n_routed_experts + self.
|
2057
|
+
num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
|
1751
2058
|
)
|
1752
2059
|
|
1753
2060
|
# Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
|
@@ -1766,7 +2073,10 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|
1766
2073
|
]
|
1767
2074
|
|
1768
2075
|
params_dict = dict(self.named_parameters())
|
2076
|
+
weight_names = []
|
1769
2077
|
for name, loaded_weight in weights:
|
2078
|
+
weight_names.append(name)
|
2079
|
+
|
1770
2080
|
if not is_nextn:
|
1771
2081
|
if hasattr(self.config, "num_nextn_predict_layers"):
|
1772
2082
|
num_nextn_layers = self.config.num_nextn_predict_layers
|
@@ -1838,7 +2148,6 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|
1838
2148
|
# Skip loading extra bias for GPTQ models.
|
1839
2149
|
if name.endswith(".bias") and name not in params_dict:
|
1840
2150
|
continue
|
1841
|
-
|
1842
2151
|
if fuse_qkv_a_proj and (
|
1843
2152
|
"q_a_proj" in name or "kv_a_proj_with_mqa" in name
|
1844
2153
|
):
|
@@ -1859,15 +2168,17 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|
1859
2168
|
q_a_proj_name in cached_a_proj
|
1860
2169
|
and kv_a_proj_name in cached_a_proj
|
1861
2170
|
):
|
1862
|
-
|
1863
2171
|
q_a_proj_weight = cached_a_proj[q_a_proj_name]
|
1864
2172
|
kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
|
1865
2173
|
fused_weight = torch.cat(
|
1866
2174
|
[q_a_proj_weight, kv_a_proj_weight], dim=0
|
1867
2175
|
)
|
1868
|
-
|
1869
|
-
|
1870
|
-
"q_a_proj"
|
2176
|
+
param_name = (
|
2177
|
+
name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
|
2178
|
+
if "q_a_proj" in name
|
2179
|
+
else name.replace(
|
2180
|
+
"kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
|
2181
|
+
)
|
1871
2182
|
)
|
1872
2183
|
param = params_dict[param_name]
|
1873
2184
|
|
@@ -1878,13 +2189,23 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|
1878
2189
|
cached_a_proj.pop(q_a_proj_name)
|
1879
2190
|
cached_a_proj.pop(kv_a_proj_name)
|
1880
2191
|
else:
|
2192
|
+
if (
|
2193
|
+
"k_scale" in name or "v_scale" in name
|
2194
|
+
) and name not in params_dict:
|
2195
|
+
# modelopt attn kv scale is named differently
|
2196
|
+
if any(scale in name for scale in ["k_scale", "v_scale"]):
|
2197
|
+
name = name.replace("_proj", "attn_mqa")
|
2198
|
+
else:
|
2199
|
+
logger.warning(
|
2200
|
+
f"Unknown scale found in checkpoint: {name}"
|
2201
|
+
)
|
1881
2202
|
param = params_dict[name]
|
1882
2203
|
weight_loader = getattr(
|
1883
2204
|
param, "weight_loader", default_weight_loader
|
1884
2205
|
)
|
1885
2206
|
weight_loader(param, loaded_weight)
|
1886
2207
|
|
1887
|
-
self.post_load_weights(is_nextn=is_nextn)
|
2208
|
+
self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
|
1888
2209
|
|
1889
2210
|
def get_embed_and_head(self):
|
1890
2211
|
return self.model.embed_tokens.weight, self.lm_head.weight
|
@@ -1897,6 +2218,14 @@ class DeepseekV2ForCausalLM(nn.Module):
|
|
1897
2218
|
torch.cuda.empty_cache()
|
1898
2219
|
torch.cuda.synchronize()
|
1899
2220
|
|
2221
|
+
@classmethod
|
2222
|
+
def get_model_config_for_expert_location(cls, config):
|
2223
|
+
return ModelConfigForExpertLocation(
|
2224
|
+
num_layers=config.num_hidden_layers,
|
2225
|
+
num_logical_experts=config.n_routed_experts,
|
2226
|
+
num_groups=config.n_group,
|
2227
|
+
)
|
2228
|
+
|
1900
2229
|
|
1901
2230
|
class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
|
1902
2231
|
pass
|