sglang 0.4.6.post5__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (318)
  1. sglang/bench_offline_throughput.py +10 -4
  2. sglang/bench_one_batch_server.py +67 -11
  3. sglang/bench_serving.py +85 -74
  4. sglang/lang/backend/runtime_endpoint.py +24 -1
  5. sglang/profiler.py +167 -0
  6. sglang/srt/_custom_ops.py +34 -0
  7. sglang/srt/configs/internvl.py +8 -12
  8. sglang/srt/configs/model_config.py +27 -1
  9. sglang/srt/constrained/base_grammar_backend.py +5 -2
  10. sglang/srt/constrained/llguidance_backend.py +9 -8
  11. sglang/srt/constrained/outlines_backend.py +5 -4
  12. sglang/srt/constrained/xgrammar_backend.py +18 -18
  13. sglang/srt/conversation.py +46 -8
  14. sglang/srt/custom_op.py +38 -3
  15. sglang/srt/debug_utils.py +74 -0
  16. sglang/srt/disaggregation/common/__init__.py +1 -0
  17. sglang/srt/disaggregation/common/conn.py +407 -0
  18. sglang/srt/disaggregation/decode.py +67 -3
  19. sglang/srt/disaggregation/fake/conn.py +1 -0
  20. sglang/srt/disaggregation/kv_events.py +60 -5
  21. sglang/srt/disaggregation/launch_lb.py +140 -0
  22. sglang/srt/disaggregation/mini_lb.py +29 -48
  23. sglang/srt/disaggregation/mooncake/conn.py +432 -140
  24. sglang/srt/disaggregation/mooncake/transfer_engine.py +32 -16
  25. sglang/srt/disaggregation/nixl/conn.py +124 -432
  26. sglang/srt/disaggregation/prefill.py +2 -0
  27. sglang/srt/disaggregation/utils.py +38 -1
  28. sglang/srt/distributed/device_communicators/pymscclpp.py +315 -0
  29. sglang/srt/distributed/parallel_state.py +52 -5
  30. sglang/srt/entrypoints/EngineBase.py +6 -0
  31. sglang/srt/entrypoints/engine.py +102 -5
  32. sglang/srt/entrypoints/http_server.py +15 -2
  33. sglang/srt/function_call/base_format_detector.py +138 -86
  34. sglang/srt/function_call/deepseekv3_detector.py +54 -6
  35. sglang/srt/function_call/ebnf_composer.py +33 -19
  36. sglang/srt/function_call/function_call_parser.py +27 -0
  37. sglang/srt/function_call/llama32_detector.py +33 -14
  38. sglang/srt/function_call/mistral_detector.py +73 -26
  39. sglang/srt/function_call/pythonic_detector.py +86 -20
  40. sglang/srt/function_call/qwen25_detector.py +64 -10
  41. sglang/srt/function_call/utils.py +17 -0
  42. sglang/srt/hf_transformers_utils.py +4 -0
  43. sglang/srt/layers/attention/aiter_backend.py +488 -123
  44. sglang/srt/layers/attention/base_attn_backend.py +4 -0
  45. sglang/srt/layers/attention/cutlass_mla_backend.py +2 -19
  46. sglang/srt/layers/attention/flashattention_backend.py +103 -18
  47. sglang/srt/layers/attention/flashinfer_backend.py +45 -1
  48. sglang/srt/layers/attention/flashinfer_mla_backend.py +37 -1
  49. sglang/srt/layers/attention/intel_amx_backend.py +128 -0
  50. sglang/srt/layers/attention/tbo_backend.py +232 -0
  51. sglang/srt/layers/attention/torch_native_backend.py +3 -0
  52. sglang/srt/layers/attention/triton_backend.py +244 -5
  53. sglang/srt/layers/attention/triton_ops/extend_attention.py +12 -4
  54. sglang/srt/layers/communicator.py +260 -194
  55. sglang/srt/layers/dp_attention.py +6 -5
  56. sglang/srt/layers/layernorm.py +30 -19
  57. sglang/srt/layers/moe/cutlass_moe.py +170 -7
  58. sglang/srt/layers/moe/cutlass_moe_params.py +169 -0
  59. sglang/srt/layers/moe/ep_moe/kernels.py +27 -6
  60. sglang/srt/layers/moe/ep_moe/layer.py +94 -40
  61. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +13 -8
  62. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  71. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +220 -25
  72. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -4
  73. sglang/srt/layers/moe/topk.py +44 -18
  74. sglang/srt/layers/multimodal.py +3 -3
  75. sglang/srt/layers/quantization/__init__.py +3 -2
  76. sglang/srt/layers/quantization/blockwise_int8.py +3 -0
  77. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +5 -0
  78. sglang/srt/layers/quantization/deep_gemm.py +55 -56
  79. sglang/srt/layers/quantization/fp8.py +28 -23
  80. sglang/srt/layers/quantization/fp8_kernel.py +118 -66
  81. sglang/srt/layers/quantization/fp8_utils.py +165 -49
  82. sglang/srt/layers/quantization/modelopt_quant.py +334 -7
  83. sglang/srt/layers/quantization/moe_wna16.py +3 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +3 -0
  85. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  86. sglang/srt/layers/rotary_embedding.py +6 -12
  87. sglang/srt/layers/sampler.py +80 -79
  88. sglang/srt/layers/utils.py +6 -0
  89. sglang/srt/lora/layers.py +12 -15
  90. sglang/srt/lora/lora.py +49 -5
  91. sglang/srt/lora/lora_manager.py +19 -5
  92. sglang/srt/lora/mem_pool.py +24 -16
  93. sglang/srt/lora/utils.py +17 -13
  94. sglang/srt/managers/data_parallel_controller.py +13 -5
  95. sglang/srt/managers/eplb_algorithms/__init__.py +63 -0
  96. sglang/srt/managers/eplb_algorithms/deepseek.py +223 -0
  97. sglang/srt/managers/{deepseek_eplb.py → eplb_algorithms/deepseek_vec.py} +5 -7
  98. sglang/srt/managers/eplb_manager.py +55 -14
  99. sglang/srt/managers/expert_distribution.py +220 -46
  100. sglang/srt/managers/expert_location.py +110 -56
  101. sglang/srt/managers/expert_location_dispatch.py +23 -6
  102. sglang/srt/managers/io_struct.py +15 -4
  103. sglang/srt/managers/mm_utils.py +88 -38
  104. sglang/srt/managers/multimodal_processors/base_processor.py +188 -16
  105. sglang/srt/managers/multimodal_processors/gemma3.py +4 -31
  106. sglang/srt/managers/multimodal_processors/internvl.py +4 -0
  107. sglang/srt/managers/multimodal_processors/kimi_vl.py +15 -34
  108. sglang/srt/managers/multimodal_processors/minicpm.py +2 -1
  109. sglang/srt/managers/multimodal_processors/phi4mm.py +87 -0
  110. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -64
  111. sglang/srt/managers/schedule_batch.py +140 -38
  112. sglang/srt/managers/scheduler.py +305 -112
  113. sglang/srt/managers/tokenizer_manager.py +134 -17
  114. sglang/srt/managers/utils.py +0 -4
  115. sglang/srt/metrics/collector.py +9 -0
  116. sglang/srt/model_executor/cuda_graph_runner.py +72 -61
  117. sglang/srt/model_executor/expert_location_updater.py +157 -22
  118. sglang/srt/model_executor/forward_batch_info.py +38 -17
  119. sglang/srt/model_executor/model_runner.py +96 -56
  120. sglang/srt/model_loader/utils.py +67 -1
  121. sglang/srt/models/deepseek_nextn.py +1 -1
  122. sglang/srt/models/deepseek_v2.py +609 -234
  123. sglang/srt/models/gemma3_causal.py +7 -0
  124. sglang/srt/models/gemma3_mm.py +19 -14
  125. sglang/srt/models/idefics2.py +342 -0
  126. sglang/srt/models/kimi_vl.py +4 -4
  127. sglang/srt/models/llama.py +1 -1
  128. sglang/srt/models/minicpmo.py +2 -5
  129. sglang/srt/models/minicpmv.py +3 -295
  130. sglang/srt/models/phi4mm.py +512 -0
  131. sglang/srt/models/qwen2.py +38 -9
  132. sglang/srt/models/qwen2_5_vl.py +3 -9
  133. sglang/srt/models/qwen2_eagle.py +4 -1
  134. sglang/srt/models/qwen2_moe.py +58 -191
  135. sglang/srt/models/qwen2_vl.py +3 -9
  136. sglang/srt/models/qwen3.py +41 -10
  137. sglang/srt/models/qwen3_moe.py +230 -191
  138. sglang/srt/models/registry.py +9 -1
  139. sglang/srt/models/transformers.py +291 -0
  140. sglang/srt/openai_api/adapter.py +86 -24
  141. sglang/srt/openai_api/protocol.py +31 -2
  142. sglang/srt/openai_api/utils.py +172 -0
  143. sglang/srt/operations.py +37 -2
  144. sglang/srt/operations_strategy.py +200 -24
  145. sglang/srt/sampling/sampling_batch_info.py +13 -1
  146. sglang/srt/sampling/sampling_params.py +2 -1
  147. sglang/srt/server_args.py +114 -27
  148. sglang/srt/speculative/build_eagle_tree.py +8 -8
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -11
  150. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +253 -0
  151. sglang/srt/speculative/eagle_utils.py +51 -91
  152. sglang/srt/speculative/eagle_worker.py +101 -21
  153. sglang/srt/two_batch_overlap.py +635 -0
  154. sglang/srt/utils.py +129 -7
  155. sglang/test/runners.py +16 -7
  156. sglang/test/send_one.py +4 -0
  157. sglang/test/test_cutlass_moe.py +3 -3
  158. sglang/test/test_fp4_moe.py +248 -0
  159. sglang/test/test_utils.py +79 -6
  160. sglang/version.py +1 -1
  161. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.dist-info}/METADATA +14 -11
  162. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.dist-info}/RECORD +318 -291
  163. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.dist-info}/WHEEL +1 -1
  164. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  165. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  166. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  167. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  168. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  169. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  170. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  171. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  172. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  173. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  174. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  175. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  176. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  177. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H200.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H200.json} +0 -0
  178. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  179. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  180. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  181. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  182. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  183. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  184. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  185. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  186. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  187. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  188. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  189. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  190. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  191. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  192. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  193. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  194. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  195. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  196. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  197. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  198. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  199. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  200. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  201. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  202. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  203. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  204. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  205. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  206. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json} +0 -0
  207. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  208. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  209. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  210. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  211. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  212. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  213. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  214. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  215. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  216. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json} +0 -0
  217. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json} +0 -0
  218. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  219. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  220. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  221. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  222. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  223. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  224. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200.json} +0 -0
  225. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  226. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  227. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200.json} +0 -0
  228. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  229. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  230. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  231. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200.json} +0 -0
  232. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  233. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  234. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  235. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
  236. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  237. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  238. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  239. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200.json} +0 -0
  240. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X.json} +0 -0
  241. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI325X.json} +0 -0
  242. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Radeon_Graphics.json} +0 -0
  243. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  244. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  245. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200.json} +0 -0
  246. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X.json} +0 -0
  247. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI325X.json} +0 -0
  248. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Radeon_Graphics.json} +0 -0
  249. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  250. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  251. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  252. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  253. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200.json} +0 -0
  254. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  255. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  256. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  257. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  258. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200.json} +0 -0
  259. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X.json} +0 -0
  260. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI325X.json} +0 -0
  261. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Radeon_Graphics.json} +0 -0
  262. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  263. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  264. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
  265. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  266. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  267. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  268. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200.json} +0 -0
  269. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_L40S.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_L40S.json} +0 -0
  270. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
  271. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
  272. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
  273. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  274. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  275. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  276. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  277. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200.json} +0 -0
  278. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X.json} +0 -0
  279. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI325X.json} +0 -0
  280. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Radeon_Graphics.json} +0 -0
  281. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  282. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  283. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  284. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  285. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200.json} +0 -0
  286. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
  287. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
  288. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
  289. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  290. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  291. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  292. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  293. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H20.json} +0 -0
  294. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H200.json} +0 -0
  295. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  296. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  297. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20.json} +0 -0
  298. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  299. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200.json} +0 -0
  300. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  301. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  302. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  303. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  304. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20.json} +0 -0
  305. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  306. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200.json} +0 -0
  307. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=96,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=96,device_name=NVIDIA_H20.json} +0 -0
  308. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  309. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  310. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  311. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  312. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  313. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  314. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  315. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  316. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  317. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.dist-info}/licenses/LICENSE +0 -0
  318. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
@@ -57,6 +57,7 @@ from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
+    is_fp8_fnuz,
     per_tensor_quant_mla_fp8,
     per_token_group_quant_mla_deep_gemm_masked_fp8,
 )
@@ -83,21 +84,28 @@ from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchI
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.operations import execute_operations
-from sglang.srt.operations_strategy import compute_layer_operations
+from sglang.srt.two_batch_overlap import (
+    MaybeTboDeepEPDispatcher,
+    model_forward_maybe_tbo,
+)
 from sglang.srt.utils import (
     BumpAllocator,
     DeepEPMode,
+    LazyValue,
     add_prefix,
+    bind_or_assign,
     get_bool_env_var,
     get_int_env_var,
     is_cuda,
     is_hip,
+    is_non_idle_and_non_empty,
     log_info_on_rank0,
 )

 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

 if _is_cuda:
     from sgl_kernel import awq_dequantize, bmm_fp8, merge_state_v2
@@ -113,6 +121,9 @@ if _is_hip:
         decode_attention_fwd_grouped_rope,
     )

+    if _use_aiter:
+        from aiter.rotary_embedding import get_rope
+
 logger = logging.getLogger(__name__)

@@ -204,14 +215,6 @@ class MoEGate(nn.Module):
         return logits


-def is_non_idle_and_non_empty(forward_mode, hidden_states):
-    return (
-        (forward_mode is not None)
-        and not forward_mode.is_idle()
-        and hidden_states.shape[0] > 0
-    )
-
-
 class DeepseekV2MoE(nn.Module):

     def __init__(
@@ -225,7 +228,12 @@ class DeepseekV2MoE(nn.Module):
         self.tp_size = get_tensor_model_parallel_world_size()
         self.routed_scaling_factor = config.routed_scaling_factor
         self.n_shared_experts = config.n_shared_experts
-        self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+        self.num_fused_shared_experts = (
+            0
+            if global_server_args_dict["disable_shared_experts_fusion"]
+            else config.n_shared_experts
+        )
+        self.config = config
         self.layer_id = layer_id

         if self.tp_size > config.n_routed_experts:
@@ -244,9 +252,9 @@ class DeepseekV2MoE(nn.Module):

         self.experts = get_moe_impl_class()(
             num_experts=config.n_routed_experts
-            + self.n_share_experts_fusion
+            + self.num_fused_shared_experts
             + global_server_args_dict["ep_num_redundant_experts"],
-            top_k=config.num_experts_per_tok + min(self.n_share_experts_fusion, 1),
+            top_k=config.num_experts_per_tok + self.num_fused_shared_experts,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
             layer_id=self.layer_id,
@@ -254,6 +262,7 @@ class DeepseekV2MoE(nn.Module):
             quant_config=quant_config,
             use_grouped_topk=True,
             num_expert_group=config.n_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
             topk_group=config.topk_group,
             correction_bias=self.gate.e_score_correction_bias,
             routed_scaling_factor=self.routed_scaling_factor,
@@ -265,7 +274,7 @@ class DeepseekV2MoE(nn.Module):
             ),
         )

-        if config.n_shared_experts is not None and self.n_share_experts_fusion == 0:
+        if config.n_shared_experts is not None and self.num_fused_shared_experts == 0:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             # disable tp for shared experts when enable deepep moe
             self.shared_experts = DeepseekV2MLP(
@@ -300,7 +309,7 @@ class DeepseekV2MoE(nn.Module):
                 else None
             )

-            self.deepep_dispatcher = DeepEPDispatcher(
+            self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
                 group=parallel_state.get_tp_group().device_group,
                 router_topk=self.top_k,
                 permute_fusion=True,
@@ -309,13 +318,11 @@ class DeepseekV2MoE(nn.Module):
                 hidden_size=config.hidden_size,
                 params_dtype=config.torch_dtype,
                 deepep_mode=DeepEPMode[global_server_args_dict["deepep_mode"]],
-                async_finish=True,  # TODO
+                async_finish=True,
                 return_recv_hook=True,
             )

-    @property
-    def _enable_deepep_moe(self):
-        return global_server_args_dict["enable_deepep_moe"]
+        self._enable_deepep_moe = global_server_args_dict["enable_deepep_moe"]

     def get_moe_weights(self):
         return [
@@ -324,8 +331,114 @@ class DeepseekV2MoE(nn.Module):
             if name not in ["correction_bias"]
         ]

+    def forward(
+        self, hidden_states: torch.Tensor, forward_batch: Optional[ForwardBatch] = None
+    ) -> torch.Tensor:
+        if not self._enable_deepep_moe:
+            return self.forward_normal(hidden_states)
+        else:
+            return self.forward_deepep(hidden_states, forward_batch)
+
+    def forward_normal(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        shared_output = self._forward_shared_experts(hidden_states)
+        # router_logits: (num_tokens, n_experts)
+        router_logits = self.gate(hidden_states)
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states, router_logits=router_logits
+        )
+        if not _is_cuda:
+            final_hidden_states *= self.routed_scaling_factor
+        if shared_output is not None:
+            final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1:
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        return final_hidden_states
+
+    def forward_deepep(
+        self, hidden_states: torch.Tensor, forward_batch: ForwardBatch
+    ) -> torch.Tensor:
+        forward_mode = forward_batch.forward_mode
+        shared_output = None
+        if is_non_idle_and_non_empty(forward_mode, hidden_states):
+            # router_logits: (num_tokens, n_experts)
+            router_logits = self.gate(hidden_states)
+            shared_output = self._forward_shared_experts(hidden_states)
+            topk_weights, topk_idx = select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=True,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                correction_bias=self.correction_bias,
+                routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=forward_batch.num_token_non_padded,
+                expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                    layer_id=self.layer_id,
+                ),
+            )
+        else:
+            topk_idx = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            topk_weights = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )
+        if self.ep_size > 1:
+            # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                reorder_topk_ids,
+                num_recv_tokens_per_expert,
+                seg_indptr,
+                masked_m,
+                expected_m,
+            ) = self.deepep_dispatcher.dispatch(
+                hidden_states=hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                forward_mode=forward_mode,
+            )
+        final_hidden_states = self.experts(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            reorder_topk_ids=reorder_topk_ids,
+            seg_indptr=seg_indptr,
+            masked_m=masked_m,
+            expected_m=expected_m,
+            num_recv_tokens_per_expert=num_recv_tokens_per_expert,
+            forward_mode=forward_mode,
+        )
+        if self.ep_size > 1:
+            final_hidden_states = self.deepep_dispatcher.combine(
+                hidden_states=final_hidden_states,
+                topk_idx=topk_idx,
+                topk_weights=topk_weights,
+                forward_mode=forward_mode,
+            )
+
+        if shared_output is not None:
+            x = shared_output
+            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            final_hidden_states = x
+        else:
+            final_hidden_states *= self.routed_scaling_factor
+
+        return final_hidden_states
+
+    def _forward_shared_experts(self, hidden_states):
+        if self.num_fused_shared_experts == 0:
+            return self.shared_experts(hidden_states)
+        else:
+            return None
+
     def op_gate(self, state):
-        if (not self._enable_deepep_moe) or is_non_idle_and_non_empty(
+        if is_non_idle_and_non_empty(
             state.forward_batch.forward_mode, state.hidden_states_mlp_input
         ):
             # router_logits: (num_tokens, n_experts)
@@ -334,22 +447,22 @@ class DeepseekV2MoE(nn.Module):
             state.router_logits = None

     def op_shared_experts(self, state):
-        if (self.n_share_experts_fusion == 0) and (
-            (not self._enable_deepep_moe)
-            or is_non_idle_and_non_empty(
-                state.forward_batch.forward_mode, state.hidden_states_mlp_input
-            )
+        hidden_states_mlp_input = state.pop("hidden_states_mlp_input")
+        if (self.num_fused_shared_experts == 0) and is_non_idle_and_non_empty(
+            state.forward_batch.forward_mode, hidden_states_mlp_input
         ):
-            state.shared_output = self.shared_experts(state.hidden_states_mlp_input)
+            state.shared_output = self.shared_experts(hidden_states_mlp_input)
         else:
             state.shared_output = None

     def op_select_experts(self, state):
-        router_logits = state.router_logits
+        router_logits = state.pop("router_logits")
         hidden_states = state.hidden_states_mlp_input

-        if self._enable_deepep_moe:
-            if router_logits is not None:
+        if router_logits is not None:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
                 state.topk_weights_local, state.topk_idx_local = select_experts(
                     hidden_states=hidden_states,
                     router_logits=router_logits,
@@ -358,90 +471,89 @@ class DeepseekV2MoE(nn.Module):
                     renormalize=self.renormalize,
                     topk_group=self.topk_group,
                     num_expert_group=self.num_expert_group,
+                    num_fused_shared_experts=self.num_fused_shared_experts,
                     correction_bias=self.correction_bias,
                     routed_scaling_factor=self.routed_scaling_factor,
+                    num_token_non_padded=state.forward_batch.num_token_non_padded,
                     expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
                         layer_id=self.layer_id,
                     ),
                 )
-            else:
-                state.topk_idx_local = torch.full(
-                    (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
-                )
-                state.topk_weights_local = torch.empty(
-                    (0, self.top_k), dtype=torch.float32, device=hidden_states.device
-                )
+        else:
+            state.topk_idx_local = torch.full(
+                (0, self.top_k), -1, dtype=torch.int, device=hidden_states.device
+            )
+            state.topk_weights_local = torch.empty(
+                (0, self.top_k), dtype=torch.float32, device=hidden_states.device
+            )

     def op_dispatch_a(self, state):
-        if self._enable_deepep_moe and (self.ep_size > 1):
+        if self.ep_size > 1:
             # TODO(ch-wan): allow users to set num_max_dispatch_tokens_per_rank value
             self.deepep_dispatcher.dispatch_a(
-                hidden_states=state.pop("hidden_states_mlp_input"),
+                hidden_states=state.hidden_states_mlp_input,
                 topk_idx=state.pop("topk_idx_local"),
                 topk_weights=state.pop("topk_weights_local"),
                 forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
             )

     def op_dispatch_b(self, state):
-        if self._enable_deepep_moe and (self.ep_size > 1):
-            (
-                state.hidden_states_experts_input,
-                state.topk_idx_dispatched,
-                state.topk_weights_dispatched,
-                state.reorder_topk_ids,
-                state.num_recv_tokens_per_expert,
-                state.seg_indptr,
-                state.masked_m,
-                state.expected_m,
-            ) = self.deepep_dispatcher.dispatch_b()
+        if self.ep_size > 1:
+            with get_global_expert_distribution_recorder().with_current_layer(
+                self.layer_id
+            ):
+                (
+                    state.hidden_states_experts_input,
+                    state.topk_idx_dispatched,
+                    state.topk_weights_dispatched,
+                    state.reorder_topk_ids,
+                    state.num_recv_tokens_per_expert,
+                    state.seg_indptr,
+                    state.masked_m,
+                    state.expected_m,
+                ) = self.deepep_dispatcher.dispatch_b(
+                    tbo_subbatch_index=state.get("tbo_subbatch_index"),
+                )

     def op_experts(self, state):
-        if self._enable_deepep_moe:
-            state.pop("router_logits")
-            state.hidden_states_experts_output = self.experts(
-                hidden_states=state.pop("hidden_states_experts_input"),
-                topk_idx=state.topk_idx_dispatched,
-                topk_weights=state.topk_weights_dispatched,
-                reorder_topk_ids=state.pop("reorder_topk_ids"),
-                seg_indptr=state.pop("seg_indptr"),
-                masked_m=state.pop("masked_m"),
-                expected_m=state.pop("expected_m"),
-                num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
-                forward_mode=state.forward_batch.forward_mode,
-            )
-        else:
-            state.hidden_states_experts_output = self.experts(
-                hidden_states=state.pop("hidden_states_mlp_input"),
-                router_logits=state.pop("router_logits"),
-            )
+        state.hidden_states_experts_output = self.experts(
+            hidden_states=state.pop("hidden_states_experts_input"),
+            topk_idx=state.topk_idx_dispatched,
+            topk_weights=state.topk_weights_dispatched,
+            reorder_topk_ids=state.pop("reorder_topk_ids"),
+            seg_indptr=state.pop("seg_indptr"),
+            masked_m=state.pop("masked_m"),
+            expected_m=state.pop("expected_m"),
+            num_recv_tokens_per_expert=state.pop("num_recv_tokens_per_expert"),
+            forward_mode=state.forward_batch.forward_mode,
+        )

     def op_combine_a(self, state):
-        if self._enable_deepep_moe and (self.ep_size > 1):
+        if self.ep_size > 1:
             self.deepep_dispatcher.combine_a(
-                state.pop("hidden_states_experts_output"),
+                hidden_states=state.pop("hidden_states_experts_output"),
                 topk_idx=state.pop("topk_idx_dispatched"),
                 topk_weights=state.pop("topk_weights_dispatched"),
                 forward_mode=state.forward_batch.forward_mode,
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
            )

     def op_combine_b(self, state):
-        if self._enable_deepep_moe and (self.ep_size > 1):
-            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b()
+        if self.ep_size > 1:
+            state.hidden_states_after_combine = self.deepep_dispatcher.combine_b(
+                tbo_subbatch_index=state.get("tbo_subbatch_index"),
+            )

     def op_output(self, state):
-        final_hidden_states = (
-            state.pop("hidden_states_after_combine")
-            if self._enable_deepep_moe
-            else state.pop("hidden_states_experts_output")
-        )
+        final_hidden_states = state.pop("hidden_states_after_combine")

-        final_hidden_states *= self.routed_scaling_factor
-
-        if (s := state.pop("shared_output")) is not None:
-            final_hidden_states = final_hidden_states + s
-
-        if (not self._enable_deepep_moe) and (self.tp_size > 1):
-            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+        if (shared_output := state.pop("shared_output")) is not None:
+            x = shared_output
+            x.add_(final_hidden_states, alpha=self.routed_scaling_factor)
+            final_hidden_states = x
+        else:
+            final_hidden_states *= self.routed_scaling_factor

         state.hidden_states_mlp_output = final_hidden_states

@@ -596,10 +708,11 @@ class DeepseekV2AttentionMLA(nn.Module):
         )

         self.alt_stream = alt_stream
+        self.attn_mha.kv_b_proj = None

         self.w_kc = None
         self.w_vc = None
-        self.w_scale = None
+        self.w_scale = 1.0

         self.w_scale_k = None
         self.w_scale_v = None
@@ -665,6 +778,15 @@ class DeepseekV2AttentionMLA(nn.Module):
  return AttnForwardMethod.MHA_CHUNKED_KV
  else:
  return _dispatch_mla_subtype()
+ elif self.attention_backend == "aiter":
+ if (
+ forward_batch.forward_mode.is_extend()
+ and not forward_batch.forward_mode.is_target_verify()
+ and not forward_batch.forward_mode.is_draft_extend()
+ ):
+ return AttnForwardMethod.MHA
+ else:
+ return AttnForwardMethod.MLA
  else:
  # Triton: Use normal computation for prefill and use weight absorption for extend/decode
  if (
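The new aiter branch above selects plain multi-head attention for ordinary extend (prefill) batches and the weight-absorbed MLA path for everything else (decode, target-verify, draft-extend). A hedged sketch of that dispatch rule with stand-in enums (ForwardMode here is a simplified stand-in, not the sglang class):

from enum import Enum, auto

class AttnForwardMethod(Enum):
    MHA = auto()
    MLA = auto()

class ForwardMode(Enum):
    EXTEND = auto()
    DECODE = auto()
    TARGET_VERIFY = auto()
    DRAFT_EXTEND = auto()

def dispatch_for_aiter(mode: ForwardMode) -> AttnForwardMethod:
    # Only an ordinary extend (prefill) uses plain MHA; speculative verify, draft extend,
    # and decode all go through the weight-absorbed MLA path.
    if mode is ForwardMode.EXTEND:
        return AttnForwardMethod.MHA
    return AttnForwardMethod.MLA

assert dispatch_for_aiter(ForwardMode.EXTEND) is AttnForwardMethod.MHA
assert dispatch_for_aiter(ForwardMode.DECODE) is AttnForwardMethod.MLA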
@@ -677,44 +799,97 @@ class DeepseekV2AttentionMLA(nn.Module):
  else:
  return _dispatch_mla_subtype()

+ def op_prepare(self, state):
+ state.attn_intermediate_state = self.forward_prepare(
+ positions=state.positions,
+ hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
+ forward_batch=state.forward_batch,
+ zero_allocator=state.zero_allocator,
+ )
+
+ def op_core(self, state):
+ state.hidden_states_after_attn = self.forward_core(
+ state.pop("attn_intermediate_state")
+ )
+
  def forward(
681
816
  self,
682
817
  positions: torch.Tensor,
683
818
  hidden_states: torch.Tensor,
684
819
  forward_batch: ForwardBatch,
685
820
  zero_allocator: BumpAllocator,
686
- ) -> torch.Tensor:
821
+ ):
822
+ s = self.forward_prepare(
823
+ positions=positions,
824
+ hidden_states=hidden_states,
825
+ forward_batch=forward_batch,
826
+ zero_allocator=zero_allocator,
827
+ )
828
+ return self.forward_core(s)
829
+
830
+ def forward_prepare(
831
+ self,
832
+ positions: torch.Tensor,
833
+ hidden_states: torch.Tensor,
834
+ forward_batch: ForwardBatch,
835
+ zero_allocator: BumpAllocator,
836
+ ):
837
+ if self.attn_mha.kv_b_proj is None:
838
+ self.attn_mha.kv_b_proj = self.kv_b_proj
839
+
687
840
  if hidden_states.shape[0] == 0:
688
841
  assert (
689
842
  not self.o_proj.reduce_results
690
843
  ), "short-circuiting allreduce will lead to hangs"
691
- return hidden_states
844
+ return hidden_states, None, forward_batch, None
692
845
 
693
846
  attn_forward_method = self.dispatch_attn_forward_method(forward_batch)
694
847
 
695
848
  if attn_forward_method == AttnForwardMethod.MHA:
696
- return self.forward_normal(positions, hidden_states, forward_batch)
849
+ inner_state = self.forward_normal_prepare(
850
+ positions, hidden_states, forward_batch, zero_allocator
851
+ )
697
852
  elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
698
- return self.forward_normal_chunked_kv(
699
- positions, hidden_states, forward_batch
853
+ inner_state = self.forward_normal_chunked_kv_prepare(
854
+ positions, hidden_states, forward_batch, zero_allocator
700
855
  )
701
856
  elif attn_forward_method == AttnForwardMethod.MLA:
702
- return self.forward_absorb(
857
+ inner_state = self.forward_absorb_prepare(
703
858
  positions, hidden_states, forward_batch, zero_allocator
704
859
  )
705
860
  elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
706
- return self.forward_absorb_fused_mla_rope(
707
- positions, hidden_states, forward_batch
861
+ inner_state = self.forward_absorb_fused_mla_rope_prepare(
862
+ positions, hidden_states, forward_batch, zero_allocator
708
863
  )
709
864
  else:
710
865
  raise NotImplementedError
866
+ return None, attn_forward_method, forward_batch, inner_state
867
+
868
+ def forward_core(self, intermediate_state):
869
+ hidden_states, attn_forward_method, forward_batch, inner_state = (
870
+ intermediate_state
871
+ )
872
+ if inner_state is None:
873
+ return hidden_states
874
+
875
+ if attn_forward_method == AttnForwardMethod.MHA:
876
+ return self.forward_normal_core(*inner_state)
877
+ elif attn_forward_method == AttnForwardMethod.MHA_CHUNKED_KV:
878
+ return self.forward_normal_chunked_kv_core(*inner_state)
879
+ elif attn_forward_method == AttnForwardMethod.MLA:
880
+ return self.forward_absorb_core(*inner_state)
881
+ elif attn_forward_method == AttnForwardMethod.MLA_FUSED_ROPE:
882
+ return self.forward_absorb_fused_mla_rope_core(*inner_state)
883
+ else:
884
+ raise NotImplementedError
711
885
 
712
- def forward_normal(
886
+ def forward_normal_prepare(
713
887
  self,
714
888
  positions: torch.Tensor,
715
889
  hidden_states: torch.Tensor,
716
890
  forward_batch: ForwardBatch,
717
- ) -> torch.Tensor:
891
+ zero_allocator: BumpAllocator,
892
+ ):
718
893
  if self.q_lora_rank is not None:
719
894
  q, latent_cache = self.fused_qkv_a_proj_with_mqa(hidden_states)[0].split(
720
895
  [self.q_lora_rank, self.kv_lora_rank + self.qk_rope_head_dim], dim=-1
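The hunk above splits the attention module into forward_prepare (projections, RoPE, KV-cache writes) and forward_core (the attention kernel and output projection); op_prepare/op_core and the two-batch-overlap path consume exactly this split. A minimal sketch, under the assumption that a scheduler simply alternates the two phases of two sub-batches so the gap between one sub-batch's phases is filled by the other's work (TwoPhaseLayer and run_tbo are illustrative, not sglang code):

from dataclasses import dataclass
from typing import Any, Dict, List

@dataclass
class TwoPhaseLayer:
    name: str

    def prepare(self, x: int) -> Dict[str, Any]:
        # Stand-in for projections / dispatch: produce an intermediate state.
        return {"layer": self.name, "value": x + 1}

    def core(self, state: Dict[str, Any]) -> int:
        # Stand-in for the attention or expert kernel consuming that state.
        return state["value"] * 2

def run_tbo(layers: List[TwoPhaseLayer], subbatch_inputs: List[int]) -> List[int]:
    outputs = list(subbatch_inputs)
    for layer in layers:
        # Phase 1 for both sub-batches, then phase 2 for both: in the real system the
        # communication launched by one phase overlaps with the other sub-batch's compute.
        states = [layer.prepare(x) for x in outputs]
        outputs = [layer.core(s) for s in states]
    return outputs

print(run_tbo([TwoPhaseLayer("l0"), TwoPhaseLayer("l1")], [1, 10]))  # [10, 46]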
@@ -749,18 +924,22 @@ class DeepseekV2AttentionMLA(nn.Module):
749
924
  forward_batch.token_to_kv_pool.set_kv_buffer(
750
925
  self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
751
926
  )
927
+
928
+ return q, k, v, forward_batch
929
+
930
+ def forward_normal_core(self, q, k, v, forward_batch):
752
931
  attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
753
932
  attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
754
933
  output, _ = self.o_proj(attn_output)
755
934
  return output
756
935
 
757
- def forward_absorb(
936
+ def forward_absorb_prepare(
758
937
  self,
759
938
  positions: torch.Tensor,
760
939
  hidden_states: torch.Tensor,
761
940
  forward_batch: ForwardBatch,
762
941
  zero_allocator: BumpAllocator,
763
- ) -> torch.Tensor:
942
+ ):
764
943
  from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
765
944
 
766
945
  if self.q_lora_rank is not None:
@@ -809,8 +988,8 @@ class DeepseekV2AttentionMLA(nn.Module):
  expected_m,
  )
  q_nope_out = q_nope_out[:, :expected_m, :]
- elif self.w_kc.dtype == torch.float8_e4m3fnuz:
- # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
+ elif _is_hip:
+ # TODO(haishaw): add bmm_fp8 to ROCm
  q_nope_out = torch.bmm(
  q_nope.to(torch.bfloat16).transpose(0, 1),
  self.w_kc.to(torch.bfloat16) * self.w_scale,
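The change above keys the fallback on _is_hip: on ROCm, where a bmm_fp8 kernel is not available yet, the fp8 weight is upcast to bfloat16 and rescaled by w_scale before a plain torch.bmm. A hedged sketch of that dequantize-then-bmm fallback (shapes, names, and the per-tensor scale are illustrative; requires a PyTorch build with float8 dtypes):

import torch

def bmm_fp8_fallback(x: torch.Tensor, w_fp8: torch.Tensor, w_scale: float) -> torch.Tensor:
    # x: [batch, m, k] activations; w_fp8: [batch, k, n] fp8 weight; w_scale: dequant scale.
    w_bf16 = w_fp8.to(torch.bfloat16) * w_scale
    return torch.bmm(x.to(torch.bfloat16), w_bf16)

x = torch.randn(4, 3, 8)
w = torch.randn(4, 8, 16).to(torch.float8_e4m3fn)
out = bmm_fp8_fallback(x, w, w_scale=0.5)
print(out.shape)  # torch.Size([4, 3, 16])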
@@ -829,6 +1008,11 @@ class DeepseekV2AttentionMLA(nn.Module):
829
1008
  q_nope_out = q_nope_out.transpose(0, 1)
830
1009
  q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
831
1010
 
1011
+ return q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
1012
+
1013
+ def forward_absorb_core(
1014
+ self, q_pe, k_pe, q_nope_out, k_nope, forward_batch, zero_allocator
1015
+ ):
832
1016
  if self.attention_backend == "fa3" or self.attention_backend == "flashinfer":
833
1017
  attn_output = self.attn_mqa(
834
1018
  q_nope_out, k_nope, k_nope, forward_batch, q_rope=q_pe, k_rope=k_pe
@@ -856,8 +1040,8 @@ class DeepseekV2AttentionMLA(nn.Module):
856
1040
  expected_m,
857
1041
  )
858
1042
  attn_bmm_output = attn_bmm_output[:, :expected_m, :]
859
- elif self.w_vc.dtype == torch.float8_e4m3fnuz:
860
- # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
1043
+ elif _is_hip:
1044
+ # TODO(haishaw): add bmm_fp8 to ROCm
861
1045
  attn_bmm_output = torch.bmm(
862
1046
  attn_output.to(torch.bfloat16).transpose(0, 1),
863
1047
  self.w_vc.to(torch.bfloat16) * self.w_scale,
@@ -881,13 +1065,13 @@ class DeepseekV2AttentionMLA(nn.Module):
881
1065
 
882
1066
  return output
883
1067
 
884
- def forward_absorb_fused_mla_rope(
1068
+ def forward_absorb_fused_mla_rope_prepare(
885
1069
  self,
886
1070
  positions: torch.Tensor,
887
1071
  hidden_states: torch.Tensor,
888
1072
  forward_batch: ForwardBatch,
889
1073
  zero_allocator: BumpAllocator,
890
- ) -> torch.Tensor:
1074
+ ):
891
1075
  enable_rope_fusion = (
892
1076
  os.getenv("SGLANG_FUSED_MLA_ENABLE_ROPE_FUSION", "1") == "1"
893
1077
  )
@@ -908,8 +1092,8 @@ class DeepseekV2AttentionMLA(nn.Module):
908
1092
  latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
909
1093
  q_nope, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
910
1094
 
911
- if self.w_kc.dtype == torch.float8_e4m3fnuz:
912
- # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
1095
+ if _is_hip:
1096
+ # TODO(haishaw): add bmm_fp8 to ROCm
913
1097
  q_nope_out = torch.bmm(
914
1098
  q_nope.to(torch.bfloat16).transpose(0, 1),
915
1099
  self.w_kc.to(torch.bfloat16) * self.w_scale,
@@ -976,6 +1160,44 @@ class DeepseekV2AttentionMLA(nn.Module):
976
1160
  )
977
1161
  val_cache_buf = key_cache_buf[..., : self.kv_lora_rank]
978
1162
 
1163
+ return (
1164
+ q_input,
1165
+ key_cache_buf,
1166
+ val_cache_buf,
1167
+ attn_output,
1168
+ kv_indptr,
1169
+ kv_indices,
1170
+ k_pe_output,
1171
+ cos_sin_cache,
1172
+ positions,
1173
+ attn_logits,
1174
+ num_kv_split,
1175
+ sm_scale,
1176
+ enable_rope_fusion,
1177
+ k_input,
1178
+ forward_batch,
1179
+ zero_allocator,
1180
+ )
1181
+
1182
+ def forward_absorb_fused_mla_rope_core(
1183
+ self,
1184
+ q_input,
1185
+ key_cache_buf,
1186
+ val_cache_buf,
1187
+ attn_output,
1188
+ kv_indptr,
1189
+ kv_indices,
1190
+ k_pe_output,
1191
+ cos_sin_cache,
1192
+ positions,
1193
+ attn_logits,
1194
+ num_kv_split,
1195
+ sm_scale,
1196
+ enable_rope_fusion,
1197
+ k_input,
1198
+ forward_batch,
1199
+ zero_allocator,
1200
+ ):
979
1201
  decode_attention_fwd_grouped_rope(
980
1202
  q_input,
981
1203
  key_cache_buf,
@@ -1004,8 +1226,8 @@ class DeepseekV2AttentionMLA(nn.Module):
1004
1226
 
1005
1227
  attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)
1006
1228
 
1007
- if self.w_vc.dtype == torch.float8_e4m3fnuz:
1008
- # TODO(kernel): add bmm_fp8 for torch.float8_e4m3fnuz
1229
+ if _is_hip:
1230
+ # TODO(haishaw): add bmm_fp8 to ROCm
1009
1231
  attn_bmm_output = torch.bmm(
1010
1232
  attn_output.to(torch.bfloat16).transpose(0, 1),
1011
1233
  self.w_vc.to(torch.bfloat16) * self.w_scale,
@@ -1082,12 +1304,13 @@ class DeepseekV2AttentionMLA(nn.Module):
1082
1304
 
1083
1305
  return accum_output
1084
1306
 
1085
- def forward_normal_chunked_kv(
1307
+ def forward_normal_chunked_kv_prepare(
1086
1308
  self,
1087
1309
  positions: torch.Tensor,
1088
1310
  hidden_states: torch.Tensor,
1089
1311
  forward_batch: ForwardBatch,
1090
- ) -> torch.Tensor:
1312
+ zero_allocator: BumpAllocator,
1313
+ ):
1091
1314
  # In normal mha, the k and v tensors will become overly large when the prefix length is long.
1092
1315
  # To avoid this, we split the kv cache into chunks and process them one after another.
1093
1316
  # Since mha is compute friendly, the for loop induced here will not introduce significant overhead.
@@ -1130,6 +1353,9 @@ class DeepseekV2AttentionMLA(nn.Module):
1130
1353
  self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
1131
1354
  )
1132
1355
 
1356
+ return q, k, v, forward_batch
1357
+
1358
+ def forward_normal_chunked_kv_core(self, q, k, v, forward_batch):
1133
1359
  # Do mha for extended part without prefix
1134
1360
  forward_batch.set_attn_attend_prefix_cache(False)
1135
1361
  attn_output, lse = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
@@ -1252,17 +1478,29 @@ class DeepseekV2DecoderLayer(nn.Module):
  residual: Optional[torch.Tensor],
  zero_allocator: BumpAllocator,
  ) -> torch.Tensor:
- return execute_operations(
- inputs=dict(
- positions=positions,
- hidden_states=hidden_states,
- forward_batch=forward_batch,
- residual=residual,
- zero_allocator=zero_allocator,
- ),
- operations=compute_layer_operations(self),
+ hidden_states, residual = self.layer_communicator.prepare_attn(
+ hidden_states, residual, forward_batch
+ )
+
+ hidden_states = self.self_attn(
+ positions=positions,
+ hidden_states=hidden_states,
+ forward_batch=forward_batch,
+ zero_allocator=zero_allocator,
  )

+ hidden_states, residual = self.layer_communicator.prepare_mlp(
+ hidden_states, residual, forward_batch
+ )
+
+ hidden_states = self.mlp(hidden_states, forward_batch)
+
+ hidden_states, residual = self.layer_communicator.postprocess_layer(
+ hidden_states, residual, forward_batch
+ )
+
+ return hidden_states, residual
+
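The rewritten DeepseekV2DecoderLayer.forward above is a straight pipeline: the layer communicator prepares the attention input, attention runs, the communicator prepares the MLP input, the MLP runs, and a postprocess step hands back (hidden_states, residual). A minimal sketch of that control flow with stand-in components (TinyCommunicator and TinyLayer are illustrative; the real LayerCommunicator also takes the forward_batch and decides where communication happens):

import torch
import torch.nn as nn

class TinyCommunicator:
    def __init__(self, hidden: int):
        self.input_norm = nn.LayerNorm(hidden)
        self.post_attn_norm = nn.LayerNorm(hidden)

    def prepare_attn(self, hidden_states, residual):
        # Pre-norm: keep the un-normalized stream as the residual.
        if residual is None:
            residual = hidden_states
        else:
            residual = hidden_states + residual
        return self.input_norm(residual), residual

    def prepare_mlp(self, hidden_states, residual):
        residual = hidden_states + residual
        return self.post_attn_norm(residual), residual

    def postprocess_layer(self, hidden_states, residual):
        return hidden_states, residual

class TinyLayer(nn.Module):
    def __init__(self, hidden: int):
        super().__init__()
        self.comm = TinyCommunicator(hidden)
        self.self_attn = nn.Linear(hidden, hidden)  # attention stand-in
        self.mlp = nn.Linear(hidden, hidden)        # MLP stand-in

    def forward(self, hidden_states, residual):
        hidden_states, residual = self.comm.prepare_attn(hidden_states, residual)
        hidden_states = self.self_attn(hidden_states)
        hidden_states, residual = self.comm.prepare_mlp(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return self.comm.postprocess_layer(hidden_states, residual)

h, r = TinyLayer(16)(torch.randn(2, 16), None)
print(h.shape, r.shape)  # torch.Size([2, 16]) torch.Size([2, 16])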
  def op_comm_prepare_attn(
1267
1505
  self,
1268
1506
  state,
@@ -1271,6 +1509,7 @@ class DeepseekV2DecoderLayer(nn.Module):
1271
1509
  forward_batch: ForwardBatch,
1272
1510
  residual: Optional[torch.Tensor],
1273
1511
  zero_allocator: BumpAllocator,
1512
+ tbo_subbatch_index: Optional[int] = None,
1274
1513
  ):
1275
1514
  state.hidden_states_after_comm_pre_attn, state.residual_after_input_ln = (
1276
1515
  self.layer_communicator.prepare_attn(hidden_states, residual, forward_batch)
@@ -1280,17 +1519,10 @@ class DeepseekV2DecoderLayer(nn.Module):
1280
1519
  forward_batch=forward_batch,
1281
1520
  positions=positions,
1282
1521
  zero_allocator=zero_allocator,
1522
+ tbo_subbatch_index=tbo_subbatch_index,
1283
1523
  )
1284
1524
  )
1285
1525
 
1286
- def op_attn(self, state):
1287
- state.hidden_states_after_attn = self.self_attn(
1288
- positions=state.positions,
1289
- hidden_states=state.pop("hidden_states_after_comm_pre_attn"),
1290
- forward_batch=state.forward_batch,
1291
- zero_allocator=state.zero_allocator,
1292
- )
1293
-
1294
1526
  def op_comm_prepare_mlp(self, state):
1295
1527
  state.hidden_states_mlp_input, state.residual_after_comm_pre_mlp = (
1296
1528
  self.layer_communicator.prepare_mlp(
@@ -1320,8 +1552,24 @@ class DeepseekV2DecoderLayer(nn.Module):
1320
1552
  state.forward_batch,
1321
1553
  )
1322
1554
 
1323
- state.clear(expect_keys={"positions", "forward_batch", "zero_allocator"})
1324
- return hidden_states, residual
1555
+ output = dict(
1556
+ positions=state.positions,
1557
+ hidden_states=hidden_states,
1558
+ residual=residual,
1559
+ forward_batch=state.forward_batch,
1560
+ zero_allocator=state.zero_allocator,
1561
+ tbo_subbatch_index=state.tbo_subbatch_index,
1562
+ )
1563
+
1564
+ state.clear(
1565
+ expect_keys={
1566
+ "positions",
1567
+ "forward_batch",
1568
+ "zero_allocator",
1569
+ "tbo_subbatch_index",
1570
+ }
1571
+ )
1572
+ return output
1325
1573
 
1326
1574
 
1327
1575
  class DeepseekV2Model(nn.Module):
@@ -1336,6 +1584,7 @@ class DeepseekV2Model(nn.Module):
1336
1584
  super().__init__()
1337
1585
  self.padding_id = config.pad_token_id
1338
1586
  self.vocab_size = config.vocab_size
1587
+ self.first_k_dense_replace = config.first_k_dense_replace
1339
1588
 
1340
1589
  self.embed_tokens = VocabParallelEmbedding(
1341
1590
  config.vocab_size,
@@ -1369,13 +1618,12 @@ class DeepseekV2Model(nn.Module):
  forward_batch: ForwardBatch,
  input_embeds: torch.Tensor = None,
  ) -> torch.Tensor:
+ total_num_layers = len(self.layers)
+ device = input_embeds.device if input_embeds is not None else input_ids.device
  zero_allocator = BumpAllocator(
- # TODO for two-batch-overlap, we need a larger buffer size
- buffer_size=len(self.layers) * 2,
+ buffer_size=total_num_layers * 2 * (2 if forward_batch.can_run_tbo else 1),
  dtype=torch.float32,
- device=(
- input_embeds.device if input_embeds is not None else input_ids.device
- ),
+ device=device,
  )

  if input_embeds is None:
@@ -1384,12 +1632,33 @@ class DeepseekV2Model(nn.Module):
  hidden_states = input_embeds

  residual = None
- for i in range(len(self.layers)):
+
+ normal_num_layers = (
+ self.first_k_dense_replace
+ if forward_batch.can_run_tbo
+ else total_num_layers
+ )
+ for i in range(normal_num_layers):
  with get_global_expert_distribution_recorder().with_current_layer(i):
  layer = self.layers[i]
  hidden_states, residual = layer(
  positions, hidden_states, forward_batch, residual, zero_allocator
  )
+
+ if normal_num_layers != total_num_layers:
+ hidden_states, residual = model_forward_maybe_tbo(
+ layers=self.layers[normal_num_layers:],
+ enable_tbo=True,
+ positions=positions,
+ forward_batch=forward_batch,
+ hidden_states=hidden_states,
+ residual=residual,
+ input_data_scatter_mode=self.layers[
+ normal_num_layers - 1
+ ].layer_scatter_modes.layer_output_mode,
+ zero_allocator=zero_allocator,
+ )
+
  if not forward_batch.forward_mode.is_idle():
  if residual is None:
  hidden_states = self.norm(hidden_states)
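With the split above, the first first_k_dense_replace dense layers run normally and only the remaining MoE layers go through model_forward_maybe_tbo; correspondingly, the zero-buffer allocator is sized for twice as many allocations when two sub-batches can be in flight. A minimal sketch of such a bump allocator (TinyBumpAllocator is illustrative, not sglang's BumpAllocator):

import torch

class TinyBumpAllocator:
    def __init__(self, buffer_size: int, dtype=torch.float32, device="cpu"):
        # One preallocated zero buffer; slices are handed out monotonically.
        self._buffer = torch.zeros(buffer_size, dtype=dtype, device=device)
        self._offset = 0

    def allocate(self, num: int) -> torch.Tensor:
        assert self._offset + num <= self._buffer.numel(), "buffer too small"
        out = self._buffer[self._offset : self._offset + num]
        self._offset += num
        return out

num_layers, allocs_per_layer, num_subbatches = 4, 2, 2
alloc = TinyBumpAllocator(num_layers * allocs_per_layer * num_subbatches)
for _ in range(num_layers * num_subbatches):
    _ = alloc.allocate(allocs_per_layer)  # each (layer, sub-batch) grabs its zeros once

Because nothing is ever freed, the buffer must be sized for every allocation in a forward pass, which is why the size doubles when two overlapped sub-batches share the allocator.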
@@ -1410,7 +1679,7 @@ class DeepseekV2ForCausalLM(nn.Module):
1410
1679
  self.config = config
1411
1680
  self.tp_size = get_tensor_model_parallel_world_size()
1412
1681
  self.quant_config = quant_config
1413
- self.determine_n_share_experts_fusion()
1682
+ self.determine_num_fused_shared_experts()
1414
1683
  self.model = DeepseekV2Model(
1415
1684
  config, quant_config, prefix=add_prefix("model", prefix)
1416
1685
  )
@@ -1424,40 +1693,67 @@ class DeepseekV2ForCausalLM(nn.Module):
  self.logits_processor = LogitsProcessor(config)
  self.dp_size = get_local_attention_dp_size()

- def determine_n_share_experts_fusion(
+ self._routed_experts_weights_of_layer = LazyValue(
+ lambda: {
+ layer_id: layer.mlp.get_moe_weights()
+ for layer_id, layer in enumerate(self.model.layers)
+ if isinstance(layer.mlp, DeepseekV2MoE)
+ }
+ )
+
+ @property
+ def routed_experts_weights_of_layer(self):
+ return self._routed_experts_weights_of_layer.value
+
+ def determine_num_fused_shared_experts(
1428
1709
  self, architecture: str = "DeepseekV3ForCausalLM"
1429
1710
  ):
1430
- self.n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
1431
- if self.n_share_experts_fusion > 0:
1711
+ self.num_fused_shared_experts = (
1712
+ 0
1713
+ if global_server_args_dict["disable_shared_experts_fusion"]
1714
+ else self.config.n_shared_experts
1715
+ )
1716
+ if self.num_fused_shared_experts > 0:
1432
1717
  # Only Deepseek V3/R1 can use shared experts fusion optimization now.
1433
1718
  if (
1434
1719
  not _is_cuda
1435
1720
  or self.config.architectures[0] != architecture
1436
1721
  or self.config.n_routed_experts != 256
1437
1722
  ):
1438
- self.n_share_experts_fusion = 0
1439
- global_server_args_dict["n_share_experts_fusion"] = 0
1723
+ self.num_fused_shared_experts = 0
1724
+ global_server_args_dict["disable_shared_experts_fusion"] = True
1440
1725
  log_info_on_rank0(
1441
1726
  logger,
1442
1727
  "Only Deepseek V3/R1 on NV-platform can use shared experts fusion optimization. Shared experts fusion optimization is disabled.",
1443
1728
  )
1444
- else:
1445
- assert (
1446
- self.n_share_experts_fusion == self.tp_size
1447
- ), f"Shared experts fusion optimization is enabled in DeepSeek V3/R1, set it to {self.tp_size} can get best optimized performance."
1448
- elif self.n_share_experts_fusion == 0:
1729
+ elif (
1730
+ global_server_args_dict["enable_deepep_moe"]
1731
+ or global_server_args_dict["enable_ep_moe"]
1732
+ ):
1733
+ self.num_fused_shared_experts = 0
1734
+ global_server_args_dict["disable_shared_experts_fusion"] = True
1735
+ log_info_on_rank0(
1736
+ logger,
1737
+ "Deepseek V3/R1 can not use shared experts fusion optimization when in deepep_moe or ep_moe mode. Shared experts fusion optimization is disabled.",
1738
+ )
1739
+ elif self.num_fused_shared_experts == 0:
1449
1740
  if (
1450
1741
  _is_cuda
1451
1742
  and torch.cuda.get_device_capability("cuda") >= (9, 0)
1452
1743
  and self.config.architectures[0] == architecture
1453
1744
  and self.config.n_routed_experts == 256
1454
- and (not global_server_args_dict["enable_deepep_moe"])
1745
+ and (
1746
+ not (
1747
+ global_server_args_dict["enable_deepep_moe"]
1748
+ or global_server_args_dict["enable_ep_moe"]
1749
+ )
1750
+ )
1455
1751
  ):
1456
- self.n_share_experts_fusion = self.tp_size
1457
- global_server_args_dict["n_share_experts_fusion"] = self.tp_size
1752
+ self.num_fused_shared_experts = self.config.n_shared_experts
1753
+ global_server_args_dict["disable_shared_experts_fusion"] = False
1458
1754
  log_info_on_rank0(
1459
1755
  logger,
1460
- "Deepseek V3/R1 with fp8 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
1756
+ "Deepseek V3/R1 with fp8/fp4 can use shared experts fusion optimization when SM version >=90. Shared experts fusion optimization is enabled.",
1461
1757
  )
1462
1758
 
1463
1759
  def get_input_embeddings(self) -> nn.Embedding:
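In the hunk above, routed_experts_weights_of_layer becomes a LazyValue-backed property, so the per-layer MoE weight map is built on first access rather than eagerly at the end of weight loading. A hedged sketch of that pattern (this LazyValue is a generic stand-in, not sglang's implementation):

from typing import Callable, Generic, TypeVar

T = TypeVar("T")

class LazyValue(Generic[T]):
    _UNSET = object()

    def __init__(self, factory: Callable[[], T]):
        self._factory = factory
        self._value = LazyValue._UNSET

    @property
    def value(self) -> T:
        # The factory runs once, on first access, and the result is cached.
        if self._value is LazyValue._UNSET:
            self._value = self._factory()
        return self._value

calls = []
lazy = LazyValue(lambda: calls.append("built") or {"layer_3": "moe weights"})
assert calls == []   # nothing happens at construction time
_ = lazy.value       # factory runs here
_ = lazy.value       # cached; the factory is not run again
assert calls == ["built"]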
@@ -1471,21 +1767,29 @@ class DeepseekV2ForCausalLM(nn.Module):
1471
1767
  forward_batch: ForwardBatch,
1472
1768
  input_embeds: torch.Tensor = None,
1473
1769
  ) -> torch.Tensor:
1474
-
1475
1770
  hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
1476
1771
 
1477
1772
  return self.logits_processor(
1478
1773
  input_ids, hidden_states, self.lm_head, forward_batch
1479
1774
  )
1480
1775
 
1481
- def post_load_weights(self, is_nextn=False):
+ def post_load_weights(self, is_nextn=False, weight_names=None):

  # Perform post-processing after loading weights
- layer_ids = (
- range(self.config.num_hidden_layers)
- if not is_nextn
- else [self.config.num_hidden_layers]
- )
+ if is_nextn:
+ layer_ids = [self.config.num_hidden_layers]
+ else:
+ if weight_names is None:
+ layer_ids = range(self.config.num_hidden_layers)
+ else:
+ layer_ids = set()
+ for name in weight_names:
+ if "kv_b_proj" in name:
+ layer_id = int(name.split(".")[2])
+ # filter the nextn layer.
+ if layer_id != self.config.num_hidden_layers:
+ layer_ids.add(layer_id)
+
  for layer_id in layer_ids:
1490
1794
  self_attn = (
1491
1795
  self.model.layers[layer_id].self_attn
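The weight_names path above lets post_load_weights touch only the layers whose kv_b_proj weights were actually (re)loaded, by parsing the layer index out of the parameter name and skipping the NextN draft layer. A minimal sketch of that filtering (layers_to_postprocess and the sample names are illustrative):

def layers_to_postprocess(weight_names, num_hidden_layers: int) -> set:
    layer_ids = set()
    for name in weight_names:
        if "kv_b_proj" in name:
            layer_id = int(name.split(".")[2])  # "model.layers.<i>. ..."
            if layer_id != num_hidden_layers:   # the NextN draft layer is handled separately
                layer_ids.add(layer_id)
    return layer_ids

names = [
    "model.layers.0.self_attn.kv_b_proj.weight",
    "model.layers.7.self_attn.kv_b_proj.weight",
    "model.layers.61.self_attn.kv_b_proj.weight",  # NextN layer when num_hidden_layers == 61
    "model.layers.7.mlp.gate_proj.weight",
]
print(sorted(layers_to_postprocess(names, num_hidden_layers=61)))  # [0, 7]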
@@ -1521,46 +1825,56 @@ class DeepseekV2ForCausalLM(nn.Module):
1521
1825
  torch.float8_e4m3fn,
1522
1826
  torch.float8_e4m3fnuz,
1523
1827
  ):
1524
- if hasattr(self.quant_config, "weight_block_size"):
1828
+ if (
1829
+ hasattr(self.quant_config, "weight_block_size")
1830
+ and self.quant_config.weight_block_size is not None
1831
+ ):
1525
1832
  weight_block_size = self.quant_config.weight_block_size
1526
- if weight_block_size is not None:
1527
- assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
1528
- if _is_hip:
1529
- weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
1530
- weight=w,
1531
- weight_scale=self_attn.kv_b_proj.weight_scale_inv,
1532
- input_scale=None,
1533
- )
1534
- else:
1535
- weight = w
1536
- weight_scale = self_attn.kv_b_proj.weight_scale_inv
1833
+ assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
1834
+ if _is_fp8_fnuz:
1835
+ weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
1836
+ weight=w,
1837
+ weight_scale=self_attn.kv_b_proj.weight_scale_inv,
1838
+ input_scale=None,
1839
+ )
1840
+ else:
1841
+ weight = w
1842
+ weight_scale = self_attn.kv_b_proj.weight_scale_inv
1537
1843
 
1538
- if (
1539
- _is_cuda
1540
- and weight_block_size[0] == 128
1541
- and weight_block_size[1] == 128
1542
- and model_dtype == torch.bfloat16
1844
+ if (
1845
+ _is_cuda
1846
+ and weight_block_size[0] == 128
1847
+ and weight_block_size[1] == 128
1848
+ and model_dtype == torch.bfloat16
1849
+ ):
1850
+ if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
1851
+ "SGL_USE_DEEPGEMM_BMM", "false"
1543
1852
  ):
1544
- if _ENABLE_JIT_DEEPGEMM and get_bool_env_var(
1545
- "SGL_USE_DEEPGEMM_BMM", "false"
1546
- ):
1547
- block_scale = weight_scale
1548
- use_deep_gemm_bmm = True
1549
- else:
1550
- w = block_quant_dequant(
1551
- weight,
1552
- weight_scale,
1553
- weight_block_size,
1554
- model_dtype,
1555
- )
1853
+ block_scale = weight_scale
1854
+ use_deep_gemm_bmm = True
1556
1855
  else:
1557
- w, scale = block_quant_to_tensor_quant(
1558
- weight, weight_scale, weight_block_size
1856
+ w = block_quant_dequant(
1857
+ weight,
1858
+ weight_scale,
1859
+ weight_block_size,
1860
+ model_dtype,
1559
1861
  )
1560
- self_attn.w_scale = scale
1862
+ else:
1863
+ w, scale = block_quant_to_tensor_quant(
1864
+ weight, weight_scale, weight_block_size
1865
+ )
1866
+ self_attn.w_scale = scale
1561
1867
  else:
1562
- weight = w
1563
- weight_scale = self_attn.kv_b_proj.weight_scale
1868
+ if _is_fp8_fnuz:
1869
+ weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
1870
+ weight=w,
1871
+ weight_scale=self_attn.kv_b_proj.weight_scale,
1872
+ input_scale=None,
1873
+ )
1874
+ else:
1875
+ weight = w
1876
+ weight_scale = self_attn.kv_b_proj.weight_scale
1877
+
1564
1878
  w, scale = channel_quant_to_tensor_quant(weight, weight_scale)
1565
1879
  self_attn.w_scale = scale
1566
1880
 
@@ -1585,13 +1899,19 @@ class DeepseekV2ForCausalLM(nn.Module):
1585
1899
  0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
1586
1900
  ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
1587
1901
  if not use_deep_gemm_bmm:
1588
- self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
1589
- self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
1902
+ self_attn.w_kc = bind_or_assign(
1903
+ self_attn.w_kc, w_kc.transpose(1, 2).contiguous().transpose(1, 2)
1904
+ )
1905
+ self_attn.w_vc = bind_or_assign(
1906
+ self_attn.w_vc, w_vc.contiguous().transpose(1, 2)
1907
+ )
1590
1908
  if (
1591
1909
  hasattr(self_attn.kv_b_proj, "weight_scale")
1592
1910
  and self_attn.w_scale is None
1593
1911
  ):
1594
- self_attn.w_scale = self_attn.kv_b_proj.weight_scale
1912
+ self_attn.w_scale = bind_or_assign(
1913
+ self_attn.w_scale, self_attn.kv_b_proj.weight_scale
1914
+ )
1595
1915
  if _is_hip:
1596
1916
  self_attn.w_scale *= 2.0
1597
1917
  else:
@@ -1600,21 +1920,20 @@ class DeepseekV2ForCausalLM(nn.Module):
1600
1920
  ws_kc, ws_vc = block_scale.unflatten(
1601
1921
  0, (-1, (num_tiles_k + num_tiles_n))
1602
1922
  ).split([num_tiles_k, num_tiles_n], dim=1)
1603
- self_attn.w_scale_k = ws_kc.transpose(1, 2).contiguous()
1604
- self_attn.w_scale_v = ws_vc.contiguous()
1605
- self_attn.w_kc = w_kc.transpose(1, 2).contiguous()
1606
- self_attn.w_vc = w_vc.contiguous()
1923
+ self_attn.w_scale_k = bind_or_assign(
1924
+ self_attn.w_scale_k, ws_kc.transpose(1, 2).contiguous()
1925
+ )
1926
+ self_attn.w_scale_v = bind_or_assign(
1927
+ self_attn.w_scale_v, ws_vc.contiguous()
1928
+ )
1929
+ self_attn.w_kc = bind_or_assign(
1930
+ self_attn.w_kc, w_kc.transpose(1, 2).contiguous()
1931
+ )
1932
+ self_attn.w_vc = bind_or_assign(self_attn.w_vc, w_vc.contiguous())
1607
1933
  self_attn.use_deep_gemm_bmm = True
1608
1934
 
1609
- # TODO support nextn later
1610
- if not is_nextn:
1611
- self.routed_experts_weights_of_layer = {
1612
- layer_id: layer.mlp.get_moe_weights()
1613
- for layer_id, layer in enumerate(self.model.layers)
1614
- if isinstance(layer.mlp, DeepseekV2MoE)
1615
- }
1616
-
1617
1935
  def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]], is_nextn=False):
1936
+
1618
1937
  if is_nextn:
1619
1938
  if hasattr(self.config, "num_nextn_predict_layers"):
1620
1939
  num_nextn_layers = self.config.num_nextn_predict_layers
@@ -1633,26 +1952,68 @@ class DeepseekV2ForCausalLM(nn.Module):
1633
1952
  ("gate_up_proj", "gate_proj", 0),
1634
1953
  ("gate_up_proj", "up_proj", 1),
1635
1954
  ]
1636
- if self.n_share_experts_fusion > 0:
1955
+ if self.num_fused_shared_experts > 0:
1956
+ assert self.num_fused_shared_experts == 1
1637
1957
  weights_list = list(weights)
1638
1958
  weights_dict = dict(weights_list)
1639
- if self.quant_config is None or self.quant_config.get_name() == "w8a8_int8":
1640
- suffix_list = [
1641
- "down_proj.weight",
1642
- "down_proj.weight_scale",
1643
- "gate_proj.weight",
1644
- "gate_proj.weight_scale",
1645
- "up_proj.weight",
1646
- "up_proj.weight_scale",
1647
- ]
1959
+ if self.quant_config is not None:
1960
+ if self.quant_config.get_name() == "w8a8_int8":
1961
+ suffix_list = [
1962
+ "down_proj.weight",
1963
+ "down_proj.weight_scale",
1964
+ "gate_proj.weight",
1965
+ "gate_proj.weight_scale",
1966
+ "up_proj.weight",
1967
+ "up_proj.weight_scale",
1968
+ ]
1969
+ elif (
1970
+ self.quant_config.get_name() == "fp8"
1971
+ or self.quant_config.get_name() == "blockwise_int8"
1972
+ ):
1973
+ suffix_list = [
1974
+ "down_proj.weight",
1975
+ "down_proj.weight_scale_inv",
1976
+ "gate_proj.weight",
1977
+ "gate_proj.weight_scale_inv",
1978
+ "up_proj.weight",
1979
+ "up_proj.weight_scale_inv",
1980
+ ]
1981
+ elif self.quant_config.get_name() == "awq":
1982
+ suffix_list = [
1983
+ "down_proj.qweight",
1984
+ "down_proj.qzeros",
1985
+ "down_proj.scales",
1986
+ "gate_proj.qweight",
1987
+ "gate_proj.qzeros",
1988
+ "gate_proj.scales",
1989
+ "up_proj.qweight",
1990
+ "up_proj.qzeros",
1991
+ "up_proj.scales",
1992
+ ]
1993
+ elif self.quant_config.get_name() == "modelopt_fp4":
1994
+ suffix_list = [
1995
+ "down_proj.weight",
1996
+ "down_proj.weight_scale",
1997
+ "down_proj.weight_scale_2",
1998
+ "down_proj.input_scale",
1999
+ "gate_proj.weight",
2000
+ "gate_proj.weight_scale",
2001
+ "gate_proj.weight_scale_2",
2002
+ "gate_proj.input_scale",
2003
+ "up_proj.weight",
2004
+ "up_proj.weight_scale",
2005
+ "up_proj.weight_scale_2",
2006
+ "up_proj.input_scale",
2007
+ ]
2008
+ else:
2009
+ raise ValueError(
2010
+ f"Unsupported shared expert fusion for quantization: {self.quant_config.get_name()}."
2011
+ )
1648
2012
  else:
1649
2013
  suffix_list = [
1650
2014
  "down_proj.weight",
1651
- "down_proj.weight_scale_inv",
1652
2015
  "gate_proj.weight",
1653
- "gate_proj.weight_scale_inv",
1654
2016
  "up_proj.weight",
1655
- "up_proj.weight_scale_inv",
1656
2017
  ]
1657
2018
  names_to_remove = []
1658
2019
 
@@ -1668,23 +2029,22 @@ class DeepseekV2ForCausalLM(nn.Module):
1668
2029
 
1669
2030
  for moe_layer in tqdm(
1670
2031
  moe_layers,
1671
- desc=f"Cloning {self.n_share_experts_fusion} "
1672
- "replicas of the shared expert into MoE",
2032
+ desc=f"Cloning {self.num_fused_shared_experts} "
2033
+ "shared expert into MoE",
1673
2034
  ):
1674
2035
  for suffix in suffix_list:
1675
2036
  shared_expert_weight_name = (
1676
2037
  f"model.layers.{moe_layer}.mlp.shared_experts.{suffix}"
1677
2038
  )
1678
- for num_repeat in range(self.n_share_experts_fusion):
1679
- weights_list.append(
1680
- (
1681
- f"model.layers.{moe_layer}."
1682
- f"mlp.experts."
1683
- f"{self.config.n_routed_experts + num_repeat}"
1684
- f".{suffix}",
1685
- weights_dict[shared_expert_weight_name],
1686
- )
2039
+ weights_list.append(
2040
+ (
2041
+ f"model.layers.{moe_layer}."
2042
+ f"mlp.experts."
2043
+ f"{self.config.n_routed_experts + 0}"
2044
+ f".{suffix}",
2045
+ weights_dict[shared_expert_weight_name],
1687
2046
  )
2047
+ )
1688
2048
  names_to_remove += [shared_expert_weight_name]
1689
2049
  weights = [w for w in weights_list if w[0] not in names_to_remove]
1690
2050
 
@@ -1694,7 +2054,7 @@ class DeepseekV2ForCausalLM(nn.Module):
1694
2054
  ckpt_gate_proj_name="gate_proj",
1695
2055
  ckpt_down_proj_name="down_proj",
1696
2056
  ckpt_up_proj_name="up_proj",
1697
- num_experts=self.config.n_routed_experts + self.n_share_experts_fusion,
2057
+ num_experts=self.config.n_routed_experts + self.num_fused_shared_experts,
1698
2058
  )
1699
2059
 
1700
2060
  # Fuse q_a_proj and kv_a_proj_with_mqa along output dimension when q_lora_rank is not None
@@ -1713,7 +2073,10 @@ class DeepseekV2ForCausalLM(nn.Module):
1713
2073
  ]
1714
2074
 
1715
2075
  params_dict = dict(self.named_parameters())
2076
+ weight_names = []
1716
2077
  for name, loaded_weight in weights:
2078
+ weight_names.append(name)
2079
+
1717
2080
  if not is_nextn:
1718
2081
  if hasattr(self.config, "num_nextn_predict_layers"):
1719
2082
  num_nextn_layers = self.config.num_nextn_predict_layers
@@ -1785,7 +2148,6 @@ class DeepseekV2ForCausalLM(nn.Module):
1785
2148
  # Skip loading extra bias for GPTQ models.
1786
2149
  if name.endswith(".bias") and name not in params_dict:
1787
2150
  continue
1788
-
1789
2151
  if fuse_qkv_a_proj and (
1790
2152
  "q_a_proj" in name or "kv_a_proj_with_mqa" in name
1791
2153
  ):
@@ -1811,9 +2173,12 @@ class DeepseekV2ForCausalLM(nn.Module):
1811
2173
  fused_weight = torch.cat(
1812
2174
  [q_a_proj_weight, kv_a_proj_weight], dim=0
1813
2175
  )
1814
-
1815
- param_name = name.replace(
1816
- "q_a_proj", "fused_qkv_a_proj_with_mqa"
2176
+ param_name = (
2177
+ name.replace("q_a_proj", "fused_qkv_a_proj_with_mqa")
2178
+ if "q_a_proj" in name
2179
+ else name.replace(
2180
+ "kv_a_proj_with_mqa", "fused_qkv_a_proj_with_mqa"
2181
+ )
1817
2182
  )
1818
2183
  param = params_dict[param_name]
1819
2184
 
@@ -1824,13 +2189,23 @@ class DeepseekV2ForCausalLM(nn.Module):
1824
2189
  cached_a_proj.pop(q_a_proj_name)
1825
2190
  cached_a_proj.pop(kv_a_proj_name)
1826
2191
  else:
2192
+ if (
2193
+ "k_scale" in name or "v_scale" in name
2194
+ ) and name not in params_dict:
2195
+ # modelopt attn kv scale is named differently
2196
+ if any(scale in name for scale in ["k_scale", "v_scale"]):
2197
+ name = name.replace("_proj", "attn_mqa")
2198
+ else:
2199
+ logger.warning(
2200
+ f"Unknown scale found in checkpoint: {name}"
2201
+ )
1827
2202
  param = params_dict[name]
1828
2203
  weight_loader = getattr(
1829
2204
  param, "weight_loader", default_weight_loader
1830
2205
  )
1831
2206
  weight_loader(param, loaded_weight)
1832
2207
 
1833
- self.post_load_weights(is_nextn=is_nextn)
2208
+ self.post_load_weights(is_nextn=is_nextn, weight_names=weight_names)
1834
2209
 
1835
2210
  def get_embed_and_head(self):
1836
2211
  return self.model.embed_tokens.weight, self.lm_head.weight