sglang 0.4.6.post5__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (318)
  1. sglang/bench_offline_throughput.py +10 -4
  2. sglang/bench_one_batch_server.py +67 -11
  3. sglang/bench_serving.py +85 -74
  4. sglang/lang/backend/runtime_endpoint.py +24 -1
  5. sglang/profiler.py +167 -0
  6. sglang/srt/_custom_ops.py +34 -0
  7. sglang/srt/configs/internvl.py +8 -12
  8. sglang/srt/configs/model_config.py +27 -1
  9. sglang/srt/constrained/base_grammar_backend.py +5 -2
  10. sglang/srt/constrained/llguidance_backend.py +9 -8
  11. sglang/srt/constrained/outlines_backend.py +5 -4
  12. sglang/srt/constrained/xgrammar_backend.py +18 -18
  13. sglang/srt/conversation.py +46 -8
  14. sglang/srt/custom_op.py +38 -3
  15. sglang/srt/debug_utils.py +74 -0
  16. sglang/srt/disaggregation/common/__init__.py +1 -0
  17. sglang/srt/disaggregation/common/conn.py +407 -0
  18. sglang/srt/disaggregation/decode.py +67 -3
  19. sglang/srt/disaggregation/fake/conn.py +1 -0
  20. sglang/srt/disaggregation/kv_events.py +60 -5
  21. sglang/srt/disaggregation/launch_lb.py +140 -0
  22. sglang/srt/disaggregation/mini_lb.py +29 -48
  23. sglang/srt/disaggregation/mooncake/conn.py +432 -140
  24. sglang/srt/disaggregation/mooncake/transfer_engine.py +32 -16
  25. sglang/srt/disaggregation/nixl/conn.py +124 -432
  26. sglang/srt/disaggregation/prefill.py +2 -0
  27. sglang/srt/disaggregation/utils.py +38 -1
  28. sglang/srt/distributed/device_communicators/pymscclpp.py +315 -0
  29. sglang/srt/distributed/parallel_state.py +52 -5
  30. sglang/srt/entrypoints/EngineBase.py +6 -0
  31. sglang/srt/entrypoints/engine.py +102 -5
  32. sglang/srt/entrypoints/http_server.py +15 -2
  33. sglang/srt/function_call/base_format_detector.py +138 -86
  34. sglang/srt/function_call/deepseekv3_detector.py +54 -6
  35. sglang/srt/function_call/ebnf_composer.py +33 -19
  36. sglang/srt/function_call/function_call_parser.py +27 -0
  37. sglang/srt/function_call/llama32_detector.py +33 -14
  38. sglang/srt/function_call/mistral_detector.py +73 -26
  39. sglang/srt/function_call/pythonic_detector.py +86 -20
  40. sglang/srt/function_call/qwen25_detector.py +64 -10
  41. sglang/srt/function_call/utils.py +17 -0
  42. sglang/srt/hf_transformers_utils.py +4 -0
  43. sglang/srt/layers/attention/aiter_backend.py +488 -123
  44. sglang/srt/layers/attention/base_attn_backend.py +4 -0
  45. sglang/srt/layers/attention/cutlass_mla_backend.py +2 -19
  46. sglang/srt/layers/attention/flashattention_backend.py +103 -18
  47. sglang/srt/layers/attention/flashinfer_backend.py +45 -1
  48. sglang/srt/layers/attention/flashinfer_mla_backend.py +37 -1
  49. sglang/srt/layers/attention/intel_amx_backend.py +128 -0
  50. sglang/srt/layers/attention/tbo_backend.py +232 -0
  51. sglang/srt/layers/attention/torch_native_backend.py +3 -0
  52. sglang/srt/layers/attention/triton_backend.py +244 -5
  53. sglang/srt/layers/attention/triton_ops/extend_attention.py +12 -4
  54. sglang/srt/layers/communicator.py +260 -194
  55. sglang/srt/layers/dp_attention.py +6 -5
  56. sglang/srt/layers/layernorm.py +30 -19
  57. sglang/srt/layers/moe/cutlass_moe.py +170 -7
  58. sglang/srt/layers/moe/cutlass_moe_params.py +169 -0
  59. sglang/srt/layers/moe/ep_moe/kernels.py +27 -6
  60. sglang/srt/layers/moe/ep_moe/layer.py +94 -40
  61. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +13 -8
  62. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  71. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +220 -25
  72. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -4
  73. sglang/srt/layers/moe/topk.py +44 -18
  74. sglang/srt/layers/multimodal.py +3 -3
  75. sglang/srt/layers/quantization/__init__.py +3 -2
  76. sglang/srt/layers/quantization/blockwise_int8.py +3 -0
  77. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +5 -0
  78. sglang/srt/layers/quantization/deep_gemm.py +55 -56
  79. sglang/srt/layers/quantization/fp8.py +28 -23
  80. sglang/srt/layers/quantization/fp8_kernel.py +118 -66
  81. sglang/srt/layers/quantization/fp8_utils.py +165 -49
  82. sglang/srt/layers/quantization/modelopt_quant.py +334 -7
  83. sglang/srt/layers/quantization/moe_wna16.py +3 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +3 -0
  85. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  86. sglang/srt/layers/rotary_embedding.py +6 -12
  87. sglang/srt/layers/sampler.py +80 -79
  88. sglang/srt/layers/utils.py +6 -0
  89. sglang/srt/lora/layers.py +12 -15
  90. sglang/srt/lora/lora.py +49 -5
  91. sglang/srt/lora/lora_manager.py +19 -5
  92. sglang/srt/lora/mem_pool.py +24 -16
  93. sglang/srt/lora/utils.py +17 -13
  94. sglang/srt/managers/data_parallel_controller.py +13 -5
  95. sglang/srt/managers/eplb_algorithms/__init__.py +63 -0
  96. sglang/srt/managers/eplb_algorithms/deepseek.py +223 -0
  97. sglang/srt/managers/{deepseek_eplb.py → eplb_algorithms/deepseek_vec.py} +5 -7
  98. sglang/srt/managers/eplb_manager.py +55 -14
  99. sglang/srt/managers/expert_distribution.py +220 -46
  100. sglang/srt/managers/expert_location.py +110 -56
  101. sglang/srt/managers/expert_location_dispatch.py +23 -6
  102. sglang/srt/managers/io_struct.py +15 -4
  103. sglang/srt/managers/mm_utils.py +88 -38
  104. sglang/srt/managers/multimodal_processors/base_processor.py +188 -16
  105. sglang/srt/managers/multimodal_processors/gemma3.py +4 -31
  106. sglang/srt/managers/multimodal_processors/internvl.py +4 -0
  107. sglang/srt/managers/multimodal_processors/kimi_vl.py +15 -34
  108. sglang/srt/managers/multimodal_processors/minicpm.py +2 -1
  109. sglang/srt/managers/multimodal_processors/phi4mm.py +87 -0
  110. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -64
  111. sglang/srt/managers/schedule_batch.py +140 -38
  112. sglang/srt/managers/scheduler.py +305 -112
  113. sglang/srt/managers/tokenizer_manager.py +134 -17
  114. sglang/srt/managers/utils.py +0 -4
  115. sglang/srt/metrics/collector.py +9 -0
  116. sglang/srt/model_executor/cuda_graph_runner.py +72 -61
  117. sglang/srt/model_executor/expert_location_updater.py +157 -22
  118. sglang/srt/model_executor/forward_batch_info.py +38 -17
  119. sglang/srt/model_executor/model_runner.py +96 -56
  120. sglang/srt/model_loader/utils.py +67 -1
  121. sglang/srt/models/deepseek_nextn.py +1 -1
  122. sglang/srt/models/deepseek_v2.py +609 -234
  123. sglang/srt/models/gemma3_causal.py +7 -0
  124. sglang/srt/models/gemma3_mm.py +19 -14
  125. sglang/srt/models/idefics2.py +342 -0
  126. sglang/srt/models/kimi_vl.py +4 -4
  127. sglang/srt/models/llama.py +1 -1
  128. sglang/srt/models/minicpmo.py +2 -5
  129. sglang/srt/models/minicpmv.py +3 -295
  130. sglang/srt/models/phi4mm.py +512 -0
  131. sglang/srt/models/qwen2.py +38 -9
  132. sglang/srt/models/qwen2_5_vl.py +3 -9
  133. sglang/srt/models/qwen2_eagle.py +4 -1
  134. sglang/srt/models/qwen2_moe.py +58 -191
  135. sglang/srt/models/qwen2_vl.py +3 -9
  136. sglang/srt/models/qwen3.py +41 -10
  137. sglang/srt/models/qwen3_moe.py +230 -191
  138. sglang/srt/models/registry.py +9 -1
  139. sglang/srt/models/transformers.py +291 -0
  140. sglang/srt/openai_api/adapter.py +86 -24
  141. sglang/srt/openai_api/protocol.py +31 -2
  142. sglang/srt/openai_api/utils.py +172 -0
  143. sglang/srt/operations.py +37 -2
  144. sglang/srt/operations_strategy.py +200 -24
  145. sglang/srt/sampling/sampling_batch_info.py +13 -1
  146. sglang/srt/sampling/sampling_params.py +2 -1
  147. sglang/srt/server_args.py +114 -27
  148. sglang/srt/speculative/build_eagle_tree.py +8 -8
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -11
  150. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +253 -0
  151. sglang/srt/speculative/eagle_utils.py +51 -91
  152. sglang/srt/speculative/eagle_worker.py +101 -21
  153. sglang/srt/two_batch_overlap.py +635 -0
  154. sglang/srt/utils.py +129 -7
  155. sglang/test/runners.py +16 -7
  156. sglang/test/send_one.py +4 -0
  157. sglang/test/test_cutlass_moe.py +3 -3
  158. sglang/test/test_fp4_moe.py +248 -0
  159. sglang/test/test_utils.py +79 -6
  160. sglang/version.py +1 -1
  161. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.dist-info}/METADATA +14 -11
  162. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.dist-info}/RECORD +318 -291
  163. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.dist-info}/WHEEL +1 -1
  164. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  165. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  166. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  167. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  168. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  169. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  170. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  171. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  172. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  173. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  174. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  175. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  176. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  177. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H200.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H200.json} +0 -0
  178. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  179. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  180. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  181. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  182. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  183. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  184. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  185. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  186. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  187. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  188. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  189. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  190. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  191. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  192. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  193. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  194. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  195. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  196. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  197. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  198. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  199. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  200. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  201. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  202. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  203. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  204. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  205. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  206. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json} +0 -0
  207. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  208. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  209. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  210. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  211. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  212. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  213. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  214. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  215. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  216. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json} +0 -0
  217. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json} +0 -0
  218. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  219. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  220. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  221. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  222. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  223. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  224. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200.json} +0 -0
  225. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  226. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  227. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200.json} +0 -0
  228. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  229. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  230. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  231. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200.json} +0 -0
  232. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  233. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  234. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  235. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
  236. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  237. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  238. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  239. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200.json} +0 -0
  240. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X.json} +0 -0
  241. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI325X.json} +0 -0
  242. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Radeon_Graphics.json} +0 -0
  243. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  244. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  245. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200.json} +0 -0
  246. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X.json} +0 -0
  247. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI325X.json} +0 -0
  248. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Radeon_Graphics.json} +0 -0
  249. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  250. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  251. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  252. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  253. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200.json} +0 -0
  254. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  255. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  256. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  257. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  258. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200.json} +0 -0
  259. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X.json} +0 -0
  260. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI325X.json} +0 -0
  261. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Radeon_Graphics.json} +0 -0
  262. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  263. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  264. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
  265. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  266. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  267. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  268. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200.json} +0 -0
  269. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_L40S.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_L40S.json} +0 -0
  270. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
  271. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
  272. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
  273. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  274. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  275. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  276. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  277. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200.json} +0 -0
  278. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X.json} +0 -0
  279. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI325X.json} +0 -0
  280. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Radeon_Graphics.json} +0 -0
  281. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  282. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  283. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  284. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  285. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200.json} +0 -0
  286. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
  287. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
  288. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
  289. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  290. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  291. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  292. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  293. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H20.json} +0 -0
  294. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H200.json} +0 -0
  295. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  296. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  297. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20.json} +0 -0
  298. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  299. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200.json} +0 -0
  300. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  301. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  302. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  303. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  304. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20.json} +0 -0
  305. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  306. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200.json} +0 -0
  307. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=96,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=96,device_name=NVIDIA_H20.json} +0 -0
  308. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  309. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  310. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  311. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  312. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  313. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  314. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  315. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  316. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  317. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.dist-info}/licenses/LICENSE +0 -0
  318. {sglang-0.4.6.post5.dist-info → sglang-0.4.7.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/aiter_backend.py

@@ -27,12 +27,19 @@ if TYPE_CHECKING:
     from sglang.srt.speculative.spec_info import SpecInfo

 try:
-    from aiter import mha_batch_prefill_func, paged_attention_ragged
+    from aiter import (
+        flash_attn_varlen_func,
+        mha_batch_prefill_func,
+        paged_attention_ragged,
+    )
+    from aiter.mla import mla_decode_fwd
 except ImportError:
     print(
         "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
     )

+from sglang.srt.configs.model_config import AttentionArch
+

 class WrapperDispatch(Enum):
     SLIDING_WINDOW = auto()
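
Note: the hunk above keeps the existing pattern of importing the AMD-only aiter kernels inside a try/except. A generic sketch of that optional-dependency pattern follows; the _AITER_AVAILABLE flag and attention_impl helper are illustrative names only and are not part of sglang.

    # Optional-import guard, similar in spirit to the hunk above.
    # _AITER_AVAILABLE and attention_impl are hypothetical names for illustration.
    try:
        from aiter import flash_attn_varlen_func  # AMD-specific kernel library
        _AITER_AVAILABLE = True
    except ImportError:
        flash_attn_varlen_func = None
        _AITER_AVAILABLE = False

    def attention_impl():
        if not _AITER_AVAILABLE:
            raise RuntimeError("aiter is required for the AMD attention backend")
        return flash_attn_varlen_func
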
@@ -43,6 +50,10 @@ class WrapperDispatch(Enum):
 class ForwardMetadata:
     kv_indptr: torch.Tensor
     kv_indices: torch.Tensor
+    qo_indptr: torch.Tensor
+    kv_last_page_len: torch.Tensor
+    max_extend_len: int
+    max_prefix_extend_len: int
     max_q_len: int
     max_kv_len: int

@@ -63,6 +74,7 @@ class AiterAttnBackend(AttentionBackend):

         self.device = model_runner.device
         self.is_multimodal = model_runner.model_config.is_multimodal
+        self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
         self.num_head = (
             model_runner.model_config.num_attention_heads // get_attention_tp_size()
         )
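
The ForwardMetadata dataclass expanded a few hunks above now carries eight fields, and the constructor calls later in this diff fill them positionally in the declared order (kv_indptr, kv_indices, qo_indptr, kv_last_page_len, max_extend_len, max_prefix_extend_len, max_q_len, max_kv_len). Below is a minimal, self-contained sketch of a decode-mode MLA instance using a stand-in dataclass and made-up tensor values:

    from dataclasses import dataclass
    from typing import Optional
    import torch

    # Stand-in mirroring the field order of the expanded ForwardMetadata above.
    @dataclass
    class ForwardMetadataSketch:
        kv_indptr: torch.Tensor
        kv_indices: torch.Tensor
        qo_indptr: Optional[torch.Tensor]
        kv_last_page_len: Optional[torch.Tensor]
        max_extend_len: Optional[int]
        max_prefix_extend_len: Optional[int]
        max_q_len: Optional[int]
        max_kv_len: Optional[int]

    # Decode mode on the MLA path: one query token per request, no prefill-only fields.
    meta = ForwardMetadataSketch(
        kv_indptr=torch.tensor([0, 7, 12], dtype=torch.int32),  # illustrative lengths
        kv_indices=torch.arange(12, dtype=torch.int32),
        qo_indptr=torch.tensor([0, 1, 2], dtype=torch.int32),
        kv_last_page_len=torch.ones(2, dtype=torch.int32),
        max_extend_len=1,
        max_prefix_extend_len=None,
        max_q_len=None,
        max_kv_len=None,
    )
    print(meta.kv_indptr.tolist(), meta.max_extend_len)  # [0, 7, 12] 1
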
@@ -75,6 +87,8 @@ class AiterAttnBackend(AttentionBackend):

         self.req_to_token = model_runner.req_to_token_pool.req_to_token

+        self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
+
         # Parse constants
         self.max_context_len = model_runner.model_config.context_len
         self.skip_prefill = skip_prefill
@@ -100,6 +114,10 @@ class AiterAttnBackend(AttentionBackend):
             self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
                 model_runner, self
             )
+            if self.use_mla:
+                self.mla_indices_updater_prefill = AiterMlaIndicesUpdaterPrefill(
+                    model_runner, self
+                )

         # aiter kernel related initialization
         self.max_num_partitions = (
@@ -108,33 +126,40 @@ class AiterAttnBackend(AttentionBackend):

         nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8

-        self.workspace_buffer = torch.empty(
-            (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
-            * nbyes_per_qo_elem
-            + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
-            dtype=torch.uint8,
-            device=self.device,
-        )
+        if not self.use_mla:
+            self.workspace_buffer = torch.empty(
+                (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
+                * nbyes_per_qo_elem
+                + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
+                dtype=torch.uint8,
+                device=self.device,
+            )

         self.scale = float(1.0 / (self.head_dim**0.5))
         self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
             self.device
         )
-        self.kv_last_page_lens = torch.ones((max_bs,), dtype=torch.int32).to(
-            self.device
-        )

         self.logits_soft_cap = 0.0

         self.forward_metadata: ForwardMetadata = None

+        if self.use_mla:
+            self.qo_indptr_ = torch.zeros(
+                (max_bs + 1,), dtype=torch.int32, device=model_runner.device
+            )
+
     def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init auxiliary variables for triton attention backend."""
+
+        bs = forward_batch.batch_size
+        kv_indptr = self.kv_indptr
+        spec_info = forward_batch.spec_info
+        qo_indptr = None
+        kv_last_page_len = None
+        max_extend_len = None
+
         if forward_batch.forward_mode.is_decode_or_idle():
-            # update for aiter
-            # create kv_indices and kv_inptr
-            bs = forward_batch.batch_size
-            kv_indptr = self.kv_indptr
-            spec_info = forward_batch.spec_info
             if spec_info is None:
                 kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
                 kv_indptr = kv_indptr[: bs + 1]
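
For a rough sense of what the non-MLA workspace formula above allocates, the sketch below plugs hypothetical sizes into the same expression (all numbers are illustrative and not taken from this diff):

    import torch

    # Hypothetical sizes for illustration only.
    max_bs, num_head, max_num_partitions, head_dim = 128, 16, 64, 128

    nbytes_per_qo_elem = torch.finfo(torch.float32).bits // 8  # 4 bytes
    workspace_bytes = (
        (max_bs * num_head * max_num_partitions * head_dim) * nbytes_per_qo_elem
        + 2 * (max_bs * num_head * max_num_partitions) * 4
    )
    print(workspace_bytes, workspace_bytes / 2**20)  # 68157440 bytes == 65.0 MiB
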
@@ -154,38 +179,103 @@ class AiterAttnBackend(AttentionBackend):
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
                 bs = kv_indptr.shape[0] - 1

-            self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)
+            if self.use_mla:
+                qo_indptr = self.qo_indptr_[: bs + 1]
+                qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
+                kv_last_page_len = self.kv_last_page_len[:bs]
+                max_extend_len = 1

-        elif forward_batch.forward_mode.is_draft_extend():
-            self.indices_updater_prefill.update(
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens_sum,
-                prefix_lens=None,
-                encoder_lens=forward_batch.encoder_lens,
-                spec_info=forward_batch.spec_info,
-            )
             self.forward_metadata = ForwardMetadata(
-                self.indices_updater_prefill.kv_indptr,
-                self.indices_updater_prefill.kv_indices,
-                self.indices_updater_prefill.max_q_len,
-                self.indices_updater_prefill.max_kv_len,
+                kv_indptr,
+                kv_indices,
+                qo_indptr,
+                kv_last_page_len,
+                max_extend_len,
+                None,
+                None,
+                None,
             )
+
+        elif forward_batch.forward_mode.is_draft_extend():
+            if self.use_mla:
+                prefix_lens = forward_batch.extend_prefix_lens
+                self.mla_indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    prefix_lens,
+                    prefix_lens.sum().item(),
+                    forward_batch.extend_seq_lens,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=None,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.mla_indices_updater_prefill.kv_indptr,
+                    self.mla_indices_updater_prefill.kv_indices,
+                    self.mla_indices_updater_prefill.qo_indptr,
+                    self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.mla_indices_updater_prefill.max_extend_len,
+                    self.mla_indices_updater_prefill.max_prefix_extend_len,
+                    None,
+                    None,
+                )
+            else:
+                self.indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
+                    prefix_lens=None,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=forward_batch.spec_info,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.indices_updater_prefill.kv_indptr,
+                    self.indices_updater_prefill.kv_indices,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.indices_updater_prefill.max_q_len,
+                    self.indices_updater_prefill.max_kv_len,
+                )
         elif forward_batch.forward_mode.is_target_verify():
-            self.indices_updater_prefill.update(
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens_sum,
-                prefix_lens=None,
-                encoder_lens=forward_batch.encoder_lens,
-                spec_info=forward_batch.spec_info,
-            )
-            self.forward_metadata = ForwardMetadata(
-                self.indices_updater_prefill.kv_indptr,
-                self.indices_updater_prefill.kv_indices,
-                self.indices_updater_prefill.max_q_len,
-                self.indices_updater_prefill.max_kv_len,
-            )
+            if self.use_mla:
+                prefix_lens = forward_batch.extend_prefix_lens
+                self.mla_indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    prefix_lens,
+                    prefix_lens.sum().item(),
+                    forward_batch.extend_seq_lens,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=None,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.mla_indices_updater_prefill.kv_indptr,
+                    self.mla_indices_updater_prefill.kv_indices,
+                    self.mla_indices_updater_prefill.qo_indptr,
+                    self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.mla_indices_updater_prefill.max_extend_len,
+                    self.mla_indices_updater_prefill.max_prefix_extend_len,
+                    None,
+                    None,
+                )
+            else:
+                self.indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
+                    prefix_lens=None,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=forward_batch.spec_info,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.indices_updater_prefill.kv_indptr,
+                    self.indices_updater_prefill.kv_indices,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.indices_updater_prefill.max_q_len,
+                    self.indices_updater_prefill.max_kv_len,
+                )
         else:
             prefix_lens = forward_batch.extend_prefix_lens
@@ -194,24 +284,49 @@ class AiterAttnBackend(AttentionBackend):
            else:
                extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)

-            self.indices_updater_prefill.update(
-                forward_batch.req_pool_indices,
-                forward_batch.seq_lens,
-                forward_batch.seq_lens_sum,
-                prefix_lens,
-                encoder_lens=forward_batch.encoder_lens,
-                spec_info=None,
-            )
-            self.forward_metadata = ForwardMetadata(
-                self.indices_updater_prefill.kv_indptr,
-                self.indices_updater_prefill.kv_indices,
-                self.indices_updater_prefill.max_q_len,
-                self.indices_updater_prefill.max_kv_len,
-            )
+            if self.use_mla:
+                self.mla_indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    prefix_lens,
+                    prefix_lens.sum().item(),
+                    forward_batch.extend_seq_lens,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=None,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.mla_indices_updater_prefill.kv_indptr,
+                    self.mla_indices_updater_prefill.kv_indices,
+                    self.mla_indices_updater_prefill.qo_indptr,
+                    self.mla_indices_updater_prefill.kv_last_page_len,
+                    self.mla_indices_updater_prefill.max_extend_len,
+                    self.mla_indices_updater_prefill.max_prefix_extend_len,
+                    None,
+                    None,
+                )
+            else:
+                self.indices_updater_prefill.update(
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    forward_batch.seq_lens_sum,
+                    prefix_lens,
+                    encoder_lens=forward_batch.encoder_lens,
+                    spec_info=None,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.indices_updater_prefill.kv_indptr,
+                    self.indices_updater_prefill.kv_indices,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.indices_updater_prefill.max_q_len,
+                    self.indices_updater_prefill.max_kv_len,
+                )

     def init_cuda_graph_state(
         self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
     ):
+        self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
         if kv_indices_buf is None:
             self.cuda_graph_kv_indices = torch.zeros(
                 (max_bs * self.max_context_len),
@@ -239,6 +354,10 @@ class AiterAttnBackend(AttentionBackend):
         spec_info: Optional[SpecInfo],
     ):
         if forward_mode.is_decode_or_idle():
+            qo_indptr = None
+            kv_last_page_len = None
+            max_extend_len = None
+
             if spec_info is None:
                 kv_indptr = self.kv_indptr
                 kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
@@ -255,25 +374,83 @@ class AiterAttnBackend(AttentionBackend):
                 )
             else:
                 kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
-            self.forward_metadata = ForwardMetadata(kv_indptr, kv_indices, None, None)

-        elif forward_mode.is_target_verify():
-            seq_lens_sum = seq_lens.sum().item()
-            self.indices_updater_prefill.update(
-                req_pool_indices,
-                seq_lens,
-                seq_lens_sum,
-                prefix_lens=None,
-                encoder_lens=encoder_lens,
-                spec_info=spec_info,
-            )
+            if self.use_mla:
+                qo_indptr = self.qo_indptr_[: bs + 1]
+                qo_indptr[1 : bs + 1] = torch.cumsum(
+                    self.cuda_graph_kv_last_page_len[:bs], dim=0
+                )
+                max_extend_len = 1
+                kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
+
             self.forward_metadata = ForwardMetadata(
-                self.indices_updater_prefill.kv_indptr,
-                self.indices_updater_prefill.kv_indices,
-                self.indices_updater_prefill.max_q_len,
-                self.indices_updater_prefill.max_kv_len,
+                kv_indptr,
+                kv_indices,
+                qo_indptr,
+                kv_last_page_len,
+                max_extend_len,
+                None,
+                None,
+                None,
             )

+        elif forward_mode.is_target_verify():
+            if self.use_mla:
+                qo_indptr = self.qo_indptr[: bs + 1]
+                qo_indptr[: bs + 1] = torch.arange(
+                    0,
+                    (1 + bs) * self.num_draft_tokens,
+                    step=self.num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                )
+                kv_indptr = self.kv_indptr[: bs + 1]
+                kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
+                kv_indices = self.cuda_graph_kv_indices
+                create_flashinfer_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    kv_indptr,
+                    None,
+                    kv_indices,
+                    self.req_to_token.stride(0),
+                )
+
+                max_extend_len = self.num_draft_tokens
+                kv_last_page_len = None
+
+                self.forward_metadata = ForwardMetadata(
+                    kv_indptr,
+                    kv_indices,
+                    qo_indptr,
+                    kv_last_page_len,
+                    max_extend_len,
+                    None,
+                    None,
+                    None,
+                )
+            else:
+                seq_lens_sum = seq_lens.sum().item()
+                self.indices_updater_prefill.update(
+                    req_pool_indices,
+                    seq_lens,
+                    seq_lens_sum,
+                    prefix_lens=None,
+                    encoder_lens=encoder_lens,
+                    spec_info=spec_info,
+                )
+                self.forward_metadata = ForwardMetadata(
+                    self.indices_updater_prefill.kv_indptr,
+                    self.indices_updater_prefill.kv_indices,
+                    None,
+                    None,
+                    None,
+                    None,
+                    self.indices_updater_prefill.max_q_len,
+                    self.indices_updater_prefill.max_kv_len,
+                )
+
         else:
             raise ValueError(f"Invalid mode: {forward_mode=}")

@@ -342,31 +519,113 @@ class AiterAttnBackend(AttentionBackend):
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-                )
-
-        k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-
-        bs0 = forward_batch.batch_size + 1
-
-        o = mha_batch_prefill_func(
-            q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            k_cache,
-            v_cache,
-            self.qo_indptr[:bs0],
-            self.forward_metadata.kv_indptr[:bs0],
-            self.forward_metadata.kv_indices,
-            self.forward_metadata.max_q_len,
-            self.forward_metadata.max_kv_len,
-            causal=True,
-            logits_soft_cap=self.logits_soft_cap,
-            alibi_slopes=None,
-            return_lse=False,
-            return_attn_probs=False,
-        )
+                if self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+
+        if self.use_mla:
+            max_extend_len = self.forward_metadata.max_extend_len
+            max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
+            kv_indptr = self.forward_metadata.kv_indptr
+            kv_indices = self.forward_metadata.kv_indices
+            kv_last_page_lens = self.forward_metadata.kv_last_page_len
+            qo_indptr = self.forward_metadata.qo_indptr
+            K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+            kv_lora_rank = V_Buffer.shape[-1]
+            qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank
+            qk_nope_head_dim = k.shape[-1] - qk_rope_head_dim
+            assert len(q.shape) == 3
+            assert len(k.shape) == 3
+            assert len(v.shape) == 3
+
+            if kv_indices.shape[0] == 0:
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    qo_indptr,
+                    max_extend_len,
+                    max_extend_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+            elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
+                K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
+                kvc, k_pe = torch.split(
+                    K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
+                )
+                kvprefix = layer.kv_b_proj(kvc.contiguous())[0]

-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+                kvprefix = kvprefix.view(
+                    -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
+                )
+                k_prefix, v_prefix = torch.split(
+                    kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
+                )
+                k_prefix = torch.cat(
+                    [
+                        k_prefix,
+                        torch.broadcast_to(
+                            k_pe,
+                            (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
+                        ),
+                    ],
+                    dim=-1,
+                )
+                assert (
+                    forward_batch.extend_prefix_lens.shape
+                    == forward_batch.extend_seq_lens.shape
+                )
+                k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
+                k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
+                assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
+                k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
+                v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
+                v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
+                v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
+
+                o = flash_attn_varlen_func(
+                    q,
+                    k,
+                    v,
+                    qo_indptr,
+                    kv_indptr,
+                    max_extend_len,
+                    max_prefix_extend_len,
+                    softmax_scale=layer.scaling,
+                    causal=True,
+                )
+                return o
+        else:
+            k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+                layer.layer_id
+            )
+
+            bs0 = forward_batch.batch_size + 1
+
+            o = mha_batch_prefill_func(
+                q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                k_cache,
+                v_cache,
+                self.qo_indptr[:bs0],
+                self.forward_metadata.kv_indptr[:bs0],
+                self.forward_metadata.kv_indices,
+                self.forward_metadata.max_q_len,
+                self.forward_metadata.max_kv_len,
+                causal=True,
+                logits_soft_cap=self.logits_soft_cap,
+                alibi_slopes=None,
+                return_lse=False,
+                return_attn_probs=False,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)

     def forward_decode(
         self,
@@ -377,6 +636,7 @@ class AiterAttnBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
+
         q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)

         if layer.qk_head_dim != layer.v_head_dim:
@@ -389,32 +649,48 @@ class AiterAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

-        self.logits_soft_cap = layer.logit_cap
-        paged_attention_ragged(
-            o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-            self.workspace_buffer,
-            q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
-            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
-                -1, 1, layer.tp_k_head_num, layer.qk_head_dim
-            ),
-            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
-                -1, 1, layer.tp_v_head_num, layer.v_head_dim
-            ),
-            self.scale,
-            self.forward_metadata.kv_indptr,
-            self.forward_metadata.kv_indices,
-            self.kv_last_page_lens,
-            1,
-            self.max_num_partitions,
-            None,
-            "auto",
-            "NHD",
-            self.logits_soft_cap,
-            self.k_scale,
-            self.v_scale,
-            None,
-            _AITER_PARTITION_SIZE_ROCM,
-        )
+        if self.use_mla:
+            k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            mla_decode_fwd(
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                k_buffer.view(-1, 1, 1, layer.qk_head_dim),
+                o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+                self.forward_metadata.qo_indptr,
+                self.forward_metadata.kv_indptr,
+                self.forward_metadata.kv_indices,
+                self.forward_metadata.kv_last_page_len,
+                self.forward_metadata.max_extend_len,
+                layer.scaling,
+                layer.logit_cap,
+            )
+            k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
+        else:
+            self.logits_soft_cap = layer.logit_cap
+            paged_attention_ragged(
+                o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                self.workspace_buffer,
+                q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+                forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
+                    -1, 1, layer.tp_k_head_num, layer.qk_head_dim
+                ),
+                forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
+                    -1, 1, layer.tp_v_head_num, layer.v_head_dim
+                ),
+                self.scale,
+                self.forward_metadata.kv_indptr,
+                self.forward_metadata.kv_indices,
+                self.kv_last_page_len,
+                1,
+                self.max_num_partitions,
+                None,
+                "auto",
+                "NHD",
+                self.logits_soft_cap,
+                self.k_scale,
+                self.v_scale,
+                None,
+                _AITER_PARTITION_SIZE_ROCM,
+            )

         return o

@@ -506,8 +782,97 @@ class AiterIndicesUpdaterPrefill:
                 spec_info.generate_attn_arg_prefill(
                     req_pool_indices,
                     paged_kernel_lens,
+                    paged_kernel_lens_sum,
+                    self.req_to_token,
+                )
+            )
+
+        self.kv_indices = kv_indices
+
+
+class AiterMlaIndicesUpdaterPrefill:
+    def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+        # Parse Constants
+        self.attn_backend = attn_backend
+
+        # Buffers and wrappers
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.update = self.update_single_wrapper
+
+        self.kv_indptr = None
+        self.kv_indices = None
+        self.qo_indptr = None
+        self.kv_last_page_len = None
+        self.max_extend_len = 0
+        self.max_prefix_extend_len = 0
+
+    def update(
+        self,
+        req_pool_indices: torch.Tensor,
+        prefix_lens: torch.Tensor,
+        prefix_lens_sum: int,
+        extend_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
+    ):
+        # Keep the signature for type checking. It will be assigned during runtime.
+        raise NotImplementedError()
+
+    def update_single_wrapper(
+        self,
+        req_pool_indices: torch.Tensor,
+        prefix_lens: torch.Tensor,
+        prefix_lens_sum: int,
+        extend_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        spec_info: Optional[SpecInfo],
+    ):
+
+        paged_kernel_lens = prefix_lens
+        paged_kernel_lens_sum = prefix_lens_sum
+
+        bs = len(req_pool_indices)
+
+        kv_indptr = self.attn_backend.kv_indptr
+
+        if spec_info is None:
+            # Normal extend
+            kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+            kv_indptr = kv_indptr[: bs + 1]
+            kv_indices = torch.empty(
+                paged_kernel_lens_sum,
+                dtype=torch.int32,
+                device=req_pool_indices.device,
+            )
+            create_flashinfer_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices,
+                paged_kernel_lens,
+                kv_indptr,
+                None,
+                kv_indices,
+                self.req_to_token.stride(0),
+            )
+
+            qo_indptr = self.attn_backend.qo_indptr
+            qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
+            qo_indptr = qo_indptr[: bs + 1]
+
+            max_extend_len = torch.max(extend_lens).item()
+            max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
+            kv_indptr += qo_indptr
+        else:
+            kv_indices, kv_indptr, qo_indptr, custom_mask = (
+                spec_info.generate_attn_arg_prefill(
+                    req_pool_indices,
+                    paged_kernel_lens,
+                    paged_kernel_lens_sum,
                     self.req_to_token,
                 )
             )

+        self.kv_indptr = kv_indptr
         self.kv_indices = kv_indices
+        self.qo_indptr = qo_indptr
+        self.max_extend_len = max_extend_len
+        self.max_prefix_extend_len = max_prefix_extend_len
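
In the normal-extend branch of AiterMlaIndicesUpdaterPrefill above, kv_indptr starts as the cumulative prefix lengths and is then shifted by qo_indptr (cumulative extend lengths), so it ends up cumulating prefix + extend tokens per request. A small numeric illustration with hypothetical lengths:

    import torch

    prefix_lens = torch.tensor([3, 0, 5])   # cached tokens per request (made up)
    extend_lens = torch.tensor([2, 4, 1])   # new tokens per request (made up)

    kv_indptr = torch.zeros(4, dtype=torch.int64)
    kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)   # [0, 3, 3, 8]

    qo_indptr = torch.zeros(4, dtype=torch.int64)
    qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)   # [0, 2, 6, 7]

    kv_indptr += qo_indptr
    print(kv_indptr.tolist())  # [0, 5, 9, 15] == cumsum(prefix_lens + extend_lens)
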