sglang 0.4.6.post4__py3-none-any.whl → 0.4.7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (358)
  1. sglang/bench_offline_throughput.py +16 -10
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +86 -22
  4. sglang/bench_serving.py +197 -110
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/lang/backend/runtime_endpoint.py +24 -1
  7. sglang/profiler.py +167 -0
  8. sglang/srt/_custom_ops.py +34 -0
  9. sglang/srt/configs/internvl.py +8 -12
  10. sglang/srt/configs/model_config.py +66 -29
  11. sglang/srt/constrained/base_grammar_backend.py +5 -2
  12. sglang/srt/constrained/llguidance_backend.py +9 -8
  13. sglang/srt/constrained/outlines_backend.py +5 -4
  14. sglang/srt/constrained/xgrammar_backend.py +18 -18
  15. sglang/srt/conversation.py +47 -9
  16. sglang/srt/custom_op.py +38 -3
  17. sglang/srt/debug_utils.py +74 -0
  18. sglang/srt/disaggregation/common/__init__.py +1 -0
  19. sglang/srt/disaggregation/common/conn.py +407 -0
  20. sglang/srt/disaggregation/decode.py +187 -134
  21. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  22. sglang/srt/disaggregation/fake/conn.py +4 -13
  23. sglang/srt/disaggregation/kv_events.py +412 -0
  24. sglang/srt/disaggregation/launch_lb.py +140 -0
  25. sglang/srt/disaggregation/mini_lb.py +84 -70
  26. sglang/srt/disaggregation/mooncake/conn.py +441 -140
  27. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -14
  28. sglang/srt/disaggregation/nixl/conn.py +124 -442
  29. sglang/srt/disaggregation/prefill.py +128 -44
  30. sglang/srt/disaggregation/utils.py +154 -6
  31. sglang/srt/distributed/device_communicators/pymscclpp.py +315 -0
  32. sglang/srt/distributed/parallel_state.py +52 -5
  33. sglang/srt/distributed/utils.py +3 -3
  34. sglang/srt/entrypoints/EngineBase.py +11 -0
  35. sglang/srt/entrypoints/engine.py +129 -12
  36. sglang/srt/entrypoints/http_server.py +21 -6
  37. sglang/srt/entrypoints/http_server_engine.py +5 -2
  38. sglang/srt/function_call/base_format_detector.py +302 -0
  39. sglang/srt/function_call/core_types.py +34 -0
  40. sglang/srt/function_call/deepseekv3_detector.py +205 -0
  41. sglang/srt/function_call/ebnf_composer.py +248 -0
  42. sglang/srt/function_call/function_call_parser.py +202 -0
  43. sglang/srt/function_call/llama32_detector.py +93 -0
  44. sglang/srt/function_call/mistral_detector.py +131 -0
  45. sglang/srt/function_call/pythonic_detector.py +229 -0
  46. sglang/srt/function_call/qwen25_detector.py +121 -0
  47. sglang/srt/function_call/utils.py +52 -0
  48. sglang/srt/hf_transformers_utils.py +50 -7
  49. sglang/srt/layers/attention/aiter_backend.py +878 -0
  50. sglang/srt/layers/attention/base_attn_backend.py +4 -0
  51. sglang/srt/layers/attention/cutlass_mla_backend.py +2 -19
  52. sglang/srt/layers/attention/flashattention_backend.py +166 -35
  53. sglang/srt/layers/attention/flashinfer_backend.py +45 -1
  54. sglang/srt/layers/attention/flashinfer_mla_backend.py +45 -5
  55. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  56. sglang/srt/layers/attention/intel_amx_backend.py +128 -0
  57. sglang/srt/layers/attention/tbo_backend.py +232 -0
  58. sglang/srt/layers/attention/torch_native_backend.py +3 -0
  59. sglang/srt/layers/attention/triton_backend.py +247 -5
  60. sglang/srt/layers/attention/triton_ops/extend_attention.py +12 -4
  61. sglang/srt/layers/attention/utils.py +2 -2
  62. sglang/srt/layers/attention/vision.py +1 -1
  63. sglang/srt/layers/communicator.py +517 -0
  64. sglang/srt/layers/dp_attention.py +6 -15
  65. sglang/srt/layers/layernorm.py +30 -19
  66. sglang/srt/layers/moe/cutlass_moe.py +370 -0
  67. sglang/srt/layers/moe/cutlass_moe_params.py +169 -0
  68. sglang/srt/layers/moe/ep_moe/kernels.py +60 -17
  69. sglang/srt/layers/moe/ep_moe/layer.py +195 -87
  70. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +88 -8
  71. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  72. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  73. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  76. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  77. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  78. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  80. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +220 -25
  81. sglang/srt/layers/moe/fused_moe_triton/layer.py +48 -4
  82. sglang/srt/layers/moe/topk.py +107 -24
  83. sglang/srt/layers/multimodal.py +70 -0
  84. sglang/srt/layers/quantization/__init__.py +10 -4
  85. sglang/srt/layers/quantization/blockwise_int8.py +3 -0
  86. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +5 -0
  87. sglang/srt/layers/quantization/deep_gemm.py +60 -59
  88. sglang/srt/layers/quantization/fp8.py +113 -18
  89. sglang/srt/layers/quantization/fp8_kernel.py +118 -66
  90. sglang/srt/layers/quantization/fp8_utils.py +165 -43
  91. sglang/srt/layers/quantization/gptq.py +298 -6
  92. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  93. sglang/srt/layers/quantization/modelopt_quant.py +334 -7
  94. sglang/srt/layers/quantization/moe_wna16.py +3 -0
  95. sglang/srt/layers/quantization/qoq.py +244 -0
  96. sglang/srt/layers/quantization/w8a8_fp8.py +3 -0
  97. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  98. sglang/srt/layers/rotary_embedding.py +6 -12
  99. sglang/srt/layers/sampler.py +80 -79
  100. sglang/srt/layers/utils.py +6 -0
  101. sglang/srt/lora/layers.py +12 -15
  102. sglang/srt/lora/lora.py +49 -5
  103. sglang/srt/lora/lora_manager.py +20 -8
  104. sglang/srt/lora/mem_pool.py +24 -16
  105. sglang/srt/lora/utils.py +17 -13
  106. sglang/srt/managers/data_parallel_controller.py +13 -5
  107. sglang/srt/managers/eplb_algorithms/__init__.py +63 -0
  108. sglang/srt/managers/eplb_algorithms/deepseek.py +223 -0
  109. sglang/srt/managers/eplb_algorithms/deepseek_vec.py +276 -0
  110. sglang/srt/managers/eplb_manager.py +96 -0
  111. sglang/srt/managers/expert_distribution.py +878 -56
  112. sglang/srt/managers/expert_location.py +448 -0
  113. sglang/srt/managers/expert_location_dispatch.py +108 -0
  114. sglang/srt/managers/io_struct.py +29 -5
  115. sglang/srt/managers/mm_utils.py +355 -151
  116. sglang/srt/managers/multimodal_processors/base_processor.py +299 -42
  117. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  118. sglang/srt/managers/multimodal_processors/gemma3.py +15 -17
  119. sglang/srt/managers/multimodal_processors/internvl.py +18 -5
  120. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  121. sglang/srt/managers/multimodal_processors/kimi_vl.py +14 -32
  122. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  123. sglang/srt/managers/multimodal_processors/minicpm.py +27 -32
  124. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  125. sglang/srt/managers/multimodal_processors/phi4mm.py +87 -0
  126. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  127. sglang/srt/managers/multimodal_processors/qwen_vl.py +35 -35
  128. sglang/srt/managers/schedule_batch.py +185 -55
  129. sglang/srt/managers/schedule_policy.py +4 -5
  130. sglang/srt/managers/scheduler.py +389 -154
  131. sglang/srt/managers/session_controller.py +1 -1
  132. sglang/srt/managers/tokenizer_manager.py +231 -39
  133. sglang/srt/managers/utils.py +0 -4
  134. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  135. sglang/srt/mem_cache/chunk_cache.py +3 -1
  136. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  137. sglang/srt/mem_cache/memory_pool.py +74 -52
  138. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  139. sglang/srt/mem_cache/radix_cache.py +58 -5
  140. sglang/srt/metrics/collector.py +11 -2
  141. sglang/srt/mm_utils.py +10 -0
  142. sglang/srt/model_executor/cuda_graph_runner.py +87 -65
  143. sglang/srt/model_executor/expert_location_updater.py +557 -0
  144. sglang/srt/model_executor/forward_batch_info.py +39 -14
  145. sglang/srt/model_executor/model_runner.py +231 -101
  146. sglang/srt/model_loader/loader.py +10 -6
  147. sglang/srt/model_loader/utils.py +67 -1
  148. sglang/srt/models/clip.py +5 -1
  149. sglang/srt/models/deepseek_nextn.py +1 -1
  150. sglang/srt/models/deepseek_v2.py +732 -403
  151. sglang/srt/models/exaone.py +8 -3
  152. sglang/srt/models/gemma3_causal.py +7 -0
  153. sglang/srt/models/gemma3_mm.py +75 -33
  154. sglang/srt/models/idefics2.py +342 -0
  155. sglang/srt/models/kimi_vl.py +4 -4
  156. sglang/srt/models/llama.py +1 -1
  157. sglang/srt/models/llama4.py +10 -2
  158. sglang/srt/models/llava.py +26 -18
  159. sglang/srt/models/mimo_mtp.py +220 -0
  160. sglang/srt/models/minicpmo.py +7 -17
  161. sglang/srt/models/minicpmv.py +3 -295
  162. sglang/srt/models/mistral.py +71 -1
  163. sglang/srt/models/mllama.py +3 -3
  164. sglang/srt/models/phi4mm.py +512 -0
  165. sglang/srt/models/qwen2.py +133 -35
  166. sglang/srt/models/qwen2_5_vl.py +5 -3
  167. sglang/srt/models/qwen2_eagle.py +4 -1
  168. sglang/srt/models/qwen2_moe.py +206 -69
  169. sglang/srt/models/qwen2_vl.py +3 -3
  170. sglang/srt/models/qwen3.py +92 -19
  171. sglang/srt/models/qwen3_moe.py +457 -55
  172. sglang/srt/models/registry.py +9 -1
  173. sglang/srt/models/siglip.py +294 -0
  174. sglang/srt/models/transformers.py +291 -0
  175. sglang/srt/openai_api/adapter.py +114 -40
  176. sglang/srt/openai_api/protocol.py +37 -2
  177. sglang/srt/openai_api/utils.py +172 -0
  178. sglang/srt/operations.py +189 -0
  179. sglang/srt/operations_strategy.py +207 -0
  180. sglang/srt/sampling/sampling_batch_info.py +13 -1
  181. sglang/srt/sampling/sampling_params.py +2 -1
  182. sglang/srt/server_args.py +235 -38
  183. sglang/srt/speculative/build_eagle_tree.py +8 -8
  184. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -11
  185. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +253 -0
  186. sglang/srt/speculative/eagle_utils.py +181 -90
  187. sglang/srt/speculative/eagle_worker.py +146 -21
  188. sglang/srt/two_batch_overlap.py +635 -0
  189. sglang/srt/utils.py +197 -19
  190. sglang/test/runners.py +16 -7
  191. sglang/test/send_one.py +4 -0
  192. sglang/test/test_cutlass_moe.py +278 -0
  193. sglang/test/test_fp4_moe.py +248 -0
  194. sglang/test/test_utils.py +81 -42
  195. sglang/utils.py +2 -2
  196. sglang/version.py +1 -1
  197. {sglang-0.4.6.post4.dist-info → sglang-0.4.7.dist-info}/METADATA +31 -19
  198. sglang-0.4.7.dist-info/RECORD +699 -0
  199. {sglang-0.4.6.post4.dist-info → sglang-0.4.7.dist-info}/WHEEL +1 -1
  200. sglang/srt/function_call_parser.py +0 -858
  201. sglang/srt/platforms/interface.py +0 -371
  202. sglang-0.4.6.post4.dist-info/RECORD +0 -646
  203. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  204. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  205. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  206. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  207. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  208. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  209. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  210. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  211. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  212. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  213. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  214. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  215. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  216. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1024,device_name=NVIDIA_H200.json → triton_3_1_0/E=16,N=1024,device_name=NVIDIA_H200.json} +0 -0
  217. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  218. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  219. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  220. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  221. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  222. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  223. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  224. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  225. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  226. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  227. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  228. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  229. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  230. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  231. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  232. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  233. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json} +0 -0
  234. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  235. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json → triton_3_1_0/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json} +0 -0
  236. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  237. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  238. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  239. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  240. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  241. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  242. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  243. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  244. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  245. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json} +0 -0
  246. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  247. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  248. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  249. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  250. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  251. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  252. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json} +0 -0
  253. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  254. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  255. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json} +0 -0
  256. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json → triton_3_1_0/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json} +0 -0
  257. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  258. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  259. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  260. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  261. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  262. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  263. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=1280,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=1280,device_name=NVIDIA_H200.json} +0 -0
  264. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  265. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  266. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=2560,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=2560,device_name=NVIDIA_H200.json} +0 -0
  267. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  268. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  269. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  270. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=320,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=320,device_name=NVIDIA_H200.json} +0 -0
  271. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_1_0/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  272. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  273. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  274. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
  275. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  276. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  277. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  278. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=64,N=640,device_name=NVIDIA_H200.json → triton_3_1_0/E=64,N=640,device_name=NVIDIA_H200.json} +0 -0
  279. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI300X.json} +0 -0
  280. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Instinct_MI325X.json} +0 -0
  281. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=14336,device_name=AMD_Radeon_Graphics.json} +0 -0
  282. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  283. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  284. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=14336,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=14336,device_name=NVIDIA_H200.json} +0 -0
  285. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI300X.json} +0 -0
  286. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Instinct_MI325X.json} +0 -0
  287. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=1792,device_name=AMD_Radeon_Graphics.json} +0 -0
  288. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  289. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  290. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  291. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  292. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=1792,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=1792,device_name=NVIDIA_H200.json} +0 -0
  293. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  294. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  295. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  296. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  297. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=2048,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=2048,device_name=NVIDIA_H200.json} +0 -0
  298. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI300X.json} +0 -0
  299. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Instinct_MI325X.json} +0 -0
  300. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=3584,device_name=AMD_Radeon_Graphics.json} +0 -0
  301. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json} +0 -0
  302. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  303. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json} +0 -0
  304. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  305. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  306. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  307. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_H200.json} +0 -0
  308. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=3584,device_name=NVIDIA_L40S.json → triton_3_1_0/E=8,N=3584,device_name=NVIDIA_L40S.json} +0 -0
  309. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
  310. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
  311. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
  312. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  313. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  314. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  315. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  316. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=4096,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=4096,device_name=NVIDIA_H200.json} +0 -0
  317. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI300X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI300X.json} +0 -0
  318. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Instinct_MI325X.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Instinct_MI325X.json} +0 -0
  319. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=AMD_Radeon_Graphics.json → triton_3_1_0/E=8,N=7168,device_name=AMD_Radeon_Graphics.json} +0 -0
  320. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json} +0 -0
  321. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  322. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  323. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  324. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=7168,device_name=NVIDIA_H200.json → triton_3_1_0/E=8,N=7168,device_name=NVIDIA_H200.json} +0 -0
  325. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json} +0 -0
  326. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json} +0 -0
  327. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json} +0 -0
  328. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json} +0 -0
  329. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json → triton_3_1_0/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json} +0 -0
  330. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  331. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  332. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H20.json} +0 -0
  333. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=192,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=192,device_name=NVIDIA_H200.json} +0 -0
  334. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  335. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  336. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H20.json} +0 -0
  337. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  338. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=384,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=384,device_name=NVIDIA_H200.json} +0 -0
  339. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  340. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  341. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json} +0 -0
  342. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  343. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H20.json} +0 -0
  344. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  345. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=768,device_name=NVIDIA_H200.json → triton_3_2_0/E=128,N=768,device_name=NVIDIA_H200.json} +0 -0
  346. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=128,N=96,device_name=NVIDIA_H20.json → triton_3_2_0/E=128,N=96,device_name=NVIDIA_H20.json} +0 -0
  347. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  348. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  349. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  350. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  351. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json} +0 -0
  352. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  353. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_2_0/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +0 -0
  354. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  355. /sglang/srt/layers/moe/fused_moe_triton/configs/{E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json → triton_3_2_0/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json} +0 -0
  356. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  357. {sglang-0.4.6.post4.dist-info → sglang-0.4.7.dist-info}/licenses/LICENSE +0 -0
  358. {sglang-0.4.6.post4.dist-info → sglang-0.4.7.dist-info}/top_level.txt +0 -0
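The largest single addition in this release is the new ROCm attention backend in sglang/srt/layers/attention/aiter_backend.py (file 49 above, +878 lines); its diff follows. As an illustrative sketch only (not code shipped in the package), the snippet below mirrors the paged-KV indexing that backend assembles in init_forward_metadata: kv_indptr is the cumulative sum of per-request sequence lengths, and kv_indices gathers each request's KV-cache slot numbers from the req_to_token table, which is what the create_flashinfer_kv_indices_triton kernel in the hunk does on-device. The helper name build_kv_indices is hypothetical.

# Illustrative CPU reference (assumption: not part of sglang) for the paged-KV
# indexing built by AiterAttnBackend.init_forward_metadata in the hunk below.
import torch

def build_kv_indices(req_to_token, req_pool_indices, seq_lens):
    bs = seq_lens.numel()
    # kv_indptr[i] .. kv_indptr[i + 1] spans request i's tokens; kv_indptr[0] == 0.
    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(seq_lens, dim=0)
    # Flatten every request's KV-cache slot numbers into one index array,
    # the same layout create_flashinfer_kv_indices_triton fills on the GPU.
    kv_indices = torch.empty(int(kv_indptr[-1]), dtype=torch.int32)
    for i in range(bs):
        kv_indices[kv_indptr[i] : kv_indptr[i + 1]] = req_to_token[
            req_pool_indices[i], : seq_lens[i]
        ]
    return kv_indptr, kv_indices

# Two requests of lengths 3 and 2 in a 4-slots-per-request token table:
req_to_token = torch.arange(8, dtype=torch.int32).reshape(2, 4)
kv_indptr, kv_indices = build_kv_indices(
    req_to_token, torch.tensor([0, 1]), torch.tensor([3, 2])
)
# kv_indptr == [0, 3, 5]; kv_indices == [0, 1, 2, 4, 5]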
@@ -0,0 +1,878 @@
1
+ from __future__ import annotations
2
+
3
+ """
4
+ end to end attention solution with aiter kernels
5
+ """
6
+
7
+ import math
8
+ import os
9
+ from dataclasses import dataclass
10
+ from enum import Enum, auto
11
+ from functools import partial
12
+ from typing import TYPE_CHECKING, List, Optional, Union
13
+
14
+ import torch
15
+ import triton
16
+ import triton.language as tl
17
+
18
+ from sglang.global_config import global_config
19
+ from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
20
+ from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
21
+ from sglang.srt.layers.dp_attention import get_attention_tp_size
22
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
23
+
24
+ if TYPE_CHECKING:
25
+ from sglang.srt.layers.radix_attention import RadixAttention
26
+ from sglang.srt.model_executor.model_runner import ModelRunner
27
+ from sglang.srt.speculative.spec_info import SpecInfo
28
+
29
+ try:
30
+ from aiter import (
31
+ flash_attn_varlen_func,
32
+ mha_batch_prefill_func,
33
+ paged_attention_ragged,
34
+ )
35
+ from aiter.mla import mla_decode_fwd
36
+ except ImportError:
37
+ print(
38
+ "aiter is AMD specific kernel library. Please make sure aiter is installed on your AMD device."
39
+ )
40
+
41
+ from sglang.srt.configs.model_config import AttentionArch
42
+
43
+
44
+ class WrapperDispatch(Enum):
45
+ SLIDING_WINDOW = auto()
46
+ CROSS_ATTENTION = auto()
47
+
48
+
49
+ @dataclass
50
+ class ForwardMetadata:
51
+ kv_indptr: torch.Tensor
52
+ kv_indices: torch.Tensor
53
+ qo_indptr: torch.Tensor
54
+ kv_last_page_len: torch.Tensor
55
+ max_extend_len: int
56
+ max_prefix_extend_len: int
57
+ max_q_len: int
58
+ max_kv_len: int
59
+
60
+
61
+ global_workspace_buffer = None
62
+
63
+ _AITER_PARTITION_SIZE_ROCM = 256
64
+
65
+
66
+ class AiterAttnBackend(AttentionBackend):
67
+ def __init__(
68
+ self,
69
+ model_runner: ModelRunner,
70
+ skip_prefill: bool = False,
71
+ kv_indptr_buf: Optional[torch.Tensor] = None,
72
+ ):
73
+ super().__init__()
74
+
75
+ self.device = model_runner.device
76
+ self.is_multimodal = model_runner.model_config.is_multimodal
77
+ self.num_draft_tokens = model_runner.server_args.speculative_num_draft_tokens
78
+ self.num_head = (
79
+ model_runner.model_config.num_attention_heads // get_attention_tp_size()
80
+ )
81
+ self.head_dim = model_runner.model_config.head_dim
82
+ self.v_head_dim = model_runner.token_to_kv_pool.get_value_buffer(0).shape[-1]
83
+ self.num_kv_head = model_runner.model_config.get_num_kv_heads(
84
+ get_attention_tp_size()
85
+ )
86
+ self.kv_cache_dtype = model_runner.kv_cache_dtype
87
+
88
+ self.req_to_token = model_runner.req_to_token_pool.req_to_token
89
+
90
+ self.use_mla = model_runner.model_config.attention_arch == AttentionArch.MLA
91
+
92
+ # Parse constants
93
+ self.max_context_len = model_runner.model_config.context_len
94
+ self.skip_prefill = skip_prefill
95
+
96
+ max_bs = model_runner.req_to_token_pool.size
97
+
98
+ if kv_indptr_buf is None:
99
+ self.kv_indptr = torch.zeros(
100
+ (max_bs + 1,), dtype=torch.int32, device=model_runner.device
101
+ )
102
+ else:
103
+ self.kv_indptr = kv_indptr_buf
104
+
105
+ self.kv_last_page_len = torch.ones(
106
+ (max_bs,), dtype=torch.int32, device=model_runner.device
107
+ )
108
+ self.qo_indptr = torch.zeros(
109
+ (max_bs + 1,), dtype=torch.int32, device=model_runner.device
110
+ )
111
+
112
+ # Create prefill indices updater
113
+ if not skip_prefill:
114
+ self.indices_updater_prefill = AiterIndicesUpdaterPrefill(
115
+ model_runner, self
116
+ )
117
+ if self.use_mla:
118
+ self.mla_indices_updater_prefill = AiterMlaIndicesUpdaterPrefill(
119
+ model_runner, self
120
+ )
121
+
122
+ # aiter kernel related initialization
123
+ self.max_num_partitions = (
124
+ self.max_context_len + _AITER_PARTITION_SIZE_ROCM - 1
125
+ ) // _AITER_PARTITION_SIZE_ROCM
126
+
127
+ nbyes_per_qo_elem = torch.finfo(torch.float32).bits // 8
128
+
129
+ if not self.use_mla:
130
+ self.workspace_buffer = torch.empty(
131
+ (max_bs * self.num_head * self.max_num_partitions * self.head_dim)
132
+ * nbyes_per_qo_elem
133
+ + 2 * (max_bs * self.num_head * self.max_num_partitions) * 4,
134
+ dtype=torch.uint8,
135
+ device=self.device,
136
+ )
137
+
138
+ self.scale = float(1.0 / (self.head_dim**0.5))
139
+ self.k_scale = self.v_scale = torch.tensor([1.0], dtype=torch.float32).to(
140
+ self.device
141
+ )
142
+
143
+ self.logits_soft_cap = 0.0
144
+
145
+ self.forward_metadata: ForwardMetadata = None
146
+
147
+ if self.use_mla:
148
+ self.qo_indptr_ = torch.zeros(
149
+ (max_bs + 1,), dtype=torch.int32, device=model_runner.device
150
+ )
151
+
152
+ def init_forward_metadata(self, forward_batch: ForwardBatch):
153
+ """Init auxiliary variables for triton attention backend."""
154
+
155
+ bs = forward_batch.batch_size
156
+ kv_indptr = self.kv_indptr
157
+ spec_info = forward_batch.spec_info
158
+ qo_indptr = None
159
+ kv_last_page_len = None
160
+ max_extend_len = None
161
+
162
+ if forward_batch.forward_mode.is_decode_or_idle():
163
+ if spec_info is None:
164
+ kv_indptr[1 : bs + 1] = torch.cumsum(forward_batch.seq_lens, dim=0)
165
+ kv_indptr = kv_indptr[: bs + 1]
166
+ kv_indices = torch.zeros(
167
+ forward_batch.seq_lens_sum, dtype=torch.int32, device=self.device
168
+ )
169
+ create_flashinfer_kv_indices_triton[(bs,)](
170
+ self.req_to_token,
171
+ forward_batch.req_pool_indices,
172
+ forward_batch.seq_lens,
173
+ kv_indptr,
174
+ None,
175
+ kv_indices,
176
+ self.req_to_token.stride(0),
177
+ )
178
+ else:
179
+ kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
180
+ bs = kv_indptr.shape[0] - 1
181
+
182
+ if self.use_mla:
183
+ qo_indptr = self.qo_indptr_[: bs + 1]
184
+ qo_indptr[1 : bs + 1] = torch.cumsum(self.kv_last_page_len[:bs], dim=0)
185
+ kv_last_page_len = self.kv_last_page_len[:bs]
186
+ max_extend_len = 1
187
+
188
+ self.forward_metadata = ForwardMetadata(
189
+ kv_indptr,
190
+ kv_indices,
191
+ qo_indptr,
192
+ kv_last_page_len,
193
+ max_extend_len,
194
+ None,
195
+ None,
196
+ None,
197
+ )
198
+
199
+ elif forward_batch.forward_mode.is_draft_extend():
200
+ if self.use_mla:
201
+ prefix_lens = forward_batch.extend_prefix_lens
202
+ self.mla_indices_updater_prefill.update(
203
+ forward_batch.req_pool_indices,
204
+ prefix_lens,
205
+ prefix_lens.sum().item(),
206
+ forward_batch.extend_seq_lens,
207
+ encoder_lens=forward_batch.encoder_lens,
208
+ spec_info=None,
209
+ )
210
+ self.forward_metadata = ForwardMetadata(
211
+ self.mla_indices_updater_prefill.kv_indptr,
212
+ self.mla_indices_updater_prefill.kv_indices,
213
+ self.mla_indices_updater_prefill.qo_indptr,
214
+ self.mla_indices_updater_prefill.kv_last_page_len,
215
+ self.mla_indices_updater_prefill.max_extend_len,
216
+ self.mla_indices_updater_prefill.max_prefix_extend_len,
217
+ None,
218
+ None,
219
+ )
220
+ else:
221
+ self.indices_updater_prefill.update(
222
+ forward_batch.req_pool_indices,
223
+ forward_batch.seq_lens,
224
+ forward_batch.seq_lens_sum,
225
+ prefix_lens=None,
226
+ encoder_lens=forward_batch.encoder_lens,
227
+ spec_info=forward_batch.spec_info,
228
+ )
229
+ self.forward_metadata = ForwardMetadata(
230
+ self.indices_updater_prefill.kv_indptr,
231
+ self.indices_updater_prefill.kv_indices,
232
+ None,
233
+ None,
234
+ None,
235
+ None,
236
+ self.indices_updater_prefill.max_q_len,
237
+ self.indices_updater_prefill.max_kv_len,
238
+ )
239
+ elif forward_batch.forward_mode.is_target_verify():
240
+ if self.use_mla:
241
+ prefix_lens = forward_batch.extend_prefix_lens
242
+ self.mla_indices_updater_prefill.update(
243
+ forward_batch.req_pool_indices,
244
+ prefix_lens,
245
+ prefix_lens.sum().item(),
246
+ forward_batch.extend_seq_lens,
247
+ encoder_lens=forward_batch.encoder_lens,
248
+ spec_info=None,
249
+ )
250
+ self.forward_metadata = ForwardMetadata(
251
+ self.mla_indices_updater_prefill.kv_indptr,
252
+ self.mla_indices_updater_prefill.kv_indices,
253
+ self.mla_indices_updater_prefill.qo_indptr,
254
+ self.mla_indices_updater_prefill.kv_last_page_len,
255
+ self.mla_indices_updater_prefill.max_extend_len,
256
+ self.mla_indices_updater_prefill.max_prefix_extend_len,
257
+ None,
258
+ None,
259
+ )
260
+ else:
261
+ self.indices_updater_prefill.update(
262
+ forward_batch.req_pool_indices,
263
+ forward_batch.seq_lens,
264
+ forward_batch.seq_lens_sum,
265
+ prefix_lens=None,
266
+ encoder_lens=forward_batch.encoder_lens,
267
+ spec_info=forward_batch.spec_info,
268
+ )
269
+ self.forward_metadata = ForwardMetadata(
270
+ self.indices_updater_prefill.kv_indptr,
271
+ self.indices_updater_prefill.kv_indices,
272
+ None,
273
+ None,
274
+ None,
275
+ None,
276
+ self.indices_updater_prefill.max_q_len,
277
+ self.indices_updater_prefill.max_kv_len,
278
+ )
279
+ else:
280
+ prefix_lens = forward_batch.extend_prefix_lens
281
+
282
+ if self.is_multimodal:
283
+ extend_no_prefix = False
284
+ else:
285
+ extend_no_prefix = not any(forward_batch.extend_prefix_lens_cpu)
286
+
287
+ if self.use_mla:
288
+ self.mla_indices_updater_prefill.update(
289
+ forward_batch.req_pool_indices,
290
+ prefix_lens,
291
+ prefix_lens.sum().item(),
292
+ forward_batch.extend_seq_lens,
293
+ encoder_lens=forward_batch.encoder_lens,
294
+ spec_info=None,
295
+ )
296
+ self.forward_metadata = ForwardMetadata(
297
+ self.mla_indices_updater_prefill.kv_indptr,
298
+ self.mla_indices_updater_prefill.kv_indices,
299
+ self.mla_indices_updater_prefill.qo_indptr,
300
+ self.mla_indices_updater_prefill.kv_last_page_len,
301
+ self.mla_indices_updater_prefill.max_extend_len,
302
+ self.mla_indices_updater_prefill.max_prefix_extend_len,
303
+ None,
304
+ None,
305
+ )
306
+ else:
307
+ self.indices_updater_prefill.update(
308
+ forward_batch.req_pool_indices,
309
+ forward_batch.seq_lens,
310
+ forward_batch.seq_lens_sum,
311
+ prefix_lens,
312
+ encoder_lens=forward_batch.encoder_lens,
313
+ spec_info=None,
314
+ )
315
+ self.forward_metadata = ForwardMetadata(
316
+ self.indices_updater_prefill.kv_indptr,
317
+ self.indices_updater_prefill.kv_indices,
318
+ None,
319
+ None,
320
+ None,
321
+ None,
322
+ self.indices_updater_prefill.max_q_len,
323
+ self.indices_updater_prefill.max_kv_len,
324
+ )
325
+
326
+ def init_cuda_graph_state(
327
+ self, max_bs: int, kv_indices_buf: Optional[torch.Tensor] = None
328
+ ):
329
+ self.cuda_graph_kv_last_page_len = torch.ones(max_bs, dtype=torch.int)
330
+ if kv_indices_buf is None:
331
+ self.cuda_graph_kv_indices = torch.zeros(
332
+ (max_bs * self.max_context_len),
333
+ dtype=torch.int32,
334
+ device=self.device,
335
+ )
336
+ else:
337
+ self.cuda_graph_kv_indices = kv_indices_buf
338
+
339
+ if not self.skip_prefill:
340
+ self.cuda_graph_custom_mask = torch.zeros(
341
+ (max_bs * self.max_context_len),
342
+ dtype=torch.uint8,
343
+ device=self.device,
344
+ )
345
+
346
+ def init_forward_metadata_capture_cuda_graph(
347
+ self,
348
+ bs: int,
349
+ num_tokens: int,
350
+ req_pool_indices: torch.Tensor,
351
+ seq_lens: torch.Tensor,
352
+ encoder_lens: Optional[torch.Tensor],
353
+ forward_mode: ForwardMode,
354
+ spec_info: Optional[SpecInfo],
355
+ ):
356
+ if forward_mode.is_decode_or_idle():
357
+ qo_indptr = None
358
+ kv_last_page_len = None
359
+ max_extend_len = None
360
+
361
+ if spec_info is None:
362
+ kv_indptr = self.kv_indptr
363
+ kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
364
+ kv_indptr = kv_indptr[: bs + 1]
365
+ kv_indices = self.cuda_graph_kv_indices
366
+ create_flashinfer_kv_indices_triton[(bs,)](
367
+ self.req_to_token,
368
+ req_pool_indices,
369
+ seq_lens,
370
+ kv_indptr,
371
+ None,
372
+ kv_indices,
373
+ self.req_to_token.stride(0),
374
+ )
375
+ else:
376
+ kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
377
+
378
+ if self.use_mla:
379
+ qo_indptr = self.qo_indptr_[: bs + 1]
380
+ qo_indptr[1 : bs + 1] = torch.cumsum(
381
+ self.cuda_graph_kv_last_page_len[:bs], dim=0
382
+ )
383
+ max_extend_len = 1
384
+ kv_last_page_len = self.cuda_graph_kv_last_page_len[:bs]
385
+
386
+ self.forward_metadata = ForwardMetadata(
387
+ kv_indptr,
388
+ kv_indices,
389
+ qo_indptr,
390
+ kv_last_page_len,
391
+ max_extend_len,
392
+ None,
393
+ None,
394
+ None,
395
+ )
396
+
397
+ elif forward_mode.is_target_verify():
398
+ if self.use_mla:
399
+ qo_indptr = self.qo_indptr[: bs + 1]
400
+ qo_indptr[: bs + 1] = torch.arange(
401
+ 0,
402
+ (1 + bs) * self.num_draft_tokens,
403
+ step=self.num_draft_tokens,
404
+ dtype=torch.int32,
405
+ device=self.device,
406
+ )
407
+ kv_indptr = self.kv_indptr[: bs + 1]
408
+ kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens, dim=0)
409
+ kv_indices = self.cuda_graph_kv_indices
410
+ create_flashinfer_kv_indices_triton[(bs,)](
411
+ self.req_to_token,
412
+ req_pool_indices,
413
+ seq_lens,
414
+ kv_indptr,
415
+ None,
416
+ kv_indices,
417
+ self.req_to_token.stride(0),
418
+ )
419
+
420
+ max_extend_len = self.num_draft_tokens
421
+ kv_last_page_len = None
422
+
423
+ self.forward_metadata = ForwardMetadata(
424
+ kv_indptr,
425
+ kv_indices,
426
+ qo_indptr,
427
+ kv_last_page_len,
428
+ max_extend_len,
429
+ None,
430
+ None,
431
+ None,
432
+ )
433
+ else:
434
+ seq_lens_sum = seq_lens.sum().item()
435
+ self.indices_updater_prefill.update(
436
+ req_pool_indices,
437
+ seq_lens,
438
+ seq_lens_sum,
439
+ prefix_lens=None,
440
+ encoder_lens=encoder_lens,
441
+ spec_info=spec_info,
442
+ )
443
+ self.forward_metadata = ForwardMetadata(
444
+ self.indices_updater_prefill.kv_indptr,
445
+ self.indices_updater_prefill.kv_indices,
446
+ None,
447
+ None,
448
+ None,
449
+ None,
450
+ self.indices_updater_prefill.max_q_len,
451
+ self.indices_updater_prefill.max_kv_len,
452
+ )
453
+
454
+ else:
455
+ raise ValueError(f"Invalid mode: {forward_mode=}")
456
+
457
+ def init_forward_metadata_replay_cuda_graph(
+ self,
+ bs: int,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ encoder_lens: Optional[torch.Tensor],
+ forward_mode: ForwardMode,
+ spec_info: Optional[SpecInfo],
+ seq_lens_cpu: Optional[torch.Tensor],
+ ):
+ if forward_mode.is_decode_or_idle():
+ kv_indptr = self.kv_indptr
+ kv_indices = self.cuda_graph_kv_indices
+ if spec_info is None:
+ kv_indptr[1 : bs + 1] = torch.cumsum(seq_lens[:bs], dim=0)
+ kv_indptr = kv_indptr[: bs + 1]
+ create_flashinfer_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices[:bs],
+ seq_lens[:bs],
+ kv_indptr,
+ None,
+ kv_indices,
+ self.req_to_token.stride(0),
+ )
+ else:
+ kv_indptr[: spec_info.kv_indptr.shape[0]] = spec_info.kv_indptr
+ kv_indices[: spec_info.kv_indices.shape[0]] = spec_info.kv_indices
+
+ elif forward_mode.is_target_verify():
+ self.indices_updater_prefill.update(
+ req_pool_indices[:bs],
+ seq_lens[:bs],
+ seq_lens_sum,
+ prefix_lens=None,
+ encoder_lens=encoder_lens[:bs] if encoder_lens is not None else None,
+ spec_info=spec_info,
+ )
+ else:
+ raise ValueError("Invalid forward mode")
+
+ def get_cuda_graph_seq_len_fill_value(self):
+ return 1
+
+ def forward_extend(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ layer: RadixAttention,
+ forward_batch: ForwardBatch,
+ save_kv_cache=True,
+ ):
+ cache_loc = (
+ forward_batch.out_cache_loc
+ if not layer.is_cross_attention
+ else forward_batch.encoder_out_cache_loc
+ )
+
+ self.logits_soft_cap = layer.logit_cap
+
+ if k is not None:
+ assert v is not None
+ if save_kv_cache:
+ if self.use_mla:
+ forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+ else:
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+ )
+
+ if self.use_mla:
+ max_extend_len = self.forward_metadata.max_extend_len
+ max_prefix_extend_len = self.forward_metadata.max_prefix_extend_len
+ kv_indptr = self.forward_metadata.kv_indptr
+ kv_indices = self.forward_metadata.kv_indices
+ kv_last_page_lens = self.forward_metadata.kv_last_page_len
+ qo_indptr = self.forward_metadata.qo_indptr
+ K_Buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+ V_Buffer = forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id)
+ kv_lora_rank = V_Buffer.shape[-1]
+ qk_rope_head_dim = K_Buffer.shape[-1] - kv_lora_rank
+ qk_nope_head_dim = k.shape[-1] - qk_rope_head_dim
+ assert len(q.shape) == 3
+ assert len(k.shape) == 3
+ assert len(v.shape) == 3
+
+ if kv_indices.shape[0] == 0:
+ o = flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ qo_indptr,
+ qo_indptr,
+ max_extend_len,
+ max_extend_len,
+ softmax_scale=layer.scaling,
+ causal=True,
+ )
+ return o
+ elif layer.qk_head_dim != (kv_lora_rank + qk_rope_head_dim):
+ K_Buffer = torch.index_select(K_Buffer, 0, kv_indices)
+ kvc, k_pe = torch.split(
+ K_Buffer, [kv_lora_rank, qk_rope_head_dim], dim=-1
+ )
+ kvprefix = layer.kv_b_proj(kvc.contiguous())[0]
+
+ kvprefix = kvprefix.view(
+ -1, layer.tp_k_head_num, qk_nope_head_dim + layer.v_head_dim
+ )
+ k_prefix, v_prefix = torch.split(
+ kvprefix, [qk_nope_head_dim, layer.v_head_dim], dim=-1
+ )
+ k_prefix = torch.cat(
+ [
+ k_prefix,
+ torch.broadcast_to(
+ k_pe,
+ (k_pe.shape[0], layer.tp_k_head_num, k_pe.shape[2]),
+ ),
+ ],
+ dim=-1,
+ )
+ assert (
+ forward_batch.extend_prefix_lens.shape
+ == forward_batch.extend_seq_lens.shape
+ )
+ k_prefix = torch.split(k_prefix, forward_batch.extend_prefix_lens_cpu)
+ k_extend = torch.split(k, forward_batch.extend_seq_lens_cpu)
+ assert len(k_prefix) == len(forward_batch.extend_prefix_lens_cpu)
+ k = torch.cat([x for el in zip(k_prefix, k_extend) for x in el])
+ v_prefix = torch.split(v_prefix, forward_batch.extend_prefix_lens_cpu)
+ v_extend = torch.split(v, forward_batch.extend_seq_lens_cpu)
+ v = torch.cat([x for el in zip(v_prefix, v_extend) for x in el])
+
+ o = flash_attn_varlen_func(
+ q,
+ k,
+ v,
+ qo_indptr,
+ kv_indptr,
+ max_extend_len,
+ max_prefix_extend_len,
+ softmax_scale=layer.scaling,
+ causal=True,
+ )
+ return o
+ else:
+ k_cache, v_cache = forward_batch.token_to_kv_pool.get_kv_buffer(
+ layer.layer_id
+ )
+
+ bs0 = forward_batch.batch_size + 1
+
+ o = mha_batch_prefill_func(
+ q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+ k_cache,
+ v_cache,
+ self.qo_indptr[:bs0],
+ self.forward_metadata.kv_indptr[:bs0],
+ self.forward_metadata.kv_indices,
+ self.forward_metadata.max_q_len,
+ self.forward_metadata.max_kv_len,
+ causal=True,
+ logits_soft_cap=self.logits_soft_cap,
+ alibi_slopes=None,
+ return_lse=False,
+ return_attn_probs=False,
+ )
+
+ return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
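Editorial note: a minimal sketch (illustrative shapes and names, not part of the package diff) of the per-request interleaving the MLA branch above performs when it rebuilds full K/V from the cached prefix tokens plus the new extend tokens, mirroring torch.cat([x for el in zip(k_prefix, k_extend) for x in el]):

    import torch

    # Two requests: prefix lengths [3, 1] already cached, extend lengths [2, 4] new.
    prefix_lens_cpu = [3, 1]
    extend_lens_cpu = [2, 4]
    num_heads, head_dim = 2, 8

    k_prefix_all = torch.randn(sum(prefix_lens_cpu), num_heads, head_dim)
    k_extend_all = torch.randn(sum(extend_lens_cpu), num_heads, head_dim)

    # Split per request, then interleave prefix_i followed by extend_i so the
    # varlen kernel sees contiguous [prefix, extend] tokens for each request.
    k_prefix = torch.split(k_prefix_all, prefix_lens_cpu)
    k_extend = torch.split(k_extend_all, extend_lens_cpu)
    k_full = torch.cat([x for pair in zip(k_prefix, k_extend) for x in pair])

    assert k_full.shape[0] == sum(prefix_lens_cpu) + sum(extend_lens_cpu)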
+ def forward_decode(
+ self,
+ q: torch.Tensor,
+ k: torch.Tensor,
+ v: torch.Tensor,
+ layer: RadixAttention,
+ forward_batch: ForwardBatch,
+ save_kv_cache=True,
+ ):
+
+ q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+ if layer.qk_head_dim != layer.v_head_dim:
+ o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+ else:
+ o = torch.empty_like(q)
+
+ if save_kv_cache:
+ forward_batch.token_to_kv_pool.set_kv_buffer(
+ layer, forward_batch.out_cache_loc, k, v
+ )
+
+ if self.use_mla:
+ k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+ mla_decode_fwd(
+ q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+ k_buffer.view(-1, 1, 1, layer.qk_head_dim),
+ o.view(-1, layer.tp_q_head_num, layer.v_head_dim),
+ self.forward_metadata.qo_indptr,
+ self.forward_metadata.kv_indptr,
+ self.forward_metadata.kv_indices,
+ self.forward_metadata.kv_last_page_len,
+ self.forward_metadata.max_extend_len,
+ layer.scaling,
+ layer.logit_cap,
+ )
+ k_buffer = k_buffer.view(-1, 1, layer.qk_head_dim)
+ else:
+ self.logits_soft_cap = layer.logit_cap
+ paged_attention_ragged(
+ o.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+ self.workspace_buffer,
+ q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
+ forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id).view(
+ -1, 1, layer.tp_k_head_num, layer.qk_head_dim
+ ),
+ forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id).view(
+ -1, 1, layer.tp_v_head_num, layer.v_head_dim
+ ),
+ self.scale,
+ self.forward_metadata.kv_indptr,
+ self.forward_metadata.kv_indices,
+ self.kv_last_page_len,
+ 1,
+ self.max_num_partitions,
+ None,
+ "auto",
+ "NHD",
+ self.logits_soft_cap,
+ self.k_scale,
+ self.v_scale,
+ None,
+ _AITER_PARTITION_SIZE_ROCM,
+ )
+
+ return o
+
+
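Editorial note: a small illustrative sketch (dimensions are assumptions, not from the diff) of why forward_decode above allocates the output separately when qk_head_dim and v_head_dim differ, as they do for MLA-style layers:

    import torch

    # Illustrative MLA-style dims: the query/key head dim (nope + rope parts)
    # differs from the value head dim, so the output cannot be empty_like(q).
    tokens, heads, qk_head_dim, v_head_dim = 4, 8, 576, 512

    q = torch.randn(tokens, heads, qk_head_dim).reshape(tokens, heads * qk_head_dim)
    o = q.new_empty((q.shape[0], heads * v_head_dim))
    assert o.shape == (tokens, heads * v_head_dim)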
+ class AiterIndicesUpdaterPrefill:
+ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+ # Parse Constants
+ self.num_qo_heads = (
+ model_runner.model_config.num_attention_heads // get_attention_tp_size()
+ )
+ self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+ get_attention_tp_size()
+ )
+ self.head_dim = model_runner.model_config.head_dim
+ self.data_type = model_runner.kv_cache_dtype
+ self.q_data_type = model_runner.dtype
+ self.sliding_window_size = model_runner.sliding_window_size
+ self.attn_backend = attn_backend
+
+ # Buffers and wrappers
+ self.kv_indptr = attn_backend.kv_indptr
+ self.kv_last_page_len = attn_backend.kv_last_page_len
+ self.qo_indptr = attn_backend.qo_indptr
+ self.req_to_token = model_runner.req_to_token_pool.req_to_token
+ self.update = self.update_single_wrapper
+
+ self.kv_indices = None
+ self.max_q_len = 0
+ self.max_kv_len = 0
+
+ def update(
+ self,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ prefix_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ spec_info: Optional[SpecInfo],
+ ):
+ # Keep the signature for type checking. It will be assigned during runtime.
+ raise NotImplementedError()
+
+ def update_single_wrapper(
+ self,
+ req_pool_indices: torch.Tensor,
+ seq_lens: torch.Tensor,
+ seq_lens_sum: int,
+ prefix_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ spec_info: Optional[SpecInfo],
+ ):
+
+ kv_start_idx = None
+ kv_indptr = self.kv_indptr
+ qo_indptr = self.qo_indptr
+ paged_kernel_lens = seq_lens
+ paged_kernel_lens_sum = seq_lens_sum
+
+ bs = len(req_pool_indices)
+ if spec_info is None:
+ # Normal extend
+ kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+ kv_indptr = kv_indptr[: bs + 1]
+ kv_indices = torch.empty(
+ paged_kernel_lens_sum + 256,
+ dtype=torch.int32,
+ device=req_pool_indices.device,
+ )
+ create_flashinfer_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices,
+ paged_kernel_lens,
+ kv_indptr,
+ kv_start_idx,
+ kv_indices,
+ self.req_to_token.shape[1],
+ )
+
+ self.max_kv_len = torch.max(paged_kernel_lens).item()
+
+ extend_lens = seq_lens - prefix_lens
+ self.max_q_len = torch.max(extend_lens).item()
+
+ qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
+ qo_indptr = qo_indptr[: bs + 1]
+ custom_mask = None
+ else:
+ kv_indices, kv_indptr, qo_indptr, custom_mask = (
+ spec_info.generate_attn_arg_prefill(
+ req_pool_indices,
+ paged_kernel_lens,
+ paged_kernel_lens_sum,
+ self.req_to_token,
+ )
+ )
+
+ self.kv_indices = kv_indices
+
+
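Editorial note: a hedged numeric sketch (values are invented for illustration, not taken from the diff) of the quantities the "normal extend" path of AiterIndicesUpdaterPrefill derives: extend_lens from seq_lens minus prefix_lens, the cumulative qo_indptr, and the max_q_len / max_kv_len bounds.

    import torch

    # Illustrative prefill batch: total lengths and already-cached prefix lengths.
    seq_lens = torch.tensor([5, 7, 4], dtype=torch.int32)
    prefix_lens = torch.tensor([2, 7, 0], dtype=torch.int32)

    extend_lens = seq_lens - prefix_lens              # new tokens per request: [3, 0, 4]
    bs = seq_lens.numel()

    qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)  # [0, 3, 3, 7]

    max_q_len = int(extend_lens.max())                # 4
    max_kv_len = int(seq_lens.max())                  # 7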
+ class AiterMlaIndicesUpdaterPrefill:
+ def __init__(self, model_runner: ModelRunner, attn_backend: AttentionBackend):
+ # Parse Constants
+ self.attn_backend = attn_backend
+
+ # Buffers and wrappers
+ self.req_to_token = model_runner.req_to_token_pool.req_to_token
+ self.update = self.update_single_wrapper
+
+ self.kv_indptr = None
+ self.kv_indices = None
+ self.qo_indptr = None
+ self.kv_last_page_len = None
+ self.max_extend_len = 0
+ self.max_prefix_extend_len = 0
+
+ def update(
+ self,
+ req_pool_indices: torch.Tensor,
+ prefix_lens: torch.Tensor,
+ prefix_lens_sum: int,
+ extend_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ spec_info: Optional[SpecInfo],
+ ):
+ # Keep the signature for type checking. It will be assigned during runtime.
+ raise NotImplementedError()
+
+ def update_single_wrapper(
+ self,
+ req_pool_indices: torch.Tensor,
+ prefix_lens: torch.Tensor,
+ prefix_lens_sum: int,
+ extend_lens: torch.Tensor,
+ encoder_lens: Optional[torch.Tensor],
+ spec_info: Optional[SpecInfo],
+ ):
+
+ paged_kernel_lens = prefix_lens
+ paged_kernel_lens_sum = prefix_lens_sum
+
+ bs = len(req_pool_indices)
+
+ kv_indptr = self.attn_backend.kv_indptr
+
+ if spec_info is None:
+ # Normal extend
+ kv_indptr[1 : bs + 1] = torch.cumsum(paged_kernel_lens, dim=0)
+ kv_indptr = kv_indptr[: bs + 1]
+ kv_indices = torch.empty(
+ paged_kernel_lens_sum,
+ dtype=torch.int32,
+ device=req_pool_indices.device,
+ )
+ create_flashinfer_kv_indices_triton[(bs,)](
+ self.req_to_token,
+ req_pool_indices,
+ paged_kernel_lens,
+ kv_indptr,
+ None,
+ kv_indices,
+ self.req_to_token.stride(0),
+ )
+
+ qo_indptr = self.attn_backend.qo_indptr
+ qo_indptr[1 : bs + 1] = torch.cumsum(extend_lens, dim=0)
+ qo_indptr = qo_indptr[: bs + 1]
+
+ max_extend_len = torch.max(extend_lens).item()
+ max_prefix_extend_len = torch.max(extend_lens + paged_kernel_lens).item()
+ kv_indptr += qo_indptr
+ else:
+ kv_indices, kv_indptr, qo_indptr, custom_mask = (
+ spec_info.generate_attn_arg_prefill(
+ req_pool_indices,
+ paged_kernel_lens,
+ paged_kernel_lens_sum,
+ self.req_to_token,
+ )
+ )
+
+ self.kv_indptr = kv_indptr
+ self.kv_indices = kv_indices
+ self.qo_indptr = qo_indptr
+ self.max_extend_len = max_extend_len
+ self.max_prefix_extend_len = max_prefix_extend_len
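
Editorial note: a hedged numeric sketch (invented lengths, not from the diff) of how the MLA updater combines prefix and extend boundaries: kv_indptr starts as cumulative prefix lengths, and adding qo_indptr (cumulative extend lengths) yields boundaries over the full prefix + extend token stream, with max_prefix_extend_len as the per-request upper bound.

    import torch

    # Illustrative MLA prefill: cached prefix lengths and new extend lengths.
    prefix_lens = torch.tensor([3, 1], dtype=torch.int32)
    extend_lens = torch.tensor([2, 4], dtype=torch.int32)
    bs = prefix_lens.numel()

    kv_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    kv_indptr[1:] = torch.cumsum(prefix_lens, dim=0)  # [0, 3, 4]
    qo_indptr = torch.zeros(bs + 1, dtype=torch.int32)
    qo_indptr[1:] = torch.cumsum(extend_lens, dim=0)  # [0, 2, 6]

    # Adding the query offsets gives boundaries over prefix + extend tokens,
    # matching the `kv_indptr += qo_indptr` step in the updater above.
    kv_indptr += qo_indptr                            # [0, 5, 10]
    max_prefix_extend_len = int((prefix_lens + extend_lens).max())  # 5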