sglang 0.4.6.post4__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (358)
  1. sglang/bench_offline_throughput.py +16 -10
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +86 -22
  4. sglang/bench_serving.py +197 -110
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/lang/backend/runtime_endpoint.py +24 -1
  7. sglang/profiler.py +167 -0
  8. sglang/srt/_custom_ops.py +34 -0
  9. sglang/srt/configs/internvl.py +8 -12
  10. sglang/srt/configs/model_config.py +66 -29
  11. sglang/srt/constrained/base_grammar_backend.py +5 -2
  12. sglang/srt/constrained/llguidance_backend.py +9 -8
  13. sglang/srt/constrained/outlines_backend.py +5 -4
  14. sglang/srt/constrained/xgrammar_backend.py +18 -18
  15. sglang/srt/conversation.py +47 -9
  16. sglang/srt/custom_op.py +38 -3
  17. sglang/srt/debug_utils.py +74 -0
  18. sglang/srt/disaggregation/common/__init__.py +1 -0
  19. sglang/srt/disaggregation/common/conn.py +407 -0
  20. sglang/srt/disaggregation/decode.py +187 -134
  21. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  22. sglang/srt/disaggregation/fake/conn.py +4 -13
  23. sglang/srt/disaggregation/kv_events.py +412 -0
  24. sglang/srt/disaggregation/launch_lb.py +140 -0
  25. sglang/srt/disaggregation/mini_lb.py +84 -70
  26. sglang/srt/disaggregation/mooncake/conn.py +441 -140
  27. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -14
  28. sglang/srt/disaggregation/nixl/conn.py +124 -442
  29. sglang/srt/disaggregation/prefill.py +128 -44
  30. sglang/srt/disaggregation/utils.py +154 -6
  31. sglang/srt/distributed/device_communicators/pymscclpp.py +315 -0
  32. sglang/srt/distributed/parallel_state.py +52 -5
  33. sglang/srt/distributed/utils.py +3 -3
  34. sglang/srt/entrypoints/EngineBase.py +11 -0
  35. sglang/srt/entrypoints/engine.py +129 -12
  36. sglang/srt/entrypoints/http_server.py +21 -6
  37. sglang/srt/entrypoints/http_server_engine.py +5 -2
  38. sglang/srt/function_call/base_format_detector.py +302 -0
  39. sglang/srt/function_call/core_types.py +34 -0
  40. sglang/srt/function_call/deepseekv3_detector.py +205 -0
  41. sglang/srt/function_call/ebnf_composer.py +248 -0
  42. sglang/srt/function_call/function_call_parser.py +202 -0
  43. sglang/srt/function_call/llama32_detector.py +93 -0
  44. sglang/srt/function_call/mistral_detector.py +131 -0
  45. sglang/srt/function_call/pythonic_detector.py +229 -0
  46. sglang/srt/function_call/qwen25_detector.py +121 -0
  47. sglang/srt/function_call/utils.py +52 -0
  48. sglang/srt/hf_transformers_utils.py +50 -7
  49. sglang/srt/layers/attention/aiter_backend.py +878 -0
  50. sglang/srt/layers/attention/base_attn_backend.py +4 -0
  51. sglang/srt/layers/attention/cutlass_mla_backend.py +2 -19
  52. sglang/srt/layers/attention/flashattention_backend.py +166 -35
  53. sglang/srt/layers/attention/flashinfer_backend.py +45 -1
  54. sglang/srt/layers/attention/flashinfer_mla_backend.py +45 -5
  55. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  56. sglang/srt/layers/attention/intel_amx_backend.py +128 -0
  57. sglang/srt/layers/attention/tbo_backend.py +232 -0
  58. sglang/srt/layers/attention/torch_native_backend.py +3 -0
  59. sglang/srt/layers/attention/triton_backend.py +247 -5
  60. sglang/srt/layers/attention/triton_ops/extend_attention.py +12 -4
  61. sglang/srt/layers/attention/utils.py +2 -2
  62. sglang/srt/layers/attention/vision.py +1 -1
  63. sglang/srt/layers/communicator.py +517 -0
  64. sglang/srt/layers/dp_attention.py +6 -15
  65. sglang/srt/layers/layernorm.py +30 -19
  66. sglang/srt/layers/moe/cutlass_moe.py +370 -0
  67. sglang/srt/layers/moe/cutlass_moe_params.py +169 -0
  68. sglang/srt/layers/moe/ep_moe/kernels.py +60 -17
  69. sglang/srt/layers/moe/ep_moe/layer.py +195 -87
  70. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +88 -8
  71. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  72. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  73. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  76. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  77. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  78. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  80. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +220 -25
  81. sglang/srt/layers/moe/fused_moe_triton/layer.py +48 -4
  82. sglang/srt/layers/moe/topk.py +107 -24
  83. sglang/srt/layers/multimodal.py +70 -0
  84. sglang/srt/layers/quantization/__init__.py +10 -4
  85. sglang/srt/layers/quantization/blockwise_int8.py +3 -0
  86. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +5 -0
  87. sglang/srt/layers/quantization/deep_gemm.py +60 -59
  88. sglang/srt/layers/quantization/fp8.py +113 -18
  89. sglang/srt/layers/quantization/fp8_kernel.py +118 -66
  90. sglang/srt/layers/quantization/fp8_utils.py +165 -43
  91. sglang/srt/layers/quantization/gptq.py +298 -6
  92. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  93. sglang/srt/layers/quantization/modelopt_quant.py +334 -7
  94. sglang/srt/layers/quantization/moe_wna16.py +3 -0
  95. sglang/srt/layers/quantization/qoq.py +244 -0
  96. sglang/srt/layers/quantization/w8a8_fp8.py +3 -0
  97. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  98. sglang/srt/layers/rotary_embedding.py +6 -12
  99. sglang/srt/layers/sampler.py +80 -79
  100. sglang/srt/layers/utils.py +6 -0
  101. sglang/srt/lora/layers.py +12 -15
  102. sglang/srt/lora/lora.py +49 -5
  103. sglang/srt/lora/lora_manager.py +20 -8
  104. sglang/srt/lora/mem_pool.py +24 -16
  105. sglang/srt/lora/utils.py +17 -13
  106. sglang/srt/managers/data_parallel_controller.py +13 -5
  107. sglang/srt/managers/eplb_algorithms/__init__.py +63 -0
  108. sglang/srt/managers/eplb_algorithms/deepseek.py +223 -0
  109. sglang/srt/managers/eplb_algorithms/deepseek_vec.py +276 -0
  110. sglang/srt/managers/eplb_manager.py +96 -0
  111. sglang/srt/managers/expert_distribution.py +878 -56
  112. sglang/srt/managers/expert_location.py +448 -0
  113. sglang/srt/managers/expert_location_dispatch.py +108 -0
  114. sglang/srt/managers/io_struct.py +29 -5
  115. sglang/srt/managers/mm_utils.py +355 -151
  116. sglang/srt/managers/multimodal_processors/base_processor.py +299 -42
  117. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  118. sglang/srt/managers/multimodal_processors/gemma3.py +15 -17
  119. sglang/srt/managers/multimodal_processors/internvl.py +18 -5
  120. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  121. sglang/srt/managers/multimodal_processors/kimi_vl.py +14 -32
  122. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  123. sglang/srt/managers/multimodal_processors/minicpm.py +27 -32
  124. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  125. sglang/srt/managers/multimodal_processors/phi4mm.py +87 -0
  126. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  127. sglang/srt/managers/multimodal_processors/qwen_vl.py +35 -35
  128. sglang/srt/managers/schedule_batch.py +185 -55
  129. sglang/srt/managers/schedule_policy.py +4 -5
  130. sglang/srt/managers/scheduler.py +389 -154
  131. sglang/srt/managers/session_controller.py +1 -1
  132. sglang/srt/managers/tokenizer_manager.py +231 -39
  133. sglang/srt/managers/utils.py +0 -4
  134. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  135. sglang/srt/mem_cache/chunk_cache.py +3 -1
  136. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  137. sglang/srt/mem_cache/memory_pool.py +74 -52
  138. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  139. sglang/srt/mem_cache/radix_cache.py +58 -5
  140. sglang/srt/metrics/collector.py +11 -2
  141. sglang/srt/mm_utils.py +10 -0
  142. sglang/srt/model_executor/cuda_graph_runner.py +87 -65
  143. sglang/srt/model_executor/expert_location_updater.py +557 -0
  144. sglang/srt/model_executor/forward_batch_info.py +39 -14
  145. sglang/srt/model_executor/model_runner.py +231 -101
  146. sglang/srt/model_loader/loader.py +10 -6
  147. sglang/srt/model_loader/utils.py +67 -1
  148. sglang/srt/models/clip.py +5 -1
  149. sglang/srt/models/deepseek_nextn.py +1 -1
  150. sglang/srt/models/deepseek_v2.py +732 -403
  151. sglang/srt/models/exaone.py +8 -3
  152. sglang/srt/models/gemma3_causal.py +7 -0
  153. sglang/srt/models/gemma3_mm.py +75 -33
  154. sglang/srt/models/idefics2.py +342 -0
  155. sglang/srt/models/kimi_vl.py +4 -4
  156. sglang/srt/models/llama.py +1 -1
  157. sglang/srt/models/llama4.py +10 -2
  158. sglang/srt/models/llava.py +26 -18
  159. sglang/srt/models/mimo_mtp.py +220 -0
  160. sglang/srt/models/minicpmo.py +7 -17
  161. sglang/srt/models/minicpmv.py +3 -295
  162. sglang/srt/models/mistral.py +71 -1
  163. sglang/srt/models/mllama.py +3 -3
  164. sglang/srt/models/phi4mm.py +512 -0
  165. sglang/srt/models/qwen2.py +133 -35
  166. sglang/srt/models/qwen2_5_vl.py +5 -3
  167. sglang/srt/models/qwen2_eagle.py +4 -1
  168. sglang/srt/models/qwen2_moe.py +206 -69
  169. sglang/srt/models/qwen2_vl.py +3 -3
  170. sglang/srt/models/qwen3.py +92 -19
  171. sglang/srt/models/qwen3_moe.py +457 -55
  172. sglang/srt/models/registry.py +9 -1
  173. sglang/srt/models/siglip.py +294 -0
  174. sglang/srt/models/transformers.py +291 -0
  175. sglang/srt/openai_api/adapter.py +114 -40
  176. sglang/srt/openai_api/protocol.py +37 -2
  177. sglang/srt/openai_api/utils.py +172 -0
  178. sglang/srt/operations.py +189 -0
  179. sglang/srt/operations_strategy.py +207 -0
  180. sglang/srt/sampling/sampling_batch_info.py +13 -1
  181. sglang/srt/sampling/sampling_params.py +2 -1
  182. sglang/srt/server_args.py +235 -38
  183. sglang/srt/speculative/build_eagle_tree.py +8 -8
  184. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -11
  185. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +253 -0
  186. sglang/srt/speculative/eagle_utils.py +181 -90
  187. sglang/srt/speculative/eagle_worker.py +146 -21
  188. sglang/srt/two_batch_overlap.py +635 -0
  189. sglang/srt/utils.py +197 -19
  190. sglang/test/runners.py +16 -7
  191. sglang/test/send_one.py +4 -0
  192. sglang/test/test_cutlass_moe.py +278 -0
  193. sglang/test/test_fp4_moe.py +248 -0
  194. sglang/test/test_utils.py +81 -42
  195. sglang/utils.py +2 -2
  196. sglang/version.py +1 -1
  197. {sglang-0.4.6.post4.dist-info → sglang-0.4.7.dist-info}/METADATA +31 -19
  198. sglang-0.4.7.dist-info/RECORD +699 -0
  199. {sglang-0.4.6.post4.dist-info → sglang-0.4.7.dist-info}/WHEEL +1 -1
  200. sglang/srt/function_call_parser.py +0 -858
  201. sglang/srt/platforms/interface.py +0 -371
  202. sglang-0.4.6.post4.dist-info/RECORD +0 -646
  203. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  204. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  205. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  206. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  207. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  208. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  209. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  210. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  211. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  212. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  213. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  214. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  215. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  216. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H200.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H200.json} +0 -0
  217. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  218. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  219. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  220. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  221. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  222. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  223. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  224. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  225. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  226. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  227. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  228. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  229. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  230. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  231. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  232. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  233. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  234. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  235. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  236. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  237. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  238. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  239. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  240. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  241. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  242. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  243. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  244. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  245. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json} +0 -0
  246. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  247. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  248. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  249. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  250. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  251. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  252. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  253. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  254. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  255. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json} +0 -0
  256. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json} +0 -0
  257. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  258. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  259. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  260. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  261. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  262. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  263. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200.json} +0 -0
  264. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  265. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  266. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200.json} +0 -0
  267. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  268. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  269. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  270. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200.json} +0 -0
  271. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  272. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  273. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  274. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
  275. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  276. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  277. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  278. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200.json} +0 -0
  279. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X.json} +0 -0
  280. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI325X.json} +0 -0
  281. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Radeon_Graphics.json} +0 -0
  282. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  283. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  284. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200.json} +0 -0
  285. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X.json} +0 -0
  286. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI325X.json} +0 -0
  287. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Radeon_Graphics.json} +0 -0
  288. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  289. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  290. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  291. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  292. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200.json} +0 -0
  293. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  294. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  295. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  296. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  297. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200.json} +0 -0
  298. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X.json} +0 -0
  299. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI325X.json} +0 -0
  300. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Radeon_Graphics.json} +0 -0
  301. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  302. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  303. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
  304. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  305. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  306. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  307. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200.json} +0 -0
  308. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_L40S.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_L40S.json} +0 -0
  309. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
  310. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
  311. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
  312. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  313. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  314. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  315. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  316. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200.json} +0 -0
  317. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X.json} +0 -0
  318. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI325X.json} +0 -0
  319. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Radeon_Graphics.json} +0 -0
  320. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  321. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  322. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  323. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  324. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200.json} +0 -0
  325. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
  326. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
  327. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
  328. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  329. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  330. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  331. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  332. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H20.json} +0 -0
  333. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H200.json} +0 -0
  334. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  335. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  336. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20.json} +0 -0
  337. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  338. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200.json} +0 -0
  339. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  340. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  341. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  342. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  343. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20.json} +0 -0
  344. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  345. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200.json} +0 -0
  346. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=96,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=96,device_name=NVIDIA_H20.json} +0 -0
  347. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  348. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  349. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  350. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  351. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  352. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  353. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  354. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  355. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  356. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  357. {sglang-0.4.6.post4.dist-info → sglang-0.4.7.dist-info}/licenses/LICENSE +0 -0
  358. {sglang-0.4.6.post4.dist-info → sglang-0.4.7.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -18,8 +18,7 @@
 
 import logging
 import os
-from dataclasses import dataclass
-from enum import Enum, IntEnum, auto
+from enum import IntEnum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -29,17 +28,17 @@ from tqdm import tqdm
 from transformers import PretrainedConfig
 
 from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     parallel_state,
     tensor_model_parallel_all_reduce,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.communicator import (
+    LayerCommunicator,
+    LayerScatterModes,
+    enable_moe_dense_fully_dp,
+)
 from sglang.srt.layers.dp_attention import (
-    attn_tp_all_gather,
-    attn_tp_reduce_scatter,
-    dp_gather_partial,
-    dp_scatter,
     get_attention_tp_rank,
     get_attention_tp_size,
     get_local_attention_dp_size,
@@ -52,13 +51,13 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, EPMoE
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
+    is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
     per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
@@ -72,28 +71,41 @@ from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location import ModelConfigForExpertLocation
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.two_batch_overlap import (
+    MaybeTboDeepEPDispatcher,
+    model_forward_maybe_tbo,
+)
 from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
+    LazyValue,
     add_prefix,
+    bind_or_assign,
     get_bool_env_var,
     get_int_env_var,
     is_cuda,
     is_hip,
+    is_non_idle_and_non_empty,
     log_info_on_rank0,
 )
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
@@ -109,7 +121,8 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
 
-expert_distribution_recorder = ExpertDistributionRecorder()
+if _use_aiter:
+    from aiter.rotary_embedding import get_rope
 
 logger = logging.getLogger(__name__)
 
@@ -125,6 +138,9 @@ class AttnForwardMethod(IntEnum):
     # This method can avoid OOM when prefix lengths are long.
     MHA_CHUNKED_KV = auto()
 
+    # Use MLA but with fused RoPE
+    MLA_FUSED_ROPE = auto()
+
 
 class DeepseekV2MLP(nn.Module):
     def __init__(
@@ -139,6 +155,8 @@ class DeepseekV2MLP(nn.Module):
         tp_size: Optional[int] = None,
     ) -> None:
         super().__init__()
+        self.tp_size = tp_size
+
         self.gate_up_proj = MergedColumnParallelLinear(
             hidden_size,
             [intermediate_size] * 2,
@@ -165,7 +183,10 @@ class DeepseekV2MLP(nn.Module):
         )
         self.act_fn = SiluAndMul()
 
-    def forward(self, x, forward_mode: Optional[ForwardMode] = None):
+    def forward(self, x, forward_batch=None):
+        if (self.tp_size == 1) and x.shape[0] == 0:
+            return x
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
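The added guard returns the input unchanged when the batch is empty, so an idle data-parallel rank launches no GEMM or activation kernels. A minimal stand-alone sketch of the same pattern (TinyMLP and its sizes are illustrative, not sglang code):

import torch
import torch.nn as nn
import torch.nn.functional as F

class TinyMLP(nn.Module):
    # Illustrative stand-in for the guarded MLP forward above.
    def __init__(self, hidden: int = 16, inter: int = 32):
        super().__init__()
        self.up = nn.Linear(hidden, inter)
        self.down = nn.Linear(inter, hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[0] == 0:
            # Empty batch: skip all kernels and return the 0-row tensor as-is.
            return x
        return self.down(F.silu(self.up(x)))

print(TinyMLP()(torch.empty(0, 16)).shape)  # torch.Size([0, 16])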
@@ -199,6 +220,7 @@ class DeepseekV2MoE(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
+        layer_id: int,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -206,7 +228,13 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        self.num_fused_shared_experts = (
+            0
+            if global_server_args_dict["disable_shared_experts_fusion"]
+            else config.n_shared_experts
+        )
+        self.config = config
+        self.layer_id = layer_id
 
         if self.tp_size > config.n_routed_experts:
             raise ValueError(
@@ -222,21 +250,19 @@ class DeepseekV2MoE(nn.Module):
 
         self.gate = MoEGate(config=config, prefix=add_prefix("gate", prefix))
 
-        MoEImpl = (
-            DeepEPMoE
-            if global_server_args_dict["enable_deepep_moe"]
-            else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
-        )
-
-        self.experts = MoEImpl(
-            num_experts=config.n_routed_experts + self.n_share_experts_fusion,
-            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
+        self.experts = get_moe_impl_class()(
+            num_experts=config.n_routed_experts
+            + self.num_fused_shared_experts
+            + global_server_args_dict["ep_num_redundant_experts"],
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
+            layer_id=self.layer_id,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
             num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
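The removed inline selection spells out what the new get_moe_impl_class() factory has to choose between (DeepEPMoE, EPMoE, or FusedMoE, driven by server flags). A hedged sketch of that selection, with the flags passed explicitly rather than read from global_server_args_dict as the real helper does:

def get_moe_impl_class_sketch(server_args: dict) -> str:
    # Mirrors the removed MoEImpl expression; returns names instead of the
    # actual classes to keep the sketch self-contained.
    if server_args.get("enable_deepep_moe"):
        return "DeepEPMoE"
    if server_args.get("enable_ep_moe"):
        return "EPMoE"
    return "FusedMoE"

print(get_moe_impl_class_sketch({"enable_ep_moe": True}))  # EPMoE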
@@ -248,35 +274,32 @@ class DeepseekV2MoE(nn.Module):
             ),
         )
 
-        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
+        if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
-            if not global_server_args_dict["enable_deepep_moe"]:
-                self.shared_experts = DeepseekV2MLP(
-                    hidden_size=config.hidden_size,
-                    intermediate_size=intermediate_size,
-                    hidden_act=config.hidden_act,
-                    quant_config=quant_config,
-                    reduce_results=False,
-                    prefix=add_prefix("shared_experts", prefix),
-                )
-            else:
-                self.shared_experts = DeepseekV2MLP(
-                    hidden_size=config.hidden_size,
-                    intermediate_size=intermediate_size,
-                    hidden_act=config.hidden_act,
-                    quant_config=quant_config,
-                    reduce_results=False,
-                    prefix=add_prefix("shared_experts", prefix),
-                    tp_rank=0,
-                    tp_size=1,
-                )
+            self.shared_experts = DeepseekV2MLP(
+                hidden_size=config.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                reduce_results=False,
+                prefix=add_prefix("shared_experts", prefix),
+                **(
+                    dict(tp_rank=0, tp_size=1)
+                    if global_server_args_dict["enable_deepep_moe"]
+                    else {}
+                ),
+            )
+
+        self.top_k = config.num_experts_per_tok
 
         if global_server_args_dict["enable_deepep_moe"]:
             # TODO: we will support tp < ep in the future
             self.ep_size = get_tensor_model_parallel_world_size()
-            self.num_experts = config.n_routed_experts
-            self.top_k = config.num_experts_per_tok
+            self.num_experts = (
+                config.n_routed_experts
+                + global_server_args_dict["ep_num_redundant_experts"]
+            )
             self.renormalize = config.norm_topk_prob
             self.topk_group = config.topk_group
             self.num_expert_group = config.n_group
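The two near-duplicate DeepseekV2MLP constructions above collapse into one call by splatting an optional keyword dict, passing the tp_rank/tp_size overrides only when DeepEP MoE is enabled. A small generic sketch of the pattern (make_shared_experts and its values are illustrative):

def make_shared_experts(hidden_size, intermediate_size, tp_rank=None, tp_size=None):
    # Illustrative constructor standing in for DeepseekV2MLP.
    return dict(hidden=hidden_size, inter=intermediate_size, tp_rank=tp_rank, tp_size=tp_size)

enable_deepep_moe = True  # assumed flag for this sketch

shared = make_shared_experts(
    4096,
    11008,
    **(dict(tp_rank=0, tp_size=1) if enable_deepep_moe else {}),
)
print(shared["tp_size"])  # 1 when the flag is set, otherwise None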
@@ -286,35 +309,45 @@ class DeepseekV2MoE(nn.Module):
                 else None
             )
 
-            self.deepep_dispatcher = DeepEPDispatcher(
+            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                 group=parallel_state.get_tp_group().device_group,
                 router_topk=self.top_k,
                 permute_fusion=True,
-                num_experts=config.n_routed_experts,
+                num_experts=self.num_experts,
                 num_local_experts=config.n_routed_experts // self.tp_size,
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
                 deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
-                async_finish=True, # TODO
+                async_finish=True,
                 return_recv_hook=True,
             )
 
+        self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
+
+    def get_moe_weights(self):
+        return [
+            x.data
+            for name, x in self.experts.named_parameters()
+            if name not in ["correction_bias"]
+        ]
+
     def forward(
-        self, hidden_states: torch.Tensor, forward_mode: Optional[ForwardMode] = None
+        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
     ) -> torch.Tensor:
-        if not global_server_args_dict["enable_deepep_moe"]:
+        if not self._enable_deepep_moe:
             return self.forward_normal(hidden_states)
         else:
-            return self.forward_deepep(hidden_states, forward_mode)
+            return self.forward_deepep(hidden_states, forward_batch)
 
     def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
         shared_output = self._forward_shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
         router_logits = self.gate(hidden_states)
-        final_hidden_states = (
-            self.experts(hidden_states=hidden_states, router_logits=router_logits)
-            * self.routed_scaling_factor
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
         )
+        if not _is_cuda:
+            final_hidden_states *= self.routed_scaling_factor
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
         if self.tp_size > 1:
@@ -322,14 +355,11 @@ class DeepseekV2MoE(nn.Module):
         return final_hidden_states
 
     def forward_deepep(
-        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
     ) -> torch.Tensor:
+        forward_mode = forward_batch.forward_mode
         shared_output = None
-        if (
-            forward_mode is not None
-            and not forward_mode.is_idle()
-            and hidden_states.shape[0] > 0
-        ):
+        if is_non_idle_and_non_empty(forward_mode, hidden_states):
             # router_logits: (num_tokens, n_experts)
             router_logits = self.gate(hidden_states)
             shared_output = self._forward_shared_experts(hidden_states)
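The five-line condition removed here documents what the new is_non_idle_and_non_empty helper imported from sglang.srt.utils must decide; a sketch consistent with the removed code (not necessarily the helper's actual implementation):

import torch

def is_non_idle_and_non_empty(forward_mode, hidden_states: torch.Tensor) -> bool:
    # True only for a real (non-idle) forward mode with at least one token,
    # mirroring the inline check this diff replaces.
    return (
        forward_mode is not None
        and not forward_mode.is_idle()
        and hidden_states.shape[0] > 0
    )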
@@ -341,8 +371,13 @@ class DeepseekV2MoE(nn.Module):
                 renormalize=self.renormalize,
                 topk_group=self.topk_group,
                 num_expert_group=self.num_expert_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
                 correction_bias=self.correction_bias,
                 routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
             )
         else:
             topk_idx = torch.full(
@@ -363,9 +398,9 @@
                 masked_m,
                 expected_m,
             ) = self.deepep_dispatcher.dispatch(
-                hidden_states,
-                topk_idx,
-                topk_weights,
+                hidden_states=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
                 forward_mode=forward_mode,
             )
         final_hidden_states = self.experts(
@@ -381,24 +416,147 @@
         )
         if self.ep_size > 1:
             final_hidden_states = self.deepep_dispatcher.combine(
-                final_hidden_states,
-                topk_idx,
-                topk_weights,
-                forward_mode,
+                hidden_states=final_hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                forward_mode=forward_mode,
             )
-        final_hidden_states *= self.routed_scaling_factor
 
         if shared_output is not None:
-            final_hidden_states = final_hidden_states + shared_output
+            x = shared_output
+            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            final_hidden_states = x
+        else:
+            final_hidden_states *= self.routed_scaling_factor
 
         return final_hidden_states
 
     def _forward_shared_experts(self, hidden_states):
-        if self.n_share_experts_fusion == 0:
+        if self.num_fused_shared_experts == 0:
            return self.shared_experts(hidden_states)
         else:
            return None
 
+    def op_gate(self, state):
+        if is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, state.hidden_states_mlp_input
+        ):
+            # router_logits: (num_tokens, n_experts)
+            state.router_logits = self.gate(state.hidden_states_mlp_input)
+        else:
+            state.router_logits = None
+
+    def op_shared_experts(self, state):
+        hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
+        if (self.num_fused_shared_experts == 0) and is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, hidden_states_mlp_input
+        ):
+            state.shared_output = self.shared_experts(hidden_states_mlp_input)
+        else:
+            state.shared_output = None
+
+    def op_select_experts(self, state):
+        router_logits = state.pop("router_logits")
+        hidden_states = state.hidden_states_mlp_input
+
+        if router_logits is not None:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                state.topk_weights_local, state.topk_idx_local = select_experts(
+                    hidden_states=hidden_states,
+                    router_logits=router_logits,
+                    top_k=self.top_k,
+                    use_grouped_topk=True,
+                    renormalize=self.renormalize,
+                    topk_group=self.topk_group,
+                    num_expert_group=self.num_expert_group,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
+                    correction_bias=self.correction_bias,
+                    routed_scaling_factor=self.routed_scaling_factor,
+                    num_token_non_padded=state.forward_batch.num_token_non_padded,
+                    expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                        layer_id=self.layer_id,
+                    ),
+                )
+        else:
+            state.topk_idx_local = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            state.topk_weights_local = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+
+    def op_dispatch_a(self, state):
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            self.deepep_dispatcher.dispatch_a(
+                hidden_states=state.hidden_states_mlp_input,
+                topk_idx=state.pop("topk_idx_local"),
+                topk_weights=state.pop("topk_weights_local"),
+                forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_dispatch_b(self, state):
+        if self.ep_size > 1:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                (
+                    state.hidden_states_experts_input,
+                    state.topk_idx_dispatched,
+                    state.topk_weights_dispatched,
+                    state.reorder_topk_ids,
+                    state.num_recv_tokens_per_expert,
+                    state.seg_indptr,
+                    state.masked_m,
+                    state.expected_m,
+                ) = self.deepep_dispatcher.dispatch_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )
+
+    def op_experts(self, state):
+        state.hidden_states_experts_output = self.experts(
+            hidden_states=state.pop("hidden_states_experts_input"),
+            topk_idx=state.topk_idx_dispatched,
+            topk_weights=state.topk_weights_dispatched,
+            reorder_topk_ids=state.pop("reorder_topk_ids"),
+            seg_indptr=state.pop("seg_indptr"),
+            masked_m=state.pop("masked_m"),
+            expected_m=state.pop("expected_m"),
+            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
+            forward_mode=state.forward_batch.forward_mode,
+        )
+
+    def op_combine_a(self, state):
+        if self.ep_size > 1:
+            self.deepep_dispatcher.combine_a(
+                hidden_states=state.pop("hidden_states_experts_output"),
+                topk_idx=state.pop("topk_idx_dispatched"),
+                topk_weights=state.pop("topk_weights_dispatched"),
+                forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_combine_b(self, state):
+        if self.ep_size > 1:
+            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )
+
+    def op_output(self, state):
+        final_hidden_states = state.pop("hidden_states_after_combine")
+
+        if (shared_output := state.pop("shared_output")) is not None:
+            x = shared_output
+            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            final_hidden_states = x
+        else:
556
+ final_hidden_states *= self.routed_scaling_factor
557
+
558
+ state.hidden_states_mlp_output = final_hidden_states
559
+
402
560
 
403
561
  def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
404
562
  import math
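
Note: op_dispatch_a/op_dispatch_b and op_combine_a/op_combine_b split each all-to-all into a launch phase and a completion phase, which is what lets two-batch overlap run one micro-batch's expert compute while the other's tokens are still in flight. A self-contained sketch of the same split-phase pattern using plain torch.distributed instead of the DeepEP dispatcher (all names are illustrative and a process group is assumed to be initialized):

import torch
import torch.distributed as dist

def dispatch_a(send_buf: torch.Tensor, recv_buf: torch.Tensor):
    # Launch the token shuffle asynchronously and hand back the work handle.
    return dist.all_to_all_single(recv_buf, send_buf, async_op=True)

def dispatch_b(handle, recv_buf: torch.Tensor) -> torch.Tensor:
    # Block until the communication issued in dispatch_a has finished.
    handle.wait()
    return recv_buf

# Schematically, for two micro-batches:
#   h0 = dispatch_a(send0, recv0)    # start communication for micro-batch 0
#   run_experts(micro_batch_1)       # compute for micro-batch 1 overlaps it
#   tokens0 = dispatch_b(h0, recv0)  # finish communication for micro-batch 0

The combine path in this hunk makes a related micro-optimization: instead of shared_output + routed_scaling_factor * final_hidden_states, it folds the scale into one in-place shared_output.add_(final_hidden_states, alpha=routed_scaling_factor), avoiding an extra temporary tensor.
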
@@ -550,10 +708,11 @@ class DeepseekV2AttentionMLA(nn.Module):
550
708
  )
551
709
 
552
710
  self.alt_stream = alt_stream
711
+ self.attn_mha.kv_b_proj = None
553
712
 
554
713
  self.w_kc = None
555
714
  self.w_vc = None
556
- self.w_scale = None
715
+ self.w_scale = 1.0
557
716
 
558
717
  self.w_scale_k = None
559
718
  self.w_scale_v = None
@@ -578,6 +737,18 @@ class DeepseekV2AttentionMLA(nn.Module):
578
737
  def dispatch_attn_forward_method(
579
738
  self, forward_batch: ForwardBatch
580
739
  ) -> AttnForwardMethod:
740
+ def _dispatch_mla_subtype():
741
+ if _is_hip:
742
+ if (
743
+ self.rocm_fused_decode_mla
744
+ and forward_batch.forward_mode.is_decode()
745
+ ):
746
+ return AttnForwardMethod.MLA_FUSED_ROPE
747
+ else:
748
+ return AttnForwardMethod.MLA
749
+ else:
750
+ return AttnForwardMethod.MLA
751
+
581
752
  if self.attention_backend == "flashinfer":
582
753
  # Flashinfer MLA: Do not absorb when enabling ragged prefill
583
754
  if (
@@ -589,7 +760,7 @@ class DeepseekV2AttentionMLA(nn.Module):
589
760
  ):
590
761
  return AttnForwardMethod.MHA
591
762
  else:
592
- return AttnForwardMethod.MLA
763
+ return _dispatch_mla_subtype()
593
764
  elif self.attention_backend == "fa3":
594
765
  # Flash Attention: Use MHA with chunked KV cache when prefilling on long sequences.
595
766
  if forward_batch.extend_prefix_lens_cpu is not None:
@@ -605,6 +776,15 @@ class DeepseekV2AttentionMLA(nn.Module):
605
776
  )
606
777
  ):
607
778
  return AttnForwardMethod.MHA_CHUNKED_KV
779
+ else:
780
+ return _dispatch_mla_subtype()
781
+ elif self.attention_backend == "aiter":
782
+ if (
783
+ forward_batch.forward_mode.is_extend()
784
+ and not forward_batch.forward_mode.is_target_verify()
785
+ and not forward_batch.forward_mode.is_draft_extend()
786
+ ):
787
+ return AttnForwardMethod.MHA
608
788
  else:
609
789
  return AttnForwardMethod.MLA
610
790
  else:
@@ -617,7 +797,20 @@ class DeepseekV2AttentionMLA(nn.Module):
617
797
  ):
618
798
  return AttnForwardMethod.MHA
619
799
  else:
620
- return AttnForwardMethod.MLA
800
+ return _dispatch_mla_subtype()
801
+
802
+ def op_prepare(self, state):
803
+ state.attn_intermediate_state = self.forward_prepare(
804
+ positions=state.positions,
805
+ hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
806
+ forward_batch=state.forward_batch,
807
+ zero_allocator=state.zero_allocator,
808
+ )
809
+
810
+ def op_core(self, state):
811
+ state.hidden_states_after_attn = self.forward_core(
812
+ state.pop("attn_intermediate_state")
813
+ )
621
814
 
622
815
  def forward(
623
816
  self,
@@ -625,45 +818,78 @@ class DeepseekV2AttentionMLA(nn.Module):
625
818
  hidden_states: torch.Tensor,
626
819
  forward_batch: ForwardBatch,
627
820
  zero_allocator: BumpAllocator,
628
- ) -> torch.Tensor:
821
+ ):
822
+ s = self.forward_prepare(
823
+ positions=positions,
824
+ hidden_states=hidden_states,
825
+ forward_batch=forward_batch,
826
+ zero_allocator=zero_allocator,
827
+ )
828
+ return self.forward_core(s)
829
+
830
+ def forward_prepare(
831
+ self,
832
+ positions: torch.Tensor,
833
+ hidden_states: torch.Tensor,
834
+ forward_batch: ForwardBatch,
835
+ zero_allocator: BumpAllocator,
836
+ ):
837
+ if self.attn_mha.kv_b_proj is None:
838
+ self.attn_mha.kv_b_proj = self.kv_b_proj
839
+
629
840
  if hidden_states.shape[0] == 0:
630
841
  assert (
631
842
  not self.o_proj.reduce_results
632
843
  ), "short-circuiting allreduce will lead to hangs"
633
- return hidden_states
844
+ return hidden_states, None, forward_batch, None
634
845
 
635
846
  attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
636
847
 
637
848
  if attn_forward_method == AttnForwardMethod.MHA:
638
- return self.forward_normal(positions, hidden_states, forward_batch)
849
+ inner_state = self.forward_normal_prepare(
850
+ positions, hidden_states, forward_batch, zero_allocator
851
+ )
639
852
  elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
640
- return self.forward_normal_chunked_kv(
641
- positions, hidden_states, forward_batch
853
+ inner_state = self.forward_normal_chunked_kv_prepare(
854
+ positions, hidden_states, forward_batch, zero_allocator
855
+ )
856
+ elif attn_forward_method == AttnForwardMethod.MLA:
857
+ inner_state = self.forward_absorb_prepare(
858
+ positions, hidden_states, forward_batch, zero_allocator
859
+ )
860
+ elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
861
+ inner_state = self.forward_absorb_fused_mla_rope_prepare(
862
+ positions, hidden_states, forward_batch, zero_allocator
642
863
  )
643
864
  else:
644
- if _is_hip:
645
- if (
646
- self.rocm_fused_decode_mla
647
- and forward_batch.forward_mode.is_decode()
648
- ):
649
- return self.forward_absorb_fused_mla_rope(
650
- positions, hidden_states, forward_batch
651
- )
652
- else:
653
- return self.forward_absorb(
654
- positions, hidden_states, forward_batch, zero_allocator
655
- )
656
- else:
657
- return self.forward_absorb(
658
- positions, hidden_states, forward_batch, zero_allocator
659
- )
865
+ raise NotImplementedError
866
+ return None, attn_forward_method, forward_batch, inner_state
867
+
868
+ def forward_core(self, intermediate_state):
869
+ hidden_states, attn_forward_method, forward_batch, inner_state = (
870
+ intermediate_state
871
+ )
872
+ if inner_state is None:
873
+ return hidden_states
660
874
 
661
- def forward_normal(
875
+ if attn_forward_method == AttnForwardMethod.MHA:
876
+ return self.forward_normal_core(*inner_state)
877
+ elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
878
+ return self.forward_normal_chunked_kv_core(*inner_state)
879
+ elif attn_forward_method == AttnForwardMethod.MLA:
880
+ return self.forward_absorb_core(*inner_state)
881
+ elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
882
+ return self.forward_absorb_fused_mla_rope_core(*inner_state)
883
+ else:
884
+ raise NotImplementedError
885
+
886
+ def forward_normal_prepare(
662
887
  self,
663
888
  positions: torch.Tensor,
664
889
  hidden_states: torch.Tensor,
665
890
  forward_batch: ForwardBatch,
666
- ) -> torch.Tensor:
891
+ zero_allocator: BumpAllocator,
892
+ ):
667
893
  if self.q_lora_rank is not None:
668
894
  q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
669
895
  [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -698,18 +924,24 @@ class DeepseekV2AttentionMLA(nn.Module):
698
924
  forward_batch.token_to_kv_pool.set_kv_buffer(
699
925
  self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
700
926
  )
927
+
928
+ return q, k, v, forward_batch
929
+
930
+ def forward_normal_core(self, q, k, v, forward_batch):
701
931
  attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
702
932
  attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
703
933
  output, _ = self.o_proj(attn_output)
704
934
  return output
705
935
 
706
- def forward_absorb(
936
+ def forward_absorb_prepare(
707
937
  self,
708
938
  positions: torch.Tensor,
709
939
  hidden_states: torch.Tensor,
710
940
  forward_batch: ForwardBatch,
711
941
  zero_allocator: BumpAllocator,
712
- ) -> torch.Tensor:
942
+ ):
943
+ from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
944
+
713
945
  if self.q_lora_rank is not None:
714
946
  q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
715
947
  [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
@@ -717,7 +949,7 @@ class DeepseekV2AttentionMLA(nn.Module):
717
949
  k_nope = latent_cache[..., : self.kv_lora_rank]
718
950
 
719
951
  # overlap qk norm
720
- if self.alt_stream is not None and torch.cuda.is_current_stream_capturing():
952
+ if self.alt_stream is not None and get_is_capture_mode():
721
953
  current_stream = torch.cuda.current_stream()
722
954
  self.alt_stream.wait_stream(current_stream)
723
955
  q = self.q_a_layernorm(q)
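
Note: the capture-mode check above guards an alternate-stream overlap of the two small normalizations; a sketch of that pattern, with q_norm/k_norm as stand-ins for the layernorms used here (names assumed):

import torch

def overlapped_qk_norm(q, k_nope, q_norm, k_norm, alt_stream):
    # Run the two norms on separate CUDA streams: q on the current stream,
    # k_nope on alt_stream, then rejoin before anything consumes k_nope.
    current = torch.cuda.current_stream()
    alt_stream.wait_stream(current)          # alt stream sees q/k_nope as produced
    q = q_norm(q)
    with torch.cuda.stream(alt_stream):
        k_nope = k_norm(k_nope)
    current.wait_stream(alt_stream)          # main stream waits for k_norm
    return q, k_nope

get_is_capture_mode() is imported from the CUDA-graph runner and replaces the direct torch.cuda.is_current_stream_capturing() query.
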
@@ -756,8 +988,8 @@ class DeepseekV2AttentionMLA(nn.Module):
756
988
  expected_m,
757
989
  )
758
990
  q_nope_out = q_nope_out[:, :expected_m, :]
759
- elif self.w_kc.dtype == torch.float8_e4m3fnuz:
760
- # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
991
+ elif _is_hip:
992
+ # TODO(haishaw): add bmm_fp8 to ROCm
761
993
  q_nope_out = torch.bmm(
762
994
  q_nope.to(torch.bfloat16).transpose(0, 1),
763
995
  self.w_kc.to(torch.bfloat16) * self.w_scale,
@@ -776,6 +1008,11 @@ class DeepseekV2AttentionMLA(nn.Module):
776
1008
  q_nope_out = q_nope_out.transpose(0, 1)
777
1009
  q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
778
1010
 
1011
+ return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
1012
+
1013
+ def forward_absorb_core(
1014
+ self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
1015
+ ):
779
1016
  if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
780
1017
  attn_output = self.attn_mqa(
781
1018
  q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
@@ -803,8 +1040,8 @@ class DeepseekV2AttentionMLA(nn.Module):
803
1040
  expected_m,
804
1041
  )
805
1042
  attn_bmm_output = attn_bmm_output[:, :expected_m, :]
806
- elif self.w_vc.dtype == torch.float8_e4m3fnuz:
807
- # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
1043
+ elif _is_hip:
1044
+ # TODO(haishaw): add bmm_fp8 to ROCm
808
1045
  attn_bmm_output = torch.bmm(
809
1046
  attn_output.to(torch.bfloat16).transpose(0, 1),
810
1047
  self.w_vc.to(torch.bfloat16) * self.w_scale,
@@ -828,13 +1065,13 @@ class DeepseekV2AttentionMLA(nn.Module):
828
1065
 
829
1066
  return output
830
1067
 
831
- def forward_absorb_fused_mla_rope(
1068
+ def forward_absorb_fused_mla_rope_prepare(
832
1069
  self,
833
1070
  positions: torch.Tensor,
834
1071
  hidden_states: torch.Tensor,
835
1072
  forward_batch: ForwardBatch,
836
1073
  zero_allocator: BumpAllocator,
837
- ) -> torch.Tensor:
1074
+ ):
838
1075
  enable_rope_fusion = (
839
1076
  os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
840
1077
  )
@@ -855,8 +1092,8 @@ class DeepseekV2AttentionMLA(nn.Module):
855
1092
  latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
856
1093
  q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
857
1094
 
858
- if self.w_kc.dtype == torch.float8_e4m3fnuz:
859
- # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
1095
+ if _is_hip:
1096
+ # TODO(haishaw): add bmm_fp8 to ROCm
860
1097
  q_nope_out = torch.bmm(
861
1098
  q_nope.to(torch.bfloat16).transpose(0, 1),
862
1099
  self.w_kc.to(torch.bfloat16) * self.w_scale,
@@ -923,6 +1160,44 @@ class DeepseekV2AttentionMLA(nn.Module):
923
1160
  )
924
1161
  val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
925
1162
 
1163
+ return (
1164
+ q_input,
1165
+ key_cache_buf,
1166
+ val_cache_buf,
1167
+ attn_output,
1168
+ kv_indptr,
1169
+ kv_indices,
1170
+ k_pe_output,
1171
+ cos_sin_cache,
1172
+ positions,
1173
+ attn_logits,
1174
+ num_kv_split,
1175
+ sm_scale,
1176
+ enable_rope_fusion,
1177
+ k_input,
1178
+ forward_batch,
1179
+ zero_allocator,
1180
+ )
1181
+
1182
+ def forward_absorb_fused_mla_rope_core(
1183
+ self,
1184
+ q_input,
1185
+ key_cache_buf,
1186
+ val_cache_buf,
1187
+ attn_output,
1188
+ kv_indptr,
1189
+ kv_indices,
1190
+ k_pe_output,
1191
+ cos_sin_cache,
1192
+ positions,
1193
+ attn_logits,
1194
+ num_kv_split,
1195
+ sm_scale,
1196
+ enable_rope_fusion,
1197
+ k_input,
1198
+ forward_batch,
1199
+ zero_allocator,
1200
+ ):
926
1201
  decode_attention_fwd_grouped_rope(
927
1202
  q_input,
928
1203
  key_cache_buf,
@@ -951,8 +1226,8 @@ class DeepseekV2AttentionMLA(nn.Module):
951
1226
 
952
1227
  attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
953
1228
 
954
- if self.w_vc.dtype == torch.float8_e4m3fnuz:
955
- # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
1229
+ if _is_hip:
1230
+ # TODO(haishaw): add bmm_fp8 to ROCm
956
1231
  attn_bmm_output = torch.bmm(
957
1232
  attn_output.to(torch.bfloat16).transpose(0, 1),
958
1233
  self.w_vc.to(torch.bfloat16) * self.w_scale,
@@ -1029,12 +1304,13 @@ class DeepseekV2AttentionMLA(nn.Module):
1029
1304
 
1030
1305
  return accum_output
1031
1306
 
1032
- def forward_normal_chunked_kv(
1307
+ def forward_normal_chunked_kv_prepare(
1033
1308
  self,
1034
1309
  positions: torch.Tensor,
1035
1310
  hidden_states: torch.Tensor,
1036
1311
  forward_batch: ForwardBatch,
1037
- ) -> torch.Tensor:
1312
+ zero_allocator: BumpAllocator,
1313
+ ):
1038
1314
  # In normal mha, the k and v tensors will become overly large when the prefix length is long.
1039
1315
  # To avoid this, we split the kv cache into chunks and process them one after another.
1040
1316
  # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
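
Note: in the chunked-KV path, each KV chunk produces a partial attention output plus its log-sum-exp, and the partial results are merged into the final output. A reference (unfused) version of that merge for two chunks; only the math is shown here, the model code delegates the merge to its attention utilities:

import torch

def merge_attn_chunks(o1, lse1, o2, lse2):
    # o*:   (tokens, heads, head_dim) partial outputs over disjoint KV chunks
    # lse*: (tokens, heads) log-sum-exp of the corresponding softmax denominators
    # Merged output = (e^lse1 * o1 + e^lse2 * o2) / (e^lse1 + e^lse2),
    # computed with max-subtraction for numerical stability.
    m = torch.maximum(lse1, lse2)
    w1 = torch.exp(lse1 - m).unsqueeze(-1)
    w2 = torch.exp(lse2 - m).unsqueeze(-1)
    return (o1 * w1 + o2 * w2) / (w1 + w2)
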
@@ -1077,6 +1353,9 @@ class DeepseekV2AttentionMLA(nn.Module):
1077
1353
  self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
1078
1354
  )
1079
1355
 
1356
+ return q, k, v, forward_batch
1357
+
1358
+ def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
1080
1359
  # Do mha for extended part without prefix
1081
1360
  forward_batch.set_attn_attend_prefix_cache(False)
1082
1361
  attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
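
Note: every attention variant (normal MHA, chunked-KV MHA, absorbed MLA, fused-rope MLA) now comes as a *_prepare step that builds the inputs and a *_core step that runs attention, with forward_prepare/forward_core dispatching between them; the prepare stage takes positions, hidden_states, forward_batch, and zero_allocator. Schematically, and with illustrative names only, two micro-batches can then be driven stage by stage:

def attn_two_micro_batches(attn, inputs_a, inputs_b):
    # Run both prepare stages first, then both core stages. The real driver
    # (the two-batch-overlap operations framework) interleaves these stages
    # so one micro-batch's communication hides behind the other's compute.
    state_a = attn.forward_prepare(**inputs_a)
    state_b = attn.forward_prepare(**inputs_b)
    return attn.forward_core(state_a), attn.forward_core(state_b)
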
@@ -1101,19 +1380,6 @@ class DeepseekV2AttentionMLA(nn.Module):
1101
1380
  return output
1102
1381
 
1103
1382
 
1104
- class _FFNInputMode(Enum):
1105
- # The MLP sublayer requires 1/tp_size tokens as input
1106
- SCATTERED = auto()
1107
- # The MLP sublayer requires all tokens as input
1108
- FULL = auto()
1109
-
1110
-
1111
- @dataclass
1112
- class _DecoderLayerInfo:
1113
- is_sparse: bool
1114
- ffn_input_mode: _FFNInputMode
1115
-
1116
-
1117
1383
  class DeepseekV2DecoderLayer(nn.Module):
1118
1384
 
1119
1385
  def __init__(
@@ -1127,14 +1393,12 @@ class DeepseekV2DecoderLayer(nn.Module):
1127
1393
  ) -> None:
1128
1394
  super().__init__()
1129
1395
  self.hidden_size = config.hidden_size
1396
+ self.config = config
1130
1397
  rope_theta = getattr(config, "rope_theta", 10000)
1131
1398
  rope_scaling = getattr(config, "rope_scaling", None)
1132
1399
  max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
1133
1400
  self.enable_dp_attention = global_server_args_dict["enable_dp_attention"]
1134
1401
  self.layer_id = layer_id
1135
- self.local_dp_size = get_local_attention_dp_size()
1136
- self.attn_tp_size = get_attention_tp_size()
1137
- self.attn_tp_rank = get_attention_tp_rank()
1138
1402
  self.self_attn = DeepseekV2AttentionMLA(
1139
1403
  config=config,
1140
1404
  hidden_size=self.hidden_size,
@@ -1156,19 +1420,25 @@ class DeepseekV2DecoderLayer(nn.Module):
1156
1420
  alt_stream=alt_stream,
1157
1421
  )
1158
1422
 
1159
- self.info = self._compute_info(config, layer_id=layer_id, is_nextn=is_nextn)
1160
- previous_layer_info = self._compute_info(
1161
- config, layer_id=layer_id - 1, is_nextn=False
1423
+ self.is_layer_sparse = self._is_layer_sparse(layer_id, is_nextn=is_nextn)
1424
+ is_previous_layer_sparse = self._is_layer_sparse(layer_id - 1, is_nextn=False)
1425
+
1426
+ self.layer_scatter_modes = LayerScatterModes.init_new(
1427
+ layer_id=layer_id,
1428
+ num_layers=config.num_hidden_layers,
1429
+ is_layer_sparse=self.is_layer_sparse,
1430
+ is_previous_layer_sparse=is_previous_layer_sparse,
1162
1431
  )
1163
1432
 
1164
- if self.info.is_sparse:
1433
+ if self.is_layer_sparse:
1165
1434
  self.mlp = DeepseekV2MoE(
1166
1435
  config=config,
1167
1436
  quant_config=quant_config,
1168
1437
  prefix=add_prefix("mlp", prefix),
1438
+ layer_id=self.layer_id,
1169
1439
  )
1170
1440
  else:
1171
- if self._enable_moe_dense_fully_dp():
1441
+ if enable_moe_dense_fully_dp():
1172
1442
  mlp_tp_rank, mlp_tp_size = 0, 1
1173
1443
  else:
1174
1444
  mlp_tp_rank, mlp_tp_size = None, None
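
Note: the old _compute_info/_DecoderLayerInfo pair is replaced by a direct sparsity predicate plus the LayerScatterModes/LayerCommunicator pair. For reference, the set of MoE layers implied by a config can be computed as in this small sketch, using the same config fields the new _is_layer_sparse reads:

def sparse_layer_ids(config):
    # A layer is sparse (MoE) when routed experts exist, the layer index is
    # past the leading dense layers, and it falls on the MoE layer frequency.
    return [
        layer_id
        for layer_id in range(config.num_hidden_layers)
        if config.n_routed_experts is not None
        and layer_id >= config.first_k_dense_replace
        and layer_id % config.moe_layer_freq == 0
    ]
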
@@ -1182,35 +1452,23 @@ class DeepseekV2DecoderLayer(nn.Module):
1182
1452
  tp_size=mlp_tp_size,
1183
1453
  )
1184
1454
 
1185
- self.input_is_scattered = (
1186
- layer_id > 0
1187
- and previous_layer_info.ffn_input_mode == _FFNInputMode.SCATTERED
1188
- )
1189
- self.is_last_layer = self.layer_id == config.num_hidden_layers - 1
1190
-
1191
1455
  self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
1192
1456
  self.post_attention_layernorm = RMSNorm(
1193
1457
  config.hidden_size, eps=config.rms_norm_eps
1194
1458
  )
1195
1459
 
1196
- @staticmethod
1197
- def _enable_moe_dense_fully_dp():
1198
- return global_server_args_dict["moe_dense_tp_size"] == 1
1199
-
1200
- @staticmethod
1201
- def _compute_info(config: PretrainedConfig, layer_id: int, is_nextn: bool):
1202
- is_sparse = is_nextn or (
1203
- config.n_routed_experts is not None
1204
- and layer_id >= config.first_k_dense_replace
1205
- and layer_id % config.moe_layer_freq == 0
1460
+ self.layer_communicator = LayerCommunicator(
1461
+ layer_scatter_modes=self.layer_scatter_modes,
1462
+ input_layernorm=self.input_layernorm,
1463
+ post_attention_layernorm=self.post_attention_layernorm,
1206
1464
  )
1207
- ffn_input_mode = (
1208
- _FFNInputMode.SCATTERED
1209
- if (global_server_args_dict["enable_deepep_moe"] and is_sparse)
1210
- or (DeepseekV2DecoderLayer._enable_moe_dense_fully_dp() and not is_sparse)
1211
- else _FFNInputMode.FULL
1465
+
1466
+ def _is_layer_sparse(self, layer_id: int, is_nextn: bool) -> bool:
1467
+ return is_nextn or (
1468
+ self.config.n_routed_experts is not None
1469
+ and layer_id >= self.config.first_k_dense_replace
1470
+ and layer_id % self.config.moe_layer_freq == 0
1212
1471
  )
1213
- return _DecoderLayerInfo(is_sparse=is_sparse, ffn_input_mode=ffn_input_mode)
1214
1472
 
1215
1473
  def forward(
1216
1474
  self,
@@ -1220,164 +1478,98 @@ class DeepseekV2DecoderLayer(nn.Module):
1220
1478
  residual: Optional[torch.Tensor],
1221
1479
  zero_allocator: BumpAllocator,
1222
1480
  ) -> torch.Tensor:
1223
- if self.info.ffn_input_mode == _FFNInputMode.SCATTERED:
1224
- return self.forward_ffn_with_scattered_input(
1225
- positions, hidden_states, forward_batch, residual, zero_allocator
1226
- )
1227
- elif self.info.ffn_input_mode == _FFNInputMode.FULL:
1228
- return self.forward_ffn_with_full_input(
1229
- positions, hidden_states, forward_batch, residual, zero_allocator
1230
- )
1231
- else:
1232
- raise NotImplementedError
1233
-
1234
- def forward_ffn_with_full_input(
1235
- self,
1236
- positions: torch.Tensor,
1237
- hidden_states: torch.Tensor,
1238
- forward_batch: ForwardBatch,
1239
- residual: Optional[torch.Tensor],
1240
- zero_allocator: BumpAllocator,
1241
- ) -> torch.Tensor:
1242
-
1243
- if hidden_states.shape[0] == 0:
1244
- residual = hidden_states
1245
- else:
1246
- if residual is None:
1247
- residual = hidden_states
1248
- hidden_states = self.input_layernorm(hidden_states)
1249
- else:
1250
- hidden_states, residual = self.input_layernorm(hidden_states, residual)
1251
-
1252
- assert not (
1253
- self.attn_tp_size != 1 and self.input_is_scattered
1254
- ), "moe_layer_freq > 1 is not supported when attn_tp_size > 1"
1481
+ hidden_states, residual = self.layer_communicator.prepare_attn(
1482
+ hidden_states, residual, forward_batch
1483
+ )
1255
1484
 
1256
- # Self Attention
1257
- hidden_states = self.self_attn(
1258
- positions=positions,
1259
- hidden_states=hidden_states,
1260
- forward_batch=forward_batch,
1261
- zero_allocator=zero_allocator,
1262
- )
1485
+ hidden_states = self.self_attn(
1486
+ positions=positions,
1487
+ hidden_states=hidden_states,
1488
+ forward_batch=forward_batch,
1489
+ zero_allocator=zero_allocator,
1490
+ )
1263
1491
 
1264
- # Gather
1265
- if get_tensor_model_parallel_world_size() > 1:
1266
- # all gather and all reduce
1267
- if self.local_dp_size != 1:
1268
- if self.attn_tp_rank == 0:
1269
- hidden_states += residual
1270
- hidden_states, local_hidden_states = (
1271
- forward_batch.gathered_buffer,
1272
- hidden_states,
1273
- )
1274
- dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
1275
- dp_scatter(residual, hidden_states, forward_batch)
1276
- hidden_states = self.post_attention_layernorm(hidden_states)
1277
- else:
1278
- hidden_states = tensor_model_parallel_all_reduce(hidden_states)
1279
- hidden_states, residual = self.post_attention_layernorm(
1280
- hidden_states, residual
1281
- )
1282
- else:
1283
- hidden_states, residual = self.post_attention_layernorm(
1284
- hidden_states, residual
1285
- )
1492
+ hidden_states, residual = self.layer_communicator.prepare_mlp(
1493
+ hidden_states, residual, forward_batch
1494
+ )
1286
1495
 
1287
- # Fully Connected
1288
- hidden_states = self.mlp(hidden_states)
1496
+ hidden_states = self.mlp(hidden_states, forward_batch)
1289
1497
 
1290
- # TODO(ch-wan): use reduce-scatter in MLP to avoid this scatter
1291
- # Scatter
1292
- if self.local_dp_size != 1:
1293
- # important: forward batch.gathered_buffer is used both after scatter and after gather.
1294
- # be careful about this!
1295
- hidden_states, global_hidden_states = (
1296
- forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
1297
- hidden_states,
1298
- )
1299
- dp_scatter(hidden_states, global_hidden_states, forward_batch)
1498
+ hidden_states, residual = self.layer_communicator.postprocess_layer(
1499
+ hidden_states, residual, forward_batch
1500
+ )
1300
1501
 
1301
1502
  return hidden_states, residual
1302
1503
 
1303
- def forward_ffn_with_scattered_input(
1504
+ def op_comm_prepare_attn(
1304
1505
  self,
1506
+ state,
1305
1507
  positions: torch.Tensor,
1306
1508
  hidden_states: torch.Tensor,
1307
1509
  forward_batch: ForwardBatch,
1308
1510
  residual: Optional[torch.Tensor],
1309
1511
  zero_allocator: BumpAllocator,
1310
- ) -> torch.Tensor:
1311
-
1312
- if hidden_states.shape[0] == 0:
1313
- residual = hidden_states
1314
- else:
1315
- if residual is None:
1316
- residual = hidden_states
1317
- hidden_states = self.input_layernorm(hidden_states)
1318
- else:
1319
- hidden_states, residual = self.input_layernorm(hidden_states, residual)
1320
-
1321
- if self.attn_tp_size != 1 and self.input_is_scattered:
1322
- hidden_states, local_hidden_states = (
1323
- forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
1324
- hidden_states,
1325
- )
1326
- attn_tp_all_gather(
1327
- list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
1512
+ tbo_subbatch_index: Optional[int] = None,
1513
+ ):
1514
+ state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
1515
+ self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
1516
+ )
1517
+ state.update(
1518
+ dict(
1519
+ forward_batch=forward_batch,
1520
+ positions=positions,
1521
+ zero_allocator=zero_allocator,
1522
+ tbo_subbatch_index=tbo_subbatch_index,
1328
1523
  )
1329
-
1330
- # Self Attention
1331
- hidden_states = self.self_attn(
1332
- positions=positions,
1333
- hidden_states=hidden_states,
1334
- forward_batch=forward_batch,
1335
- zero_allocator=zero_allocator,
1336
1524
  )
1337
1525
 
1338
- if self.attn_tp_size != 1:
1339
- if self.input_is_scattered:
1340
- tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
1341
- hidden_states = tensor_list[self.attn_tp_rank]
1342
- attn_tp_reduce_scatter(hidden_states, tensor_list)
1343
- if hidden_states.shape[0] != 0:
1344
- hidden_states, residual = self.post_attention_layernorm(
1345
- hidden_states, residual
1346
- )
1347
- else:
1348
- if self.attn_tp_rank == 0:
1349
- hidden_states += residual
1350
- tensor_list = list(hidden_states.tensor_split(self.attn_tp_size))
1351
- hidden_states = tensor_list[self.attn_tp_rank]
1352
- attn_tp_reduce_scatter(hidden_states, tensor_list)
1353
- residual = hidden_states
1354
- if hidden_states.shape[0] != 0:
1355
- hidden_states = self.post_attention_layernorm(hidden_states)
1356
- else:
1357
- if hidden_states.shape[0] != 0:
1358
- hidden_states, residual = self.post_attention_layernorm(
1359
- hidden_states, residual
1360
- )
1526
+ def op_comm_prepare_mlp(self, state):
1527
+ state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
1528
+ self.layer_communicator.prepare_mlp(
1529
+ state.pop("hidden_states_after_attn"),
1530
+ state.pop("residual_after_input_ln"),
1531
+ state.forward_batch,
1532
+ )
1533
+ )
1361
1534
 
1535
+ def op_mlp(self, state):
1536
+ hidden_states = state.pop("hidden_states_mlp_input")
1362
1537
  if not (
1363
- self._enable_moe_dense_fully_dp()
1364
- and (not self.info.is_sparse)
1538
+ enable_moe_dense_fully_dp()
1539
+ and (not self.is_layer_sparse)
1365
1540
  and hidden_states.shape[0] == 0
1366
1541
  ):
1367
- hidden_states = self.mlp(hidden_states, forward_batch.forward_mode)
1368
-
1369
- if self.is_last_layer and self.attn_tp_size != 1:
1370
- hidden_states += residual
1371
- residual = None
1372
- hidden_states, local_hidden_states = (
1373
- forward_batch.gathered_buffer[: forward_batch.input_ids.shape[0]],
1374
- hidden_states,
1375
- )
1376
- attn_tp_all_gather(
1377
- list(hidden_states.tensor_split(self.attn_tp_size)), local_hidden_states
1542
+ state.hidden_states_mlp_output = self.mlp(
1543
+ hidden_states, state.forward_batch.forward_mode
1378
1544
  )
1545
+ else:
1546
+ state.hidden_states_mlp_output = hidden_states
1379
1547
 
1380
- return hidden_states, residual
1548
+ def op_comm_postprocess_layer(self, state):
1549
+ hidden_states, residual = self.layer_communicator.postprocess_layer(
1550
+ state.pop("hidden_states_mlp_output"),
1551
+ state.pop("residual_after_comm_pre_mlp"),
1552
+ state.forward_batch,
1553
+ )
1554
+
1555
+ output = dict(
1556
+ positions=state.positions,
1557
+ hidden_states=hidden_states,
1558
+ residual=residual,
1559
+ forward_batch=state.forward_batch,
1560
+ zero_allocator=state.zero_allocator,
1561
+ tbo_subbatch_index=state.tbo_subbatch_index,
1562
+ )
1563
+
1564
+ state.clear(
1565
+ expect_keys={
1566
+ "positions",
1567
+ "forward_batch",
1568
+ "zero_allocator",
1569
+ "tbo_subbatch_index",
1570
+ }
1571
+ )
1572
+ return output
1381
1573
 
1382
1574
 
1383
1575
  class DeepseekV2Model(nn.Module):
@@ -1392,13 +1584,14 @@ class DeepseekV2Model(nn.Module):
1392
1584
  super().__init__()
1393
1585
  self.padding_id = config.pad_token_id
1394
1586
  self.vocab_size = config.vocab_size
1587
+ self.first_k_dense_replace = config.first_k_dense_replace
1395
1588
 
1396
1589
  self.embed_tokens = VocabParallelEmbedding(
1397
1590
  config.vocab_size,
1398
1591
  config.hidden_size,
1399
1592
  enable_tp=not global_server_args_dict["enable_dp_attention"],
1400
1593
  )
1401
- self.alt_stream = torch.cuda.Stream()
1594
+ self.alt_stream = torch.cuda.Stream() if _is_cuda else None
1402
1595
  self.layers = nn.ModuleList(
1403
1596
  [
1404
1597
  DeepseekV2DecoderLayer(
@@ -1425,13 +1618,12 @@ class DeepseekV2Model(nn.Module):
1425
1618
  forward_batch: ForwardBatch,
1426
1619
  input_embeds: torch.Tensor = None,
1427
1620
  ) -> torch.Tensor:
1621
+ total_num_layers = len(self.layers)
1622
+ device = input_embeds.device if input_embeds is not None else input_ids.device
1428
1623
  zero_allocator = BumpAllocator(
1429
- # TODO for two-batch-overlap, we need a larger buffer size
1430
- buffer_size=len(self.layers) * 2,
1624
+ buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
1431
1625
  dtype=torch.float32,
1432
- device=(
1433
- input_embeds.device if input_embeds is not None else input_ids.device
1434
- ),
1626
+ device=device,
1435
1627
  )
1436
1628
 
1437
1629
  if input_embeds is None:
@@ -1440,12 +1632,33 @@ class DeepseekV2Model(nn.Module):
1440
1632
  hidden_states = input_embeds
1441
1633
 
1442
1634
  residual = None
1443
- for i in range(len(self.layers)):
1444
- expert_distribution_recorder.set_current_layer(i)
1445
- layer = self.layers[i]
1446
- hidden_states, residual = layer(
1447
- positions, hidden_states, forward_batch, residual, zero_allocator
1635
+
1636
+ normal_num_layers = (
1637
+ self.first_k_dense_replace
1638
+ if forward_batch.can_run_tbo
1639
+ else total_num_layers
1640
+ )
1641
+ for i in range(normal_num_layers):
1642
+ with get_global_expert_distribution_recorder().with_current_layer(i):
1643
+ layer = self.layers[i]
1644
+ hidden_states, residual = layer(
1645
+ positions, hidden_states, forward_batch, residual, zero_allocator
1646
+ )
1647
+
1648
+ if normal_num_layers != total_num_layers:
1649
+ hidden_states, residual = model_forward_maybe_tbo(
1650
+ layers=self.layers[normal_num_layers:],
1651
+ enable_tbo=True,
1652
+ positions=positions,
1653
+ forward_batch=forward_batch,
1654
+ hidden_states=hidden_states,
1655
+ residual=residual,
1656
+ input_data_scatter_mode=self.layers[
1657
+ normal_num_layers - 1
1658
+ ].layer_scatter_modes.layer_output_mode,
1659
+ zero_allocator=zero_allocator,
1448
1660
  )
1661
+
1449
1662
  if not forward_batch.forward_mode.is_idle():
1450
1663
  if residual is None:
1451
1664
  hidden_states = self.norm(hidden_states)
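
Note: the zero-allocator buffer is now sized as total_num_layers * 2, doubled again when the batch can run two-batch overlap, since two micro-batches then allocate per layer. A minimal sketch of a bump allocator with that usage pattern (the real class lives elsewhere in the package; the method name here is assumed):

import torch

class BumpAllocator:
    # Hands out consecutive slices of one preallocated buffer; nothing is
    # freed, so the buffer must cover the worst-case number of allocations.
    def __init__(self, buffer_size: int, dtype: torch.dtype, device: str):
        self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, size: int) -> torch.Tensor:
        assert self._offset + size <= self._buffer.numel(), "buffer too small"
        out = self._buffer[self._offset : self._offset + size]
        self._offset += size
        return out

# e.g. 61 layers, 2 allocations per layer, doubled when two-batch overlap runs:
# BumpAllocator(buffer_size=61 * 2 * 2, dtype=torch.float32, device="cuda")
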
@@ -1466,7 +1679,7 @@ class DeepseekV2ForCausalLM(nn.Module):
1466
1679
  self.config = config
1467
1680
  self.tp_size = get_tensor_model_parallel_world_size()
1468
1681
  self.quant_config = quant_config
1469
- self.determine_n_share_experts_fusion()
1682
+ self.determine_num_fused_shared_experts()
1470
1683
  self.model = DeepseekV2Model(
1471
1684
  config, quant_config, prefix=add_prefix("model", prefix)
1472
1685
  )
@@ -1480,40 +1693,67 @@ class DeepseekV2ForCausalLM(nn.Module):
1480
1693
  self.logits_processor = LogitsProcessor(config)
1481
1694
  self.dp_size = get_local_attention_dp_size()
1482
1695
 
1483
- def determine_n_share_experts_fusion(
1696
+ self._routed_experts_weights_of_layer = LazyValue(
1697
+ lambda: {
1698
+ layer_id: layer.mlp.get_moe_weights()
1699
+ for layer_id, layer in enumerate(self.model.layers)
1700
+ if isinstance(layer.mlp, DeepseekV2MoE)
1701
+ }
1702
+ )
1703
+
1704
+ @property
1705
+ def routed_experts_weights_of_layer(self):
1706
+ return self._routed_experts_weights_of_layer.value
1707
+
1708
+ def determine_num_fused_shared_experts(
1484
1709
  self, architecture: str = "DeepseekV3ForCausalLM"
1485
1710
  ):
1486
- self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
1487
- if self.n_share_experts_fusion > 0:
1711
+ self.num_fused_shared_experts = (
1712
+ 0
1713
+ if global_server_args_dict["disable_shared_experts_fusion"]
1714
+ else self.config.n_shared_experts
1715
+ )
1716
+ if self.num_fused_shared_experts > 0:
1488
1717
  # Only Deepseek V3/R1 can use shared experts fusion optimization now.
1489
1718
  if (
1490
1719
  not _is_cuda
1491
1720
  or self.config.architectures[0] != architecture
1492
1721
  or self.config.n_routed_experts != 256
1493
1722
  ):
1494
- self.n_share_experts_fusion = 0
1495
- global_server_args_dict["n_share_experts_fusion"] = 0
1723
+ self.num_fused_shared_experts = 0
1724
+ global_server_args_dict["disable_shared_experts_fusion"] = True
1496
1725
  log_info_on_rank0(
1497
1726
  logger,
1498
1727
  "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
1499
1728
  )
1500
- else:
1501
- assert (
1502
- self.n_share_experts_fusion == self.tp_size
1503
- ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
1504
- elif self.n_share_experts_fusion == 0:
1729
+ elif (
1730
+ global_server_args_dict["enable_deepep_moe"]
1731
+ or global_server_args_dict["enable_ep_moe"]
1732
+ ):
1733
+ self.num_fused_shared_experts = 0
1734
+ global_server_args_dict["disable_shared_experts_fusion"] = True
1735
+ log_info_on_rank0(
1736
+ logger,
1737
+ "Deepseek V3/R1 cannot use shared experts fusion optimization in deepep_moe or ep_moe mode. Shared experts fusion optimization is disabled.",
1738
+ )
1739
+ elif self.num_fused_shared_experts == 0:
1505
1740
  if (
1506
1741
  _is_cuda
1507
1742
  and torch.cuda.get_device_capability("cuda") >= (9, 0)
1508
1743
  and self.config.architectures[0] == architecture
1509
1744
  and self.config.n_routed_experts == 256
1510
- and (not global_server_args_dict["enable_deepep_moe"])
1745
+ and (
1746
+ not (
1747
+ global_server_args_dict["enable_deepep_moe"]
1748
+ or global_server_args_dict["enable_ep_moe"]
1749
+ )
1750
+ )
1511
1751
  ):
1512
- self.n_share_experts_fusion = self.tp_size
1513
- global_server_args_dict["n_share_experts_fusion"] = self.tp_size
1752
+ self.num_fused_shared_experts = self.config.n_shared_experts
1753
+ global_server_args_dict["disable_shared_experts_fusion"] = False
1514
1754
  log_info_on_rank0(
1515
1755
  logger,
1516
- "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
1756
+ "Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
1517
1757
  )
1518
1758
 
1519
1759
  def get_input_embeddings(self) -> nn.Embedding:
@@ -1527,21 +1767,29 @@ class DeepseekV2ForCausalLM(nn.Module):
1527
1767
  forward_batch: ForwardBatch,
1528
1768
  input_embeds: torch.Tensor = None,
1529
1769
  ) -> torch.Tensor:
1530
-
1531
1770
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
1532
1771
 
1533
1772
  return self.logits_processor(
1534
1773
  input_ids, hidden_states, self.lm_head, forward_batch
1535
1774
  )
1536
1775
 
1537
- def post_load_weights(self, is_nextn=False):
1776
+ def post_load_weights(self, is_nextn=False, weight_names=None):
1538
1777
 
1539
1778
  # Perform post-processing after loading weights
1540
- layer_ids = (
1541
- range(self.config.num_hidden_layers)
1542
- if not is_nextn
1543
- else [self.config.num_hidden_layers]
1544
- )
1779
+ if is_nextn:
1780
+ layer_ids = [self.config.num_hidden_layers]
1781
+ else:
1782
+ if weight_names is None:
1783
+ layer_ids = range(self.config.num_hidden_layers)
1784
+ else:
1785
+ layer_ids = set()
1786
+ for name in weight_names:
1787
+ if "kv_b_proj" in name:
1788
+ layer_id = int(name.split(".")[2])
1789
+ # filter the nextn layer.
1790
+ if layer_id != self.config.num_hidden_layers:
1791
+ layer_ids.add(layer_id)
1792
+
1545
1793
  for layer_id in layer_ids:
1546
1794
  self_attn = (
1547
1795
  self.model.layers[layer_id].self_attn
@@ -1577,46 +1825,56 @@ class DeepseekV2ForCausalLM(nn.Module):
1577
1825
  torch.float8_e4m3fn,
1578
1826
  torch.float8_e4m3fnuz,
1579
1827
  ):
1580
- if hasattr(self.quant_config, "weight_block_size"):
1828
+ if (
1829
+ hasattr(self.quant_config, "weight_block_size")
1830
+ and self.quant_config.weight_block_size is not None
1831
+ ):
1581
1832
  weight_block_size = self.quant_config.weight_block_size
1582
- if weight_block_size is not None:
1583
- assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
1584
- if _is_hip:
1585
- weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
1586
- weight=w,
1587
- weight_scale=self_attn.kv_b_proj.weight_scale_inv,
1588
- input_scale=None,
1589
- )
1590
- else:
1591
- weight = w
1592
- weight_scale = self_attn.kv_b_proj.weight_scale_inv
1833
+ assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
1834
+ if _is_fp8_fnuz:
1835
+ weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
1836
+ weight=w,
1837
+ weight_scale=self_attn.kv_b_proj.weight_scale_inv,
1838
+ input_scale=None,
1839
+ )
1840
+ else:
1841
+ weight = w
1842
+ weight_scale = self_attn.kv_b_proj.weight_scale_inv
1593
1843
 
1594
- if (
1595
- _is_cuda
1596
- and weight_block_size[0] == 128
1597
- and weight_block_size[1] == 128
1598
- and model_dtype == torch.bfloat16
1844
+ if (
1845
+ _is_cuda
1846
+ and weight_block_size[0] == 128
1847
+ and weight_block_size[1] == 128
1848
+ and model_dtype == torch.bfloat16
1849
+ ):
1850
+ if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
1851
+ "SGL_USE_DEEPGEMM_BMM", "false"
1599
1852
  ):
1600
- if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
1601
- "SGL_USE_DEEPGEMM_BMM", "false"
1602
- ):
1603
- block_scale = weight_scale
1604
- use_deep_gemm_bmm = True
1605
- else:
1606
- w = block_quant_dequant(
1607
- weight,
1608
- weight_scale,
1609
- weight_block_size,
1610
- model_dtype,
1611
- )
1853
+ block_scale = weight_scale
1854
+ use_deep_gemm_bmm = True
1612
1855
  else:
1613
- w, scale = block_quant_to_tensor_quant(
1614
- weight, weight_scale, weight_block_size
1856
+ w = block_quant_dequant(
1857
+ weight,
1858
+ weight_scale,
1859
+ weight_block_size,
1860
+ model_dtype,
1615
1861
  )
1616
- self_attn.w_scale = scale
1862
+ else:
1863
+ w, scale = block_quant_to_tensor_quant(
1864
+ weight, weight_scale, weight_block_size
1865
+ )
1866
+ self_attn.w_scale = scale
1617
1867
  else:
1618
- weight = w
1619
- weight_scale = self_attn.kv_b_proj.weight_scale
1868
+ if _is_fp8_fnuz:
1869
+ weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
1870
+ weight=w,
1871
+ weight_scale=self_attn.kv_b_proj.weight_scale,
1872
+ input_scale=None,
1873
+ )
1874
+ else:
1875
+ weight = w
1876
+ weight_scale = self_attn.kv_b_proj.weight_scale
1877
+
1620
1878
  w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
1621
1879
  self_attn.w_scale = scale
1622
1880
 
@@ -1641,13 +1899,19 @@ class DeepseekV2ForCausalLM(nn.Module):
1641
1899
  0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
1642
1900
  ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
1643
1901
  if not use_deep_gemm_bmm:
1644
- self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
1645
- self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
1902
+ self_attn.w_kc = bind_or_assign(
1903
+ self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
1904
+ )
1905
+ self_attn.w_vc = bind_or_assign(
1906
+ self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
1907
+ )
1646
1908
  if (
1647
1909
  hasattr(self_attn.kv_b_proj, "weight_scale")
1648
1910
  and self_attn.w_scale is None
1649
1911
  ):
1650
- self_attn.w_scale = self_attn.kv_b_proj.weight_scale
1912
+ self_attn.w_scale = bind_or_assign(
1913
+ self_attn.w_scale, self_attn.kv_b_proj.weight_scale
1914
+ )
1651
1915
  if _is_hip:
1652
1916
  self_attn.w_scale *= 2.0
1653
1917
  else:
@@ -1656,13 +1920,20 @@ class DeepseekV2ForCausalLM(nn.Module):
1656
1920
  ws_kc, ws_vc = block_scale.unflatten(
1657
1921
  0, (-1, (num_tiles_k + num_tiles_n))
1658
1922
  ).split([num_tiles_k, num_tiles_n], dim=1)
1659
- self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
1660
- self_attn.w_scale_v = ws_vc.contiguous()
1661
- self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
1662
- self_attn.w_vc = w_vc.contiguous()
1923
+ self_attn.w_scale_k = bind_or_assign(
1924
+ self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
1925
+ )
1926
+ self_attn.w_scale_v = bind_or_assign(
1927
+ self_attn.w_scale_v, ws_vc.contiguous()
1928
+ )
1929
+ self_attn.w_kc = bind_or_assign(
1930
+ self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
1931
+ )
1932
+ self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
1663
1933
  self_attn.use_deep_gemm_bmm = True
1664
1934
 
1665
1935
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
1936
+
1666
1937
  if is_nextn:
1667
1938
  if hasattr(self.config, "num_nextn_predict_layers"):
1668
1939
  num_nextn_layers = self.config.num_nextn_predict_layers
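
Note: bind_or_assign wraps each assignment of w_kc/w_vc and the scales so that an already-bound tensor, when present, is updated in place rather than replaced. A sketch with the semantics assumed here:

import torch
from typing import Optional

def bind_or_assign(target: Optional[torch.Tensor], source: torch.Tensor) -> torch.Tensor:
    # If a tensor is already bound (e.g. during an in-place weight update),
    # copy into it so existing references keep pointing at valid storage;
    # otherwise adopt the new tensor.
    if target is not None:
        target.copy_(source)
        return target
    return source
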
@@ -1681,26 +1952,68 @@ class DeepseekV2ForCausalLM(nn.Module):
1681
1952
  ("gate_up_proj", "gate_proj", 0),
1682
1953
  ("gate_up_proj", "up_proj", 1),
1683
1954
  ]
1684
- if self.n_share_experts_fusion > 0:
1955
+ if self.num_fused_shared_experts > 0:
1956
+ assert self.num_fused_shared_experts == 1
1685
1957
  weights_list = list(weights)
1686
1958
  weights_dict = dict(weights_list)
1687
- if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
1688
- suffix_list = [
1689
- "down_proj.weight",
1690
- "down_proj.weight_scale",
1691
- "gate_proj.weight",
1692
- "gate_proj.weight_scale",
1693
- "up_proj.weight",
1694
- "up_proj.weight_scale",
1695
- ]
1959
+ if self.quant_config is not None:
1960
+ if self.quant_config.get_name() == "w8a8_int8":
1961
+ suffix_list = [
1962
+ "down_proj.weight",
1963
+ "down_proj.weight_scale",
1964
+ "gate_proj.weight",
1965
+ "gate_proj.weight_scale",
1966
+ "up_proj.weight",
1967
+ "up_proj.weight_scale",
1968
+ ]
1969
+ elif (
1970
+ self.quant_config.get_name() == "fp8"
1971
+ or self.quant_config.get_name() == "blockwise_int8"
1972
+ ):
1973
+ suffix_list = [
1974
+ "down_proj.weight",
1975
+ "down_proj.weight_scale_inv",
1976
+ "gate_proj.weight",
1977
+ "gate_proj.weight_scale_inv",
1978
+ "up_proj.weight",
1979
+ "up_proj.weight_scale_inv",
1980
+ ]
1981
+ elif self.quant_config.get_name() == "awq":
1982
+ suffix_list = [
1983
+ "down_proj.qweight",
1984
+ "down_proj.qzeros",
1985
+ "down_proj.scales",
1986
+ "gate_proj.qweight",
1987
+ "gate_proj.qzeros",
1988
+ "gate_proj.scales",
1989
+ "up_proj.qweight",
1990
+ "up_proj.qzeros",
1991
+ "up_proj.scales",
1992
+ ]
1993
+ elif self.quant_config.get_name() == "modelopt_fp4":
1994
+ suffix_list = [
1995
+ "down_proj.weight",
1996
+ "down_proj.weight_scale",
1997
+ "down_proj.weight_scale_2",
1998
+ "down_proj.input_scale",
1999
+ "gate_proj.weight",
2000
+ "gate_proj.weight_scale",
2001
+ "gate_proj.weight_scale_2",
2002
+ "gate_proj.input_scale",
2003
+ "up_proj.weight",
2004
+ "up_proj.weight_scale",
2005
+ "up_proj.weight_scale_2",
2006
+ "up_proj.input_scale",
2007
+ ]
2008
+ else:
2009
+ raise ValueError(
2010
+ f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
2011
+ )
1696
2012
  else:
1697
2013
  suffix_list = [
1698
2014
  "down_proj.weight",
1699
- "down_proj.weight_scale_inv",
1700
2015
  "gate_proj.weight",
1701
- "gate_proj.weight_scale_inv",
1702
2016
  "up_proj.weight",
1703
- "up_proj.weight_scale_inv",
1704
2017
  ]
1705
2018
  names_to_remove = []
1706
2019
 
@@ -1716,38 +2029,32 @@ class DeepseekV2ForCausalLM(nn.Module):
1716
2029
 
1717
2030
  for moe_layer in tqdm(
1718
2031
  moe_layers,
1719
- desc=f"Cloning {self.n_share_experts_fusion} "
1720
- "replicas of the shared expert into MoE",
2032
+ desc=f"Cloning {self.num_fused_shared_experts} "
2033
+ "shared expert into MoE",
1721
2034
  ):
1722
2035
  for suffix in suffix_list:
1723
2036
  shared_expert_weight_name = (
1724
2037
  f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
1725
2038
  )
1726
- for num_repeat in range(self.n_share_experts_fusion):
1727
- weights_list.append(
1728
- (
1729
- f"model.layers.{moe_layer}."
1730
- f"mlp.experts."
1731
- f"{self.config.n_routed_experts + num_repeat}"
1732
- f".{suffix}",
1733
- weights_dict[shared_expert_weight_name],
1734
- )
2039
+ weights_list.append(
2040
+ (
2041
+ f"model.layers.{moe_layer}."
2042
+ f"mlp.experts."
2043
+ f"{self.config.n_routed_experts + 0}"
2044
+ f".{suffix}",
2045
+ weights_dict[shared_expert_weight_name],
1735
2046
  )
2047
+ )
1736
2048
  names_to_remove += [shared_expert_weight_name]
1737
2049
  weights = [w for w in weights_list if w[0] not in names_to_remove]
1738
2050
 
1739
2051
  # Params for weights, fp8 weight scales, fp8 activation scales
1740
2052
  # (param_name, weight_name, expert_id, shard_id)
1741
- MoEImpl = (
1742
- DeepEPMoE
1743
- if global_server_args_dict["enable_deepep_moe"]
1744
- else (EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE)
1745
- )
1746
- expert_params_mapping = MoEImpl.make_expert_params_mapping(
2053
+ expert_params_mapping = get_moe_impl_class().make_expert_params_mapping(
1747
2054
  ckpt_gate_proj_name="gate_proj",
1748
2055
  ckpt_down_proj_name="down_proj",
1749
2056
  ckpt_up_proj_name="up_proj",
1750
- num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
2057
+ num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
1751
2058
  )
1752
2059
 
1753
2060
  # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
@@ -1766,7 +2073,10 @@ class DeepseekV2ForCausalLM(nn.Module):
1766
2073
  ]
1767
2074
 
1768
2075
  params_dict = dict(self.named_parameters())
2076
+ weight_names = []
1769
2077
  for name, loaded_weight in weights:
2078
+ weight_names.append(name)
2079
+
1770
2080
  if not is_nextn:
1771
2081
  if hasattr(self.config, "num_nextn_predict_layers"):
1772
2082
  num_nextn_layers = self.config.num_nextn_predict_layers
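
Note: load_weights now records every weight name it sees and passes the list to post_load_weights, which then only post-processes layers whose kv_b_proj was actually (re)loaded. The layer-id extraction mirrors the code added earlier in this diff:

def touched_kv_b_layers(weight_names, num_hidden_layers):
    # "model.layers.<i>.self_attn.kv_b_proj..." -> i, skipping the nextn layer
    # (whose index equals num_hidden_layers).
    layer_ids = set()
    for name in weight_names:
        if "kv_b_proj" in name:
            layer_id = int(name.split(".")[2])
            if layer_id != num_hidden_layers:
                layer_ids.add(layer_id)
    return layer_ids

# touched_kv_b_layers(["model.layers.5.self_attn.kv_b_proj.weight"], 61) == {5}
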
@@ -1838,7 +2148,6 @@ class DeepseekV2ForCausalLM(nn.Module):
1838
2148
  # Skip loading extra bias for GPTQ models.
1839
2149
  if name.endswith(".bias") and name not in params_dict:
1840
2150
  continue
1841
-
1842
2151
  if fuse_qkv_a_proj and (
1843
2152
  "q_a_proj" in name or "kv_a_proj_with_mqa" in name
1844
2153
  ):
@@ -1859,15 +2168,17 @@ class DeepseekV2ForCausalLM(nn.Module):
1859
2168
  q_a_proj_name in cached_a_proj
1860
2169
  and kv_a_proj_name in cached_a_proj
1861
2170
  ):
1862
-
1863
2171
  q_a_proj_weight = cached_a_proj[q_a_proj_name]
1864
2172
  kv_a_proj_weight = cached_a_proj[kv_a_proj_name]
1865
2173
  fused_weight = torch.cat(
1866
2174
  [q_a_proj_weight, kv_a_proj_weight], dim=0
1867
2175
  )
1868
-
1869
- param_name = name.replace(
1870
- "q_a_proj", "fused_qkv_a_proj_with_mqa"
2176
+ param_name = (
2177
+ name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
2178
+ if "q_a_proj" in name
2179
+ else name.replace(
2180
+ "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
2181
+ )
1871
2182
  )
1872
2183
  param = params_dict[param_name]
1873
2184
 
@@ -1878,13 +2189,23 @@ class DeepseekV2ForCausalLM(nn.Module):
1878
2189
  cached_a_proj.pop(q_a_proj_name)
1879
2190
  cached_a_proj.pop(kv_a_proj_name)
1880
2191
  else:
2192
+ if (
2193
+ "k_scale" in name or "v_scale" in name
2194
+ ) and name not in params_dict:
2195
+ # modelopt attn kv scale is named differently
2196
+ if any(scale in name for scale in ["k_scale", "v_scale"]):
2197
+ name = name.replace("_proj", "attn_mqa")
2198
+ else:
2199
+ logger.warning(
2200
+ f"Unknown scale found in checkpoint: {name}"
2201
+ )
1881
2202
  param = params_dict[name]
1882
2203
  weight_loader = getattr(
1883
2204
  param, "weight_loader", default_weight_loader
1884
2205
  )
1885
2206
  weight_loader(param, loaded_weight)
1886
2207
 
1887
- self.post_load_weights(is_nextn=is_nextn)
2208
+ self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
1888
2209
 
1889
2210
  def get_embed_and_head(self):
1890
2211
  return self.model.embed_tokens.weight, self.lm_head.weight
@@ -1897,6 +2218,14 @@ class DeepseekV2ForCausalLM(nn.Module):
1897
2218
  torch.cuda.empty_cache()
1898
2219
  torch.cuda.synchronize()
1899
2220
 
2221
+ @classmethod
2222
+ def get_model_config_for_expert_location(cls, config):
2223
+ return ModelConfigForExpertLocation(
2224
+ num_layers=config.num_hidden_layers,
2225
+ num_logical_experts=config.n_routed_experts,
2226
+ num_groups=config.n_group,
2227
+ )
2228
+
1900
2229
 
1901
2230
  class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
1902
2231
  pass
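
Note: get_model_config_for_expert_location reports the expert topology (layers, logical experts, expert groups) to the expert-location machinery. A sketch of the returned container, with field names taken from the call site above and everything else assumed:

from dataclasses import dataclass
from typing import Optional

@dataclass
class ModelConfigForExpertLocation:
    num_layers: int
    num_logical_experts: int
    num_groups: Optional[int] = None
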