vllm-cpu-amxbf16 0.11.2.post2__cp310-cp310-manylinux_2_17_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (1536)
  1. vllm/_C.abi3.so +0 -0
  2. vllm/__init__.py +225 -0
  3. vllm/_aiter_ops.py +983 -0
  4. vllm/_bc_linter.py +54 -0
  5. vllm/_custom_ops.py +2863 -0
  6. vllm/_ipex_ops.py +457 -0
  7. vllm/_version.py +34 -0
  8. vllm/assets/__init__.py +0 -0
  9. vllm/assets/audio.py +43 -0
  10. vllm/assets/base.py +40 -0
  11. vllm/assets/image.py +59 -0
  12. vllm/assets/video.py +149 -0
  13. vllm/attention/__init__.py +18 -0
  14. vllm/attention/backends/__init__.py +0 -0
  15. vllm/attention/backends/abstract.py +391 -0
  16. vllm/attention/backends/registry.py +195 -0
  17. vllm/attention/backends/utils.py +33 -0
  18. vllm/attention/layer.py +1052 -0
  19. vllm/attention/layers/__init__.py +0 -0
  20. vllm/attention/layers/chunked_local_attention.py +121 -0
  21. vllm/attention/layers/cross_attention.py +178 -0
  22. vllm/attention/layers/encoder_only_attention.py +103 -0
  23. vllm/attention/ops/__init__.py +0 -0
  24. vllm/attention/ops/chunked_prefill_paged_decode.py +401 -0
  25. vllm/attention/ops/common.py +414 -0
  26. vllm/attention/ops/flashmla.py +251 -0
  27. vllm/attention/ops/merge_attn_states.py +47 -0
  28. vllm/attention/ops/paged_attn.py +262 -0
  29. vllm/attention/ops/pallas_kv_cache_update.py +130 -0
  30. vllm/attention/ops/prefix_prefill.py +814 -0
  31. vllm/attention/ops/rocm_aiter_paged_attn.py +123 -0
  32. vllm/attention/ops/triton_decode_attention.py +712 -0
  33. vllm/attention/ops/triton_merge_attn_states.py +105 -0
  34. vllm/attention/ops/triton_reshape_and_cache_flash.py +184 -0
  35. vllm/attention/ops/triton_unified_attention.py +941 -0
  36. vllm/attention/ops/vit_attn_wrappers.py +178 -0
  37. vllm/attention/selector.py +231 -0
  38. vllm/attention/utils/__init__.py +0 -0
  39. vllm/attention/utils/fa_utils.py +109 -0
  40. vllm/attention/utils/kv_sharing_utils.py +33 -0
  41. vllm/attention/utils/kv_transfer_utils.py +60 -0
  42. vllm/beam_search.py +88 -0
  43. vllm/benchmarks/__init__.py +0 -0
  44. vllm/benchmarks/datasets.py +3222 -0
  45. vllm/benchmarks/latency.py +172 -0
  46. vllm/benchmarks/lib/__init__.py +3 -0
  47. vllm/benchmarks/lib/endpoint_request_func.py +777 -0
  48. vllm/benchmarks/lib/ready_checker.py +72 -0
  49. vllm/benchmarks/lib/utils.py +79 -0
  50. vllm/benchmarks/serve.py +1531 -0
  51. vllm/benchmarks/sweep/__init__.py +0 -0
  52. vllm/benchmarks/sweep/cli.py +38 -0
  53. vllm/benchmarks/sweep/param_sweep.py +91 -0
  54. vllm/benchmarks/sweep/plot.py +580 -0
  55. vllm/benchmarks/sweep/serve.py +416 -0
  56. vllm/benchmarks/sweep/serve_sla.py +492 -0
  57. vllm/benchmarks/sweep/server.py +114 -0
  58. vllm/benchmarks/sweep/sla_sweep.py +132 -0
  59. vllm/benchmarks/sweep/utils.py +4 -0
  60. vllm/benchmarks/throughput.py +799 -0
  61. vllm/collect_env.py +857 -0
  62. vllm/compilation/__init__.py +0 -0
  63. vllm/compilation/activation_quant_fusion.py +209 -0
  64. vllm/compilation/backends.py +759 -0
  65. vllm/compilation/base_static_graph.py +57 -0
  66. vllm/compilation/caching.py +178 -0
  67. vllm/compilation/collective_fusion.py +1234 -0
  68. vllm/compilation/compiler_interface.py +639 -0
  69. vllm/compilation/counter.py +48 -0
  70. vllm/compilation/cuda_graph.py +208 -0
  71. vllm/compilation/decorators.py +571 -0
  72. vllm/compilation/fix_functionalization.py +253 -0
  73. vllm/compilation/fusion.py +374 -0
  74. vllm/compilation/fusion_attn.py +359 -0
  75. vllm/compilation/fx_utils.py +91 -0
  76. vllm/compilation/inductor_pass.py +133 -0
  77. vllm/compilation/matcher_utils.py +317 -0
  78. vllm/compilation/monitor.py +62 -0
  79. vllm/compilation/noop_elimination.py +134 -0
  80. vllm/compilation/partition_rules.py +72 -0
  81. vllm/compilation/pass_manager.py +135 -0
  82. vllm/compilation/piecewise_backend.py +121 -0
  83. vllm/compilation/post_cleanup.py +21 -0
  84. vllm/compilation/qk_norm_rope_fusion.py +238 -0
  85. vllm/compilation/sequence_parallelism.py +363 -0
  86. vllm/compilation/torch25_custom_graph_pass.py +44 -0
  87. vllm/compilation/vllm_inductor_pass.py +173 -0
  88. vllm/compilation/wrapper.py +238 -0
  89. vllm/config/__init__.py +102 -0
  90. vllm/config/cache.py +207 -0
  91. vllm/config/compilation.py +975 -0
  92. vllm/config/device.py +75 -0
  93. vllm/config/ec_transfer.py +110 -0
  94. vllm/config/kv_events.py +56 -0
  95. vllm/config/kv_transfer.py +114 -0
  96. vllm/config/load.py +124 -0
  97. vllm/config/lora.py +112 -0
  98. vllm/config/model.py +2162 -0
  99. vllm/config/multimodal.py +248 -0
  100. vllm/config/observability.py +123 -0
  101. vllm/config/parallel.py +655 -0
  102. vllm/config/pooler.py +122 -0
  103. vllm/config/scheduler.py +298 -0
  104. vllm/config/speculative.py +654 -0
  105. vllm/config/speech_to_text.py +38 -0
  106. vllm/config/structured_outputs.py +92 -0
  107. vllm/config/utils.py +178 -0
  108. vllm/config/vllm.py +1166 -0
  109. vllm/connections.py +189 -0
  110. vllm/device_allocator/__init__.py +0 -0
  111. vllm/device_allocator/cumem.py +327 -0
  112. vllm/distributed/__init__.py +6 -0
  113. vllm/distributed/communication_op.py +43 -0
  114. vllm/distributed/device_communicators/__init__.py +0 -0
  115. vllm/distributed/device_communicators/all2all.py +490 -0
  116. vllm/distributed/device_communicators/all_reduce_utils.py +344 -0
  117. vllm/distributed/device_communicators/base_device_communicator.py +297 -0
  118. vllm/distributed/device_communicators/cpu_communicator.py +209 -0
  119. vllm/distributed/device_communicators/cuda_communicator.py +340 -0
  120. vllm/distributed/device_communicators/cuda_wrapper.py +216 -0
  121. vllm/distributed/device_communicators/custom_all_reduce.py +326 -0
  122. vllm/distributed/device_communicators/mnnvl_compat.py +27 -0
  123. vllm/distributed/device_communicators/pynccl.py +386 -0
  124. vllm/distributed/device_communicators/pynccl_allocator.py +191 -0
  125. vllm/distributed/device_communicators/pynccl_wrapper.py +564 -0
  126. vllm/distributed/device_communicators/quick_all_reduce.py +290 -0
  127. vllm/distributed/device_communicators/ray_communicator.py +259 -0
  128. vllm/distributed/device_communicators/shm_broadcast.py +733 -0
  129. vllm/distributed/device_communicators/shm_object_storage.py +660 -0
  130. vllm/distributed/device_communicators/symm_mem.py +156 -0
  131. vllm/distributed/device_communicators/tpu_communicator.py +107 -0
  132. vllm/distributed/device_communicators/xpu_communicator.py +95 -0
  133. vllm/distributed/ec_transfer/__init__.py +14 -0
  134. vllm/distributed/ec_transfer/ec_connector/__init__.py +0 -0
  135. vllm/distributed/ec_transfer/ec_connector/base.py +247 -0
  136. vllm/distributed/ec_transfer/ec_connector/factory.py +88 -0
  137. vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py +201 -0
  138. vllm/distributed/ec_transfer/ec_transfer_state.py +42 -0
  139. vllm/distributed/eplb/__init__.py +8 -0
  140. vllm/distributed/eplb/eplb_state.py +837 -0
  141. vllm/distributed/eplb/rebalance_algo.py +260 -0
  142. vllm/distributed/eplb/rebalance_execute.py +431 -0
  143. vllm/distributed/kv_events.py +371 -0
  144. vllm/distributed/kv_transfer/README.md +29 -0
  145. vllm/distributed/kv_transfer/__init__.py +20 -0
  146. vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg +0 -0
  147. vllm/distributed/kv_transfer/kv_connector/__init__.py +0 -0
  148. vllm/distributed/kv_transfer/kv_connector/base.py +10 -0
  149. vllm/distributed/kv_transfer/kv_connector/factory.py +192 -0
  150. vllm/distributed/kv_transfer/kv_connector/utils.py +268 -0
  151. vllm/distributed/kv_transfer/kv_connector/v1/__init__.py +19 -0
  152. vllm/distributed/kv_transfer/kv_connector/v1/base.py +546 -0
  153. vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py +419 -0
  154. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +216 -0
  155. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py +18 -0
  156. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py +379 -0
  157. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py +221 -0
  158. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py +1411 -0
  159. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py +867 -0
  160. vllm/distributed/kv_transfer/kv_connector/v1/metrics.py +189 -0
  161. vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +454 -0
  162. vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +2440 -0
  163. vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +504 -0
  164. vllm/distributed/kv_transfer/kv_connector/v1/p2p/__init__.py +0 -0
  165. vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +531 -0
  166. vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +632 -0
  167. vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py +273 -0
  168. vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +450 -0
  169. vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py +0 -0
  170. vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +179 -0
  171. vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py +164 -0
  172. vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +242 -0
  173. vllm/distributed/kv_transfer/kv_pipe/__init__.py +0 -0
  174. vllm/distributed/kv_transfer/kv_pipe/base.py +66 -0
  175. vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +295 -0
  176. vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +285 -0
  177. vllm/distributed/kv_transfer/kv_transfer_state.py +78 -0
  178. vllm/distributed/parallel_state.py +1759 -0
  179. vllm/distributed/tpu_distributed_utils.py +188 -0
  180. vllm/distributed/utils.py +543 -0
  181. vllm/engine/__init__.py +0 -0
  182. vllm/engine/arg_utils.py +2144 -0
  183. vllm/engine/async_llm_engine.py +6 -0
  184. vllm/engine/llm_engine.py +6 -0
  185. vllm/engine/protocol.py +170 -0
  186. vllm/entrypoints/__init__.py +0 -0
  187. vllm/entrypoints/anthropic/__init__.py +0 -0
  188. vllm/entrypoints/anthropic/protocol.py +162 -0
  189. vllm/entrypoints/anthropic/serving_messages.py +460 -0
  190. vllm/entrypoints/api_server.py +184 -0
  191. vllm/entrypoints/chat_utils.py +1690 -0
  192. vllm/entrypoints/cli/__init__.py +13 -0
  193. vllm/entrypoints/cli/benchmark/__init__.py +0 -0
  194. vllm/entrypoints/cli/benchmark/base.py +25 -0
  195. vllm/entrypoints/cli/benchmark/latency.py +21 -0
  196. vllm/entrypoints/cli/benchmark/main.py +56 -0
  197. vllm/entrypoints/cli/benchmark/serve.py +21 -0
  198. vllm/entrypoints/cli/benchmark/sweep.py +21 -0
  199. vllm/entrypoints/cli/benchmark/throughput.py +21 -0
  200. vllm/entrypoints/cli/collect_env.py +38 -0
  201. vllm/entrypoints/cli/main.py +79 -0
  202. vllm/entrypoints/cli/openai.py +256 -0
  203. vllm/entrypoints/cli/run_batch.py +68 -0
  204. vllm/entrypoints/cli/serve.py +249 -0
  205. vllm/entrypoints/cli/types.py +29 -0
  206. vllm/entrypoints/constants.py +10 -0
  207. vllm/entrypoints/context.py +572 -0
  208. vllm/entrypoints/dynamic_lora.py +57 -0
  209. vllm/entrypoints/harmony_utils.py +535 -0
  210. vllm/entrypoints/launcher.py +175 -0
  211. vllm/entrypoints/llm.py +1768 -0
  212. vllm/entrypoints/logger.py +84 -0
  213. vllm/entrypoints/openai/__init__.py +0 -0
  214. vllm/entrypoints/openai/api_server.py +2096 -0
  215. vllm/entrypoints/openai/cli_args.py +302 -0
  216. vllm/entrypoints/openai/orca_metrics.py +120 -0
  217. vllm/entrypoints/openai/protocol.py +3299 -0
  218. vllm/entrypoints/openai/run_batch.py +547 -0
  219. vllm/entrypoints/openai/serving_chat.py +1772 -0
  220. vllm/entrypoints/openai/serving_classification.py +235 -0
  221. vllm/entrypoints/openai/serving_completion.py +715 -0
  222. vllm/entrypoints/openai/serving_embedding.py +695 -0
  223. vllm/entrypoints/openai/serving_engine.py +1433 -0
  224. vllm/entrypoints/openai/serving_models.py +304 -0
  225. vllm/entrypoints/openai/serving_pooling.py +346 -0
  226. vllm/entrypoints/openai/serving_responses.py +2021 -0
  227. vllm/entrypoints/openai/serving_score.py +503 -0
  228. vllm/entrypoints/openai/serving_tokenization.py +203 -0
  229. vllm/entrypoints/openai/serving_tokens.py +269 -0
  230. vllm/entrypoints/openai/serving_transcription.py +148 -0
  231. vllm/entrypoints/openai/speech_to_text.py +405 -0
  232. vllm/entrypoints/openai/tool_parsers/__init__.py +142 -0
  233. vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +273 -0
  234. vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py +390 -0
  235. vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +390 -0
  236. vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py +210 -0
  237. vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py +200 -0
  238. vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +273 -0
  239. vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +253 -0
  240. vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +494 -0
  241. vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py +420 -0
  242. vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +227 -0
  243. vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +323 -0
  244. vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py +590 -0
  245. vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +341 -0
  246. vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +290 -0
  247. vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py +37 -0
  248. vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py +643 -0
  249. vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py +849 -0
  250. vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +390 -0
  251. vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py +366 -0
  252. vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py +97 -0
  253. vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +120 -0
  254. vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +332 -0
  255. vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py +781 -0
  256. vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py +1316 -0
  257. vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py +744 -0
  258. vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py +303 -0
  259. vllm/entrypoints/openai/tool_parsers/utils.py +229 -0
  260. vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py +556 -0
  261. vllm/entrypoints/renderer.py +409 -0
  262. vllm/entrypoints/responses_utils.py +77 -0
  263. vllm/entrypoints/sagemaker/__init__.py +4 -0
  264. vllm/entrypoints/sagemaker/routes.py +72 -0
  265. vllm/entrypoints/score_utils.py +242 -0
  266. vllm/entrypoints/ssl.py +78 -0
  267. vllm/entrypoints/tool.py +143 -0
  268. vllm/entrypoints/tool_server.py +209 -0
  269. vllm/entrypoints/utils.py +319 -0
  270. vllm/env_override.py +378 -0
  271. vllm/envs.py +1659 -0
  272. vllm/forward_context.py +356 -0
  273. vllm/inputs/__init__.py +44 -0
  274. vllm/inputs/data.py +359 -0
  275. vllm/inputs/parse.py +137 -0
  276. vllm/inputs/preprocess.py +727 -0
  277. vllm/logger.py +267 -0
  278. vllm/logging_utils/__init__.py +10 -0
  279. vllm/logging_utils/dump_input.py +83 -0
  280. vllm/logging_utils/formatter.py +77 -0
  281. vllm/logging_utils/log_time.py +34 -0
  282. vllm/logits_process.py +121 -0
  283. vllm/logprobs.py +208 -0
  284. vllm/lora/__init__.py +0 -0
  285. vllm/lora/layers/__init__.py +41 -0
  286. vllm/lora/layers/base.py +67 -0
  287. vllm/lora/layers/base_linear.py +164 -0
  288. vllm/lora/layers/column_parallel_linear.py +578 -0
  289. vllm/lora/layers/fused_moe.py +472 -0
  290. vllm/lora/layers/logits_processor.py +252 -0
  291. vllm/lora/layers/replicated_linear.py +70 -0
  292. vllm/lora/layers/row_parallel_linear.py +181 -0
  293. vllm/lora/layers/utils.py +65 -0
  294. vllm/lora/layers/vocal_parallel_embedding.py +166 -0
  295. vllm/lora/lora_weights.py +198 -0
  296. vllm/lora/models.py +890 -0
  297. vllm/lora/ops/__init__.py +0 -0
  298. vllm/lora/ops/ipex_ops/__init__.py +6 -0
  299. vllm/lora/ops/ipex_ops/lora_ops.py +57 -0
  300. vllm/lora/ops/torch_ops/__init__.py +20 -0
  301. vllm/lora/ops/torch_ops/lora_ops.py +128 -0
  302. vllm/lora/ops/triton_ops/README_TUNING.md +60 -0
  303. vllm/lora/ops/triton_ops/__init__.py +21 -0
  304. vllm/lora/ops/triton_ops/fused_moe_lora_op.py +641 -0
  305. vllm/lora/ops/triton_ops/kernel_utils.py +340 -0
  306. vllm/lora/ops/triton_ops/lora_expand_op.py +310 -0
  307. vllm/lora/ops/triton_ops/lora_kernel_metadata.py +154 -0
  308. vllm/lora/ops/triton_ops/lora_shrink_op.py +287 -0
  309. vllm/lora/ops/triton_ops/utils.py +295 -0
  310. vllm/lora/ops/xla_ops/__init__.py +6 -0
  311. vllm/lora/ops/xla_ops/lora_ops.py +141 -0
  312. vllm/lora/peft_helper.py +128 -0
  313. vllm/lora/punica_wrapper/__init__.py +10 -0
  314. vllm/lora/punica_wrapper/punica_base.py +492 -0
  315. vllm/lora/punica_wrapper/punica_cpu.py +351 -0
  316. vllm/lora/punica_wrapper/punica_gpu.py +411 -0
  317. vllm/lora/punica_wrapper/punica_selector.py +21 -0
  318. vllm/lora/punica_wrapper/punica_tpu.py +359 -0
  319. vllm/lora/punica_wrapper/punica_xpu.py +279 -0
  320. vllm/lora/punica_wrapper/utils.py +150 -0
  321. vllm/lora/request.py +100 -0
  322. vllm/lora/resolver.py +88 -0
  323. vllm/lora/utils.py +293 -0
  324. vllm/lora/worker_manager.py +279 -0
  325. vllm/model_executor/__init__.py +11 -0
  326. vllm/model_executor/custom_op.py +194 -0
  327. vllm/model_executor/layers/__init__.py +0 -0
  328. vllm/model_executor/layers/activation.py +569 -0
  329. vllm/model_executor/layers/attention_layer_base.py +35 -0
  330. vllm/model_executor/layers/batch_invariant.py +854 -0
  331. vllm/model_executor/layers/conv.py +236 -0
  332. vllm/model_executor/layers/fla/__init__.py +8 -0
  333. vllm/model_executor/layers/fla/ops/__init__.py +17 -0
  334. vllm/model_executor/layers/fla/ops/chunk.py +240 -0
  335. vllm/model_executor/layers/fla/ops/chunk_delta_h.py +344 -0
  336. vllm/model_executor/layers/fla/ops/chunk_o.py +183 -0
  337. vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py +154 -0
  338. vllm/model_executor/layers/fla/ops/cumsum.py +280 -0
  339. vllm/model_executor/layers/fla/ops/fused_recurrent.py +390 -0
  340. vllm/model_executor/layers/fla/ops/index.py +41 -0
  341. vllm/model_executor/layers/fla/ops/kda.py +1351 -0
  342. vllm/model_executor/layers/fla/ops/l2norm.py +146 -0
  343. vllm/model_executor/layers/fla/ops/layernorm_guard.py +396 -0
  344. vllm/model_executor/layers/fla/ops/op.py +60 -0
  345. vllm/model_executor/layers/fla/ops/solve_tril.py +556 -0
  346. vllm/model_executor/layers/fla/ops/utils.py +194 -0
  347. vllm/model_executor/layers/fla/ops/wy_fast.py +158 -0
  348. vllm/model_executor/layers/fused_moe/__init__.py +106 -0
  349. vllm/model_executor/layers/fused_moe/all2all_utils.py +160 -0
  350. vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +406 -0
  351. vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +180 -0
  352. vllm/model_executor/layers/fused_moe/config.py +916 -0
  353. vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  354. vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  355. vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  356. vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  357. vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  358. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  359. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  360. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
  361. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H200,dtype=int8_w8a16.json +146 -0
  362. vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  363. vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  364. vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  365. vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  366. vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  367. vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  368. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  369. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X.json +200 -0
  370. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  371. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H100,dtype=fp8_w8a8.json +123 -0
  372. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  373. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H200.json +146 -0
  374. vllm/model_executor/layers/fused_moe/configs/E=128,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json +147 -0
  375. vllm/model_executor/layers/fused_moe/configs/E=128,N=1856,device_name=NVIDIA_L40S.json +147 -0
  376. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  377. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  378. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20-3e.json +146 -0
  379. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  380. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  381. vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +122 -0
  382. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  383. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  384. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  385. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  386. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  387. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20-3e.json +146 -0
  388. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  389. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  390. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  391. vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  392. vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  393. vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +114 -0
  394. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  395. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=AMD_Instinct_MI308X.json +213 -0
  396. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  397. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  398. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  399. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  400. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  401. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  402. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  403. vllm/model_executor/layers/fused_moe/configs/E=128,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=bf16.json +82 -0
  404. vllm/model_executor/layers/fused_moe/configs/E=128,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +82 -0
  405. vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_H100_80GB_HBM3.json +147 -0
  406. vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_L40S.json +147 -0
  407. vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  408. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json +200 -0
  409. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +147 -0
  410. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  411. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json +146 -0
  412. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  413. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  414. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  415. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  416. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  417. vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  418. vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  419. vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  420. vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  421. vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  422. vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  423. vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  424. vllm/model_executor/layers/fused_moe/configs/E=16,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  425. vllm/model_executor/layers/fused_moe/configs/E=16,N=2048,device_name=NVIDIA_H200.json +146 -0
  426. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  427. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  428. vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  429. vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  430. vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  431. vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H200,dtype=int8_w8a16.json +146 -0
  432. vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  433. vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  434. vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  435. vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  436. vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  437. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  438. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  439. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  440. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  441. vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  442. vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=AMD_Instinct_MI300X.json +201 -0
  443. vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=AMD_Instinct_MI350_OAM,dtype=fp8_w8a8.json +164 -0
  444. vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  445. vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_H20-3e.json +146 -0
  446. vllm/model_executor/layers/fused_moe/configs/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  447. vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  448. vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI350_OAM,dtype=fp8_w8a8.json +164 -0
  449. vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI355_OAM,dtype=fp8_w8a8.json +164 -0
  450. vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  451. vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  452. vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  453. vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  454. vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  455. vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  456. vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  457. vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json +200 -0
  458. vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  459. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  460. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  461. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  462. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  463. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  464. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  465. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  466. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  467. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  468. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  469. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  470. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  471. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  472. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  473. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  474. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  475. vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +147 -0
  476. vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  477. vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  478. vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  479. vllm/model_executor/layers/fused_moe/configs/E=32,N=1408,device_name=NVIDIA_B200.json +147 -0
  480. vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +147 -0
  481. vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +147 -0
  482. vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  483. vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  484. vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  485. vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  486. vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  487. vllm/model_executor/layers/fused_moe/configs/E=40,N=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +147 -0
  488. vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  489. vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  490. vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  491. vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_A100-SXM4-80GB.json +147 -0
  492. vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_B200.json +146 -0
  493. vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json +147 -0
  494. vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  495. vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  496. vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  497. vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  498. vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json +146 -0
  499. vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +147 -0
  500. vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  501. vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  502. vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  503. vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_B200.json +146 -0
  504. vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json +146 -0
  505. vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  506. vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H20-3e.json +146 -0
  507. vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H200.json +146 -0
  508. vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +147 -0
  509. vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_B200.json +146 -0
  510. vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H20-3e.json +146 -0
  511. vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  512. vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json +200 -0
  513. vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json +200 -0
  514. vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json +200 -0
  515. vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json +200 -0
  516. vllm/model_executor/layers/fused_moe/configs/E=62,N=128,device_name=AMD_Instinct_MI300X.json +200 -0
  517. vllm/model_executor/layers/fused_moe/configs/E=62,N=256,device_name=AMD_Instinct_MI300X.json +200 -0
  518. vllm/model_executor/layers/fused_moe/configs/E=62,N=256,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  519. vllm/model_executor/layers/fused_moe/configs/E=62,N=512,device_name=AMD_Instinct_MI300X.json +200 -0
  520. vllm/model_executor/layers/fused_moe/configs/E=62,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  521. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  522. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  523. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  524. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  525. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  526. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  527. vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=NVIDIA_B200.json +147 -0
  528. vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8.json +146 -0
  529. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  530. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  531. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  532. vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20,dtype=fp8_w8a8.json +146 -0
  533. vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20.json +146 -0
  534. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  535. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  536. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  537. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  538. vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8.json +146 -0
  539. vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20.json +146 -0
  540. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  541. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  542. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  543. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  544. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  545. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  546. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  547. vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H100_PCIe,dtype=fp8_w8a8,block_shape=[128,128].json +147 -0
  548. vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8.json +146 -0
  549. vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20.json +146 -0
  550. vllm/model_executor/layers/fused_moe/configs/E=64,N=896,device_name=NVIDIA_H20.json +146 -0
  551. vllm/model_executor/layers/fused_moe/configs/E=64,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=bf16.json +82 -0
  552. vllm/model_executor/layers/fused_moe/configs/E=64,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +82 -0
  553. vllm/model_executor/layers/fused_moe/configs/E=72,N=192,device_name=AMD_Instinct_MI300X.json +200 -0
  554. vllm/model_executor/layers/fused_moe/configs/E=72,N=384,device_name=AMD_Instinct_MI300X.json +200 -0
  555. vllm/model_executor/layers/fused_moe/configs/E=72,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  556. vllm/model_executor/layers/fused_moe/configs/E=72,N=768,device_name=AMD_Instinct_MI300X.json +200 -0
  557. vllm/model_executor/layers/fused_moe/configs/E=72,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  558. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  559. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
  560. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  561. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  562. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
  563. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  564. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  565. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  566. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X.json +200 -0
  567. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  568. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json +200 -0
  569. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  570. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
  571. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  572. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
  573. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  574. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  575. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  576. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  577. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  578. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  579. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X.json +200 -0
  580. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  581. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json +200 -0
  582. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  583. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  584. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  585. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +154 -0
  586. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  587. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  588. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  589. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
  590. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  591. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
  592. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  593. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  594. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  595. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  596. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  597. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  598. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  599. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
  600. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  601. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X.json +200 -0
  602. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  603. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json +200 -0
  604. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  605. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  606. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  607. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  608. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  609. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  610. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
  611. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  612. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
  613. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  614. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  615. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  616. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  617. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  618. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  619. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X.json +200 -0
  620. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  621. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json +200 -0
  622. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  623. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  624. vllm/model_executor/layers/fused_moe/configs/README +12 -0
  625. vllm/model_executor/layers/fused_moe/cpu_fused_moe.py +354 -0
  626. vllm/model_executor/layers/fused_moe/cutlass_moe.py +1052 -0
  627. vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +387 -0
  628. vllm/model_executor/layers/fused_moe/deep_gemm_utils.py +416 -0
  629. vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +420 -0
  630. vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +367 -0
  631. vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +307 -0
  632. vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +362 -0
  633. vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +192 -0
  634. vllm/model_executor/layers/fused_moe/fused_batched_moe.py +1012 -0
  635. vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +792 -0
  636. vllm/model_executor/layers/fused_moe/fused_moe.py +2175 -0
  637. vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +112 -0
  638. vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +164 -0
  639. vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +316 -0
  640. vllm/model_executor/layers/fused_moe/layer.py +1944 -0
  641. vllm/model_executor/layers/fused_moe/modular_kernel.py +1222 -0
  642. vllm/model_executor/layers/fused_moe/moe_align_block_size.py +174 -0
  643. vllm/model_executor/layers/fused_moe/moe_pallas.py +83 -0
  644. vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +229 -0
  645. vllm/model_executor/layers/fused_moe/moe_torch_iterative.py +60 -0
  646. vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +362 -0
  647. vllm/model_executor/layers/fused_moe/prepare_finalize.py +77 -0
  648. vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +265 -0
  649. vllm/model_executor/layers/fused_moe/routing_simulator.py +310 -0
  650. vllm/model_executor/layers/fused_moe/shared_fused_moe.py +97 -0
  651. vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py +171 -0
  652. vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +163 -0
  653. vllm/model_executor/layers/fused_moe/trtllm_moe.py +143 -0
  654. vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +578 -0
  655. vllm/model_executor/layers/fused_moe/utils.py +332 -0
  656. vllm/model_executor/layers/kda.py +448 -0
  657. vllm/model_executor/layers/layernorm.py +442 -0
  658. vllm/model_executor/layers/lightning_attn.py +729 -0
  659. vllm/model_executor/layers/linear.py +1424 -0
  660. vllm/model_executor/layers/logits_processor.py +106 -0
  661. vllm/model_executor/layers/mamba/__init__.py +0 -0
  662. vllm/model_executor/layers/mamba/abstract.py +71 -0
  663. vllm/model_executor/layers/mamba/linear_attn.py +402 -0
  664. vllm/model_executor/layers/mamba/mamba_mixer.py +535 -0
  665. vllm/model_executor/layers/mamba/mamba_mixer2.py +928 -0
  666. vllm/model_executor/layers/mamba/mamba_utils.py +225 -0
  667. vllm/model_executor/layers/mamba/ops/__init__.py +0 -0
  668. vllm/model_executor/layers/mamba/ops/causal_conv1d.py +1240 -0
  669. vllm/model_executor/layers/mamba/ops/layernorm_gated.py +172 -0
  670. vllm/model_executor/layers/mamba/ops/mamba_ssm.py +478 -0
  671. vllm/model_executor/layers/mamba/ops/ssd_bmm.py +211 -0
  672. vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +456 -0
  673. vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +700 -0
  674. vllm/model_executor/layers/mamba/ops/ssd_combined.py +230 -0
  675. vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +157 -0
  676. vllm/model_executor/layers/mamba/short_conv.py +264 -0
  677. vllm/model_executor/layers/mla.py +168 -0
  678. vllm/model_executor/layers/pooler.py +817 -0
  679. vllm/model_executor/layers/quantization/__init__.py +174 -0
  680. vllm/model_executor/layers/quantization/auto_round.py +454 -0
  681. vllm/model_executor/layers/quantization/awq.py +277 -0
  682. vllm/model_executor/layers/quantization/awq_marlin.py +659 -0
  683. vllm/model_executor/layers/quantization/awq_triton.py +337 -0
  684. vllm/model_executor/layers/quantization/base_config.py +170 -0
  685. vllm/model_executor/layers/quantization/bitblas.py +502 -0
  686. vllm/model_executor/layers/quantization/bitsandbytes.py +658 -0
  687. vllm/model_executor/layers/quantization/compressed_tensors/__init__.py +3 -0
  688. vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +914 -0
  689. vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2284 -0
  690. vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +35 -0
  691. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +392 -0
  692. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +55 -0
  693. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +176 -0
  694. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py +124 -0
  695. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +218 -0
  696. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py +183 -0
  697. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py +153 -0
  698. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +138 -0
  699. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +200 -0
  700. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +125 -0
  701. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +219 -0
  702. vllm/model_executor/layers/quantization/compressed_tensors/transform/__init__.py +0 -0
  703. vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py +260 -0
  704. vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py +173 -0
  705. vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/__init__.py +0 -0
  706. vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py +64 -0
  707. vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py +13 -0
  708. vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +224 -0
  709. vllm/model_executor/layers/quantization/compressed_tensors/utils.py +216 -0
  710. vllm/model_executor/layers/quantization/deepspeedfp.py +218 -0
  711. vllm/model_executor/layers/quantization/experts_int8.py +240 -0
  712. vllm/model_executor/layers/quantization/fbgemm_fp8.py +195 -0
  713. vllm/model_executor/layers/quantization/fp8.py +1333 -0
  714. vllm/model_executor/layers/quantization/fp_quant.py +420 -0
  715. vllm/model_executor/layers/quantization/gguf.py +643 -0
  716. vllm/model_executor/layers/quantization/gptq.py +393 -0
  717. vllm/model_executor/layers/quantization/gptq_bitblas.py +482 -0
  718. vllm/model_executor/layers/quantization/gptq_marlin.py +789 -0
  719. vllm/model_executor/layers/quantization/gptq_marlin_24.py +320 -0
  720. vllm/model_executor/layers/quantization/hqq_marlin.py +371 -0
  721. vllm/model_executor/layers/quantization/inc.py +65 -0
  722. vllm/model_executor/layers/quantization/input_quant_fp8.py +171 -0
  723. vllm/model_executor/layers/quantization/ipex_quant.py +467 -0
  724. vllm/model_executor/layers/quantization/kernels/__init__.py +0 -0
  725. vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +94 -0
  726. vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +105 -0
  727. vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py +115 -0
  728. vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py +323 -0
  729. vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py +98 -0
  730. vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py +119 -0
  731. vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py +111 -0
  732. vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py +161 -0
  733. vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +159 -0
  734. vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +166 -0
  735. vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +73 -0
  736. vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +97 -0
  737. vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +120 -0
  738. vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py +219 -0
  739. vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +140 -0
  740. vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py +42 -0
  741. vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +105 -0
  742. vllm/model_executor/layers/quantization/kv_cache.py +146 -0
  743. vllm/model_executor/layers/quantization/modelopt.py +1788 -0
  744. vllm/model_executor/layers/quantization/moe_wna16.py +541 -0
  745. vllm/model_executor/layers/quantization/mxfp4.py +1162 -0
  746. vllm/model_executor/layers/quantization/petit.py +320 -0
  747. vllm/model_executor/layers/quantization/ptpc_fp8.py +137 -0
  748. vllm/model_executor/layers/quantization/quark/__init__.py +0 -0
  749. vllm/model_executor/layers/quantization/quark/quark.py +528 -0
  750. vllm/model_executor/layers/quantization/quark/quark_moe.py +683 -0
  751. vllm/model_executor/layers/quantization/quark/schemes/__init__.py +9 -0
  752. vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +306 -0
  753. vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  754. vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +179 -0
  755. vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +139 -0
  756. vllm/model_executor/layers/quantization/quark/utils.py +105 -0
  757. vllm/model_executor/layers/quantization/qutlass_utils.py +185 -0
  758. vllm/model_executor/layers/quantization/rtn.py +652 -0
  759. vllm/model_executor/layers/quantization/schema.py +90 -0
  760. vllm/model_executor/layers/quantization/torchao.py +380 -0
  761. vllm/model_executor/layers/quantization/tpu_int8.py +139 -0
  762. vllm/model_executor/layers/quantization/utils/__init__.py +6 -0
  763. vllm/model_executor/layers/quantization/utils/allspark_utils.py +67 -0
  764. vllm/model_executor/layers/quantization/utils/bitblas_utils.py +229 -0
  765. vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  766. vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  767. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  768. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  769. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  770. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  771. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  772. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  773. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  774. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  775. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  776. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  777. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  778. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  779. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  780. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  781. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  782. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  783. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  784. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  785. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  786. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  787. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  788. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  789. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  790. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  791. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  792. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  793. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  794. vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  795. vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  796. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  797. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  798. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  799. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  800. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  801. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  802. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  803. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  804. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  805. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  806. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  807. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  808. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  809. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  810. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  811. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  812. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  813. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  814. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  815. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  816. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  817. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  818. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  819. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  820. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  821. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  822. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  823. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  824. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  825. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  826. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  827. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  828. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  829. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  830. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  831. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  832. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  833. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  834. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  835. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  836. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  837. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  838. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  839. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  840. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  841. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  842. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  843. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  844. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  845. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  846. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  847. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  848. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  849. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  850. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  851. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  852. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  853. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  854. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  855. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  856. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  857. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  858. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  859. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  860. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  861. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  862. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  863. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  864. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  865. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  866. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  867. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  868. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  869. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  870. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  871. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  872. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  873. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  874. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  875. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  876. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  877. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  878. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  879. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  880. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  881. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  882. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  883. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  884. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  885. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  886. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  887. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  888. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  889. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  890. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  891. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  892. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  893. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  894. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +18 -0
  895. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  896. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  897. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  898. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  899. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  900. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  901. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  902. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  903. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  904. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  905. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  906. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  907. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  908. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  909. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  910. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  911. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  912. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  913. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  914. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  915. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  916. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  917. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  918. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  919. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  920. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  921. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  922. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  923. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  924. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  925. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  926. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  927. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  928. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  929. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  930. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  931. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  932. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  933. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  934. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  935. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  936. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  937. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  938. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  939. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  940. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  941. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  942. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  943. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  944. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  945. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  946. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  947. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  948. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  949. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  950. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  951. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  952. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  953. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  954. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  955. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  956. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  957. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  958. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  959. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  960. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  961. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  962. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  963. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  964. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  965. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  966. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  967. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  968. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  969. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  970. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  971. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  972. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  973. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  974. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  975. vllm/model_executor/layers/quantization/utils/configs/README.md +3 -0
  976. vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +89 -0
  977. vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +298 -0
  978. vllm/model_executor/layers/quantization/utils/fp8_utils.py +1203 -0
  979. vllm/model_executor/layers/quantization/utils/gptq_utils.py +158 -0
  980. vllm/model_executor/layers/quantization/utils/int8_utils.py +489 -0
  981. vllm/model_executor/layers/quantization/utils/layer_utils.py +41 -0
  982. vllm/model_executor/layers/quantization/utils/machete_utils.py +56 -0
  983. vllm/model_executor/layers/quantization/utils/marlin_utils.py +575 -0
  984. vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +397 -0
  985. vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +351 -0
  986. vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +161 -0
  987. vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py +467 -0
  988. vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +181 -0
  989. vllm/model_executor/layers/quantization/utils/mxfp6_utils.py +142 -0
  990. vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +24 -0
  991. vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +142 -0
  992. vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py +63 -0
  993. vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py +51 -0
  994. vllm/model_executor/layers/quantization/utils/petit_utils.py +124 -0
  995. vllm/model_executor/layers/quantization/utils/quant_utils.py +687 -0
  996. vllm/model_executor/layers/quantization/utils/w8a8_utils.py +516 -0
  997. vllm/model_executor/layers/resampler.py +283 -0
  998. vllm/model_executor/layers/rotary_embedding/__init__.py +278 -0
  999. vllm/model_executor/layers/rotary_embedding/base.py +235 -0
  1000. vllm/model_executor/layers/rotary_embedding/common.py +188 -0
  1001. vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +165 -0
  1002. vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py +215 -0
  1003. vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py +43 -0
  1004. vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py +68 -0
  1005. vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py +75 -0
  1006. vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py +115 -0
  1007. vllm/model_executor/layers/rotary_embedding/llama3_rope.py +54 -0
  1008. vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py +80 -0
  1009. vllm/model_executor/layers/rotary_embedding/mrope.py +397 -0
  1010. vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py +47 -0
  1011. vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py +159 -0
  1012. vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py +81 -0
  1013. vllm/model_executor/layers/utils.py +251 -0
  1014. vllm/model_executor/layers/vocab_parallel_embedding.py +558 -0
  1015. vllm/model_executor/model_loader/__init__.py +148 -0
  1016. vllm/model_executor/model_loader/base_loader.py +57 -0
  1017. vllm/model_executor/model_loader/bitsandbytes_loader.py +822 -0
  1018. vllm/model_executor/model_loader/default_loader.py +327 -0
  1019. vllm/model_executor/model_loader/dummy_loader.py +28 -0
  1020. vllm/model_executor/model_loader/gguf_loader.py +176 -0
  1021. vllm/model_executor/model_loader/online_quantization.py +224 -0
  1022. vllm/model_executor/model_loader/runai_streamer_loader.py +116 -0
  1023. vllm/model_executor/model_loader/sharded_state_loader.py +206 -0
  1024. vllm/model_executor/model_loader/tensorizer.py +790 -0
  1025. vllm/model_executor/model_loader/tensorizer_loader.py +151 -0
  1026. vllm/model_executor/model_loader/tpu.py +118 -0
  1027. vllm/model_executor/model_loader/utils.py +288 -0
  1028. vllm/model_executor/model_loader/weight_utils.py +1084 -0
  1029. vllm/model_executor/models/__init__.py +44 -0
  1030. vllm/model_executor/models/adapters.py +543 -0
  1031. vllm/model_executor/models/afmoe.py +711 -0
  1032. vllm/model_executor/models/aimv2.py +247 -0
  1033. vllm/model_executor/models/apertus.py +587 -0
  1034. vllm/model_executor/models/arcee.py +439 -0
  1035. vllm/model_executor/models/arctic.py +635 -0
  1036. vllm/model_executor/models/aria.py +655 -0
  1037. vllm/model_executor/models/aya_vision.py +450 -0
  1038. vllm/model_executor/models/baichuan.py +496 -0
  1039. vllm/model_executor/models/bailing_moe.py +646 -0
  1040. vllm/model_executor/models/bamba.py +522 -0
  1041. vllm/model_executor/models/bee.py +157 -0
  1042. vllm/model_executor/models/bert.py +925 -0
  1043. vllm/model_executor/models/bert_with_rope.py +732 -0
  1044. vllm/model_executor/models/blip.py +349 -0
  1045. vllm/model_executor/models/blip2.py +695 -0
  1046. vllm/model_executor/models/bloom.py +390 -0
  1047. vllm/model_executor/models/chameleon.py +1120 -0
  1048. vllm/model_executor/models/chatglm.py +498 -0
  1049. vllm/model_executor/models/clip.py +965 -0
  1050. vllm/model_executor/models/cohere2_vision.py +472 -0
  1051. vllm/model_executor/models/commandr.py +473 -0
  1052. vllm/model_executor/models/config.py +503 -0
  1053. vllm/model_executor/models/dbrx.py +482 -0
  1054. vllm/model_executor/models/deepencoder.py +673 -0
  1055. vllm/model_executor/models/deepseek_eagle.py +260 -0
  1056. vllm/model_executor/models/deepseek_mtp.py +360 -0
  1057. vllm/model_executor/models/deepseek_ocr.py +593 -0
  1058. vllm/model_executor/models/deepseek_v2.py +1649 -0
  1059. vllm/model_executor/models/deepseek_vl2.py +655 -0
  1060. vllm/model_executor/models/dots1.py +574 -0
  1061. vllm/model_executor/models/dots_ocr.py +900 -0
  1062. vllm/model_executor/models/ernie45.py +53 -0
  1063. vllm/model_executor/models/ernie45_moe.py +759 -0
  1064. vllm/model_executor/models/ernie45_vl.py +1742 -0
  1065. vllm/model_executor/models/ernie45_vl_moe.py +803 -0
  1066. vllm/model_executor/models/ernie_mtp.py +279 -0
  1067. vllm/model_executor/models/exaone.py +545 -0
  1068. vllm/model_executor/models/exaone4.py +531 -0
  1069. vllm/model_executor/models/fairseq2_llama.py +154 -0
  1070. vllm/model_executor/models/falcon.py +545 -0
  1071. vllm/model_executor/models/falcon_h1.py +685 -0
  1072. vllm/model_executor/models/flex_olmo.py +155 -0
  1073. vllm/model_executor/models/fuyu.py +373 -0
  1074. vllm/model_executor/models/gemma.py +426 -0
  1075. vllm/model_executor/models/gemma2.py +439 -0
  1076. vllm/model_executor/models/gemma3.py +571 -0
  1077. vllm/model_executor/models/gemma3_mm.py +741 -0
  1078. vllm/model_executor/models/gemma3n.py +1165 -0
  1079. vllm/model_executor/models/gemma3n_mm.py +811 -0
  1080. vllm/model_executor/models/glm.py +23 -0
  1081. vllm/model_executor/models/glm4.py +305 -0
  1082. vllm/model_executor/models/glm4_1v.py +1821 -0
  1083. vllm/model_executor/models/glm4_moe.py +747 -0
  1084. vllm/model_executor/models/glm4_moe_mtp.py +359 -0
  1085. vllm/model_executor/models/glm4v.py +784 -0
  1086. vllm/model_executor/models/gpt2.py +397 -0
  1087. vllm/model_executor/models/gpt_bigcode.py +339 -0
  1088. vllm/model_executor/models/gpt_j.py +346 -0
  1089. vllm/model_executor/models/gpt_neox.py +344 -0
  1090. vllm/model_executor/models/gpt_oss.py +738 -0
  1091. vllm/model_executor/models/granite.py +516 -0
  1092. vllm/model_executor/models/granite_speech.py +913 -0
  1093. vllm/model_executor/models/granitemoe.py +569 -0
  1094. vllm/model_executor/models/granitemoehybrid.py +709 -0
  1095. vllm/model_executor/models/granitemoeshared.py +333 -0
  1096. vllm/model_executor/models/gritlm.py +245 -0
  1097. vllm/model_executor/models/grok1.py +558 -0
  1098. vllm/model_executor/models/h2ovl.py +554 -0
  1099. vllm/model_executor/models/hunyuan_v1.py +1053 -0
  1100. vllm/model_executor/models/hyperclovax_vision.py +1166 -0
  1101. vllm/model_executor/models/idefics2_vision_model.py +426 -0
  1102. vllm/model_executor/models/idefics3.py +717 -0
  1103. vllm/model_executor/models/interfaces.py +1092 -0
  1104. vllm/model_executor/models/interfaces_base.py +214 -0
  1105. vllm/model_executor/models/intern_vit.py +453 -0
  1106. vllm/model_executor/models/internlm2.py +460 -0
  1107. vllm/model_executor/models/internlm2_ve.py +142 -0
  1108. vllm/model_executor/models/interns1.py +830 -0
  1109. vllm/model_executor/models/interns1_vit.py +432 -0
  1110. vllm/model_executor/models/internvl.py +1452 -0
  1111. vllm/model_executor/models/jais.py +397 -0
  1112. vllm/model_executor/models/jamba.py +610 -0
  1113. vllm/model_executor/models/jina_vl.py +147 -0
  1114. vllm/model_executor/models/keye.py +1761 -0
  1115. vllm/model_executor/models/keye_vl1_5.py +726 -0
  1116. vllm/model_executor/models/kimi_linear.py +663 -0
  1117. vllm/model_executor/models/kimi_vl.py +578 -0
  1118. vllm/model_executor/models/lfm2.py +532 -0
  1119. vllm/model_executor/models/lfm2_moe.py +762 -0
  1120. vllm/model_executor/models/lightonocr.py +195 -0
  1121. vllm/model_executor/models/llama.py +732 -0
  1122. vllm/model_executor/models/llama4.py +859 -0
  1123. vllm/model_executor/models/llama4_eagle.py +223 -0
  1124. vllm/model_executor/models/llama_eagle.py +218 -0
  1125. vllm/model_executor/models/llama_eagle3.py +367 -0
  1126. vllm/model_executor/models/llava.py +842 -0
  1127. vllm/model_executor/models/llava_next.py +583 -0
  1128. vllm/model_executor/models/llava_next_video.py +467 -0
  1129. vllm/model_executor/models/llava_onevision.py +923 -0
  1130. vllm/model_executor/models/longcat_flash.py +749 -0
  1131. vllm/model_executor/models/longcat_flash_mtp.py +349 -0
  1132. vllm/model_executor/models/mamba.py +276 -0
  1133. vllm/model_executor/models/mamba2.py +289 -0
  1134. vllm/model_executor/models/medusa.py +179 -0
  1135. vllm/model_executor/models/midashenglm.py +827 -0
  1136. vllm/model_executor/models/mimo.py +188 -0
  1137. vllm/model_executor/models/mimo_mtp.py +294 -0
  1138. vllm/model_executor/models/minicpm.py +664 -0
  1139. vllm/model_executor/models/minicpm3.py +242 -0
  1140. vllm/model_executor/models/minicpm_eagle.py +389 -0
  1141. vllm/model_executor/models/minicpmo.py +768 -0
  1142. vllm/model_executor/models/minicpmv.py +1745 -0
  1143. vllm/model_executor/models/minimax_m2.py +552 -0
  1144. vllm/model_executor/models/minimax_text_01.py +1012 -0
  1145. vllm/model_executor/models/minimax_vl_01.py +396 -0
  1146. vllm/model_executor/models/mistral3.py +637 -0
  1147. vllm/model_executor/models/mixtral.py +621 -0
  1148. vllm/model_executor/models/mllama4.py +1147 -0
  1149. vllm/model_executor/models/mlp_speculator.py +235 -0
  1150. vllm/model_executor/models/modernbert.py +450 -0
  1151. vllm/model_executor/models/module_mapping.py +74 -0
  1152. vllm/model_executor/models/molmo.py +1555 -0
  1153. vllm/model_executor/models/moonvit.py +677 -0
  1154. vllm/model_executor/models/mpt.py +335 -0
  1155. vllm/model_executor/models/nano_nemotron_vl.py +1740 -0
  1156. vllm/model_executor/models/nemotron.py +518 -0
  1157. vllm/model_executor/models/nemotron_h.py +852 -0
  1158. vllm/model_executor/models/nemotron_nas.py +491 -0
  1159. vllm/model_executor/models/nemotron_vl.py +653 -0
  1160. vllm/model_executor/models/nvlm_d.py +216 -0
  1161. vllm/model_executor/models/olmo.py +414 -0
  1162. vllm/model_executor/models/olmo2.py +454 -0
  1163. vllm/model_executor/models/olmoe.py +498 -0
  1164. vllm/model_executor/models/openpangu.py +1062 -0
  1165. vllm/model_executor/models/openpangu_mtp.py +265 -0
  1166. vllm/model_executor/models/opt.py +426 -0
  1167. vllm/model_executor/models/orion.py +372 -0
  1168. vllm/model_executor/models/ouro.py +516 -0
  1169. vllm/model_executor/models/ovis.py +559 -0
  1170. vllm/model_executor/models/ovis2_5.py +673 -0
  1171. vllm/model_executor/models/paddleocr_vl.py +1407 -0
  1172. vllm/model_executor/models/paligemma.py +412 -0
  1173. vllm/model_executor/models/persimmon.py +377 -0
  1174. vllm/model_executor/models/phi.py +374 -0
  1175. vllm/model_executor/models/phi3.py +18 -0
  1176. vllm/model_executor/models/phi3v.py +737 -0
  1177. vllm/model_executor/models/phi4_multimodal.py +1447 -0
  1178. vllm/model_executor/models/phi4mm.py +1253 -0
  1179. vllm/model_executor/models/phi4mm_audio.py +1296 -0
  1180. vllm/model_executor/models/phi4mm_utils.py +1907 -0
  1181. vllm/model_executor/models/phimoe.py +675 -0
  1182. vllm/model_executor/models/pixtral.py +1352 -0
  1183. vllm/model_executor/models/plamo2.py +981 -0
  1184. vllm/model_executor/models/qwen.py +368 -0
  1185. vllm/model_executor/models/qwen2.py +541 -0
  1186. vllm/model_executor/models/qwen2_5_omni_thinker.py +1246 -0
  1187. vllm/model_executor/models/qwen2_5_vl.py +1613 -0
  1188. vllm/model_executor/models/qwen2_audio.py +473 -0
  1189. vllm/model_executor/models/qwen2_moe.py +596 -0
  1190. vllm/model_executor/models/qwen2_rm.py +123 -0
  1191. vllm/model_executor/models/qwen2_vl.py +1670 -0
  1192. vllm/model_executor/models/qwen3.py +336 -0
  1193. vllm/model_executor/models/qwen3_moe.py +744 -0
  1194. vllm/model_executor/models/qwen3_next.py +1395 -0
  1195. vllm/model_executor/models/qwen3_next_mtp.py +296 -0
  1196. vllm/model_executor/models/qwen3_omni_moe_thinker.py +1721 -0
  1197. vllm/model_executor/models/qwen3_vl.py +1673 -0
  1198. vllm/model_executor/models/qwen3_vl_moe.py +415 -0
  1199. vllm/model_executor/models/qwen_vl.py +802 -0
  1200. vllm/model_executor/models/radio.py +555 -0
  1201. vllm/model_executor/models/registry.py +1155 -0
  1202. vllm/model_executor/models/roberta.py +259 -0
  1203. vllm/model_executor/models/rvl.py +107 -0
  1204. vllm/model_executor/models/seed_oss.py +497 -0
  1205. vllm/model_executor/models/siglip.py +1174 -0
  1206. vllm/model_executor/models/siglip2navit.py +724 -0
  1207. vllm/model_executor/models/skyworkr1v.py +953 -0
  1208. vllm/model_executor/models/smolvlm.py +38 -0
  1209. vllm/model_executor/models/solar.py +502 -0
  1210. vllm/model_executor/models/stablelm.py +359 -0
  1211. vllm/model_executor/models/starcoder2.py +367 -0
  1212. vllm/model_executor/models/step3_text.py +559 -0
  1213. vllm/model_executor/models/step3_vl.py +1148 -0
  1214. vllm/model_executor/models/swin.py +514 -0
  1215. vllm/model_executor/models/tarsier.py +619 -0
  1216. vllm/model_executor/models/telechat2.py +153 -0
  1217. vllm/model_executor/models/teleflm.py +78 -0
  1218. vllm/model_executor/models/terratorch.py +319 -0
  1219. vllm/model_executor/models/transformers/__init__.py +127 -0
  1220. vllm/model_executor/models/transformers/base.py +464 -0
  1221. vllm/model_executor/models/transformers/causal.py +65 -0
  1222. vllm/model_executor/models/transformers/legacy.py +90 -0
  1223. vllm/model_executor/models/transformers/moe.py +318 -0
  1224. vllm/model_executor/models/transformers/multimodal.py +411 -0
  1225. vllm/model_executor/models/transformers/pooling.py +119 -0
  1226. vllm/model_executor/models/transformers/utils.py +207 -0
  1227. vllm/model_executor/models/ultravox.py +681 -0
  1228. vllm/model_executor/models/utils.py +877 -0
  1229. vllm/model_executor/models/vision.py +552 -0
  1230. vllm/model_executor/models/voxtral.py +845 -0
  1231. vllm/model_executor/models/whisper.py +959 -0
  1232. vllm/model_executor/models/zamba2.py +986 -0
  1233. vllm/model_executor/parameter.py +642 -0
  1234. vllm/model_executor/utils.py +94 -0
  1235. vllm/model_executor/warmup/__init__.py +0 -0
  1236. vllm/model_executor/warmup/deep_gemm_warmup.py +314 -0
  1237. vllm/model_executor/warmup/kernel_warmup.py +98 -0
  1238. vllm/multimodal/__init__.py +40 -0
  1239. vllm/multimodal/audio.py +118 -0
  1240. vllm/multimodal/base.py +26 -0
  1241. vllm/multimodal/cache.py +755 -0
  1242. vllm/multimodal/evs.py +294 -0
  1243. vllm/multimodal/hasher.py +106 -0
  1244. vllm/multimodal/image.py +130 -0
  1245. vllm/multimodal/inputs.py +1036 -0
  1246. vllm/multimodal/parse.py +544 -0
  1247. vllm/multimodal/processing.py +2186 -0
  1248. vllm/multimodal/profiling.py +369 -0
  1249. vllm/multimodal/registry.py +360 -0
  1250. vllm/multimodal/utils.py +512 -0
  1251. vllm/multimodal/video.py +306 -0
  1252. vllm/outputs.py +345 -0
  1253. vllm/platforms/__init__.py +277 -0
  1254. vllm/platforms/cpu.py +414 -0
  1255. vllm/platforms/cuda.py +657 -0
  1256. vllm/platforms/interface.py +639 -0
  1257. vllm/platforms/rocm.py +466 -0
  1258. vllm/platforms/tpu.py +276 -0
  1259. vllm/platforms/xpu.py +274 -0
  1260. vllm/plugins/__init__.py +78 -0
  1261. vllm/plugins/io_processors/__init__.py +68 -0
  1262. vllm/plugins/io_processors/interface.py +77 -0
  1263. vllm/plugins/lora_resolvers/__init__.py +0 -0
  1264. vllm/plugins/lora_resolvers/filesystem_resolver.py +52 -0
  1265. vllm/pooling_params.py +228 -0
  1266. vllm/profiler/__init__.py +0 -0
  1267. vllm/profiler/gpu_profiler.py +37 -0
  1268. vllm/profiler/layerwise_profile.py +392 -0
  1269. vllm/profiler/utils.py +151 -0
  1270. vllm/py.typed +2 -0
  1271. vllm/ray/__init__.py +0 -0
  1272. vllm/ray/lazy_utils.py +26 -0
  1273. vllm/ray/ray_env.py +79 -0
  1274. vllm/reasoning/__init__.py +92 -0
  1275. vllm/reasoning/abs_reasoning_parsers.py +290 -0
  1276. vllm/reasoning/basic_parsers.py +162 -0
  1277. vllm/reasoning/deepseek_r1_reasoning_parser.py +67 -0
  1278. vllm/reasoning/deepseek_v3_reasoning_parser.py +62 -0
  1279. vllm/reasoning/ernie45_reasoning_parser.py +165 -0
  1280. vllm/reasoning/glm4_moe_reasoning_parser.py +171 -0
  1281. vllm/reasoning/gptoss_reasoning_parser.py +173 -0
  1282. vllm/reasoning/granite_reasoning_parser.py +363 -0
  1283. vllm/reasoning/hunyuan_a13b_reasoning_parser.py +237 -0
  1284. vllm/reasoning/identity_reasoning_parser.py +58 -0
  1285. vllm/reasoning/minimax_m2_reasoning_parser.py +67 -0
  1286. vllm/reasoning/mistral_reasoning_parser.py +55 -0
  1287. vllm/reasoning/olmo3_reasoning_parser.py +302 -0
  1288. vllm/reasoning/qwen3_reasoning_parser.py +67 -0
  1289. vllm/reasoning/seedoss_reasoning_parser.py +27 -0
  1290. vllm/reasoning/step3_reasoning_parser.py +107 -0
  1291. vllm/sampling_params.py +669 -0
  1292. vllm/scalar_type.py +355 -0
  1293. vllm/scripts.py +17 -0
  1294. vllm/sequence.py +98 -0
  1295. vllm/tasks.py +13 -0
  1296. vllm/third_party/__init__.py +0 -0
  1297. vllm/third_party/pynvml.py +6140 -0
  1298. vllm/tracing.py +135 -0
  1299. vllm/transformers_utils/__init__.py +26 -0
  1300. vllm/transformers_utils/chat_templates/__init__.py +5 -0
  1301. vllm/transformers_utils/chat_templates/registry.py +73 -0
  1302. vllm/transformers_utils/chat_templates/template_basic.jinja +3 -0
  1303. vllm/transformers_utils/chat_templates/template_blip2.jinja +11 -0
  1304. vllm/transformers_utils/chat_templates/template_chatml.jinja +10 -0
  1305. vllm/transformers_utils/chat_templates/template_deepseek_ocr.jinja +14 -0
  1306. vllm/transformers_utils/chat_templates/template_deepseek_vl2.jinja +23 -0
  1307. vllm/transformers_utils/chat_templates/template_fuyu.jinja +3 -0
  1308. vllm/transformers_utils/chat_templates/template_minicpmv45.jinja +93 -0
  1309. vllm/transformers_utils/config.py +1203 -0
  1310. vllm/transformers_utils/config_parser_base.py +20 -0
  1311. vllm/transformers_utils/configs/__init__.py +70 -0
  1312. vllm/transformers_utils/configs/afmoe.py +84 -0
  1313. vllm/transformers_utils/configs/arctic.py +206 -0
  1314. vllm/transformers_utils/configs/chatglm.py +75 -0
  1315. vllm/transformers_utils/configs/deepseek_vl2.py +126 -0
  1316. vllm/transformers_utils/configs/dotsocr.py +71 -0
  1317. vllm/transformers_utils/configs/eagle.py +84 -0
  1318. vllm/transformers_utils/configs/falcon.py +89 -0
  1319. vllm/transformers_utils/configs/flex_olmo.py +77 -0
  1320. vllm/transformers_utils/configs/jais.py +243 -0
  1321. vllm/transformers_utils/configs/kimi_linear.py +144 -0
  1322. vllm/transformers_utils/configs/kimi_vl.py +38 -0
  1323. vllm/transformers_utils/configs/lfm2_moe.py +159 -0
  1324. vllm/transformers_utils/configs/medusa.py +65 -0
  1325. vllm/transformers_utils/configs/midashenglm.py +103 -0
  1326. vllm/transformers_utils/configs/mistral.py +174 -0
  1327. vllm/transformers_utils/configs/mlp_speculator.py +69 -0
  1328. vllm/transformers_utils/configs/moonvit.py +33 -0
  1329. vllm/transformers_utils/configs/nemotron.py +212 -0
  1330. vllm/transformers_utils/configs/nemotron_h.py +282 -0
  1331. vllm/transformers_utils/configs/olmo3.py +79 -0
  1332. vllm/transformers_utils/configs/ovis.py +182 -0
  1333. vllm/transformers_utils/configs/qwen3_next.py +274 -0
  1334. vllm/transformers_utils/configs/radio.py +89 -0
  1335. vllm/transformers_utils/configs/speculators/__init__.py +2 -0
  1336. vllm/transformers_utils/configs/speculators/algos.py +38 -0
  1337. vllm/transformers_utils/configs/speculators/base.py +114 -0
  1338. vllm/transformers_utils/configs/step3_vl.py +174 -0
  1339. vllm/transformers_utils/configs/ultravox.py +118 -0
  1340. vllm/transformers_utils/detokenizer_utils.py +198 -0
  1341. vllm/transformers_utils/dynamic_module.py +59 -0
  1342. vllm/transformers_utils/processor.py +402 -0
  1343. vllm/transformers_utils/processors/__init__.py +15 -0
  1344. vllm/transformers_utils/processors/deepseek_ocr.py +438 -0
  1345. vllm/transformers_utils/processors/deepseek_vl2.py +406 -0
  1346. vllm/transformers_utils/processors/ovis.py +453 -0
  1347. vllm/transformers_utils/processors/ovis2_5.py +468 -0
  1348. vllm/transformers_utils/runai_utils.py +104 -0
  1349. vllm/transformers_utils/s3_utils.py +95 -0
  1350. vllm/transformers_utils/tokenizer.py +293 -0
  1351. vllm/transformers_utils/tokenizer_base.py +155 -0
  1352. vllm/transformers_utils/tokenizers/__init__.py +16 -0
  1353. vllm/transformers_utils/tokenizers/mistral.py +502 -0
  1354. vllm/transformers_utils/utils.py +130 -0
  1355. vllm/triton_utils/__init__.py +19 -0
  1356. vllm/triton_utils/importing.py +103 -0
  1357. vllm/usage/__init__.py +0 -0
  1358. vllm/usage/usage_lib.py +294 -0
  1359. vllm/utils/__init__.py +82 -0
  1360. vllm/utils/argparse_utils.py +487 -0
  1361. vllm/utils/async_utils.py +303 -0
  1362. vllm/utils/cache.py +214 -0
  1363. vllm/utils/collection_utils.py +139 -0
  1364. vllm/utils/counter.py +45 -0
  1365. vllm/utils/deep_gemm.py +391 -0
  1366. vllm/utils/flashinfer.py +490 -0
  1367. vllm/utils/func_utils.py +236 -0
  1368. vllm/utils/gc_utils.py +147 -0
  1369. vllm/utils/hashing.py +63 -0
  1370. vllm/utils/import_utils.py +411 -0
  1371. vllm/utils/jsontree.py +165 -0
  1372. vllm/utils/math_utils.py +32 -0
  1373. vllm/utils/mem_constants.py +13 -0
  1374. vllm/utils/mem_utils.py +232 -0
  1375. vllm/utils/nccl.py +64 -0
  1376. vllm/utils/network_utils.py +331 -0
  1377. vllm/utils/platform_utils.py +59 -0
  1378. vllm/utils/profiling.py +56 -0
  1379. vllm/utils/registry.py +49 -0
  1380. vllm/utils/serial_utils.py +169 -0
  1381. vllm/utils/system_utils.py +229 -0
  1382. vllm/utils/tensor_schema.py +255 -0
  1383. vllm/utils/torch_utils.py +657 -0
  1384. vllm/v1/__init__.py +0 -0
  1385. vllm/v1/attention/__init__.py +0 -0
  1386. vllm/v1/attention/backends/__init__.py +0 -0
  1387. vllm/v1/attention/backends/cpu_attn.py +496 -0
  1388. vllm/v1/attention/backends/flash_attn.py +1028 -0
  1389. vllm/v1/attention/backends/flashinfer.py +1572 -0
  1390. vllm/v1/attention/backends/flex_attention.py +926 -0
  1391. vllm/v1/attention/backends/gdn_attn.py +387 -0
  1392. vllm/v1/attention/backends/linear_attn.py +74 -0
  1393. vllm/v1/attention/backends/mamba1_attn.py +165 -0
  1394. vllm/v1/attention/backends/mamba2_attn.py +354 -0
  1395. vllm/v1/attention/backends/mamba_attn.py +115 -0
  1396. vllm/v1/attention/backends/mla/__init__.py +0 -0
  1397. vllm/v1/attention/backends/mla/common.py +2031 -0
  1398. vllm/v1/attention/backends/mla/cutlass_mla.py +275 -0
  1399. vllm/v1/attention/backends/mla/flashattn_mla.py +337 -0
  1400. vllm/v1/attention/backends/mla/flashinfer_mla.py +171 -0
  1401. vllm/v1/attention/backends/mla/flashmla.py +314 -0
  1402. vllm/v1/attention/backends/mla/flashmla_sparse.py +548 -0
  1403. vllm/v1/attention/backends/mla/indexer.py +362 -0
  1404. vllm/v1/attention/backends/mla/rocm_aiter_mla.py +294 -0
  1405. vllm/v1/attention/backends/mla/triton_mla.py +171 -0
  1406. vllm/v1/attention/backends/pallas.py +436 -0
  1407. vllm/v1/attention/backends/rocm_aiter_fa.py +816 -0
  1408. vllm/v1/attention/backends/rocm_aiter_unified_attn.py +196 -0
  1409. vllm/v1/attention/backends/rocm_attn.py +362 -0
  1410. vllm/v1/attention/backends/short_conv_attn.py +105 -0
  1411. vllm/v1/attention/backends/tree_attn.py +425 -0
  1412. vllm/v1/attention/backends/triton_attn.py +373 -0
  1413. vllm/v1/attention/backends/utils.py +1116 -0
  1414. vllm/v1/attention/backends/xformers.py +417 -0
  1415. vllm/v1/core/__init__.py +0 -0
  1416. vllm/v1/core/block_pool.py +428 -0
  1417. vllm/v1/core/encoder_cache_manager.py +343 -0
  1418. vllm/v1/core/kv_cache_coordinator.py +480 -0
  1419. vllm/v1/core/kv_cache_manager.py +420 -0
  1420. vllm/v1/core/kv_cache_utils.py +1340 -0
  1421. vllm/v1/core/sched/__init__.py +0 -0
  1422. vllm/v1/core/sched/async_scheduler.py +62 -0
  1423. vllm/v1/core/sched/interface.py +181 -0
  1424. vllm/v1/core/sched/output.py +202 -0
  1425. vllm/v1/core/sched/request_queue.py +221 -0
  1426. vllm/v1/core/sched/scheduler.py +1617 -0
  1427. vllm/v1/core/sched/utils.py +72 -0
  1428. vllm/v1/core/single_type_kv_cache_manager.py +736 -0
  1429. vllm/v1/cudagraph_dispatcher.py +148 -0
  1430. vllm/v1/engine/__init__.py +206 -0
  1431. vllm/v1/engine/async_llm.py +797 -0
  1432. vllm/v1/engine/coordinator.py +377 -0
  1433. vllm/v1/engine/core.py +1420 -0
  1434. vllm/v1/engine/core_client.py +1400 -0
  1435. vllm/v1/engine/detokenizer.py +351 -0
  1436. vllm/v1/engine/exceptions.py +18 -0
  1437. vllm/v1/engine/llm_engine.py +408 -0
  1438. vllm/v1/engine/logprobs.py +182 -0
  1439. vllm/v1/engine/output_processor.py +642 -0
  1440. vllm/v1/engine/parallel_sampling.py +145 -0
  1441. vllm/v1/engine/processor.py +621 -0
  1442. vllm/v1/engine/utils.py +1072 -0
  1443. vllm/v1/executor/__init__.py +6 -0
  1444. vllm/v1/executor/abstract.py +352 -0
  1445. vllm/v1/executor/multiproc_executor.py +877 -0
  1446. vllm/v1/executor/ray_distributed_executor.py +8 -0
  1447. vllm/v1/executor/ray_executor.py +626 -0
  1448. vllm/v1/executor/ray_utils.py +465 -0
  1449. vllm/v1/executor/uniproc_executor.py +183 -0
  1450. vllm/v1/kv_cache_interface.py +403 -0
  1451. vllm/v1/kv_offload/__init__.py +0 -0
  1452. vllm/v1/kv_offload/abstract.py +161 -0
  1453. vllm/v1/kv_offload/arc_manager.py +237 -0
  1454. vllm/v1/kv_offload/backend.py +97 -0
  1455. vllm/v1/kv_offload/backends/__init__.py +0 -0
  1456. vllm/v1/kv_offload/backends/cpu.py +62 -0
  1457. vllm/v1/kv_offload/cpu.py +93 -0
  1458. vllm/v1/kv_offload/factory.py +56 -0
  1459. vllm/v1/kv_offload/lru_manager.py +139 -0
  1460. vllm/v1/kv_offload/mediums.py +39 -0
  1461. vllm/v1/kv_offload/spec.py +62 -0
  1462. vllm/v1/kv_offload/worker/__init__.py +0 -0
  1463. vllm/v1/kv_offload/worker/cpu_gpu.py +185 -0
  1464. vllm/v1/kv_offload/worker/worker.py +144 -0
  1465. vllm/v1/metrics/__init__.py +0 -0
  1466. vllm/v1/metrics/loggers.py +1238 -0
  1467. vllm/v1/metrics/prometheus.py +82 -0
  1468. vllm/v1/metrics/ray_wrappers.py +169 -0
  1469. vllm/v1/metrics/reader.py +257 -0
  1470. vllm/v1/metrics/stats.py +420 -0
  1471. vllm/v1/outputs.py +249 -0
  1472. vllm/v1/pool/__init__.py +0 -0
  1473. vllm/v1/pool/metadata.py +82 -0
  1474. vllm/v1/request.py +259 -0
  1475. vllm/v1/sample/__init__.py +0 -0
  1476. vllm/v1/sample/logits_processor/__init__.py +352 -0
  1477. vllm/v1/sample/logits_processor/builtin.py +274 -0
  1478. vllm/v1/sample/logits_processor/interface.py +106 -0
  1479. vllm/v1/sample/logits_processor/state.py +165 -0
  1480. vllm/v1/sample/metadata.py +44 -0
  1481. vllm/v1/sample/ops/__init__.py +0 -0
  1482. vllm/v1/sample/ops/bad_words.py +52 -0
  1483. vllm/v1/sample/ops/logprobs.py +25 -0
  1484. vllm/v1/sample/ops/penalties.py +57 -0
  1485. vllm/v1/sample/ops/topk_topp_sampler.py +290 -0
  1486. vllm/v1/sample/rejection_sampler.py +793 -0
  1487. vllm/v1/sample/sampler.py +316 -0
  1488. vllm/v1/sample/tpu/__init__.py +0 -0
  1489. vllm/v1/sample/tpu/metadata.py +120 -0
  1490. vllm/v1/sample/tpu/sampler.py +215 -0
  1491. vllm/v1/serial_utils.py +532 -0
  1492. vllm/v1/spec_decode/__init__.py +0 -0
  1493. vllm/v1/spec_decode/eagle.py +1225 -0
  1494. vllm/v1/spec_decode/medusa.py +73 -0
  1495. vllm/v1/spec_decode/metadata.py +66 -0
  1496. vllm/v1/spec_decode/metrics.py +224 -0
  1497. vllm/v1/spec_decode/ngram_proposer.py +291 -0
  1498. vllm/v1/spec_decode/suffix_decoding.py +103 -0
  1499. vllm/v1/spec_decode/utils.py +16 -0
  1500. vllm/v1/structured_output/__init__.py +338 -0
  1501. vllm/v1/structured_output/backend_guidance.py +265 -0
  1502. vllm/v1/structured_output/backend_lm_format_enforcer.py +177 -0
  1503. vllm/v1/structured_output/backend_outlines.py +324 -0
  1504. vllm/v1/structured_output/backend_types.py +136 -0
  1505. vllm/v1/structured_output/backend_xgrammar.py +362 -0
  1506. vllm/v1/structured_output/request.py +94 -0
  1507. vllm/v1/structured_output/utils.py +469 -0
  1508. vllm/v1/utils.py +414 -0
  1509. vllm/v1/worker/__init__.py +0 -0
  1510. vllm/v1/worker/block_table.py +327 -0
  1511. vllm/v1/worker/cpu_model_runner.py +122 -0
  1512. vllm/v1/worker/cpu_worker.py +206 -0
  1513. vllm/v1/worker/dp_utils.py +230 -0
  1514. vllm/v1/worker/ec_connector_model_runner_mixin.py +87 -0
  1515. vllm/v1/worker/gpu_input_batch.py +975 -0
  1516. vllm/v1/worker/gpu_model_runner.py +5102 -0
  1517. vllm/v1/worker/gpu_ubatch_wrapper.py +466 -0
  1518. vllm/v1/worker/gpu_worker.py +894 -0
  1519. vllm/v1/worker/kv_connector_model_runner_mixin.py +144 -0
  1520. vllm/v1/worker/lora_model_runner_mixin.py +213 -0
  1521. vllm/v1/worker/tpu_input_batch.py +593 -0
  1522. vllm/v1/worker/tpu_model_runner.py +2173 -0
  1523. vllm/v1/worker/tpu_worker.py +355 -0
  1524. vllm/v1/worker/ubatch_utils.py +73 -0
  1525. vllm/v1/worker/ubatching.py +231 -0
  1526. vllm/v1/worker/utils.py +366 -0
  1527. vllm/v1/worker/worker_base.py +375 -0
  1528. vllm/v1/worker/xpu_model_runner.py +55 -0
  1529. vllm/v1/worker/xpu_worker.py +189 -0
  1530. vllm/version.py +39 -0
  1531. vllm/vllm_flash_attn/.gitkeep +0 -0
  1532. vllm_cpu_amxbf16-0.11.2.post2.dist-info/METADATA +345 -0
  1533. vllm_cpu_amxbf16-0.11.2.post2.dist-info/RECORD +1536 -0
  1534. vllm_cpu_amxbf16-0.11.2.post2.dist-info/WHEEL +5 -0
  1535. vllm_cpu_amxbf16-0.11.2.post2.dist-info/entry_points.txt +5 -0
  1536. vllm_cpu_amxbf16-0.11.2.post2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2031 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ """
4
+ # MLA Common Components
5
+
6
+ This file implements common components for MLA implementations.
7
+
8
+ First we define:
9
+
10
+ Sq as Q sequence length
11
+ Skv as KV sequence length
12
+
13
+ MLA has two possible ways of computing: a data-movement friendly approach and a
14
+ compute friendly approach. We generally want to use the compute friendly
15
+ approach for "prefill" (i.e. when the ratio Sq / Skv is "small", near 1)
16
+ and the data-movement friendly approach for "decode" (i.e. the ratio
17
+ Sq / Skv is "large").
18
+
19
+ NOTE: what we deem small and large is currently determined by whether it is labelled
20
+ prefill or decode by the scheduler, but this is something we should probably
21
+ tune.
22
+
23
+ Main reference: DeepseekV2 paper, and FlashInfer Implementation
24
+ (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
25
+
26
+ Deepseek's MLA attention works the following way:
27
+ * Use a single latent vector to represent the per-token entry of the KV cache.
28
+ * For decode (i.e. the memory friendly approach) the attention "simulates" a
29
+ multi-head attention, while the compute is similar to multi-query attention.
30
+
31
+ Below is an example of both paths, assuming batch size = 1
32
+
33
+ ## Additional Definitions:
34
+
35
+ C Context length, `Skv - Sq`
36
+ H hidden size
37
+ N number of attention heads
38
+ Lq latent dimension for Q 1536 in DSV3
39
+ Lkv latent dimension for K/V 512 in DSV3
40
+ P nope dimension, no rope. 128 in DSV3
41
+ R rope dimension, goes through rope. 64 in DSV3
42
+ V V head dim. 128 in DSV3
43
+
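Note that the cached per-token entry is just the latent plus the rope part,
i.e. Lkv + R = 512 + 64 = 576 elements for DSV3; this matches the single
supported MLA head size (576) used for the KV cache and the chunked-prefill
workspace further below.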
44
+ ## Vector/Matrix Definitions
45
+
46
+ h_t hidden states (input to attention) shape [Sq, H]
47
+ q_c latent/compressed Q shape [Sq, Lq]
48
+ q_nope uncompressed Q (no-rope) shape [Sq, N, P]
49
+ q_pe uncompressed Q (rope) shape [Sq, N, R]
50
+ kv_c latent/compressed KV shape [Skv, Lkv]
51
+ k_pe decoupled k position embeddings shape [Skv, R]
52
+ new_kv_c new kv_c from current iter shape [Sq, Lkv]
53
+ new_k_pe new k_pe from current iter shape [Sq, R]
54
+ cache_kv_c cached k_c from previous iters shape [C, Lkv]
55
+ cache_k_pe cached k_pe from previous iters shape [C, R]
56
+ W_DQ project h_t to q_c shape [H, Lq]
57
+ W_UQ project q_c to q_nope shape [Lq, N * P]
58
+ W_QR project q_c to q_pe shape [Lq, N * R]
59
+ W_DKV project h_t to kv_c shape [H, Lkv]
60
+ W_UK project kv_c to k_nope shape [Lkv, N, P]
61
+ W_KR project h_t to k_pe shape [H, R]
62
+ W_UV project kv_c to v shape [Lkv, N, V]
63
+ W_O project v to h_t shape [N * V, H]
64
+
65
+
66
+ ## Compute Friendly Approach (i.e. "_forward_prefill"):
67
+
68
+ q_c = h_t @ W_DQ
69
+ q_nope = (q_c @ W_UQ).view(Sq, N, P)
70
+ q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
71
+ new_kv_c = h_t @ W_DKV
72
+ new_k_pe = RoPE(h_t @ W_KR)
73
+ kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
74
+ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
75
+ k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
76
+ v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
77
+
78
+ // MHA with QK headdim = P + R
79
+ // V headdim = V
80
+ // spda_o shape [Sq, N, V]
81
+ spda_o = scaled_dot_product_attention(
82
+ torch.cat([q_nope, q_pe], dim=-1),
83
+ torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
84
+ v
85
+ )
86
+ return spda_o @ W_O
87
+
88
+ NOTE: in the actual code,
89
+ `kv_b_proj` is [W_UK; W_UV] concatenated per head
90
+ `q_b_proj` is [W_UQ; W_QR] concatenated per head
91
+ `out_proj` is W_O
92
+
93
+
94
+ ## Data-Movement Friendly Approach (i.e. "_forward_decode"):
95
+
96
+ Runtime
97
+ q_c = h_t @ W_DQ
98
+ q_nope = (q_c @ W_UQ).view(-1, N, P)
99
+ ql_nope = einsum("snh,lnh->snl", q_nope, W_UK)
100
+ q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
101
+ new_kv_c = h_t @ W_DKV
102
+ new_k_pe = RoPE(h_t @ W_KR)
103
+ kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
104
+ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
105
+
106
+ // MQA with QK headdim = Lkv + R
107
+ // V headdim = Lkv
108
+ // spda_o shape [Sq, N, Lkv]
109
+ // NOTE: this is less compute-friendly since Lkv > P
110
+ // but is more data-movement friendly since it's MQA vs MHA
111
+ spda_o = scaled_dot_product_attention(
112
+ torch.cat([ql_nope, q_pe], dim=-1),
113
+ torch.cat([kv_c, k_pe], dim=-1),
114
+ kv_c
115
+ )
116
+
117
+ o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
118
+ return o.view(-1, N * V) @ W_O
119
+
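For a rough sense of the data-movement difference, a back-of-the-envelope
estimate using the DSV3 dimensions above (N=128, P=128, R=64, V=128, Lkv=512)
and assuming bf16/fp16 KV, i.e. 2 bytes per element (illustrative only):

    # compute friendly (MHA-like) path: k_nope + k_pe per head, plus v per head
    2 * (128 * (128 + 64) + 128 * 128)   # = 81,920 bytes read per cached token
    # data-movement friendly (MQA-like) path: one latent vector plus the shared
    # rope part (the latent doubles as V)
    2 * (512 + 64)                        # = 1,152 bytes read per cached token

i.e. roughly a ~70x reduction in KV traffic per cached token, paid for with a
larger QK head dim in the kernel (Lkv + R = 576 vs P + R = 192).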
120
+
121
+ ## Chunked Prefill
122
+
123
+ For chunked prefill we want to use the compute friendly algorithm. We assume
123
+ a sufficiently large Sq / Skv ratio; in the future we may want to switch to
124
+ the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
126
+
127
+ However, the compute-friendly approach can potentially run out of memory if Skv
128
+ is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
129
+
130
+ To mitigate this, we chunk the computation of attention with respect to the
131
+ current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
132
+ fixed workspace size.
133
+
134
+ The chunked prefill approach is as follows:
135
+
136
+ MCC Max chunk of context to process per iter, computed dynamically,
137
+ used to bound the memory usage
138
+
139
+ q_c = h_t @ W_DQ
140
+ q_nope = (q_c @ W_UQ).view(Sq, N, P)
141
+ q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
142
+ new_kv_c = h_t @ W_DKV
143
+ new_k_pe = RoPE(h_t @ W_KR)
144
+ new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
145
+ new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
146
+
147
+ // MHA between queries and new KV
148
+ // with QK headdim = P + R
149
+ // V headdim = V
150
+ // curr_o shape [Sq, N, V]
151
+ // curr_lse shape [N, Sq], this is just the order FA returns
152
+ curr_o, curr_lse = scaled_dot_product_attention(
153
+ torch.cat([q_nope, q_pe], dim=-1),
154
+ torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
155
+ new_v,
156
+ causal=True,
157
+ return_softmax_lse=True
158
+ )
159
+
160
+ // Compute attention with the already existing context
161
+ for chunk_idx in range(cdiv(C, MCC)):
162
+ chunk_start = chunk_idx * MCC
163
+ chunk_end = min(chunk_start + MCC, C)
164
+ Sc = chunk_end - chunk_start
165
+ cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
166
+ cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
167
+ cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
168
+ cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
169
+
170
+ chunk_o, chunk_lse = scaled_dot_product_attention(
171
+ torch.cat([q_nope, q_pe], dim=-1),
172
+ torch.cat([cache_k_nope_chunk,
173
+ cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
174
+ dim=-1),
175
+ cache_v_chunk,
176
+ causal=False,
177
+ return_softmax_lse=True
178
+ )
179
+
180
+ curr_o, curr_lse = merge_attn_states(
181
+ suffix_output=curr_o,
182
+ suffix_lse=curr_lse,
183
+ prefix_output=chunk_o,
184
+ prefix_lse=chunk_lse,
185
+ )
186
+
187
+ return curr_o @ W_O
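Here merge_attn_states combines the two partial attention results using their
log-sum-exp values; conceptually (per head and per query position, ignoring
the numerically stable fused implementation) it computes:

    merged_lse = log(exp(curr_lse) + exp(chunk_lse))
    merged_o   = curr_o  * exp(curr_lse  - merged_lse)
               + chunk_o * exp(chunk_lse - merged_lse)
    # the loop then continues with curr_o, curr_lse = merged_o, merged_lse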
188
+ """
189
+
190
+ import functools
191
+ from abc import abstractmethod
192
+ from dataclasses import dataclass, field
193
+ from enum import Enum
194
+ from typing import ClassVar, Generic, TypeVar
195
+
196
+ import torch
197
+ from tqdm import tqdm
198
+
199
+ from vllm import _custom_ops as ops
200
+ from vllm import envs
201
+ from vllm._aiter_ops import rocm_aiter_ops
202
+ from vllm.attention.backends.abstract import (
203
+ AttentionBackend,
204
+ AttentionLayer,
205
+ MLAAttentionImpl,
206
+ )
207
+ from vllm.attention.backends.utils import get_mla_dims
208
+ from vllm.attention.ops.common import cp_lse_ag_out_rs
209
+ from vllm.attention.ops.merge_attn_states import merge_attn_states
210
+ from vllm.attention.utils.fa_utils import get_flash_attn_version
211
+ from vllm.config import VllmConfig, get_current_vllm_config
212
+ from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
213
+ from vllm.logger import init_logger
214
+ from vllm.model_executor.layers.batch_invariant import (
215
+ vllm_is_batch_invariant,
216
+ )
217
+ from vllm.model_executor.layers.linear import (
218
+ ColumnParallelLinear,
219
+ LinearBase,
220
+ UnquantizedLinearMethod,
221
+ )
222
+ from vllm.platforms import current_platform
223
+ from vllm.utils.flashinfer import has_nvidia_artifactory
224
+ from vllm.utils.math_utils import cdiv, round_down
225
+ from vllm.v1.attention.backends.utils import (
226
+ AttentionMetadataBuilder,
227
+ CommonAttentionMetadata,
228
+ get_dcp_local_seq_lens,
229
+ get_per_layer_parameters,
230
+ infer_global_hyperparameters,
231
+ split_decodes_and_prefills,
232
+ )
233
+ from vllm.v1.kv_cache_interface import AttentionSpec
234
+
235
+
236
+ class QueryLenSupport(Enum):
237
+ """Defines the level of query length support for an attention backend's
238
+ decode pipeline.
239
+
240
+ - SINGLE_ONLY: Decode pipeline only supports single-token queries
241
+ (query_len=1)
242
+ - UNIFORM: Decode pipeline supports uniform multi-token queries
243
+ (all requests must have same query_len > 1)
244
+ - VARLEN: Decode pipeline supports variable-length queries
245
+ (mixed query lengths in same batch)
246
+ """
247
+
248
+ SINGLE_ONLY = "single_only"
249
+ UNIFORM = "uniform"
250
+ VARLEN = "varlen"
251
+
252
+
253
+ try:
254
+ from vllm.vllm_flash_attn import flash_attn_varlen_func
255
+
256
+ is_vllm_fa = True
257
+ except ImportError:
258
+ # For ROCm, use upstream flash attention
259
+ if current_platform.is_rocm():
260
+ from flash_attn import flash_attn_varlen_func
261
+ is_vllm_fa = False
262
+
263
+ try:
264
+ from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
265
+ from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache # noqa: F401
266
+
267
+ flashinfer_available = True
268
+ except ImportError:
269
+ BatchPrefillWithRaggedKVCacheWrapper = object
270
+
271
+ flashinfer_available = False
272
+
273
+
274
+ def dynamic_per_batched_tensor_quant(
275
+ x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
276
+ ):
277
+ DTYPE_MAX = torch.finfo(dtype).max
278
+ min_val, max_val = x.aminmax()
279
+ amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
280
+ scale = DTYPE_MAX / amax
281
+ x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
282
+ return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
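A minimal usage sketch for the helper above (hypothetical shapes; the second
return value is the reciprocal scale, so dequantization is a single multiply):

    w = torch.randn(16, 512, 128, dtype=torch.bfloat16)
    w_fp8, w_scale_inv = dynamic_per_batched_tensor_quant(w)
    # approximate round-trip, purely to illustrate the returned scale:
    w_approx = w_fp8.to(torch.bfloat16) * w_scale_inv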
283
+
284
+
285
+ logger = init_logger(__name__)
286
+
287
+ CUDNN_WORKSPACE_SIZE = 12800
288
+
289
+
290
+ class MLACommonBackend(AttentionBackend):
291
+ accept_output_buffer: bool = True
292
+
293
+ @staticmethod
294
+ def get_name() -> str:
295
+ return "TRITON_MLA"
296
+
297
+ @staticmethod
298
+ def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
299
+ return MLACommonMetadataBuilder
300
+
301
+ @staticmethod
302
+ def get_kv_cache_shape(
303
+ num_blocks: int,
304
+ block_size: int,
305
+ num_kv_heads: int, # assumed to be 1 for MLA
306
+ head_size: int,
307
+ cache_dtype_str: str = "auto",
308
+ ) -> tuple[int, ...]:
309
+ return (num_blocks, block_size, head_size)
310
+
311
+ @classmethod
312
+ def get_supported_head_sizes(cls) -> list[int]:
313
+ return [576]
314
+
315
+ @classmethod
316
+ def is_mla(cls) -> bool:
317
+ return True
318
+
319
+
320
+ @dataclass
321
+ class MLACommonPrefillMetadata:
322
+ """Prefill Specific Metadata"""
323
+
324
+ @dataclass
325
+ class ChunkedContextMetadata:
326
+ # New for MLA (compared to FlashAttention)
327
+ # For handling chunked prefill
328
+ cu_seq_lens: torch.Tensor
329
+ starts: torch.Tensor
330
+ seq_tot: list[int]
331
+ max_seq_lens: list[int]
332
+ seq_lens: torch.Tensor
333
+ workspace: torch.Tensor
334
+
335
+ # for mla DCP
336
+ padded_local_chunk_seq_lens: list[list[int]] | None = None
337
+ local_context_lens_allranks: list[list[int]] | None = None
338
+ padded_local_cu_seq_lens: torch.Tensor | None = None
339
+ cu_seq_lens_lst: list[list[int]] | None = None
340
+ chunk_size: int | None = None
341
+
342
+ block_table: torch.Tensor
343
+ query_start_loc: torch.Tensor
344
+ max_query_len: int
345
+ chunked_context: ChunkedContextMetadata | None = None
346
+ query_seq_lens: torch.Tensor | None = None
347
+
348
+
349
+ @dataclass
350
+ class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
351
+ prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
352
+ prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = field(
353
+ default_factory=list
354
+ )
355
+
356
+
357
+ @dataclass
358
+ class CudnnPrefillMetadata(MLACommonPrefillMetadata):
359
+ class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata):
360
+ seq_lens: torch.Tensor
361
+
362
+ cudnn_workspace: torch.Tensor | None = None
363
+
364
+
365
+ @dataclass
366
+ class MLACommonDecodeMetadata:
367
+ block_table: torch.Tensor
368
+ seq_lens: torch.Tensor
369
+ dcp_tot_seq_lens: torch.Tensor | None
370
+
371
+
372
+ D = TypeVar("D", bound=MLACommonDecodeMetadata)
373
+
374
+
375
+ @dataclass
376
+ class MLACommonMetadata(Generic[D]):
377
+ """Metadata for MLACommon.
378
+
379
+ NOTE: Please read the comment at the top of the file before trying to
380
+ understand this class
381
+ """
382
+
383
+ # NOTE(sang): Definition of context_len, query_len, and seq_len.
384
+ # |---------- N-1 iteration --------|
385
+ # |---------------- N iteration ---------------------|
386
+ # |- tokenA -|......................|-- newTokens ---|
387
+ # |---------- context_len ----------|
388
+ # |-------------------- seq_len ---------------------|
389
+ # |-- query_len ---|
390
+
391
+ num_reqs: int
392
+ max_query_len: int
393
+ max_seq_len: int
394
+
395
+ num_actual_tokens: int # Number of tokens excluding padding.
396
+ query_start_loc: torch.Tensor
397
+ slot_mapping: torch.Tensor
398
+
399
+ # New for MLA (compared to FlashAttention)
400
+ # For handling prefill decode split
401
+ num_decodes: int
402
+ num_decode_tokens: int
403
+ num_prefills: int
404
+
405
+ # The dimension of the attention heads
406
+ head_dim: int | None = None
407
+
408
+ decode: D | None = None
409
+ prefill: (
410
+ MLACommonPrefillMetadata
411
+ | FlashInferPrefillMetadata
412
+ | CudnnPrefillMetadata
413
+ | None
414
+ ) = None
415
+
416
+ def __post_init__(self):
417
+ if self.head_dim is not None and not MLACommonBackend.supports_head_size(
418
+ self.head_dim
419
+ ):
420
+ raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.")
421
+
422
+
423
+ M = TypeVar("M", bound=MLACommonMetadata)
424
+ A = TypeVar("A")
425
+
426
+
427
+ def use_flashinfer_prefill() -> bool:
428
+ # For Blackwell, default to flashinfer prefill if it's available since
429
+ # it is faster than FA2.
430
+ return (
431
+ not envs.VLLM_DISABLE_FLASHINFER_PREFILL
432
+ and flashinfer_available
433
+ and not envs.VLLM_USE_CUDNN_PREFILL
434
+ and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
435
+ and current_platform.is_device_capability(100)
436
+ )
437
+
438
+
439
+ def use_cudnn_prefill() -> bool:
440
+ return (
441
+ flashinfer_available
442
+ and envs.VLLM_USE_CUDNN_PREFILL
443
+ and current_platform.is_device_capability(100)
444
+ and has_nvidia_artifactory()
445
+ )
446
+
447
+
448
+ def use_trtllm_ragged_deepseek_prefill() -> bool:
449
+ """Check if TRT-LLM ragged DeepSeek prefill should be used."""
450
+ return (
451
+ flashinfer_available
452
+ and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
453
+ and current_platform.is_device_capability(100)
454
+ )
455
+
456
+
457
+ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
458
+ """
459
+ NOTE: Please read the comment at the top of the file before trying to
460
+ understand this class
461
+ """
462
+
463
+ # Defines the level of query length support for this backend.
464
+ # - SINGLE_ONLY: Only single-token queries (no spec decode support)
465
+ # - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths)
466
+ # - VARLEN: Supports variable-length queries (spec decode with mixed lengths)
467
+ # If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when
468
+ # speculative decoding is enabled.
469
+ query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY
470
+
471
+ # The threshold for reordering the batch into decode and prefill requests.
472
+ # If > 1, the batch will be reordered such that requests with
473
+ # query length <= threshold are classified as decode requests.
474
+ # Use `query_len_support` (above) to set this automatically
475
+ # when speculative decoding is enabled.
476
+ reorder_batch_threshold: int = 1
477
+
478
+ @staticmethod
479
+ def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
480
+ scheduler_config = vllm_config.scheduler_config
481
+ cache_config = vllm_config.cache_config
482
+ model_config = vllm_config.model_config
483
+
484
+ chunked_prefill_workspace_size = min(
485
+ # Try for 8 full-length requests or at least 4 pages per request
486
+ max(
487
+ 8 * model_config.max_model_len,
488
+ 4 * scheduler_config.max_num_seqs * cache_config.block_size,
489
+ ),
490
+ # For long-context models, try not to over-allocate (which would limit
491
+ # kv-cache space) by capping the workspace at 64k tokens,
492
+ # which would result in the workspace being:
493
+ # 2*(576)*(64*1024) = 72mb
494
+ # (assuming 576 MLA head dim, and fp16)
495
+ # which would result in up-projected context being
496
+ # 2*(192*128)*(64*1024) = 3gb
497
+ # (assuming 192 QK head dim, 128 heads, and fp16)
498
+ 64 * 1024,
499
+ )
500
+
501
+ # Enforce that we have enough for at least 1 page per request
502
+ chunked_prefill_workspace_size = max(
503
+ chunked_prefill_workspace_size,
504
+ scheduler_config.max_num_seqs * cache_config.block_size,
505
+ )
506
+
507
+ return chunked_prefill_workspace_size
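For intuition, a worked example with assumed (hypothetical) settings
max_model_len=32768, max_num_seqs=256, block_size=16:

    max(8 * 32768, 4 * 256 * 16)   # = max(262144, 16384) = 262144
    min(262144, 64 * 1024)         # = 65536
    max(65536, 256 * 16)           # = 65536 tokens of workspace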
508
+
509
+ def __init__(
510
+ self,
511
+ kv_cache_spec: AttentionSpec,
512
+ layer_names: list[str],
513
+ vllm_config: VllmConfig,
514
+ device: torch.device,
515
+ metadata_cls: type[M] | None = None,
516
+ supports_dcp_with_varlen: bool = False,
517
+ ):
518
+ self.metadata_cls = (
519
+ metadata_cls if metadata_cls is not None else MLACommonMetadata
520
+ )
521
+ self.kv_cache_spec = kv_cache_spec
522
+ scheduler_config = vllm_config.scheduler_config
523
+ self.model_config = vllm_config.model_config
524
+ parallel_config = vllm_config.parallel_config
525
+ self.compilation_config = vllm_config.compilation_config
526
+ self.vllm_config = vllm_config
527
+ self.device = device
528
+
529
+ self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
530
+ self.mla_dims = get_mla_dims(self.model_config)
531
+ self.aot_schedule = current_platform.is_cuda()
532
+ try:
533
+ self.dcp_world_size = get_dcp_group().world_size
534
+ self.dcp_rank = get_dcp_group().rank_in_group
535
+ except AssertionError:
536
+ # DCP might not be initialized in testing
537
+ self.dcp_world_size = 1
538
+ self.dcp_rank = 0
539
+ self.dcp_local_block_size = parallel_config.dcp_kv_cache_interleave_size
540
+ self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
541
+
542
+ # Don't try to access the runner on AMD
543
+ if self.aot_schedule:
544
+ self.page_size = self.kv_cache_spec.block_size
545
+
546
+ self.chunked_prefill_workspace_size = (
547
+ self.determine_chunked_prefill_workspace_size(vllm_config)
548
+ )
549
+
550
+ if self.dcp_world_size > 1:
551
+ # Note(hc): The local kvcache is incomplete when DCP is triggered;
552
+ # an additional kvcache allgather across the DCP group is therefore
553
+ # required, so the workspace has to be enlarged by 1/DCP relative
554
+ # to the original TP allocation.
555
+ assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0
556
+ self.chunked_prefill_workspace = torch.empty(
557
+ (
558
+ self.chunked_prefill_workspace_size
559
+ + self.chunked_prefill_workspace_size // self.dcp_world_size,
560
+ self.model_config.get_head_size(),
561
+ ),
562
+ dtype=self.model_config.dtype,
563
+ device=device,
564
+ )
565
+ else:
566
+ self.chunked_prefill_workspace = torch.empty(
567
+ (
568
+ self.chunked_prefill_workspace_size,
569
+ self.model_config.get_head_size(),
570
+ ),
571
+ dtype=self.model_config.dtype,
572
+ device=device,
573
+ )
574
+
575
+ self._use_cudnn_prefill = use_cudnn_prefill()
576
+ self._use_fi_prefill = use_flashinfer_prefill()
577
+ self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill()
578
+ self.prefill_metadata_cls = (
579
+ FlashInferPrefillMetadata
580
+ if self._use_fi_prefill
581
+ else CudnnPrefillMetadata
582
+ if self._use_cudnn_prefill
583
+ else MLACommonPrefillMetadata
584
+ )
585
+
586
+ if self._use_fi_prefill:
587
+ self._workspace_buffer = torch.empty(
588
+ envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
589
+ dtype=torch.uint8,
590
+ device=device,
591
+ )
592
+
593
+ self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
594
+ self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = []
595
+
596
+ self._global_hyperparameters = infer_global_hyperparameters(
597
+ get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
598
+ )
599
+
600
+ if self._use_trtllm_ragged_prefill:
601
+ self._workspace_buffer = torch.empty(
602
+ envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
603
+ dtype=torch.uint8,
604
+ device=device,
605
+ )
606
+
607
+ if self._use_cudnn_prefill:
608
+ self.cudnn_workspace = torch.empty(
609
+ CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
610
+ dtype=torch.int8,
611
+ device=device,
612
+ )
613
+
614
+ supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
615
+ self._init_reorder_batch_threshold(
616
+ self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen
617
+ )
618
+
619
+ # Validate consistency between query_len_support and reorder_batch_threshold
620
+ if self.query_len_support == QueryLenSupport.SINGLE_ONLY:
621
+ assert self.reorder_batch_threshold == 1, (
622
+ f"reorder_batch_threshold must be 1 when query_len_support is "
623
+ f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
624
+ )
625
+
626
+ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
627
+ qo_indptr = prefill.query_start_loc
628
+
629
+ has_context = False
630
+ if prefill.chunked_context is not None:
631
+ chunked_context = prefill.chunked_context
632
+ has_context = True
633
+
634
+ if self._fi_prefill_main is None:
635
+ self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper(
636
+ self._workspace_buffer, "NHD", backend="cutlass"
637
+ )
638
+
639
+ if has_context:
640
+ num_chunks = chunked_context.cu_seq_lens.shape[0]
641
+ # Allocate more prefill chunk wrappers if needed
642
+ if len(self._fi_prefill_chunks) < num_chunks:
643
+ for _ in range(len(self._fi_prefill_chunks), num_chunks):
644
+ self._fi_prefill_chunks.append(
645
+ BatchPrefillWithRaggedKVCacheWrapper(
646
+ self._workspace_buffer, "NHD", backend="cutlass"
647
+ )
648
+ )
649
+ assert num_chunks <= len(self._fi_prefill_chunks)
650
+
651
+ # In MLA, the non-latent num_qo_heads == num_kv_heads
652
+ num_qo_heads = self.num_heads
653
+ num_kv_heads = num_qo_heads
654
+
655
+ # Sanity: Verify that num_kv_heads == 1 since it is latent space
656
+ assert self.kv_cache_spec.num_kv_heads == 1
657
+
658
+ # Get non-latent head_dim_qk and head_dim_vo
659
+ head_dim_qk = self.mla_dims.qk_nope_head_dim + self.mla_dims.qk_rope_head_dim
660
+ head_dim_vo = self.mla_dims.v_head_dim
661
+
662
+ # For main run, qo_indptr == kv_indptr
663
+ kv_indptr = qo_indptr.clone()
664
+
665
+ # Prepare main prefill
666
+ self._fi_prefill_main.plan(
667
+ qo_indptr=qo_indptr,
668
+ kv_indptr=kv_indptr,
669
+ num_qo_heads=num_qo_heads,
670
+ num_kv_heads=num_kv_heads,
671
+ head_dim_qk=head_dim_qk,
672
+ head_dim_vo=head_dim_vo,
673
+ causal=True, # This is main run
674
+ sm_scale=self._global_hyperparameters.sm_scale,
675
+ window_left=self._global_hyperparameters.window_left,
676
+ logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
677
+ q_data_type=self.model_config.dtype,
678
+ )
679
+
680
+ # Prepare context prefills
681
+ if has_context:
682
+ for i in range(num_chunks):
683
+ kv_indptr_chunk = chunked_context.cu_seq_lens[i]
684
+
685
+ self._fi_prefill_chunks[i].plan(
686
+ qo_indptr=qo_indptr,
687
+ kv_indptr=kv_indptr_chunk,
688
+ num_qo_heads=num_qo_heads,
689
+ num_kv_heads=num_kv_heads,
690
+ head_dim_qk=head_dim_qk,
691
+ head_dim_vo=head_dim_vo,
692
+ causal=False, # This is context run
693
+ sm_scale=self._global_hyperparameters.sm_scale,
694
+ window_left=self._global_hyperparameters.window_left,
695
+ logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
696
+ q_data_type=self.model_config.dtype,
697
+ )
698
+
699
+ prefill.prefill_main = self._fi_prefill_main
700
+ prefill.prefill_chunks = self._fi_prefill_chunks
701
+
702
+ def _build_decode(
703
+ self,
704
+ block_table_tensor: torch.Tensor,
705
+ seq_lens_cpu: torch.Tensor,
706
+ seq_lens_device: torch.Tensor,
707
+ query_start_loc_cpu: torch.Tensor,
708
+ query_start_loc_device: torch.Tensor,
709
+ num_decode_tokens: int,
710
+ dcp_tot_seq_lens_device: torch.Tensor | None,
711
+ ) -> MLACommonDecodeMetadata:
712
+ return MLACommonDecodeMetadata(
713
+ block_table=block_table_tensor,
714
+ seq_lens=seq_lens_device,
715
+ dcp_tot_seq_lens=dcp_tot_seq_lens_device,
716
+ )
717
+
718
+ def build_for_cudagraph_capture(
719
+ self, common_attn_metadata: CommonAttentionMetadata
720
+ ) -> M:
721
+ """
722
+ This method builds the metadata for full cudagraph capture.
723
+ Currently, only decode is supported for full cudagraphs with MLA.
724
+ """
725
+ m = common_attn_metadata
726
+ assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), (
727
+ "MLA only supports decode-only full CUDAGraph capture. "
728
+ "Make sure all cudagraph capture sizes <= max_num_seq."
729
+ )
730
+
731
+ assert m.max_query_len <= self.reorder_batch_threshold # decode only
732
+
733
+ return self.build(0, m)
734
+
735
+ def build(
736
+ self,
737
+ common_prefix_len: int,
738
+ common_attn_metadata: CommonAttentionMetadata,
739
+ fast_build: bool = False,
740
+ ) -> M:
741
+ num_reqs = common_attn_metadata.num_reqs
742
+ num_tokens = common_attn_metadata.num_actual_tokens
743
+ max_query_len = common_attn_metadata.max_query_len
744
+ max_seq_len = common_attn_metadata.max_seq_len
745
+
746
+ # Note(simon): be careful about the CPU <> GPU memory movement in this
747
+ # function. We should avoid GPU -> CPU sync as much as possible because
748
+ # it blocks on all previous kernels.
749
+ device = self.device
750
+ block_table_tensor = common_attn_metadata.block_table_tensor
751
+ slot_mapping = common_attn_metadata.slot_mapping
752
+
753
+ query_start_loc = common_attn_metadata.query_start_loc
754
+ query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
755
+ seq_lens = common_attn_metadata.seq_lens
756
+ seq_lens_cpu = common_attn_metadata.seq_lens_cpu
757
+ dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
758
+ dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
759
+
760
+ query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
761
+
762
+ num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu
763
+
764
+ num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
765
+ split_decodes_and_prefills(
766
+ common_attn_metadata,
767
+ decode_threshold=self.reorder_batch_threshold,
768
+ require_uniform=(self.query_len_support != QueryLenSupport.VARLEN),
769
+ )
770
+ )
771
+
772
+ assert num_decodes + num_prefills == num_reqs
773
+ assert num_decode_tokens + num_prefill_tokens == num_tokens
774
+
775
+ prefill_metadata = None
776
+ if num_prefills > 0:
777
+ reqs_start = num_decodes # prefill_start
778
+
779
+ context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
780
+ max_context_len_cpu = context_lens_cpu.max().item()
781
+ num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
782
+ prefill_query_start_loc = (
783
+ query_start_loc[reqs_start:] - query_start_loc[reqs_start]
784
+ )
785
+
786
+ chunked_context_metadata = None
787
+ if max_context_len_cpu > 0:
788
+ # NOTE: it is recommended that you read the `Chunked Prefill` section
789
+ # in the comment at the top of the file before trying to
790
+ # understand the following code
791
+
792
+ # currently we allocate an equal amount of workspace for each
793
+ # prefill in the batch; we could probably use a more advanced
794
+ # algorithm here and allocate more workspace to prefills with
795
+ # longer context lengths
796
+ max_context_chunk = (
797
+ self.chunked_prefill_workspace_size // num_prefills_with_context_cpu
798
+ )
799
+
800
+ if self.aot_schedule:
801
+ # align max_context_chunk to page_size by rounding down,
802
+ # currently the `gather_and_maybe_dequant_cache` kernel
803
+ # cannot handle `context_chunk_starts` that are not aligned
804
+ # to page_size
805
+ max_context_chunk = round_down(max_context_chunk, self.page_size)
806
+
807
+ assert max_context_chunk > 0
808
+ num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
809
+
810
+ # if `max_context_chunk = 256`, `num_chunks = 3`, and
811
+ # `num_prefills_with_context = 4`, create a tensor that looks
812
+ # like
813
+ # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
814
+ # Note(simon): this is done in CPU because of downstream's
815
+ # use of `to_list`.
816
+ chunk_starts = (
817
+ torch.arange(num_chunks, dtype=torch.int32)
818
+ .unsqueeze(1)
819
+ .expand(-1, num_prefills)
820
+ * max_context_chunk
821
+ )
822
+ chunk_ends = torch.min(
823
+ context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk
824
+ )
825
+ chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
826
+
827
+ cu_seq_lens_cpu = torch.zeros(
828
+ num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True
829
+ )
830
+ torch.cumsum(
831
+ chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
832
+ )
833
+
834
+ if self.dcp_world_size > 1:
835
+ local_context_lens_allranks = get_dcp_local_seq_lens(
836
+ context_lens_cpu,
837
+ self.dcp_world_size,
838
+ None,
839
+ self.dcp_local_block_size,
840
+ )
841
+ # Note(qcs): The max local context lengths
842
+ # are padded to `dcp_local_block_size`.
843
+ padded_local_context_lens_cpu = (
844
+ cdiv(
845
+ context_lens_cpu,
846
+ self.dcp_virtual_block_size,
847
+ )
848
+ * self.dcp_local_block_size
849
+ )
850
+ # Note(hc): The above max_context_chunk already enforces
851
+ # block_size alignment; DCP just needs the block_size to
852
+ # be divisible by dcp_world_size, because DCP uses
853
+ # cp_gather_cache, which does not require `cp_chunk_starts`
854
+ # to be aligned to page_size.
855
+ assert max_context_chunk % self.dcp_world_size == 0
856
+ padded_local_max_context_chunk_across_ranks = (
857
+ cdiv(
858
+ max_context_chunk,
859
+ self.dcp_virtual_block_size,
860
+ )
861
+ * self.dcp_local_block_size
862
+ )
863
+ local_chunk_starts = (
864
+ torch.arange(num_chunks, dtype=torch.int32)
865
+ .unsqueeze(1)
866
+ .expand(-1, num_prefills)
867
+ * padded_local_max_context_chunk_across_ranks
868
+ )
869
+ local_chunk_ends = torch.min(
870
+ padded_local_context_lens_cpu.unsqueeze(0),
871
+ local_chunk_starts
872
+ + padded_local_max_context_chunk_across_ranks,
873
+ )
874
+ padded_local_chunk_seq_lens = (
875
+ local_chunk_ends - local_chunk_starts
876
+ ).clamp(min=0)
877
+
878
+ padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
879
+ num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True
880
+ )
881
+ torch.cumsum(
882
+ padded_local_chunk_seq_lens,
883
+ dim=1,
884
+ out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
885
+ dtype=torch.int32,
886
+ )
887
+
888
+ chunked_context_metadata_cls = (
889
+ CudnnPrefillMetadata.ChunkedContextMetadata
890
+ if self._use_cudnn_prefill
891
+ else MLACommonPrefillMetadata.ChunkedContextMetadata
892
+ )
893
+ if self.dcp_world_size > 1:
894
+ chunked_context_metadata = chunked_context_metadata_cls(
895
+ cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
896
+ starts=local_chunk_starts.to(device, non_blocking=True),
897
+ seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
898
+ max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
899
+ seq_lens=chunk_seq_lens,
900
+ workspace=self.chunked_prefill_workspace,
901
+ padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
902
+ local_context_lens_allranks=local_context_lens_allranks.tolist(),
903
+ padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
904
+ device, non_blocking=True
905
+ ),
906
+ cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
907
+ chunk_size=padded_local_max_context_chunk_across_ranks,
908
+ )
909
+ else:
910
+ chunked_context_metadata = chunked_context_metadata_cls(
911
+ cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
912
+ starts=chunk_starts.to(device, non_blocking=True),
913
+ seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
914
+ max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
915
+ seq_lens=chunk_seq_lens,
916
+ workspace=self.chunked_prefill_workspace,
917
+ )
918
+
919
+ if self._use_cudnn_prefill:
920
+ chunked_context_metadata.seq_lens = chunk_seq_lens
921
+
922
+ assert (
923
+ max(chunked_context_metadata.max_seq_lens)
924
+ <= self.chunked_prefill_workspace_size
925
+ )
926
+
927
+ prefill_metadata = self.prefill_metadata_cls(
928
+ block_table=block_table_tensor[reqs_start:, ...],
929
+ query_start_loc=prefill_query_start_loc,
930
+ max_query_len=max_query_len,
931
+ chunked_context=chunked_context_metadata,
932
+ )
933
+
934
+ if self._use_cudnn_prefill:
935
+ assert isinstance(prefill_metadata, CudnnPrefillMetadata)
936
+ prefill_metadata.query_seq_lens = (
937
+ prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
938
+ )
939
+ prefill_metadata.cudnn_workspace = self.cudnn_workspace
940
+
941
+ if self._use_trtllm_ragged_prefill:
942
+ prefill_metadata.query_seq_lens = (
943
+ prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
944
+ )
945
+
946
+ decode_metadata = None
947
+ if num_decodes > 0:
948
+ dcp_tot_seq_lens_device = None
949
+ if self.dcp_world_size > 1:
950
+ dcp_tot_seq_lens_device = seq_lens[:num_decodes]
951
+ seq_lens_cpu = dcp_local_seq_lens_cpu
952
+ seq_lens = dcp_local_seq_lens
953
+
954
+ decode_metadata = self._build_decode(
955
+ block_table_tensor=block_table_tensor[:num_decodes, ...],
956
+ seq_lens_cpu=seq_lens_cpu[:num_decodes],
957
+ seq_lens_device=seq_lens[:num_decodes],
958
+ query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
959
+ query_start_loc_device=query_start_loc[: num_decodes + 1],
960
+ num_decode_tokens=num_decode_tokens,
961
+ dcp_tot_seq_lens_device=dcp_tot_seq_lens_device,
962
+ )
963
+
964
+ attn_metadata = self.metadata_cls(
965
+ num_reqs=common_attn_metadata.num_reqs,
966
+ max_query_len=common_attn_metadata.max_query_len,
967
+ max_seq_len=max_seq_len,
968
+ num_actual_tokens=num_tokens,
969
+ query_start_loc=query_start_loc,
970
+ slot_mapping=slot_mapping,
971
+ head_dim=self.model_config.get_head_size(),
972
+ # MLACommonMetadata Chunk prefill specific
973
+ num_decodes=num_decodes,
974
+ num_decode_tokens=num_decode_tokens,
975
+ num_prefills=num_prefills,
976
+ prefill=prefill_metadata,
977
+ decode=decode_metadata,
978
+ )
979
+
980
+ if self._use_fi_prefill and num_prefills > 0:
981
+ assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata)
982
+ self._build_fi_prefill_wrappers(attn_metadata.prefill)
983
+
984
+ return attn_metadata
985
+
986
+
987
+ def reorg_kvcache(
988
+ allgatered_kv_c_normed: torch.Tensor,
989
+ allgatered_k_pe: torch.Tensor,
990
+ padded_local_chunk_seq_lens_lst: list[int],
991
+ local_context_lens_allranks: list[list[int]],
992
+ sum_seq_len: int,
993
+ max_seq_len: int,
994
+ chunk_size: int,
995
+ chunk_idx: int,
996
+ toks: int,
997
+ ) -> tuple[torch.Tensor, torch.Tensor]:
998
+ """
999
+ Reorganize and unpad the kvcache after the CP local gather into the TP layout for the attn kernel.
1000
+ e.g.
1001
+ allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...,
1002
+ T0_4, T0_5, pad, pad, T1_2, pad, ...]
1003
+ -> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5,
1004
+ T1_0, T1_1, T1_2, ...]
1005
+ Args:
1006
+ padded_local_chunk_seq_lens_lst: local chunk context lengths
1007
+ under current CP rank.
1008
+ local_context_lens_allranks: local context lengths on each CP rank.
1009
+ sum_seq_len: the sum of cp_chunk_seq_lens_lst.
1010
+ max_seq_len: the max value of cp_chunk_seq_lens_lst.
1011
+ chunk_size: the local padded max context chunk from
1012
+ chunked_context_metadata building.
1013
+ chunk_idx: chunk idx of chunked_prefill.
1014
+ toks: the number of tokens for local gather cache.
1015
+ """
1016
+ kv_c_segments = []
1017
+ k_pe_segments = []
1018
+ src_token_idx = 0
1019
+ max_seq_len_check = 0
1020
+ for padded_local_chunk_seq_len, local_context_lens in zip(
1021
+ padded_local_chunk_seq_lens_lst, local_context_lens_allranks
1022
+ ):
1023
+ cur_seq_len = 0
1024
+ for rank, local_context_len in enumerate(local_context_lens):
1025
+ # Note(qcs): We split the context into multiple chunks,
1026
+ # depending on the size of the workspace.
1027
+ # local_context in dcp0: |-----------------|
1028
+ # local_context in dcp1: |--------------|
1029
+ # n*padded_local_chunk: |-----|-----|-----|
1030
+ # local_chunk_len in dcp1: |-----|-----|--|
1031
+ # so we need to update the last chunk length in dcp1.
1032
+ local_chunk_len = min(
1033
+ max(0, local_context_len - chunk_idx * chunk_size),
1034
+ padded_local_chunk_seq_len,
1035
+ )
1036
+ if local_chunk_len != 0:
1037
+ kv_c_segment = allgatered_kv_c_normed[
1038
+ rank * toks + src_token_idx : rank * toks
1039
+ + src_token_idx
1040
+ + local_chunk_len
1041
+ ]
1042
+ k_pe_segment = allgatered_k_pe[
1043
+ rank * toks + src_token_idx : rank * toks
1044
+ + src_token_idx
1045
+ + local_chunk_len
1046
+ ]
1047
+ kv_c_segments.append(kv_c_segment)
1048
+ k_pe_segments.append(k_pe_segment)
1049
+ cur_seq_len += local_chunk_len
1050
+ max_seq_len_check = max(max_seq_len_check, cur_seq_len)
1051
+ src_token_idx += padded_local_chunk_seq_len
1052
+ reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
1053
+ reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
1054
+ assert reorganized_kv_c_normed.shape[0] == sum_seq_len
1055
+ assert reorganized_k_pe.shape[0] == sum_seq_len
1056
+ assert max_seq_len_check == max_seq_len
1057
+ return reorganized_kv_c_normed, reorganized_k_pe
1058
+
1059
+
1060
+ # TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl,
1061
+ # and MLACommonImpl -> MLACommonDenseImpl or something like that
1062
+ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
1063
+ """
1064
+ NOTE: Please read the comment at the top of the file before trying to
1065
+ understand this class
1066
+ """
1067
+
1068
+ def __init__(
1069
+ self,
1070
+ num_heads: int,
1071
+ head_size: int,
1072
+ scale: float,
1073
+ num_kv_heads: int,
1074
+ alibi_slopes: list[float] | None,
1075
+ sliding_window: int | None,
1076
+ kv_cache_dtype: str,
1077
+ logits_soft_cap: float | None,
1078
+ attn_type: str,
1079
+ kv_sharing_target_layer_name: str | None,
1080
+ # MLA Specific Arguments
1081
+ q_lora_rank: int | None,
1082
+ kv_lora_rank: int,
1083
+ qk_nope_head_dim: int,
1084
+ qk_rope_head_dim: int,
1085
+ qk_head_dim: int,
1086
+ v_head_dim: int,
1087
+ kv_b_proj: ColumnParallelLinear,
1088
+ indexer=None,
1089
+ q_pad_num_heads: int | None = None,
1090
+ ) -> None:
1091
+ if kv_sharing_target_layer_name is not None:
1092
+ raise NotImplementedError("KV sharing is not supported for MLA")
1093
+
1094
+ self.num_heads = num_heads
1095
+ self.head_size = head_size
1096
+ self.scale = float(scale)
1097
+ self.num_kv_heads = num_kv_heads
1098
+ self.kv_cache_dtype = kv_cache_dtype
1099
+
1100
+ self.q_lora_rank = q_lora_rank
1101
+ self.kv_lora_rank = kv_lora_rank
1102
+ self.qk_nope_head_dim = qk_nope_head_dim
1103
+ self.qk_rope_head_dim = qk_rope_head_dim
1104
+ self.qk_head_dim = qk_head_dim
1105
+ self.v_head_dim = v_head_dim
1106
+ self.kv_b_proj = kv_b_proj
1107
+ self.indexer = indexer
1108
+ self.q_pad_num_heads = q_pad_num_heads
1109
+ self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
1110
+
1111
+ def process_weights_after_loading(self, act_dtype: torch.dtype):
1112
+ def get_layer_weight(layer):
1113
+ WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
1114
+ for attr in WEIGHT_NAMES:
1115
+ if hasattr(layer, attr):
1116
+ return getattr(layer, attr)
1117
+ raise AttributeError(
1118
+ f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
1119
+ )
1120
+
1121
+ def get_and_maybe_dequant_weights(layer: LinearBase):
1122
+ if not isinstance(layer.quant_method, UnquantizedLinearMethod):
1123
+ # NOTE: This should only be used offline, since it's O(N^3)
1124
+ eye = torch.eye(
1125
+ layer.input_size_per_partition,
1126
+ dtype=act_dtype,
1127
+ device=get_layer_weight(layer).device,
1128
+ )
1129
+ dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
1130
+ del eye
1131
+ # standardize to (output, input)
1132
+ return dequant_weights.T
1133
+ return layer.weight
1134
+
1135
+ # we currently do not have quantized bmm's, which are needed for
1136
+ # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
1137
+ # the bmm's in 16-bit; the extra memory overhead of this is fairly low
1138
+ kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
1139
+ assert kv_b_proj_weight.shape == (
1140
+ self.kv_lora_rank,
1141
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
1142
+ ), (
1143
+ f"{kv_b_proj_weight.shape=}, "
1144
+ f"{self.kv_lora_rank=}, "
1145
+ f"{self.num_heads=}, "
1146
+ f"{self.qk_nope_head_dim=}, "
1147
+ f"{self.v_head_dim=}"
1148
+ )
1149
+ kv_b_proj_weight = kv_b_proj_weight.view(
1150
+ self.kv_lora_rank,
1151
+ self.num_heads,
1152
+ self.qk_nope_head_dim + self.v_head_dim,
1153
+ )
1154
+
1155
+ W_UK, W_UV = kv_b_proj_weight.split(
1156
+ [self.qk_nope_head_dim, self.v_head_dim], dim=-1
1157
+ )
1158
+
1159
+ if self.is_aiter_triton_fp8_bmm_enabled:
1160
+ W_K = W_UK.transpose(0, 1) # 16 512 128
1161
+ W_V = W_UV.permute(1, 2, 0) # 16 128 512
1162
+ self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
1163
+ W_K, dtype=current_platform.fp8_dtype()
1164
+ )
1165
+ self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
1166
+ W_V, dtype=current_platform.fp8_dtype()
1167
+ )
1168
+
1169
+ # The kernel operates on non-padded inputs. Hence, we pre-compile the
1170
+ # triton kernel to avoid runtime compilation for unseen batch sizes.
1171
+ # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
1172
+ # On DS-R1, this step adds roughly 50s to the model loading time.
1173
+ max_batch_size = 1024 # [ToDo] Find the optimal upper limit
1174
+ pre_compilation_list = list(range(1, max_batch_size + 1))
1175
+ if is_global_first_rank():
1176
+ pre_compilation_list = tqdm(
1177
+ pre_compilation_list,
1178
+ desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
1179
+ total=max_batch_size,
1180
+ )
1181
+
1182
+ for m in pre_compilation_list:
1183
+ x = torch.empty(
1184
+ (self.W_K.shape[0], m, self.W_K.shape[2]),
1185
+ dtype=torch.bfloat16,
1186
+ device=self.W_K.device,
1187
+ )
1188
+ rocm_aiter_ops.triton_fp8_bmm(
1189
+ x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
1190
+ )
1191
+
1192
+ x = torch.empty(
1193
+ (self.W_V.shape[0], m, self.W_V.shape[2]),
1194
+ dtype=torch.bfloat16,
1195
+ device=self.W_V.device,
1196
+ )
1197
+ rocm_aiter_ops.triton_fp8_bmm(
1198
+ x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
1199
+ )
1200
+ else:
1201
+ # Convert from (L, N, V) to (N, L, V)
1202
+ self.W_UV = W_UV.transpose(0, 1)
1203
+ # Convert from (L, N, P) to (N, P, L)
1204
+ self.W_UK_T = W_UK.permute(1, 2, 0)
1205
+
1206
+ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
1207
+ # Convert from (B, N, L) to (N, B, L)
1208
+ x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
1209
+ if self.is_aiter_triton_fp8_bmm_enabled:
1210
+ # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
1211
+ x = rocm_aiter_ops.triton_fp8_bmm(
1212
+ x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
1213
+ )
1214
+ # Convert from (B, N, V) to (B, N * V)
1215
+ x = x.reshape(-1, self.num_heads * self.v_head_dim)
1216
+ # Copy result
1217
+ out.copy_(x)
1218
+ else:
1219
+ # Convert from (B, N * V) to (N, B, V)
1220
+ out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
1221
+
1222
+ # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
1223
+ torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
1224
+
1225
+ # Convert from (N, B, V) to (B, N * V)
1226
+ out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
1227
+
1228
+ # Adjust output buffer shape back to the original (B, N * V)
1229
+ N, B, V = out.shape
1230
+ out.resize_((B, N * V))
1231
+ out.copy_(out_new) # Copy result
1232
+
1233
+
1234
+ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
1235
+ """
1236
+ NOTE: Please read the comment at the top of the file before trying to
1237
+ understand this class
1238
+ """
1239
+
1240
+ def __init__(self, *args, **kwargs) -> None:
1241
+ super().__init__(*args, **kwargs)
1242
+
1243
+ if use_flashinfer_prefill():
1244
+ logger.debug_once("Using FlashInfer prefill for MLA")
1245
+ self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
1246
+ self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
1247
+ self._pad_v = False
1248
+ elif use_trtllm_ragged_deepseek_prefill():
1249
+ logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA")
1250
+ self._run_prefill_context_chunk = (
1251
+ self._run_prefill_context_chunk_trtllm_ragged
1252
+ )
1253
+ self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
1254
+ self._pad_v = False
1255
+ elif use_cudnn_prefill():
1256
+ logger.debug_once("Using CUDNN prefill for MLA")
1257
+ self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
1258
+ self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
1259
+ self._pad_v = False
1260
+ else: # Use FlashAttention
1261
+ logger.debug_once("Using FlashAttention prefill for MLA")
1262
+ self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
1263
+ self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa
1264
+
1265
+ # Handle the differences between the flash_attn_varlen from
1266
+ # flash_attn and the one from vllm_flash_attn. The former is used on
1267
+ # ROCm and the latter has an additional parameter to control
1268
+ # FA2 vs FA3
1269
+ self.flash_attn_varlen_func = flash_attn_varlen_func
1270
+ self.vllm_flash_attn_version = get_flash_attn_version()
1271
+ if self.vllm_flash_attn_version is not None:
1272
+ self.flash_attn_varlen_func = functools.partial(
1273
+ flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
1274
+ )
1275
+
1276
+ # For MLA the v head dim is smaller than the qk head dim, so we pad v
1277
+ # with 0s to match the qk head dim for attention backends that do not
1278
+ # support different head dims.
1279
+ # We don't need to pad V if we are on a Hopper system with FA3.
1280
+ self._pad_v = self.vllm_flash_attn_version is None or not (
1281
+ self.vllm_flash_attn_version == 3
1282
+ and current_platform.get_device_capability()[0] == 9
1283
+ )
1284
+
1285
+ self.dcp_world_size: int | None = None
1286
+
1287
+ self.chunked_prefill_workspace_size = (
1288
+ MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
1289
+ get_current_vllm_config()
1290
+ )
1291
+ )
1292
+ self.dcp_kv_cache_interleave_size: int = (
1293
+ get_current_vllm_config().parallel_config.dcp_kv_cache_interleave_size
1294
+ )
1295
+
1296
+ def _flash_attn_varlen_diff_headdims(
1297
+ self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
1298
+ ):
1299
+ maybe_padded_v = v
1300
+ if self._pad_v:
1301
+ maybe_padded_v = torch.nn.functional.pad(
1302
+ v, [0, q.shape[-1] - v.shape[-1]], value=0
1303
+ )
1304
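+ # e.g. with DeepSeek-style dims (v_head_dim=128, qk head dim=192) this
+ # zero-pads the last dim of v from 128 to 192 so q, k and v share one
+ # head dim for kernels that cannot handle a smaller v.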
+
1305
+ if is_vllm_fa:
1306
+ kwargs["return_softmax_lse"] = return_softmax_lse
1307
+ else:
1308
+ # ROCm leverages the upstream flash_attn, which takes a parameter
1309
+ # called "return_attn_probs" instead of return_softmax_lse
1310
+ kwargs["return_attn_probs"] = return_softmax_lse
1311
+ if vllm_is_batch_invariant():
1312
+ kwargs["num_splits"] = 1
1313
+
1314
+ attn_out = self.flash_attn_varlen_func(
1315
+ q=q,
1316
+ k=k,
1317
+ v=maybe_padded_v,
1318
+ softmax_scale=softmax_scale,
1319
+ **kwargs,
1320
+ )
1321
+
1322
+ # Unpack the output if there are multiple results
1323
+ lse = None
1324
+ if isinstance(attn_out, tuple):
1325
+ attn_out, lse = attn_out[0], attn_out[1]
1326
+
1327
+ # Remain consistent with old `flash_attn_varlen_func` where there
1328
+ # is only one output tensor if `return_softmax_lse` is False.
1329
+ if return_softmax_lse:
1330
+ return attn_out, lse
1331
+ return attn_out
1332
+
1333
+ def _run_prefill_new_tokens_fa(
1334
+ self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
1335
+ ):
1336
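+ # New tokens only attend to other new tokens here (causal), so the K/V
+ # sequence boundaries reuse query_start_loc; any already-cached context
+ # is attended to separately per chunk and merged via the returned LSE.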
+ return self._flash_attn_varlen_diff_headdims(
1337
+ q=q,
1338
+ k=k,
1339
+ v=v,
1340
+ cu_seqlens_q=prefill.query_start_loc,
1341
+ cu_seqlens_k=prefill.query_start_loc,
1342
+ max_seqlen_q=prefill.max_query_len,
1343
+ max_seqlen_k=prefill.max_query_len,
1344
+ softmax_scale=self.scale,
1345
+ causal=True,
1346
+ return_softmax_lse=return_softmax_lse,
1347
+ )
1348
+
1349
+ def _run_prefill_new_tokens_fi(
1350
+ self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
1351
+ ):
1352
+ assert isinstance(prefill, FlashInferPrefillMetadata)
1353
+ assert prefill.prefill_main is not None
1354
+
1355
+ ret = prefill.prefill_main.run(
1356
+ q=q,
1357
+ k=k,
1358
+ v=v,
1359
+ return_lse=return_softmax_lse,
1360
+ )
1361
+
1362
+ if isinstance(ret, tuple):
1363
+ return ret[0], ret[1].transpose(0, 1).contiguous()
1364
+ return ret
1365
+
1366
+ def _run_prefill_new_tokens_cudnn(
1367
+ self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
1368
+ ):
1369
+ assert isinstance(prefill, CudnnPrefillMetadata)
1370
+ assert prefill.query_seq_lens is not None
1371
+ output, lse = cudnn_batch_prefill_with_kv_cache(
1372
+ q=q,
1373
+ k_cache=k,
1374
+ v_cache=v,
1375
+ scale=self.scale,
1376
+ workspace_buffer=prefill.cudnn_workspace,
1377
+ max_token_per_sequence=prefill.max_query_len,
1378
+ max_sequence_kv=prefill.max_query_len,
1379
+ actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
1380
+ actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1),
1381
+ causal=True,
1382
+ # Do not support False for now
1383
+ return_lse=True,
1384
+ # Indicates actual_seq_lens are on GPU or CPU.
1385
+ is_cuda_graph_compatible=True,
1386
+ )
1387
+ if return_softmax_lse:
1388
+ return output, lse
1389
+ return output
1390
+
1391
+ def _run_prefill_context_chunk_fa(
1392
+ self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
1393
+ ):
1394
+ assert prefill.chunked_context is not None
1395
+ return self._flash_attn_varlen_diff_headdims(
1396
+ q=q,
1397
+ k=k,
1398
+ v=v,
1399
+ cu_seqlens_q=prefill.query_start_loc,
1400
+ cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx],
1401
+ max_seqlen_q=prefill.max_query_len,
1402
+ max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx],
1403
+ softmax_scale=self.scale,
1404
+ causal=False, # Context is unmasked
1405
+ return_softmax_lse=True,
1406
+ )
1407
+
1408
+ def _run_prefill_context_chunk_fi(
1409
+ self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
1410
+ ):
1411
+ assert isinstance(prefill, FlashInferPrefillMetadata)
1412
+
1413
+ attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
1414
+ q=q,
1415
+ k=k,
1416
+ v=v,
1417
+ return_lse=True,
1418
+ )
1419
+
1420
+ # Convert from (q_len, num_heads) to (num_heads, q_len)
1421
+ return attn_out, lse.transpose(0, 1).contiguous()
1422
+
1423
+ def _run_prefill_context_chunk_cudnn(
1424
+ self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
1425
+ ):
1426
+ assert isinstance(prefill, CudnnPrefillMetadata)
1427
+ assert prefill.chunked_context is not None
1428
+ assert prefill.chunked_context.seq_lens[chunk_idx] is not None
1429
+ assert prefill.query_seq_lens is not None
1430
+ return cudnn_batch_prefill_with_kv_cache(
1431
+ q=q,
1432
+ k_cache=k,
1433
+ v_cache=v,
1434
+ scale=self.scale,
1435
+ workspace_buffer=prefill.cudnn_workspace,
1436
+ max_token_per_sequence=prefill.max_query_len,
1437
+ max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx],
1438
+ actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
1439
+ actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].view(
1440
+ -1, 1, 1, 1
1441
+ ),
1442
+ causal=False,
1443
+ return_lse=True,
1444
+ # Indicates actual_seq_lens are on GPU or CPU.
1445
+ is_cuda_graph_compatible=True,
1446
+ )
1447
+
1448
+ def _run_prefill_new_tokens_trtllm_ragged(
1449
+ self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
1450
+ ):
1451
+ """TRT-LLM ragged attention for new tokens (causal)."""
1452
+ from flashinfer.prefill import trtllm_ragged_attention_deepseek
1453
+
1454
+ assert prefill.query_seq_lens is not None
1455
+
1456
+ ret = trtllm_ragged_attention_deepseek(
1457
+ query=q,
1458
+ key=k,
1459
+ value=v,
1460
+ workspace_buffer=self._workspace_buffer,
1461
+ seq_lens=prefill.query_seq_lens,
1462
+ max_q_len=prefill.max_query_len,
1463
+ max_kv_len=prefill.max_query_len,
1464
+ bmm1_scale=self.scale,
1465
+ bmm2_scale=1.0,
1466
+ o_sf_scale=1.0,
1467
+ batch_size=prefill.query_seq_lens.shape[0],
1468
+ window_left=-1,
1469
+ cum_seq_lens_q=prefill.query_start_loc,
1470
+ cum_seq_lens_kv=prefill.query_start_loc,
1471
+ enable_pdl=False,
1472
+ is_causal=True,
1473
+ return_lse=return_softmax_lse,
1474
+ )
1475
+
1476
+ if isinstance(ret, tuple):
1477
+ # Convert from (q_len, num_heads) to (num_heads, q_len)
1478
+ return ret[0], ret[1].transpose(0, 1).contiguous()
1479
+ return ret
1480
+
1481
+ def _run_prefill_context_chunk_trtllm_ragged(
1482
+ self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
1483
+ ):
1484
+ """TRT-LLM ragged attention for context chunks (non-causal)."""
1485
+ from flashinfer.prefill import trtllm_ragged_attention_deepseek
1486
+
1487
+ assert prefill.chunked_context is not None
1488
+ assert prefill.chunked_context.seq_lens[chunk_idx] is not None
1489
+
1490
+ out = torch.zeros(
1491
+ q.shape[0],
1492
+ q.shape[1],
1493
+ v.shape[2],
1494
+ device=q.device,
1495
+ dtype=q.dtype,
1496
+ )
1497
+ self._workspace_buffer.fill_(0)
1498
+
1499
+ attn_out, lse = trtllm_ragged_attention_deepseek(
1500
+ query=q,
1501
+ key=k,
1502
+ value=v,
1503
+ workspace_buffer=self._workspace_buffer,
1504
+ seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
1505
+ max_q_len=prefill.max_query_len,
1506
+ max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
1507
+ bmm1_scale=self.scale,
1508
+ bmm2_scale=1.0,
1509
+ o_sf_scale=1.0,
1510
+ batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0],
1511
+ window_left=-1,
1512
+ cum_seq_lens_q=prefill.query_start_loc,
1513
+ cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx],
1514
+ enable_pdl=False,
1515
+ is_causal=False,
1516
+ return_lse=True,
1517
+ out=out,
1518
+ )
1519
+
1520
+ # Convert from (q_len, num_heads) to (num_heads, q_len)
1521
+ return attn_out, lse.transpose(0, 1).contiguous()
1522
+
1523
+ def process_weights_after_loading(self, act_dtype: torch.dtype):
1524
+ def get_layer_weight(layer):
1525
+ WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
1526
+ for attr in WEIGHT_NAMES:
1527
+ if hasattr(layer, attr):
1528
+ return getattr(layer, attr)
1529
+ raise AttributeError(
1530
+ f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
1531
+ )
1532
+
1533
+ def get_and_maybe_dequant_weights(layer: LinearBase):
1534
+ if not isinstance(layer.quant_method, UnquantizedLinearMethod):
1535
+ # NOTE: This should only be used offline, since it's O(N^3)
1536
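+ # Passing an identity matrix through the quantized linear method
+ # reconstructs the dequantized weight one column per input dimension,
+ # which is why the cost is cubic.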
+ eye = torch.eye(
1537
+ layer.input_size_per_partition,
1538
+ dtype=act_dtype,
1539
+ device=get_layer_weight(layer).device,
1540
+ )
1541
+ dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
1542
+ del eye
1543
+ # standardize to (output, input)
1544
+ return dequant_weights.T
1545
+ return layer.weight
1546
+
1547
+ # We currently do not have the quantized bmm's that are needed for
1548
+ # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
1549
+ # the bmm's in 16-bit; the extra memory overhead of this is fairly low.
1550
+ kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
1551
+ assert kv_b_proj_weight.shape == (
1552
+ self.kv_lora_rank,
1553
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
1554
+ ), (
1555
+ f"{kv_b_proj_weight.shape=}, "
1556
+ f"{self.kv_lora_rank=}, "
1557
+ f"{self.num_heads=}, "
1558
+ f"{self.qk_nope_head_dim=}, "
1559
+ f"{self.v_head_dim=}"
1560
+ )
1561
+ kv_b_proj_weight = kv_b_proj_weight.view(
1562
+ self.kv_lora_rank,
1563
+ self.num_heads,
1564
+ self.qk_nope_head_dim + self.v_head_dim,
1565
+ )
1566
+
1567
+ W_UK, W_UV = kv_b_proj_weight.split(
1568
+ [self.qk_nope_head_dim, self.v_head_dim], dim=-1
1569
+ )
1570
+
1571
+ if self.is_aiter_triton_fp8_bmm_enabled:
1572
+ W_K = W_UK.transpose(0, 1) # (N, L, P), e.g. 16 x 512 x 128
1573
+ W_V = W_UV.permute(1, 2, 0) # (N, V, L), e.g. 16 x 128 x 512
1574
+ self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
1575
+ W_K, dtype=current_platform.fp8_dtype()
1576
+ )
1577
+ self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
1578
+ W_V, dtype=current_platform.fp8_dtype()
1579
+ )
1580
+
1581
+ # The kernel operates on non-padded inputs, so we pre-compile the Triton
1582
+ # kernel to avoid runtime compilation for unseen batch sizes.
1583
+ # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
1584
+ # On DS-R1, this step adds roughly 50s to the model loading time.
1585
+ max_batch_size = 1024 # [ToDo] Find the optimal upper limit
1586
+ pre_compilation_list = list(range(1, max_batch_size + 1))
1587
+ if is_global_first_rank():
1588
+ pre_compilation_list = tqdm(
1589
+ pre_compilation_list,
1590
+ desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
1591
+ total=max_batch_size,
1592
+ )
1593
+
1594
+ for m in pre_compilation_list:
1595
+ x = torch.empty(
1596
+ (self.W_K.shape[0], m, self.W_K.shape[2]),
1597
+ dtype=torch.bfloat16,
1598
+ device=self.W_K.device,
1599
+ )
1600
+ rocm_aiter_ops.triton_fp8_bmm(
1601
+ x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
1602
+ )
1603
+
1604
+ x = torch.empty(
1605
+ (self.W_V.shape[0], m, self.W_V.shape[2]),
1606
+ dtype=torch.bfloat16,
1607
+ device=self.W_V.device,
1608
+ )
1609
+ rocm_aiter_ops.triton_fp8_bmm(
1610
+ x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
1611
+ )
1612
+ else:
1613
+ # Convert from (L, N, V) to (N, L, V)
1614
+ self.W_UV = W_UV.transpose(0, 1)
1615
+ # Convert from (L, N, P) to (N, P, L)
1616
+ self.W_UK_T = W_UK.permute(1, 2, 0)
1617
+
1618
+ def _compute_prefill_context(
1619
+ self,
1620
+ q: torch.Tensor,
1621
+ kv_c_and_k_pe_cache: torch.Tensor,
1622
+ attn_metadata: MLACommonMetadata,
1623
+ k_scale: torch.Tensor,
1624
+ ):
1625
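+ # Chunked-context prefill: the cached context is processed in fixed-size
+ # chunks. Each chunk is gathered from the paged KV cache into the
+ # workspace, up-projected through kv_b_proj, attended to without a causal
+ # mask, and the partial results are combined with merge_attn_states using
+ # the per-chunk softmax LSEs.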
+ assert attn_metadata.prefill is not None
1626
+ prefill_metadata = attn_metadata.prefill
1627
+ assert prefill_metadata.chunked_context is not None
1628
+
1629
+ output = None
1630
+ iters = len(prefill_metadata.chunked_context.seq_tot)
1631
+ workspace = prefill_metadata.chunked_context.workspace
1632
+
1633
+ for i in range(iters):
1634
+ toks = prefill_metadata.chunked_context.seq_tot[i]
1635
+
1636
+ ops.gather_and_maybe_dequant_cache(
1637
+ src_cache=kv_c_and_k_pe_cache,
1638
+ dst=workspace,
1639
+ block_table=prefill_metadata.block_table,
1640
+ cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
1641
+ batch_size=attn_metadata.num_prefills,
1642
+ kv_cache_dtype=self.kv_cache_dtype,
1643
+ scale=k_scale,
1644
+ seq_starts=prefill_metadata.chunked_context.starts[i],
1645
+ )
1646
+
1647
+ kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
1648
+ k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
1649
+
1650
+ kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
1651
+ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
1652
+ )
1653
+ k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
1654
+
1655
+ k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
1656
+
1657
+ attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
1658
+ prefill=prefill_metadata,
1659
+ chunk_idx=i,
1660
+ q=q,
1661
+ k=k,
1662
+ v=v,
1663
+ )
1664
+
1665
+ if output is None:
1666
+ output = attn_output
1667
+ output_lse = attn_softmax_lse
1668
+ else:
1669
+ output_tmp = torch.empty_like(output)
1670
+ output_lse_tmp = torch.empty_like(output_lse)
1671
+ merge_attn_states(
1672
+ output=output_tmp,
1673
+ output_lse=output_lse_tmp,
1674
+ prefix_output=output,
1675
+ prefix_lse=output_lse,
1676
+ suffix_output=attn_output,
1677
+ suffix_lse=attn_softmax_lse,
1678
+ )
1679
+ output = output_tmp
1680
+ output_lse = output_lse_tmp
1681
+
1682
+ return output, output_lse
1683
+
1684
+ def _context_parallel_compute_prefill_context(
1685
+ self,
1686
+ q: torch.Tensor,
1687
+ kv_c_and_k_pe_cache: torch.Tensor,
1688
+ attn_metadata: MLACommonMetadata,
1689
+ k_scale: torch.Tensor,
1690
+ dcp_world_size: int,
1691
+ ):
1692
+ assert k_scale is None, "DCP does not support scaled kvcache yet."
1693
+ assert attn_metadata.prefill is not None
1694
+ prefill_metadata = attn_metadata.prefill
1695
+ assert prefill_metadata.chunked_context is not None
1696
+ assert prefill_metadata.chunked_context.padded_local_chunk_seq_lens is not None
1697
+ assert prefill_metadata.chunked_context.local_context_lens_allranks is not None
1698
+ assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None
1699
+ assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
1700
+ assert prefill_metadata.chunked_context.chunk_size is not None
1701
+
1702
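+ # DCP variant of the chunked-context prefill: each rank gathers its local
+ # slice of the context KV, the slices are all-gathered into the shared
+ # workspace, and reorg_kvcache restores token order before the usual
+ # up-project + attend + LSE-merge sequence.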
+ output = None
1703
+ iters = len(prefill_metadata.chunked_context.seq_tot)
1704
+ workspace = prefill_metadata.chunked_context.workspace
1705
+
1706
+ for i in range(iters):
1707
+ toks = prefill_metadata.chunked_context.seq_tot[i]
1708
+ ops.cp_gather_cache(
1709
+ src_cache=kv_c_and_k_pe_cache,
1710
+ dst=workspace,
1711
+ block_table=prefill_metadata.block_table,
1712
+ cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[
1713
+ i
1714
+ ],
1715
+ batch_size=attn_metadata.num_prefills,
1716
+ seq_starts=prefill_metadata.chunked_context.starts[i],
1717
+ )
1718
+ # workspace
1719
+ # |------- N tokens --------|--------- N*dcp_size tokens ----------|
1720
+ # |<- use for local_gather ->|<-------- use for allgather -------->|
1721
+ allgather_offset = workspace.shape[0] // (dcp_world_size + 1)
1722
+ assert allgather_offset * (dcp_world_size + 1) == workspace.shape[0]
1723
+ assert toks <= allgather_offset
1724
+ local_gathered_kvcache = workspace[:toks]
1725
+ cur_allgather_workspace = workspace[
1726
+ allgather_offset : allgather_offset * (1 + dcp_world_size)
1727
+ ]
1728
+ assert toks * dcp_world_size <= cur_allgather_workspace.shape[0]
1729
+ cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size]
1730
+ cur_allgather_kvcache.copy_(
1731
+ get_dcp_group().all_gather(local_gathered_kvcache, dim=0)
1732
+ )
1733
+ assert (
1734
+ cur_allgather_kvcache.shape[-1]
1735
+ == self.kv_lora_rank + self.qk_rope_head_dim
1736
+ )
1737
+ allgatered_kv_c_normed, allgatered_k_pe = cur_allgather_kvcache.unsqueeze(
1738
+ 1
1739
+ ).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
1740
+
1741
+ kv_c_normed, k_pe = reorg_kvcache(
1742
+ allgatered_kv_c_normed,
1743
+ allgatered_k_pe,
1744
+ padded_local_chunk_seq_lens_lst=prefill_metadata.chunked_context.padded_local_chunk_seq_lens[
1745
+ i
1746
+ ],
1747
+ local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks,
1748
+ sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
1749
+ max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
1750
+ chunk_size=prefill_metadata.chunked_context.chunk_size,
1751
+ chunk_idx=i,
1752
+ toks=toks,
1753
+ )
1754
+
1755
+ kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
1756
+ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
1757
+ )
1758
+ k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
1759
+ k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
1760
+
1761
+ attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
1762
+ prefill=prefill_metadata,
1763
+ chunk_idx=i,
1764
+ q=q,
1765
+ k=k,
1766
+ v=v,
1767
+ )
1768
+
1769
+ if output is None:
1770
+ output = attn_output
1771
+ output_lse = attn_softmax_lse
1772
+ else:
1773
+ output_tmp = torch.empty_like(output)
1774
+ output_lse_tmp = torch.empty_like(output_lse)
1775
+ merge_attn_states(
1776
+ output=output_tmp,
1777
+ output_lse=output_lse_tmp,
1778
+ prefix_output=output,
1779
+ prefix_lse=output_lse,
1780
+ suffix_output=attn_output,
1781
+ suffix_lse=attn_softmax_lse,
1782
+ )
1783
+ output = output_tmp
1784
+ output_lse = output_lse_tmp
1785
+
1786
+ return output, output_lse
1787
+
1788
+ def _forward_prefill(
1789
+ self,
1790
+ q: torch.Tensor,
1791
+ kv_c_normed: torch.Tensor,
1792
+ k_pe: torch.Tensor,
1793
+ kv_c_and_k_pe_cache: torch.Tensor,
1794
+ attn_metadata: MLACommonMetadata,
1795
+ k_scale: torch.Tensor,
1796
+ ) -> torch.Tensor:
1797
+ # TODO (zyongye): Prefill function here
1798
+ assert attn_metadata.prefill is not None
1799
+ assert self.dcp_world_size is not None
1800
+
1801
+ has_context = attn_metadata.prefill.chunked_context is not None
1802
+ kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
1803
+ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
1804
+ )
1805
+ k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
1806
+
1807
+ k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
1808
+
1809
+ output = self._run_prefill_new_tokens(
1810
+ prefill=attn_metadata.prefill,
1811
+ q=q,
1812
+ k=k,
1813
+ v=v,
1814
+ return_softmax_lse=has_context,
1815
+ )
1816
+
1817
+ if has_context:
1818
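+ # `output` currently holds attention over the new tokens only (the
+ # "suffix"); compute attention over the cached context (the "prefix")
+ # and merge the two using their softmax LSEs.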
+ suffix_output, suffix_lse = output
1819
+ if self.dcp_world_size > 1:
1820
+ context_output, context_lse = (
1821
+ self._context_parallel_compute_prefill_context(
1822
+ q,
1823
+ kv_c_and_k_pe_cache,
1824
+ attn_metadata,
1825
+ k_scale=None,
1826
+ dcp_world_size=self.dcp_world_size,
1827
+ )
1828
+ )
1829
+ else:
1830
+ context_output, context_lse = self._compute_prefill_context(
1831
+ q, kv_c_and_k_pe_cache, attn_metadata, k_scale
1832
+ )
1833
+
1834
+ output = torch.empty_like(suffix_output)
1835
+ merge_attn_states(
1836
+ output=output,
1837
+ prefix_output=context_output,
1838
+ prefix_lse=context_lse,
1839
+ suffix_output=suffix_output,
1840
+ suffix_lse=suffix_lse,
1841
+ )
1842
+
1843
+ # unpad if necessary
1844
+ if self._pad_v:
1845
+ output = output[..., : v.shape[-1]]
1846
+
1847
+ return output.flatten(start_dim=-2)
1848
+
1849
+ @abstractmethod
1850
+ def _forward_decode(
1851
+ self,
1852
+ q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
1853
+ kv_c_and_k_pe_cache: torch.Tensor,
1854
+ attn_metadata: M,
1855
+ layer: AttentionLayer,
1856
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
1857
+ raise NotImplementedError
1858
+
1859
+ def forward(
1860
+ self,
1861
+ layer: AttentionLayer,
1862
+ q: torch.Tensor,
1863
+ k_c_normed: torch.Tensor, # key in unified attn
1864
+ k_pe: torch.Tensor, # value in unified attn
1865
+ kv_cache: torch.Tensor,
1866
+ attn_metadata: M,
1867
+ output: torch.Tensor | None = None,
1868
+ output_scale: torch.Tensor | None = None,
1869
+ output_block_scale: torch.Tensor | None = None,
1870
+ ) -> torch.Tensor:
1871
+ assert output is not None, "Output tensor must be provided."
1872
+
1873
+ if output_scale is not None or output_block_scale is not None:
1874
+ raise NotImplementedError(
1875
+ "fused output quantization is not yet supported for MLACommonImpl"
1876
+ )
1877
+
1878
+ if attn_metadata is None:
1879
+ # During the profile run, try to simulate the worst-case output size
1880
+ # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
1881
+ # since this can be large
1882
+ _ = torch.empty(
1883
+ (
1884
+ self.chunked_prefill_workspace_size,
1885
+ self.num_heads,
1886
+ self.qk_nope_head_dim + self.v_head_dim,
1887
+ ),
1888
+ device=k_c_normed.device,
1889
+ dtype=k_c_normed.dtype,
1890
+ )
1891
+
1892
+ # The zero fill is required when used with DP + EP
1893
+ # to ensure all ranks within a DP group compute the
1894
+ # same expert outputs.
1895
+ return output.fill_(0)
1896
+
1897
+ if self.dcp_world_size is None:
1898
+ self.dcp_world_size = get_dcp_group().world_size
1899
+
1900
+ fp8_attention = self.kv_cache_dtype.startswith("fp8")
1901
+
1902
+ num_actual_toks = attn_metadata.num_actual_tokens
1903
+
1904
+ # Inputs and outputs may be padded for CUDA graphs
1905
+ output_padded = output
1906
+ output = output[:num_actual_toks, ...]
1907
+ q = q[:num_actual_toks, ...]
1908
+ k_c_normed = k_c_normed[:num_actual_toks, ...]
1909
+ k_pe = k_pe[:num_actual_toks, ...]
1910
+
1911
+ assert (
1912
+ attn_metadata.num_decodes is not None
1913
+ and attn_metadata.num_prefills is not None
1914
+ and attn_metadata.num_decode_tokens is not None
1915
+ )
1916
+
1917
+ has_decode = attn_metadata.num_decodes > 0
1918
+ has_prefill = attn_metadata.num_prefills > 0
1919
+ num_decode_tokens = attn_metadata.num_decode_tokens
1920
+
1921
+ decode_q = q[:num_decode_tokens]
1922
+
1923
+ prefill_q = q[num_decode_tokens:]
1924
+ prefill_k_pe = k_pe[num_decode_tokens:]
1925
+ prefill_k_c_normed = k_c_normed[num_decode_tokens:]
1926
+
1927
+ # write the latent and rope to kv cache
1928
+ if kv_cache.numel() > 0:
1929
+ ops.concat_and_cache_mla(
1930
+ k_c_normed,
1931
+ k_pe.squeeze(1),
1932
+ kv_cache,
1933
+ attn_metadata.slot_mapping.flatten(),
1934
+ kv_cache_dtype=self.kv_cache_dtype,
1935
+ scale=layer._k_scale,
1936
+ )
1937
+
1938
+ if fp8_attention:
1939
+ kv_cache = kv_cache.view(current_platform.fp8_dtype())
1940
+
1941
+ if has_prefill:
1942
+ output[num_decode_tokens:] = self._forward_prefill(
1943
+ prefill_q,
1944
+ prefill_k_c_normed,
1945
+ prefill_k_pe,
1946
+ kv_cache,
1947
+ attn_metadata,
1948
+ layer._k_scale,
1949
+ )
1950
+
1951
+ if has_decode:
1952
+ assert attn_metadata.decode is not None
1953
+
1954
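+ # Decode uses the absorbed-weight form of MLA: the nope part of q is
+ # multiplied by W_UK^T (or its fp8 copy) so attention can be computed
+ # directly against the compressed latent KV cache of width kv_lora_rank
+ # rather than against fully up-projected keys.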
+ decode_q_nope, decode_q_pe = decode_q.split(
1955
+ [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
1956
+ )
1957
+
1958
+ # Convert from (B, N, P) to (N, B, P)
1959
+ decode_q_nope = decode_q_nope.transpose(0, 1)
1960
+
1961
+ if self.q_pad_num_heads is not None:
1962
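+ # Allocate storage for `q_pad_num_heads` heads but view only the real N
+ # heads; this appears to let decode kernels that expect a padded head
+ # count grow the buffer in place without reallocating.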
+ B, N, L = decode_q_pe.shape
1963
+ decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L))
1964
+ decode_pe_padded.resize_((B, N, L))
1965
+ decode_pe_padded.copy_(decode_q_pe)
1966
+ decode_q_pe = decode_pe_padded
1967
+
1968
+ if self.is_aiter_triton_fp8_bmm_enabled:
1969
+ # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
1970
+ decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
1971
+ decode_q_nope,
1972
+ self.W_K,
1973
+ self.W_K_scale,
1974
+ group_size=128,
1975
+ transpose_bm=True,
1976
+ )
1977
+ else:
1978
+ # Pads the head_dim if necessary (for the underlying kernel)
1979
+ N, B, P = decode_q_nope.shape
1980
+ _, _, L = self.W_UK_T.shape
1981
+
1982
+ if self.q_pad_num_heads is not None:
1983
+ decode_ql_nope = decode_q_nope.new_empty(
1984
+ (self.q_pad_num_heads, B, L)
1985
+ )
1986
+ decode_ql_nope.resize_((N, B, L))
1987
+ else:
1988
+ decode_ql_nope = decode_q_nope.new_empty((N, B, L))
1989
+
1990
+ # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
1991
+ torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
1992
+
1993
+ # Convert from (N, B, L) to (B, N, L)
1994
+ decode_ql_nope = decode_ql_nope.transpose(0, 1)
1995
+
1996
+ if fp8_attention:
1997
+ ql_nope_shape = decode_ql_nope.shape
1998
+ decode_ql_nope, _ = ops.scaled_fp8_quant(
1999
+ decode_ql_nope.reshape(
2000
+ [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]]
2001
+ ),
2002
+ layer._q_scale,
2003
+ )
2004
+ decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
2005
+ q_pe_shape = decode_q_pe.shape
2006
+ decode_q_pe, _ = ops.scaled_fp8_quant(
2007
+ decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
2008
+ layer._q_scale,
2009
+ )
2010
+ decode_q_pe = decode_q_pe.reshape(q_pe_shape)
2011
+
2012
+ decode_q = (decode_ql_nope, decode_q_pe)
2013
+ if self.dcp_world_size > 1:
2014
+ assert not fp8_attention, "DCP does not support fp8 kvcache yet."
2015
+ # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
2016
+ decode_q = torch.cat(decode_q, dim=-1)
2017
+ # decode_q does an allgather along the head dim.
2018
+ decode_q = get_dcp_group().all_gather(decode_q, dim=1)
2019
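+ # After the all-gather each DCP rank holds queries for every head while
+ # attending only to its own shard of the KV cache; the partial outputs
+ # are reconciled below by cp_lse_ag_out_rs using the returned LSEs.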
+
2020
+ # call decode attn
2021
+ attn_out, lse = self._forward_decode(
2022
+ decode_q, kv_cache, attn_metadata, layer
2023
+ )
2024
+
2025
+ # Correct the DCP attn_out using the LSE.
2026
+ if self.dcp_world_size > 1:
2027
+ attn_out = cp_lse_ag_out_rs(attn_out, lse, get_dcp_group())
2028
+
2029
+ # v_up projection
2030
+ self._v_up_proj(attn_out, out=output[:num_decode_tokens])
2031
+ return output_padded