sglang 0.4.6.post5__py3-none-any.whl → 0.4.7.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (359)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_offline_throughput.py +10 -4
  4. sglang/bench_one_batch_server.py +67 -11
  5. sglang/bench_serving.py +86 -75
  6. sglang/lang/backend/runtime_endpoint.py +24 -1
  7. sglang/lang/interpreter.py +40 -1
  8. sglang/lang/ir.py +27 -0
  9. sglang/math_utils.py +8 -0
  10. sglang/profiler.py +167 -0
  11. sglang/srt/_custom_ops.py +34 -0
  12. sglang/srt/configs/internvl.py +8 -12
  13. sglang/srt/configs/model_config.py +33 -1
  14. sglang/srt/constrained/base_grammar_backend.py +5 -2
  15. sglang/srt/constrained/llguidance_backend.py +9 -8
  16. sglang/srt/constrained/outlines_backend.py +5 -4
  17. sglang/srt/constrained/xgrammar_backend.py +18 -18
  18. sglang/srt/conversation.py +52 -8
  19. sglang/srt/custom_op.py +38 -3
  20. sglang/srt/debug_utils.py +74 -0
  21. sglang/srt/disaggregation/base/__init__.py +1 -1
  22. sglang/srt/disaggregation/base/conn.py +25 -11
  23. sglang/srt/disaggregation/common/__init__.py +5 -0
  24. sglang/srt/disaggregation/common/conn.py +407 -0
  25. sglang/srt/disaggregation/common/utils.py +42 -0
  26. sglang/srt/disaggregation/decode.py +261 -52
  27. sglang/srt/disaggregation/fake/__init__.py +1 -1
  28. sglang/srt/disaggregation/fake/conn.py +16 -9
  29. sglang/srt/disaggregation/kv_events.py +60 -5
  30. sglang/srt/disaggregation/launch_lb.py +140 -0
  31. sglang/srt/disaggregation/mini_lb.py +29 -48
  32. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  33. sglang/srt/disaggregation/mooncake/conn.py +446 -149
  34. sglang/srt/disaggregation/mooncake/transfer_engine.py +32 -16
  35. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  36. sglang/srt/disaggregation/nixl/conn.py +134 -437
  37. sglang/srt/disaggregation/prefill.py +130 -43
  38. sglang/srt/disaggregation/utils.py +127 -86
  39. sglang/srt/distributed/device_communicators/pymscclpp.py +315 -0
  40. sglang/srt/distributed/parallel_state.py +52 -5
  41. sglang/srt/entrypoints/EngineBase.py +6 -0
  42. sglang/srt/entrypoints/engine.py +116 -5
  43. sglang/srt/entrypoints/http_server.py +28 -4
  44. sglang/srt/eplb_simulator/__init__.py +1 -0
  45. sglang/srt/eplb_simulator/reader.py +51 -0
  46. sglang/srt/function_call/base_format_detector.py +138 -86
  47. sglang/srt/function_call/deepseekv3_detector.py +54 -6
  48. sglang/srt/function_call/ebnf_composer.py +33 -19
  49. sglang/srt/function_call/function_call_parser.py +27 -0
  50. sglang/srt/function_call/llama32_detector.py +33 -14
  51. sglang/srt/function_call/mistral_detector.py +73 -26
  52. sglang/srt/function_call/pythonic_detector.py +86 -20
  53. sglang/srt/function_call/qwen25_detector.py +64 -10
  54. sglang/srt/function_call/utils.py +17 -0
  55. sglang/srt/hf_transformers_utils.py +4 -0
  56. sglang/srt/layers/activation.py +19 -0
  57. sglang/srt/layers/attention/aiter_backend.py +503 -125
  58. sglang/srt/layers/attention/base_attn_backend.py +4 -0
  59. sglang/srt/layers/attention/cutlass_mla_backend.py +40 -34
  60. sglang/srt/layers/attention/flashattention_backend.py +137 -63
  61. sglang/srt/layers/attention/flashinfer_backend.py +46 -3
  62. sglang/srt/layers/attention/flashinfer_mla_backend.py +59 -25
  63. sglang/srt/layers/attention/flashmla_backend.py +2 -10
  64. sglang/srt/layers/attention/intel_amx_backend.py +128 -0
  65. sglang/srt/layers/attention/tbo_backend.py +232 -0
  66. sglang/srt/layers/attention/torch_native_backend.py +3 -0
  67. sglang/srt/layers/attention/triton_backend.py +304 -65
  68. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  69. sglang/srt/layers/attention/triton_ops/extend_attention.py +12 -4
  70. sglang/srt/layers/attention/vision.py +51 -24
  71. sglang/srt/layers/communicator.py +281 -197
  72. sglang/srt/layers/dp_attention.py +6 -5
  73. sglang/srt/layers/layernorm.py +30 -19
  74. sglang/srt/layers/linear.py +0 -4
  75. sglang/srt/layers/logits_processor.py +0 -12
  76. sglang/srt/layers/moe/cutlass_moe.py +170 -7
  77. sglang/srt/layers/moe/cutlass_moe_params.py +169 -0
  78. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  79. sglang/srt/layers/moe/ep_moe/layer.py +136 -72
  80. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +24 -45
  81. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  82. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  84. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  85. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +221 -29
  91. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -4
  92. sglang/srt/layers/moe/topk.py +60 -26
  93. sglang/srt/layers/multimodal.py +3 -3
  94. sglang/srt/layers/pooler.py +56 -0
  95. sglang/srt/layers/quantization/__init__.py +3 -2
  96. sglang/srt/layers/quantization/blockwise_int8.py +3 -0
  97. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +5 -0
  98. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  99. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +69 -127
  100. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  101. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  102. sglang/srt/layers/quantization/fp8.py +28 -23
  103. sglang/srt/layers/quantization/fp8_kernel.py +156 -75
  104. sglang/srt/layers/quantization/fp8_utils.py +250 -69
  105. sglang/srt/layers/quantization/modelopt_quant.py +334 -7
  106. sglang/srt/layers/quantization/moe_wna16.py +3 -0
  107. sglang/srt/layers/quantization/w8a8_fp8.py +3 -0
  108. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  109. sglang/srt/layers/radix_attention.py +2 -3
  110. sglang/srt/layers/rotary_embedding.py +6 -12
  111. sglang/srt/layers/sampler.py +80 -79
  112. sglang/srt/layers/utils.py +6 -0
  113. sglang/srt/lora/layers.py +12 -15
  114. sglang/srt/lora/lora.py +49 -5
  115. sglang/srt/lora/lora_manager.py +98 -39
  116. sglang/srt/lora/mem_pool.py +28 -21
  117. sglang/srt/lora/utils.py +17 -13
  118. sglang/srt/managers/cache_controller.py +2 -1
  119. sglang/srt/managers/data_parallel_controller.py +13 -5
  120. sglang/srt/managers/eplb_algorithms/__init__.py +63 -0
  121. sglang/srt/managers/eplb_algorithms/deepseek.py +223 -0
  122. sglang/srt/managers/{deepseek_eplb.py → eplb_algorithms/deepseek_vec.py} +5 -7
  123. sglang/srt/managers/eplb_manager.py +55 -14
  124. sglang/srt/managers/expert_distribution.py +220 -46
  125. sglang/srt/managers/expert_location.py +110 -56
  126. sglang/srt/managers/expert_location_dispatch.py +23 -6
  127. sglang/srt/managers/io_struct.py +43 -8
  128. sglang/srt/managers/mm_utils.py +88 -38
  129. sglang/srt/managers/multimodal_processors/base_processor.py +190 -18
  130. sglang/srt/managers/multimodal_processors/gemma3.py +4 -31
  131. sglang/srt/managers/multimodal_processors/internvl.py +4 -0
  132. sglang/srt/managers/multimodal_processors/kimi_vl.py +15 -34
  133. sglang/srt/managers/multimodal_processors/minicpm.py +2 -1
  134. sglang/srt/managers/multimodal_processors/phi4mm.py +87 -0
  135. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -64
  136. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  137. sglang/srt/managers/schedule_batch.py +173 -38
  138. sglang/srt/managers/scheduler.py +376 -127
  139. sglang/srt/managers/tokenizer_manager.py +163 -19
  140. sglang/srt/managers/utils.py +0 -4
  141. sglang/srt/mem_cache/chunk_cache.py +1 -0
  142. sglang/srt/mem_cache/hiradix_cache.py +4 -2
  143. sglang/srt/mem_cache/memory_pool.py +111 -407
  144. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  145. sglang/srt/mem_cache/radix_cache.py +36 -12
  146. sglang/srt/metrics/collector.py +9 -0
  147. sglang/srt/model_executor/cuda_graph_runner.py +191 -113
  148. sglang/srt/model_executor/expert_location_updater.py +157 -22
  149. sglang/srt/model_executor/forward_batch_info.py +52 -22
  150. sglang/srt/model_executor/model_runner.py +102 -62
  151. sglang/srt/model_loader/loader.py +8 -1
  152. sglang/srt/model_loader/utils.py +67 -1
  153. sglang/srt/models/bert.py +113 -13
  154. sglang/srt/models/deepseek_nextn.py +1 -1
  155. sglang/srt/models/deepseek_v2.py +623 -290
  156. sglang/srt/models/gemma3_causal.py +7 -0
  157. sglang/srt/models/gemma3_mm.py +19 -14
  158. sglang/srt/models/idefics2.py +342 -0
  159. sglang/srt/models/internvl.py +46 -102
  160. sglang/srt/models/kimi_vl.py +4 -4
  161. sglang/srt/models/llama.py +1 -1
  162. sglang/srt/models/minicpmo.py +2 -5
  163. sglang/srt/models/minicpmv.py +3 -295
  164. sglang/srt/models/phi4mm.py +512 -0
  165. sglang/srt/models/qwen2.py +38 -9
  166. sglang/srt/models/qwen2_5_vl.py +3 -9
  167. sglang/srt/models/qwen2_eagle.py +4 -1
  168. sglang/srt/models/qwen2_moe.py +58 -191
  169. sglang/srt/models/qwen2_vl.py +3 -9
  170. sglang/srt/models/qwen3.py +41 -10
  171. sglang/srt/models/qwen3_moe.py +230 -191
  172. sglang/srt/models/registry.py +9 -1
  173. sglang/srt/models/roberta.py +117 -9
  174. sglang/srt/models/transformers.py +291 -0
  175. sglang/srt/models/vila.py +305 -0
  176. sglang/srt/openai_api/adapter.py +248 -28
  177. sglang/srt/openai_api/protocol.py +68 -3
  178. sglang/srt/openai_api/utils.py +172 -0
  179. sglang/srt/operations.py +37 -2
  180. sglang/srt/operations_strategy.py +200 -24
  181. sglang/srt/sampling/sampling_batch_info.py +37 -1
  182. sglang/srt/sampling/sampling_params.py +4 -1
  183. sglang/srt/server_args.py +381 -209
  184. sglang/srt/speculative/build_eagle_tree.py +9 -9
  185. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +12 -14
  186. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +256 -0
  187. sglang/srt/speculative/eagle_utils.py +440 -200
  188. sglang/srt/speculative/eagle_worker.py +234 -63
  189. sglang/srt/two_batch_overlap.py +637 -0
  190. sglang/srt/utils.py +187 -7
  191. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  192. sglang/test/runners.py +54 -10
  193. sglang/test/send_one.py +4 -0
  194. sglang/test/test_block_fp8.py +1 -0
  195. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  196. sglang/test/test_block_fp8_ep.py +1 -0
  197. sglang/test/test_cutlass_moe.py +3 -3
  198. sglang/test/test_fp4_moe.py +248 -0
  199. sglang/test/test_utils.py +82 -7
  200. sglang/utils.py +9 -0
  201. sglang/version.py +1 -1
  202. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/METADATA +17 -14
  203. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/RECORD +359 -321
  204. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/WHEEL +1 -1
  205. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  206. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  207. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  208. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  209. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  210. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  211. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  212. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  213. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  214. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  215. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  216. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  217. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  218. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H200.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H200.json} +0 -0
  219. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  220. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  221. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  222. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  223. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  224. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  225. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  226. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  227. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  228. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  229. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  230. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  231. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  232. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  233. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  234. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  235. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  236. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  237. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  238. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  239. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  240. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  241. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  242. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  243. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  244. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  245. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  246. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  247. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json} +0 -0
  248. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  249. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  250. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  251. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  252. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  253. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  254. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  255. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  256. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  257. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json} +0 -0
  258. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json} +0 -0
  259. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  260. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  261. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  262. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  263. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  264. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  265. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200.json} +0 -0
  266. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  267. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  268. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200.json} +0 -0
  269. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  270. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  271. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  272. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200.json} +0 -0
  273. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  274. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  275. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  276. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
  277. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  278. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  279. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  280. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200.json} +0 -0
  281. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X.json} +0 -0
  282. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI325X.json} +0 -0
  283. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Radeon_Graphics.json} +0 -0
  284. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  285. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  286. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200.json} +0 -0
  287. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X.json} +0 -0
  288. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI325X.json} +0 -0
  289. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Radeon_Graphics.json} +0 -0
  290. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  291. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  292. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  293. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  294. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200.json} +0 -0
  295. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  296. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  297. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  298. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  299. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200.json} +0 -0
  300. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X.json} +0 -0
  301. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI325X.json} +0 -0
  302. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Radeon_Graphics.json} +0 -0
  303. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  304. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  305. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
  306. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  307. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  308. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  309. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200.json} +0 -0
  310. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_L40S.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_L40S.json} +0 -0
  311. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
  312. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
  313. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
  314. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  315. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  316. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  317. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  318. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200.json} +0 -0
  319. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X.json} +0 -0
  320. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI325X.json} +0 -0
  321. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Radeon_Graphics.json} +0 -0
  322. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  323. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  324. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  325. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  326. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200.json} +0 -0
  327. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
  328. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
  329. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
  330. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  331. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  332. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  333. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  334. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H20.json} +0 -0
  335. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H200.json} +0 -0
  336. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  337. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  338. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20.json} +0 -0
  339. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  340. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200.json} +0 -0
  341. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  342. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  343. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  344. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  345. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20.json} +0 -0
  346. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  347. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200.json} +0 -0
  348. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=96,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=96,device_name=NVIDIA_H20.json} +0 -0
  349. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  350. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  351. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  352. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  353. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  354. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  355. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  356. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  357. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  358. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/licenses/LICENSE +0 -0
  359. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.post1.dist-info}/top_level.txt +0 -0
@@ -51,12 +51,13 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.ep_moe.layer import DeepEPMoE, get_moe_impl_class
 from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
-from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
+    is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
     per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
@@ -65,6 +66,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     block_quant_to_tensor_quant,
     channel_quant_to_tensor_quant,
     normalize_e4m3fn_to_e4m3fnuz,
+    requant_weight_ue8m0_inplace,
 )
 from sglang.srt.layers.quantization.int8_utils import (
     block_dequant as int8_block_dequant,
@@ -83,28 +85,31 @@ from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchI
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.operations import execute_operations
-from sglang.srt.operations_strategy import compute_layer_operations
+from sglang.srt.two_batch_overlap import (
+    MaybeTboDeepEPDispatcher,
+    model_forward_maybe_tbo,
+)
 from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
+    LazyValue,
     add_prefix,
+    bind_or_assign,
     get_bool_env_var,
     get_int_env_var,
     is_cuda,
     is_hip,
+    is_non_idle_and_non_empty,
     log_info_on_rank0,
 )
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
-
-    from sglang.srt.layers.quantization.deep_gemm import (
-        grouped_gemm_nt_f8f8bf16_masked as deep_gemm_grouped_gemm_nt_f8f8bf16_masked,
-    )
 else:
     from vllm._custom_ops import awq_dequantize
 
@@ -113,6 +118,9 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )
 
+    if _use_aiter:
+        from aiter.rotary_embedding import get_rope
+
 logger = logging.getLogger(__name__)
 
 
@@ -204,14 +212,6 @@ class MoEGate(nn.Module):
         return logits
 
 
-def is_non_idle_and_non_empty(forward_mode, hidden_states):
-    return (
-        (forward_mode is not None)
-        and not forward_mode.is_idle()
-        and hidden_states.shape[0] > 0
-    )
-
-
 class DeepseekV2MoE(nn.Module):
 
     def __init__(
@@ -225,7 +225,12 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        self.num_fused_shared_experts = (
+            0
+            if global_server_args_dict["disable_shared_experts_fusion"]
+            else config.n_shared_experts
+        )
+        self.config = config
         self.layer_id = layer_id
 
         if self.tp_size > config.n_routed_experts:
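
A minimal sketch of the shared-experts fusion change in the hunk above: the explicit n_share_experts_fusion server argument is gone, and the fused count is now derived from a disable flag plus the model config. Illustrative only; fused_shared_expert_count, server_args, and cfg are stand-ins for this note, not sglang APIs.

from types import SimpleNamespace

def fused_shared_expert_count(config, server_args) -> int:
    # 0.4.6.post5 read an explicit count from server_args["n_share_experts_fusion"];
    # 0.4.7.post1 derives it from a single disable flag instead (see the hunk above).
    if server_args["disable_shared_experts_fusion"]:
        return 0
    return config.n_shared_experts

cfg = SimpleNamespace(n_shared_experts=1)  # hypothetical config value
assert fused_shared_expert_count(cfg, {"disable_shared_experts_fusion": False}) == 1
assert fused_shared_expert_count(cfg, {"disable_shared_experts_fusion": True}) == 0

With fusion enabled, this count also feeds the MoE layer construction in the next hunk: num_experts grows by the fused count, and top_k becomes config.num_experts_per_tok + num_fused_shared_experts (previously + min(n_share_experts_fusion, 1)).
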
@@ -244,9 +249,9 @@ class DeepseekV2MoE(nn.Module):
 
         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts
-            + self.n_share_experts_fusion
+            + self.num_fused_shared_experts
            + global_server_args_dict["ep_num_redundant_experts"],
-            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             layer_id=self.layer_id,
@@ -254,6 +259,7 @@ class DeepseekV2MoE(nn.Module):
             quant_config=quant_config,
             use_grouped_topk=True,
             num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
@@ -265,7 +271,7 @@ class DeepseekV2MoE(nn.Module):
             ),
         )
 
-        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
+        if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
             self.shared_experts = DeepseekV2MLP(
@@ -300,7 +306,7 @@ class DeepseekV2MoE(nn.Module):
                 else None
             )
 
-            self.deepep_dispatcher = DeepEPDispatcher(
+            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                 group=parallel_state.get_tp_group().device_group,
                 router_topk=self.top_k,
                 permute_fusion=True,
@@ -309,13 +315,11 @@ class DeepseekV2MoE(nn.Module):
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
                 deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
-                async_finish=True,  # TODO
+                async_finish=True,
                 return_recv_hook=True,
             )
 
-    @property
-    def _enable_deepep_moe(self):
-        return global_server_args_dict["enable_deepep_moe"]
+        self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]
 
     def get_moe_weights(self):
         return [
@@ -324,8 +328,114 @@ class DeepseekV2MoE(nn.Module):
             if name not in ["correction_bias"]
         ]
 
+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+    ) -> torch.Tensor:
+        if not self._enable_deepep_moe:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_batch)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        shared_output = self._forward_shared_experts(hidden_states)
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+        if not _is_cuda:
+            final_hidden_states *= self.routed_scaling_factor
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        forward_mode = forward_batch.forward_mode
+        shared_output = None
+        if is_non_idle_and_non_empty(forward_mode, hidden_states):
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=True,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                correction_bias=self.correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_idx = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            topk_weights = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                num_recv_tokens_per_expert,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                forward_mode=forward_mode,
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            reorder_topk_ids=reorder_topk_ids,
+            seg_indptr=seg_indptr,
+            masked_m=masked_m,
+            expected_m=expected_m,
+            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
+            forward_mode=forward_mode,
+        )
+        if self.ep_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                hidden_states=final_hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                forward_mode=forward_mode,
+            )
+
+        if shared_output is not None:
+            x = shared_output
+            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            final_hidden_states = x
+        else:
+            final_hidden_states *= self.routed_scaling_factor
+
+        return final_hidden_states
+
+    def _forward_shared_experts(self, hidden_states):
+        if self.num_fused_shared_experts == 0:
+            return self.shared_experts(hidden_states)
+        else:
+            return None
+
     def op_gate(self, state):
-        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
+        if is_non_idle_and_non_empty(
             state.forward_batch.forward_mode, state.hidden_states_mlp_input
         ):
             # router_logits: (num_tokens, n_experts)
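
The new forward_deepep and op_output (next hunks) fold the routed-scaling multiply into the shared-expert add by using an in-place add with a scale factor. A small self-contained check of that identity; illustrative only, plain PyTorch with made-up shapes, not sglang code.

import torch

scale = 2.5                      # stands in for self.routed_scaling_factor
shared = torch.randn(4, 8)       # shared-expert output
routed = torch.randn(4, 8)       # routed (MoE) expert output

expected = shared + scale * routed
out = shared.clone()
out.add_(routed, alpha=scale)    # in-place: out += scale * routed, no extra temporary

assert torch.allclose(out, expected)
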
@@ -334,22 +444,22 @@ class DeepseekV2MoE(nn.Module):
             state.router_logits = None
 
     def op_shared_experts(self, state):
-        if (self.n_share_experts_fusion == 0) and (
-            (not self._enable_deepep_moe)
-            or is_non_idle_and_non_empty(
-                state.forward_batch.forward_mode, state.hidden_states_mlp_input
-            )
+        hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
+        if (self.num_fused_shared_experts == 0) and is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, hidden_states_mlp_input
         ):
-            state.shared_output = self.shared_experts(state.hidden_states_mlp_input)
+            state.shared_output = self.shared_experts(hidden_states_mlp_input)
         else:
             state.shared_output = None
 
     def op_select_experts(self, state):
-        router_logits = state.router_logits
+        router_logits = state.pop("router_logits")
         hidden_states = state.hidden_states_mlp_input
 
-        if self._enable_deepep_moe:
-            if router_logits is not None:
+        if router_logits is not None:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
                 state.topk_weights_local, state.topk_idx_local = select_experts(
                     hidden_states=hidden_states,
                     router_logits=router_logits,
@@ -358,90 +468,89 @@ class DeepseekV2MoE(nn.Module):
358
468
  renormalize=self.renormalize,
359
469
  topk_group=self.topk_group,
360
470
  num_expert_group=self.num_expert_group,
471
+ num_fused_shared_experts=self.num_fused_shared_experts,
361
472
  correction_bias=self.correction_bias,
362
473
  routed_scaling_factor=self.routed_scaling_factor,
474
+ num_token_non_padded=state.forward_batch.num_token_non_padded,
363
475
  expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
364
476
  layer_id=self.layer_id,
365
477
  ),
366
478
  )
367
- else:
368
- state.topk_idx_local = torch.full(
369
- (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
370
- )
371
- state.topk_weights_local = torch.empty(
372
- (0, self.top_k), dtype=torch.float32, device=hidden_states.device
373
- )
479
+ else:
480
+ state.topk_idx_local = torch.full(
481
+ (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
482
+ )
483
+ state.topk_weights_local = torch.empty(
484
+ (0, self.top_k), dtype=torch.float32, device=hidden_states.device
485
+ )
374
486
 
375
487
  def op_dispatch_a(self, state):
376
- if self._enable_deepep_moe and (self.ep_size > 1):
488
+ if self.ep_size > 1:
377
489
  # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
378
490
  self.deepep_dispatcher.dispatch_a(
379
- hidden_states=state.pop("hidden_states_mlp_input"),
491
+ hidden_states=state.hidden_states_mlp_input,
380
492
  topk_idx=state.pop("topk_idx_local"),
381
493
  topk_weights=state.pop("topk_weights_local"),
382
494
  forward_mode=state.forward_batch.forward_mode,
495
+ tbo_subbatch_index=state.get("tbo_subbatch_index"),
383
496
  )
384
497
 
385
498
  def op_dispatch_b(self, state):
386
- if self._enable_deepep_moe and (self.ep_size > 1):
387
- (
388
- state.hidden_states_experts_input,
389
- state.topk_idx_dispatched,
390
- state.topk_weights_dispatched,
391
- state.reorder_topk_ids,
392
- state.num_recv_tokens_per_expert,
393
- state.seg_indptr,
394
- state.masked_m,
395
- state.expected_m,
396
- ) = self.deepep_dispatcher.dispatch_b()
499
+ if self.ep_size > 1:
500
+ with get_global_expert_distribution_recorder().with_current_layer(
501
+ self.layer_id
502
+ ):
503
+ (
504
+ state.hidden_states_experts_input,
505
+ state.topk_idx_dispatched,
506
+ state.topk_weights_dispatched,
507
+ state.reorder_topk_ids,
508
+ state.num_recv_tokens_per_expert,
509
+ state.seg_indptr,
510
+ state.masked_m,
511
+ state.expected_m,
512
+ ) = self.deepep_dispatcher.dispatch_b(
513
+ tbo_subbatch_index=state.get("tbo_subbatch_index"),
514
+ )
397
515
 
398
516
  def op_experts(self, state):
399
- if self._enable_deepep_moe:
400
- state.pop("router_logits")
401
- state.hidden_states_experts_output = self.experts(
402
- hidden_states=state.pop("hidden_states_experts_input"),
403
- topk_idx=state.topk_idx_dispatched,
404
- topk_weights=state.topk_weights_dispatched,
405
- reorder_topk_ids=state.pop("reorder_topk_ids"),
406
- seg_indptr=state.pop("seg_indptr"),
407
- masked_m=state.pop("masked_m"),
408
- expected_m=state.pop("expected_m"),
409
- num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
410
- forward_mode=state.forward_batch.forward_mode,
411
- )
412
- else:
413
- state.hidden_states_experts_output = self.experts(
414
- hidden_states=state.pop("hidden_states_mlp_input"),
415
- router_logits=state.pop("router_logits"),
416
- )
517
+ state.hidden_states_experts_output = self.experts(
518
+ hidden_states=state.pop("hidden_states_experts_input"),
519
+ topk_idx=state.topk_idx_dispatched,
520
+ topk_weights=state.topk_weights_dispatched,
521
+ reorder_topk_ids=state.pop("reorder_topk_ids"),
522
+ seg_indptr=state.pop("seg_indptr"),
523
+ masked_m=state.pop("masked_m"),
524
+ expected_m=state.pop("expected_m"),
525
+ num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
526
+ forward_mode=state.forward_batch.forward_mode,
527
+ )
417
528
 
418
529
  def op_combine_a(self, state):
419
- if self._enable_deepep_moe and (self.ep_size > 1):
530
+ if self.ep_size > 1:
420
531
  self.deepep_dispatcher.combine_a(
421
- state.pop("hidden_states_experts_output"),
532
+ hidden_states=state.pop("hidden_states_experts_output"),
422
533
  topk_idx=state.pop("topk_idx_dispatched"),
423
534
  topk_weights=state.pop("topk_weights_dispatched"),
424
535
  forward_mode=state.forward_batch.forward_mode,
536
+ tbo_subbatch_index=state.get("tbo_subbatch_index"),
425
537
  )
426
538
 
427
539
  def op_combine_b(self, state):
428
- if self._enable_deepep_moe and (self.ep_size > 1):
429
- state.hidden_states_after_combine = self.deepep_dispatcher.combine_b()
540
+ if self.ep_size > 1:
541
+ state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
542
+ tbo_subbatch_index=state.get("tbo_subbatch_index"),
543
+ )
430
544
 
431
545
  def op_output(self, state):
432
- final_hidden_states = (
433
- state.pop("hidden_states_after_combine")
434
- if self._enable_deepep_moe
435
- else state.pop("hidden_states_experts_output")
436
- )
437
-
438
- final_hidden_states *= self.routed_scaling_factor
439
-
440
- if (s := state.pop("shared_output")) is not None:
441
- final_hidden_states = final_hidden_states + s
546
+ final_hidden_states = state.pop("hidden_states_after_combine")
442
547
 
443
- if (not self._enable_deepep_moe) and (self.tp_size > 1):
444
- final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
548
+ if (shared_output := state.pop("shared_output")) is not None:
549
+ x = shared_output
550
+ x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
551
+ final_hidden_states = x
552
+ else:
553
+ final_hidden_states *= self.routed_scaling_factor
445
554
 
446
555
  state.hidden_states_mlp_output = final_hidden_states
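Note on the op_output change above: the routed-expert output is now folded into the shared-expert buffer with one in-place add that also applies the routed scaling factor. A minimal sketch (illustrative only, not part of the diff, using stand-in random tensors) of why the fused form is equivalent:

import torch

routed_scaling_factor = 2.5
routed_output = torch.randn(4, 8)   # stand-in for hidden_states_after_combine
shared_output = torch.randn(4, 8)   # stand-in for the shared-expert MLP output

# reference: shared + scale * routed, with a fresh output tensor
reference = shared_output + routed_scaling_factor * routed_output

# fused: reuse the shared-expert buffer and apply the scale inside add_
fused = shared_output.clone()
fused.add_(routed_output, alpha=routed_scaling_factor)

assert torch.allclose(reference, fused)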

@@ -596,10 +705,11 @@ class DeepseekV2AttentionMLA(nn.Module):
  )

  self.alt_stream = alt_stream
+ self.attn_mha.kv_b_proj = None

  self.w_kc = None
  self.w_vc = None
- self.w_scale = None
+ self.w_scale = 1.0

  self.w_scale_k = None
  self.w_scale_v = None
@@ -665,6 +775,15 @@ class DeepseekV2AttentionMLA(nn.Module):
  return AttnForwardMethod.MHA_CHUNKED_KV
  else:
  return _dispatch_mla_subtype()
+ elif self.attention_backend == "aiter":
+ if (
+ forward_batch.forward_mode.is_extend()
+ and not forward_batch.forward_mode.is_target_verify()
+ and not forward_batch.forward_mode.is_draft_extend()
+ ):
+ return AttnForwardMethod.MHA
+ else:
+ return AttnForwardMethod.MLA
  else:
  # Triton: Use normal computation for prefill and use weight absorption for extend/decode
  if (
@@ -677,44 +796,97 @@ class DeepseekV2AttentionMLA(nn.Module):
  else:
  return _dispatch_mla_subtype()

+ def op_prepare(self, state):
+ state.attn_intermediate_state = self.forward_prepare(
+ positions=state.positions,
+ hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+ forward_batch=state.forward_batch,
+ zero_allocator=state.zero_allocator,
+ )
+
+ def op_core(self, state):
+ state.hidden_states_after_attn = self.forward_core(
+ state.pop("attn_intermediate_state")
+ )
+
  def forward(
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
  zero_allocator: BumpAllocator,
- ) -> torch.Tensor:
+ ):
+ s = self.forward_prepare(
+ positions=positions,
+ hidden_states=hidden_states,
+ forward_batch=forward_batch,
+ zero_allocator=zero_allocator,
+ )
+ return self.forward_core(s)
+
+ def forward_prepare(
+ self,
+ positions: torch.Tensor,
+ hidden_states: torch.Tensor,
+ forward_batch: ForwardBatch,
+ zero_allocator: BumpAllocator,
+ ):
+ if self.attn_mha.kv_b_proj is None:
+ self.attn_mha.kv_b_proj = self.kv_b_proj
+
  if hidden_states.shape[0] == 0:
  assert (
  not self.o_proj.reduce_results
  ), "short-circuiting allreduce will lead to hangs"
- return hidden_states
+ return hidden_states, None, forward_batch, None

  attn_forward_method = self.dispatch_attn_forward_method(forward_batch)

  if attn_forward_method == AttnForwardMethod.MHA:
- return self.forward_normal(positions, hidden_states, forward_batch)
+ inner_state = self.forward_normal_prepare(
+ positions, hidden_states, forward_batch, zero_allocator
+ )
  elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
- return self.forward_normal_chunked_kv(
- positions, hidden_states, forward_batch
+ inner_state = self.forward_normal_chunked_kv_prepare(
+ positions, hidden_states, forward_batch, zero_allocator
  )
  elif attn_forward_method == AttnForwardMethod.MLA:
- return self.forward_absorb(
+ inner_state = self.forward_absorb_prepare(
  positions, hidden_states, forward_batch, zero_allocator
  )
  elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
- return self.forward_absorb_fused_mla_rope(
- positions, hidden_states, forward_batch
+ inner_state = self.forward_absorb_fused_mla_rope_prepare(
+ positions, hidden_states, forward_batch, zero_allocator
  )
  else:
  raise NotImplementedError
+ return None, attn_forward_method, forward_batch, inner_state
+
+ def forward_core(self, intermediate_state):
+ hidden_states, attn_forward_method, forward_batch, inner_state = (
+ intermediate_state
+ )
+ if inner_state is None:
+ return hidden_states
+
+ if attn_forward_method == AttnForwardMethod.MHA:
+ return self.forward_normal_core(*inner_state)
+ elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
+ return self.forward_normal_chunked_kv_core(*inner_state)
+ elif attn_forward_method == AttnForwardMethod.MLA:
+ return self.forward_absorb_core(*inner_state)
+ elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
+ return self.forward_absorb_fused_mla_rope_core(*inner_state)
+ else:
+ raise NotImplementedError

- def forward_normal(
+ def forward_normal_prepare(
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
- ) -> torch.Tensor:
+ zero_allocator: BumpAllocator,
+ ):
  if self.q_lora_rank is not None:
  q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
  [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
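The hunk above splits the attention forward pass into forward_prepare and forward_core so a scheduler can interleave two micro-batches (two-batch overlap). A minimal sketch of the pattern, with toy stand-in classes rather than the real sglang modules:

# Illustrative only: the unsplit forward() is just prepare followed by core,
# while the split form lets prepare(B) run while core(A) is still in flight.
class TinyAttention:
    def forward_prepare(self, hidden_states):
        # projections / bookkeeping that must happen before the attention kernel
        inner_state = (hidden_states, hidden_states * 2.0)
        return inner_state

    def forward_core(self, inner_state):
        q, k = inner_state
        return q + k  # stand-in for the actual attention computation

    def forward(self, hidden_states):
        # the unsplit path simply chains the two halves
        return self.forward_core(self.forward_prepare(hidden_states))

attn = TinyAttention()
assert attn.forward(1.0) == attn.forward_core(attn.forward_prepare(1.0))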
@@ -749,18 +921,22 @@ class DeepseekV2AttentionMLA(nn.Module):
  forward_batch.token_to_kv_pool.set_kv_buffer(
  self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
  )
+
+ return q, k, v, forward_batch
+
+ def forward_normal_core(self, q, k, v, forward_batch):
  attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
  attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
  output, _ = self.o_proj(attn_output)
  return output

- def forward_absorb(
+ def forward_absorb_prepare(
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
  zero_allocator: BumpAllocator,
- ) -> torch.Tensor:
+ ):
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode

  if self.q_lora_rank is not None:
@@ -801,7 +977,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  q_nope_out = q_nope.new_empty(
  (self.num_local_heads, aligned_m, self.kv_lora_rank)
  )
- deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
  (q_nope_val, q_nope_scale),
  (self.w_kc, self.w_scale_k),
  q_nope_out,
@@ -809,8 +985,8 @@ class DeepseekV2AttentionMLA(nn.Module):
  expected_m,
  )
  q_nope_out = q_nope_out[:, :expected_m, :]
- elif self.w_kc.dtype == torch.float8_e4m3fnuz:
- # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+ elif _is_hip:
+ # TODO(haishaw): add bmm_fp8 to ROCm
  q_nope_out = torch.bmm(
  q_nope.to(torch.bfloat16).transpose(0, 1),
  self.w_kc.to(torch.bfloat16) * self.w_scale,
@@ -829,7 +1005,16 @@ class DeepseekV2AttentionMLA(nn.Module):
  q_nope_out = q_nope_out.transpose(0, 1)
  q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)

- if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
+ return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+
+ def forward_absorb_core(
+ self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
+ ):
+ if (
+ self.attention_backend == "fa3"
+ or self.attention_backend == "flashinfer"
+ or self.attention_backend == "cutlass_mla"
+ ):
  attn_output = self.attn_mqa(
  q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
  )
@@ -848,7 +1033,7 @@ class DeepseekV2AttentionMLA(nn.Module):
  attn_bmm_output = attn_output.new_empty(
  (self.num_local_heads, aligned_m, self.v_head_dim)
  )
- deep_gemm_grouped_gemm_nt_f8f8bf16_masked(
+ deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
  (attn_output_val, attn_output_scale),
  (self.w_vc, self.w_scale_v),
  attn_bmm_output,
@@ -856,8 +1041,8 @@ class DeepseekV2AttentionMLA(nn.Module):
  expected_m,
  )
  attn_bmm_output = attn_bmm_output[:, :expected_m, :]
- elif self.w_vc.dtype == torch.float8_e4m3fnuz:
- # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+ elif _is_hip:
+ # TODO(haishaw): add bmm_fp8 to ROCm
  attn_bmm_output = torch.bmm(
  attn_output.to(torch.bfloat16).transpose(0, 1),
  self.w_vc.to(torch.bfloat16) * self.w_scale,
@@ -881,13 +1066,13 @@ class DeepseekV2AttentionMLA(nn.Module):

  return output

- def forward_absorb_fused_mla_rope(
+ def forward_absorb_fused_mla_rope_prepare(
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
  zero_allocator: BumpAllocator,
- ) -> torch.Tensor:
+ ):
  enable_rope_fusion = (
  os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
  )
@@ -908,8 +1093,8 @@ class DeepseekV2AttentionMLA(nn.Module):
  latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
  q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)

- if self.w_kc.dtype == torch.float8_e4m3fnuz:
- # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+ if _is_hip:
+ # TODO(haishaw): add bmm_fp8 to ROCm
  q_nope_out = torch.bmm(
  q_nope.to(torch.bfloat16).transpose(0, 1),
  self.w_kc.to(torch.bfloat16) * self.w_scale,
@@ -976,6 +1161,44 @@ class DeepseekV2AttentionMLA(nn.Module):
  )
  val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]

+ return (
+ q_input,
+ key_cache_buf,
+ val_cache_buf,
+ attn_output,
+ kv_indptr,
+ kv_indices,
+ k_pe_output,
+ cos_sin_cache,
+ positions,
+ attn_logits,
+ num_kv_split,
+ sm_scale,
+ enable_rope_fusion,
+ k_input,
+ forward_batch,
+ zero_allocator,
+ )
+
+ def forward_absorb_fused_mla_rope_core(
+ self,
+ q_input,
+ key_cache_buf,
+ val_cache_buf,
+ attn_output,
+ kv_indptr,
+ kv_indices,
+ k_pe_output,
+ cos_sin_cache,
+ positions,
+ attn_logits,
+ num_kv_split,
+ sm_scale,
+ enable_rope_fusion,
+ k_input,
+ forward_batch,
+ zero_allocator,
+ ):
  decode_attention_fwd_grouped_rope(
  q_input,
  key_cache_buf,
@@ -1004,8 +1227,8 @@ class DeepseekV2AttentionMLA(nn.Module):

  attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

- if self.w_vc.dtype == torch.float8_e4m3fnuz:
- # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+ if _is_hip:
+ # TODO(haishaw): add bmm_fp8 to ROCm
  attn_bmm_output = torch.bmm(
  attn_output.to(torch.bfloat16).transpose(0, 1),
  self.w_vc.to(torch.bfloat16) * self.w_scale,
@@ -1082,12 +1305,13 @@ class DeepseekV2AttentionMLA(nn.Module):

  return accum_output

- def forward_normal_chunked_kv(
+ def forward_normal_chunked_kv_prepare(
  self,
  positions: torch.Tensor,
  hidden_states: torch.Tensor,
  forward_batch: ForwardBatch,
- ) -> torch.Tensor:
+ zero_allocator: BumpAllocator,
+ ):
  # In normal mha, the k and v tensors will become overly large when the prefix length is long.
  # To avoid this, we split the kv cache into chunks and process them one after another.
  # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
@@ -1130,6 +1354,9 @@ class DeepseekV2AttentionMLA(nn.Module):
  self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
  )

+ return q, k, v, forward_batch
+
+ def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
  # Do mha for extended part without prefix
  forward_batch.set_attn_attend_prefix_cache(False)
  attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
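For the chunked-KV path above, each chunk's attention output comes back with a per-query log-sum-exp (lse), which is what allows partial results to be merged exactly. A minimal sketch of that merging (illustrative only, assuming tiny random tensors, not the real attention backend API):

import torch

def attn_with_lse(q, k, v):
    scores = q @ k.T                       # [num_q, num_kv]
    lse = torch.logsumexp(scores, dim=-1)  # per-query normalizer in log space
    out = torch.softmax(scores, dim=-1) @ v
    return out, lse

def merge(out_a, lse_a, out_b, lse_b):
    # weight each partial output by its share of the total softmax mass
    lse = torch.logaddexp(lse_a, lse_b)
    w_a = torch.exp(lse_a - lse).unsqueeze(-1)
    w_b = torch.exp(lse_b - lse).unsqueeze(-1)
    return w_a * out_a + w_b * out_b, lse

q, k, v = torch.randn(3, 8), torch.randn(10, 8), torch.randn(10, 8)
full, _ = attn_with_lse(q, k, v)
out1, lse1 = attn_with_lse(q, k[:6], v[:6])
out2, lse2 = attn_with_lse(q, k[6:], v[6:])
merged, _ = merge(out1, lse1, out2, lse2)
assert torch.allclose(full, merged, atol=1e-5)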
@@ -1252,17 +1479,29 @@ class DeepseekV2DecoderLayer(nn.Module):
  residual: Optional[torch.Tensor],
  zero_allocator: BumpAllocator,
  ) -> torch.Tensor:
- return execute_operations(
- inputs=dict(
- positions=positions,
- hidden_states=hidden_states,
- forward_batch=forward_batch,
- residual=residual,
- zero_allocator=zero_allocator,
- ),
- operations=compute_layer_operations(self),
+ hidden_states, residual = self.layer_communicator.prepare_attn(
+ hidden_states, residual, forward_batch
  )

+ hidden_states = self.self_attn(
+ positions=positions,
+ hidden_states=hidden_states,
+ forward_batch=forward_batch,
+ zero_allocator=zero_allocator,
+ )
+
+ hidden_states, residual = self.layer_communicator.prepare_mlp(
+ hidden_states, residual, forward_batch
+ )
+
+ hidden_states = self.mlp(hidden_states, forward_batch)
+
+ hidden_states, residual = self.layer_communicator.postprocess_layer(
+ hidden_states, residual, forward_batch
+ )
+
+ return hidden_states, residual
+
  def op_comm_prepare_attn(
  self,
  state,
@@ -1271,6 +1510,7 @@ class DeepseekV2DecoderLayer(nn.Module):
  forward_batch: ForwardBatch,
  residual: Optional[torch.Tensor],
  zero_allocator: BumpAllocator,
+ tbo_subbatch_index: Optional[int] = None,
  ):
  state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
  self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
@@ -1280,17 +1520,10 @@ class DeepseekV2DecoderLayer(nn.Module):
  forward_batch=forward_batch,
  positions=positions,
  zero_allocator=zero_allocator,
+ tbo_subbatch_index=tbo_subbatch_index,
  )
  )

- def op_attn(self, state):
- state.hidden_states_after_attn = self.self_attn(
- positions=state.positions,
- hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
- forward_batch=state.forward_batch,
- zero_allocator=state.zero_allocator,
- )
-
  def op_comm_prepare_mlp(self, state):
  state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
  self.layer_communicator.prepare_mlp(
@@ -1320,8 +1553,24 @@ class DeepseekV2DecoderLayer(nn.Module):
  state.forward_batch,
  )

- state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"})
- return hidden_states, residual
+ output = dict(
+ positions=state.positions,
+ hidden_states=hidden_states,
+ residual=residual,
+ forward_batch=state.forward_batch,
+ zero_allocator=state.zero_allocator,
+ tbo_subbatch_index=state.tbo_subbatch_index,
+ )
+
+ state.clear(
+ expect_keys={
+ "positions",
+ "forward_batch",
+ "zero_allocator",
+ "tbo_subbatch_index",
+ }
+ )
+ return output


  class DeepseekV2Model(nn.Module):
@@ -1336,6 +1585,7 @@ class DeepseekV2Model(nn.Module):
  super().__init__()
  self.padding_id = config.pad_token_id
  self.vocab_size = config.vocab_size
+ self.first_k_dense_replace = config.first_k_dense_replace

  self.embed_tokens = VocabParallelEmbedding(
  config.vocab_size,
@@ -1369,13 +1619,12 @@ class DeepseekV2Model(nn.Module):
  forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
+ total_num_layers = len(self.layers)
+ device = input_embeds.device if input_embeds is not None else input_ids.device
  zero_allocator = BumpAllocator(
- # TODO for two-batch-overlap, we need a larger buffer size
- buffer_size=len(self.layers) * 2,
+ buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
  dtype=torch.float32,
- device=(
- input_embeds.device if input_embeds is not None else input_ids.device
- ),
+ device=device,
  )

  if input_embeds is None:
@@ -1384,12 +1633,33 @@ class DeepseekV2Model(nn.Module):
  hidden_states = input_embeds

  residual = None
- for i in range(len(self.layers)):
+
+ normal_num_layers = (
+ self.first_k_dense_replace
+ if forward_batch.can_run_tbo
+ else total_num_layers
+ )
+ for i in range(normal_num_layers):
  with get_global_expert_distribution_recorder().with_current_layer(i):
  layer = self.layers[i]
  hidden_states, residual = layer(
  positions, hidden_states, forward_batch, residual, zero_allocator
  )
+
+ if normal_num_layers != total_num_layers:
+ hidden_states, residual = model_forward_maybe_tbo(
+ layers=self.layers[normal_num_layers:],
+ enable_tbo=True,
+ positions=positions,
+ forward_batch=forward_batch,
+ hidden_states=hidden_states,
+ residual=residual,
+ input_data_scatter_mode=self.layers[
+ normal_num_layers - 1
+ ].layer_scatter_modes.layer_output_mode,
+ zero_allocator=zero_allocator,
+ )
+
  if not forward_batch.forward_mode.is_idle():
  if residual is None:
  hidden_states = self.norm(hidden_states)
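The model forward above now runs only the first_k_dense_replace dense layers in the plain loop when two-batch overlap is possible and sizes the zero buffer accordingly. A small sketch of that bookkeeping (the layer counts below are made-up examples, not read from a real config):

def plan_layers(total_num_layers, first_k_dense_replace, can_run_tbo):
    # dense prefix layers stay on the normal path; the rest go to the TBO path
    normal_num_layers = first_k_dense_replace if can_run_tbo else total_num_layers
    # the bump allocator reserves twice the slots when two sub-batches overlap
    buffer_size = total_num_layers * 2 * (2 if can_run_tbo else 1)
    return normal_num_layers, buffer_size

assert plan_layers(61, 3, can_run_tbo=True) == (3, 244)
assert plan_layers(61, 3, can_run_tbo=False) == (61, 122)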
@@ -1410,7 +1680,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  self.config = config
  self.tp_size = get_tensor_model_parallel_world_size()
  self.quant_config = quant_config
- self.determine_n_share_experts_fusion()
+ self.determine_num_fused_shared_experts()
  self.model = DeepseekV2Model(
  config, quant_config, prefix=add_prefix("model", prefix)
  )
@@ -1424,41 +1694,50 @@ class DeepseekV2ForCausalLM(nn.Module):
  self.logits_processor = LogitsProcessor(config)
  self.dp_size = get_local_attention_dp_size()

- def determine_n_share_experts_fusion(
+ self._routed_experts_weights_of_layer = LazyValue(
+ lambda: {
+ layer_id: layer.mlp.get_moe_weights()
+ for layer_id, layer in enumerate(self.model.layers)
+ if isinstance(layer.mlp, DeepseekV2MoE)
+ }
+ )
+
+ @property
+ def routed_experts_weights_of_layer(self):
+ return self._routed_experts_weights_of_layer.value
+
+ def determine_num_fused_shared_experts(
  self, architecture: str = "DeepseekV3ForCausalLM"
  ):
- self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
- if self.n_share_experts_fusion > 0:
- # Only Deepseek V3/R1 can use shared experts fusion optimization now.
- if (
- not _is_cuda
- or self.config.architectures[0] != architecture
- or self.config.n_routed_experts != 256
- ):
- self.n_share_experts_fusion = 0
- global_server_args_dict["n_share_experts_fusion"] = 0
- log_info_on_rank0(
- logger,
- "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
- )
- else:
- assert (
- self.n_share_experts_fusion == self.tp_size
- ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
- elif self.n_share_experts_fusion == 0:
- if (
- _is_cuda
- and torch.cuda.get_device_capability("cuda") >= (9, 0)
- and self.config.architectures[0] == architecture
- and self.config.n_routed_experts == 256
- and (not global_server_args_dict["enable_deepep_moe"])
- ):
- self.n_share_experts_fusion = self.tp_size
- global_server_args_dict["n_share_experts_fusion"] = self.tp_size
- log_info_on_rank0(
- logger,
- "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
- )
+ self.num_fused_shared_experts = 0
+ if global_server_args_dict["disable_shared_experts_fusion"]:
+ return
+
+ # Only Deepseek V3/R1 can use shared experts fusion optimization now.
+ disable_reason = None
+ if (
+ not _is_cuda
+ or torch.cuda.get_device_capability("cuda") < (9, 0)
+ or self.config.architectures[0] != architecture
+ or self.config.n_routed_experts != 256
+ or self.config.n_shared_experts != 1
+ ):
+ disable_reason = "Only Deepseek V3/R1 on NV-platform with capability >= 90 can use shared experts fusion optimization."
+ elif (
+ global_server_args_dict["enable_deepep_moe"]
+ or global_server_args_dict["enable_ep_moe"]
+ ):
+ disable_reason = "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode."
+
+ if disable_reason is not None:
+ global_server_args_dict["disable_shared_experts_fusion"] = True
+ log_info_on_rank0(
+ logger,
+ f"{disable_reason} Shared experts fusion optimization is disabled.",
+ )
+ return
+
+ self.num_fused_shared_experts = self.config.n_shared_experts

  def get_input_embeddings(self) -> nn.Embedding:
  return self.model.embed_tokens
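determine_num_fused_shared_experts above replaces the old n_share_experts_fusion logic with a single capability check. A standalone sketch of the same gating (the helper below is hypothetical, with the relevant config fields passed in as plain arguments):

def num_fused_shared_experts(
    is_cuda, sm_major, architecture, n_routed_experts, n_shared_experts,
    enable_deepep_moe, enable_ep_moe,
):
    if (
        not is_cuda
        or sm_major < 9
        or architecture != "DeepseekV3ForCausalLM"
        or n_routed_experts != 256
        or n_shared_experts != 1
    ):
        return 0  # unsupported platform or model shape
    if enable_deepep_moe or enable_ep_moe:
        return 0  # fusion conflicts with expert-parallel dispatch
    return n_shared_experts

assert num_fused_shared_experts(True, 9, "DeepseekV3ForCausalLM", 256, 1, False, False) == 1
assert num_fused_shared_experts(True, 8, "DeepseekV3ForCausalLM", 256, 1, False, False) == 0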
@@ -1471,21 +1750,28 @@ class DeepseekV2ForCausalLM(nn.Module):
  forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
-
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)

  return self.logits_processor(
  input_ids, hidden_states, self.lm_head, forward_batch
  )

- def post_load_weights(self, is_nextn=False):
+ def post_load_weights(self, is_nextn=False, weight_names=None):

  # Perform post-processing after loading weights
- layer_ids = (
- range(self.config.num_hidden_layers)
- if not is_nextn
- else [self.config.num_hidden_layers]
- )
+ if is_nextn:
+ layer_ids = [self.config.num_hidden_layers]
+ else:
+ if weight_names is None:
+ layer_ids = range(self.config.num_hidden_layers)
+ else:
+ layer_ids = set()
+ for name in weight_names:
+ if "kv_b_proj" in name:
+ layer_id = int(name.split(".")[2])
+ if layer_id < self.config.num_hidden_layers:
+ layer_ids.add(layer_id)
+
  for layer_id in layer_ids:
  self_attn = (
  self.model.layers[layer_id].self_attn
@@ -1521,46 +1807,58 @@ class DeepseekV2ForCausalLM(nn.Module):
  torch.float8_e4m3fn,
  torch.float8_e4m3fnuz,
  ):
- if hasattr(self.quant_config, "weight_block_size"):
+ if (
+ hasattr(self.quant_config, "weight_block_size")
+ and self.quant_config.weight_block_size is not None
+ ):
  weight_block_size = self.quant_config.weight_block_size
- if weight_block_size is not None:
- assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
- if _is_hip:
- weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
- weight=w,
- weight_scale=self_attn.kv_b_proj.weight_scale_inv,
- input_scale=None,
- )
- else:
- weight = w
- weight_scale = self_attn.kv_b_proj.weight_scale_inv
+ assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+ if _is_fp8_fnuz:
+ weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+ weight=w,
+ weight_scale=self_attn.kv_b_proj.weight_scale_inv,
+ input_scale=None,
+ )
+ else:
+ weight = w
+ weight_scale = self_attn.kv_b_proj.weight_scale_inv

+ if (
+ _is_cuda
+ and weight_block_size[0] == 128
+ and weight_block_size[1] == 128
+ and model_dtype == torch.bfloat16
+ ):
  if (
- _is_cuda
- and weight_block_size[0] == 128
- and weight_block_size[1] == 128
- and model_dtype == torch.bfloat16
+ deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+ and not deep_gemm_wrapper.DEEPGEMM_BLACKWELL
+ and get_bool_env_var("SGL_USE_DEEPGEMM_BMM", "false")
  ):
- if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
- "SGL_USE_DEEPGEMM_BMM", "false"
- ):
- block_scale = weight_scale
- use_deep_gemm_bmm = True
- else:
- w = block_quant_dequant(
- weight,
- weight_scale,
- weight_block_size,
- model_dtype,
- )
+ block_scale = weight_scale
+ use_deep_gemm_bmm = True
  else:
- w, scale = block_quant_to_tensor_quant(
- weight, weight_scale, weight_block_size
+ w = block_quant_dequant(
+ weight,
+ weight_scale,
+ weight_block_size,
+ model_dtype,
  )
- self_attn.w_scale = scale
+ else:
+ w, scale = block_quant_to_tensor_quant(
+ weight, weight_scale, weight_block_size
+ )
+ self_attn.w_scale = scale
  else:
- weight = w
- weight_scale = self_attn.kv_b_proj.weight_scale
+ if _is_fp8_fnuz:
+ weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
+ weight=w,
+ weight_scale=self_attn.kv_b_proj.weight_scale,
+ input_scale=None,
+ )
+ else:
+ weight = w
+ weight_scale = self_attn.kv_b_proj.weight_scale
+
  w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
  self_attn.w_scale = scale

@@ -1585,13 +1883,19 @@ class DeepseekV2ForCausalLM(nn.Module):
  0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
  ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
  if not use_deep_gemm_bmm:
- self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
- self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
+ self_attn.w_kc = bind_or_assign(
+ self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
+ )
+ self_attn.w_vc = bind_or_assign(
+ self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
+ )
  if (
  hasattr(self_attn.kv_b_proj, "weight_scale")
  and self_attn.w_scale is None
  ):
- self_attn.w_scale = self_attn.kv_b_proj.weight_scale
+ self_attn.w_scale = bind_or_assign(
+ self_attn.w_scale, self_attn.kv_b_proj.weight_scale
+ )
  if _is_hip:
  self_attn.w_scale *= 2.0
  else:
@@ -1600,21 +1904,79 @@ class DeepseekV2ForCausalLM(nn.Module):
  ws_kc, ws_vc = block_scale.unflatten(
  0, (-1, (num_tiles_k + num_tiles_n))
  ).split([num_tiles_k, num_tiles_n], dim=1)
- self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
- self_attn.w_scale_v = ws_vc.contiguous()
- self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
- self_attn.w_vc = w_vc.contiguous()
+ self_attn.w_scale_k = bind_or_assign(
+ self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
+ )
+ self_attn.w_scale_v = bind_or_assign(
+ self_attn.w_scale_v, ws_vc.contiguous()
+ )
+ self_attn.w_kc = bind_or_assign(
+ self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
+ )
+ self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
  self_attn.use_deep_gemm_bmm = True

- # TODO support nextn later
- if not is_nextn:
- self.routed_experts_weights_of_layer = {
- layer_id: layer.mlp.get_moe_weights()
- for layer_id, layer in enumerate(self.model.layers)
- if isinstance(layer.mlp, DeepseekV2MoE)
- }
+ if (
+ deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
+ and deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
+ ):
+ self._weight_requant_ue8m0()
+
+ def _weight_requant_ue8m0(self):
+ weight_block_size = self.quant_config.weight_block_size
+
+ moe_layers = list(
+ range(
+ self.config.first_k_dense_replace,
+ self.config.num_hidden_layers,
+ self.config.moe_layer_freq,
+ )
+ )
+
+ for layer_id in range(self.config.num_hidden_layers):
+ layer = self.model.layers[layer_id]
+
+ for module in [
+ layer.self_attn.fused_qkv_a_proj_with_mqa,
+ layer.self_attn.q_b_proj,
+ layer.self_attn.kv_b_proj,
+ layer.self_attn.o_proj,
+ ]:
+ requant_weight_ue8m0_inplace(
+ module.weight, module.weight_scale_inv, weight_block_size
+ )
+
+ if layer_id in moe_layers:
+ shared_experts = getattr(layer.mlp, "shared_experts", None)
+ if shared_experts is not None:
+ for module in [
+ shared_experts.gate_up_proj,
+ shared_experts.down_proj,
+ ]:
+ requant_weight_ue8m0_inplace(
+ module.weight, module.weight_scale_inv, weight_block_size
+ )
+
+ experts = layer.mlp.experts
+ if isinstance(experts, DeepEPMoE):
+ for w in [
+ experts.w13_weight_fp8,
+ experts.w2_weight_fp8,
+ ]:
+ requant_weight_ue8m0_inplace(w[0], w[1], weight_block_size)
+ else:
+ mlp = layer.mlp
+ assert isinstance(mlp, DeepseekV2MLP)
+ for module in [
+ mlp.gate_up_proj,
+ mlp.down_proj,
+ ]:
+ requant_weight_ue8m0_inplace(
+ module.weight, module.weight_scale_inv, weight_block_size
+ )

  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
+
  if is_nextn:
  if hasattr(self.config, "num_nextn_predict_layers"):
  num_nextn_layers = self.config.num_nextn_predict_layers
@@ -1633,60 +1995,6 @@ class DeepseekV2ForCausalLM(nn.Module):
  ("gate_up_proj", "gate_proj", 0),
  ("gate_up_proj", "up_proj", 1),
  ]
- if self.n_share_experts_fusion > 0:
- weights_list = list(weights)
- weights_dict = dict(weights_list)
- if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
- suffix_list = [
- "down_proj.weight",
- "down_proj.weight_scale",
- "gate_proj.weight",
- "gate_proj.weight_scale",
- "up_proj.weight",
- "up_proj.weight_scale",
- ]
- else:
- suffix_list = [
- "down_proj.weight",
- "down_proj.weight_scale_inv",
- "gate_proj.weight",
- "gate_proj.weight_scale_inv",
- "up_proj.weight",
- "up_proj.weight_scale_inv",
- ]
- names_to_remove = []
-
- moe_layers = (
- range(
- self.config.first_k_dense_replace,
- self.config.num_hidden_layers,
- self.config.moe_layer_freq,
- )
- if not is_nextn
- else [nextn_layer_id]
- )
-
- for moe_layer in tqdm(
- moe_layers,
- desc=f"Cloning {self.n_share_experts_fusion} "
- "replicas of the shared expert into MoE",
- ):
- for suffix in suffix_list:
- shared_expert_weight_name = (
- f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
- )
- for num_repeat in range(self.n_share_experts_fusion):
- weights_list.append(
- (
- f"model.layers.{moe_layer}."
- f"mlp.experts."
- f"{self.config.n_routed_experts + num_repeat}"
- f".{suffix}",
- weights_dict[shared_expert_weight_name],
- )
- )
- names_to_remove += [shared_expert_weight_name]
- weights = [w for w in weights_list if w[0] not in names_to_remove]

  # Params for weights, fp8 weight scales, fp8 activation scales
  # (param_name, weight_name, expert_id, shard_id)
@@ -1694,7 +2002,7 @@ class DeepseekV2ForCausalLM(nn.Module):
  ckpt_gate_proj_name="gate_proj",
  ckpt_down_proj_name="down_proj",
  ckpt_up_proj_name="up_proj",
- num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
+ num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
  )

  # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
@@ -1712,8 +2020,21 @@ class DeepseekV2ForCausalLM(nn.Module):
  "hnorm",
  ]

+ if self.num_fused_shared_experts > 0:
+ assert self.num_fused_shared_experts == 1
+ logger.info("Shared experts fusion optimization enabled.")
+
  params_dict = dict(self.named_parameters())
+ weight_names = []
  for name, loaded_weight in weights:
+ if self.num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
+ name = name.replace(
+ "mlp.shared_experts",
+ f"mlp.experts.{self.config.n_routed_experts}",
+ )
+
+ weight_names.append(name)
+
  if not is_nextn:
  if hasattr(self.config, "num_nextn_predict_layers"):
  num_nextn_layers = self.config.num_nextn_predict_layers
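With the fused shared expert, load_weights above no longer clones shared-expert tensors; it simply remaps their checkpoint names onto one extra routed-expert slot. A minimal sketch of the renaming (the remap helper is hypothetical; 256 mirrors n_routed_experts for V3/R1):

n_routed_experts = 256

def remap(name, num_fused_shared_experts=1):
    # the shared expert is loaded as routed expert index n_routed_experts
    if num_fused_shared_experts > 0 and "mlp.shared_experts" in name:
        return name.replace("mlp.shared_experts", f"mlp.experts.{n_routed_experts}")
    return name

assert (
    remap("model.layers.5.mlp.shared_experts.down_proj.weight")
    == "model.layers.5.mlp.experts.256.down_proj.weight"
)
assert remap("model.layers.5.mlp.gate.weight") == "model.layers.5.mlp.gate.weight"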
@@ -1785,7 +2106,6 @@ class DeepseekV2ForCausalLM(nn.Module):
  # Skip loading extra bias for GPTQ models.
  if name.endswith(".bias") and name not in params_dict:
  continue
-
  if fuse_qkv_a_proj and (
  "q_a_proj" in name or "kv_a_proj_with_mqa" in name
  ):
@@ -1811,9 +2131,12 @@ class DeepseekV2ForCausalLM(nn.Module):
  fused_weight = torch.cat(
  [q_a_proj_weight, kv_a_proj_weight], dim=0
  )
-
- param_name = name.replace(
- "q_a_proj", "fused_qkv_a_proj_with_mqa"
+ param_name = (
+ name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
+ if "q_a_proj" in name
+ else name.replace(
+ "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
+ )
  )
  param = params_dict[param_name]

@@ -1824,13 +2147,23 @@ class DeepseekV2ForCausalLM(nn.Module):
  cached_a_proj.pop(q_a_proj_name)
  cached_a_proj.pop(kv_a_proj_name)
  else:
+ if (
+ "k_scale" in name or "v_scale" in name
+ ) and name not in params_dict:
+ # modelopt attn kv scale is named differently
+ if any(scale in name for scale in ["k_scale", "v_scale"]):
+ name = name.replace("_proj", "attn_mqa")
+ else:
+ logger.warning(
+ f"Unknown scale found in checkpoint: {name}"
+ )
  param = params_dict[name]
  weight_loader = getattr(
  param, "weight_loader", default_weight_loader
  )
  weight_loader(param, loaded_weight)

- self.post_load_weights(is_nextn=is_nextn)
+ self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)

  def get_embed_and_head(self):
  return self.model.embed_tokens.weight, self.lm_head.weight
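post_load_weights now receives the list of weight names that were actually loaded, so only the touched layers are re-absorbed. A small sketch of how the layer ids are derived from kv_b_proj names (hypothetical standalone function, with a made-up layer count):

num_hidden_layers = 4

def layers_to_postprocess(weight_names):
    # None means a full load, so every layer is post-processed
    if weight_names is None:
        return set(range(num_hidden_layers))
    layer_ids = set()
    for name in weight_names:
        if "kv_b_proj" in name:
            layer_id = int(name.split(".")[2])  # "model.layers.<id>. ..."
            if layer_id < num_hidden_layers:
                layer_ids.add(layer_id)
    return layer_ids

names = ["model.layers.1.self_attn.kv_b_proj.weight", "model.layers.3.mlp.gate.weight"]
assert layers_to_postprocess(names) == {1}
assert layers_to_postprocess(None) == {0, 1, 2, 3}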