vllm_cpu-0.12.0-cp313-cp313-manylinux_2_17_aarch64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
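A quick way to see what the wheel filename above encodes is to parse it with the standard `packaging` library, which implements the wheel-tag rules. This is a minimal sketch, not part of the package being diffed; it assumes `packaging` is installed and uses the filename reconstructed from the header.

    # Sketch: decode the wheel filename and check whether the current
    # interpreter could install it. Requires `pip install packaging`.
    from packaging.tags import sys_tags
    from packaging.utils import parse_wheel_filename

    WHEEL = "vllm_cpu-0.12.0-cp313-cp313-manylinux_2_17_aarch64.whl"

    name, version, _build, tags = parse_wheel_filename(WHEEL)
    print(name, version)                      # vllm-cpu 0.12.0
    print(sorted(str(t) for t in tags))       # ['cp313-cp313-manylinux_2_17_aarch64']

    # True only on a CPython 3.13 / aarch64 Linux host with glibc >= 2.17.
    print("installable here:", any(t in set(sys_tags()) for t in tags))

In other words, this build targets CPython 3.13 on 64-bit ARM Linux (manylinux_2_17), and pip will refuse it on any other interpreter or platform.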
Files changed (1600)
  1. vllm/_C.abi3.so +0 -0
  2. vllm/__init__.py +107 -0
  3. vllm/_aiter_ops.py +1018 -0
  4. vllm/_bc_linter.py +54 -0
  5. vllm/_custom_ops.py +2925 -0
  6. vllm/_ipex_ops.py +457 -0
  7. vllm/_version.py +34 -0
  8. vllm/assets/__init__.py +0 -0
  9. vllm/assets/audio.py +43 -0
  10. vllm/assets/base.py +40 -0
  11. vllm/assets/image.py +59 -0
  12. vllm/assets/video.py +149 -0
  13. vllm/attention/__init__.py +0 -0
  14. vllm/attention/backends/__init__.py +0 -0
  15. vllm/attention/backends/abstract.py +434 -0
  16. vllm/attention/backends/registry.py +286 -0
  17. vllm/attention/backends/utils.py +33 -0
  18. vllm/attention/layer.py +975 -0
  19. vllm/attention/layers/__init__.py +0 -0
  20. vllm/attention/layers/chunked_local_attention.py +120 -0
  21. vllm/attention/layers/cross_attention.py +178 -0
  22. vllm/attention/layers/encoder_only_attention.py +103 -0
  23. vllm/attention/ops/__init__.py +0 -0
  24. vllm/attention/ops/chunked_prefill_paged_decode.py +401 -0
  25. vllm/attention/ops/common.py +469 -0
  26. vllm/attention/ops/flashmla.py +251 -0
  27. vllm/attention/ops/merge_attn_states.py +47 -0
  28. vllm/attention/ops/paged_attn.py +51 -0
  29. vllm/attention/ops/pallas_kv_cache_update.py +130 -0
  30. vllm/attention/ops/prefix_prefill.py +814 -0
  31. vllm/attention/ops/rocm_aiter_mla_sparse.py +210 -0
  32. vllm/attention/ops/triton_decode_attention.py +712 -0
  33. vllm/attention/ops/triton_merge_attn_states.py +116 -0
  34. vllm/attention/ops/triton_reshape_and_cache_flash.py +184 -0
  35. vllm/attention/ops/triton_unified_attention.py +941 -0
  36. vllm/attention/ops/vit_attn_wrappers.py +136 -0
  37. vllm/attention/selector.py +268 -0
  38. vllm/attention/utils/__init__.py +0 -0
  39. vllm/attention/utils/fa_utils.py +117 -0
  40. vllm/attention/utils/kv_sharing_utils.py +33 -0
  41. vllm/attention/utils/kv_transfer_utils.py +60 -0
  42. vllm/beam_search.py +88 -0
  43. vllm/benchmarks/__init__.py +0 -0
  44. vllm/benchmarks/datasets.py +3222 -0
  45. vllm/benchmarks/latency.py +172 -0
  46. vllm/benchmarks/lib/__init__.py +3 -0
  47. vllm/benchmarks/lib/endpoint_request_func.py +777 -0
  48. vllm/benchmarks/lib/ready_checker.py +72 -0
  49. vllm/benchmarks/lib/utils.py +79 -0
  50. vllm/benchmarks/serve.py +1531 -0
  51. vllm/benchmarks/sweep/__init__.py +0 -0
  52. vllm/benchmarks/sweep/cli.py +41 -0
  53. vllm/benchmarks/sweep/param_sweep.py +91 -0
  54. vllm/benchmarks/sweep/plot.py +580 -0
  55. vllm/benchmarks/sweep/plot_pareto.py +393 -0
  56. vllm/benchmarks/sweep/serve.py +448 -0
  57. vllm/benchmarks/sweep/serve_sla.py +492 -0
  58. vllm/benchmarks/sweep/server.py +114 -0
  59. vllm/benchmarks/sweep/sla_sweep.py +132 -0
  60. vllm/benchmarks/sweep/utils.py +4 -0
  61. vllm/benchmarks/throughput.py +799 -0
  62. vllm/collect_env.py +857 -0
  63. vllm/compilation/__init__.py +0 -0
  64. vllm/compilation/activation_quant_fusion.py +209 -0
  65. vllm/compilation/backends.py +827 -0
  66. vllm/compilation/base_static_graph.py +57 -0
  67. vllm/compilation/caching.py +180 -0
  68. vllm/compilation/collective_fusion.py +1234 -0
  69. vllm/compilation/compiler_interface.py +639 -0
  70. vllm/compilation/counter.py +48 -0
  71. vllm/compilation/cuda_graph.py +208 -0
  72. vllm/compilation/decorators.py +614 -0
  73. vllm/compilation/fix_functionalization.py +253 -0
  74. vllm/compilation/fusion.py +374 -0
  75. vllm/compilation/fusion_attn.py +359 -0
  76. vllm/compilation/fx_utils.py +91 -0
  77. vllm/compilation/inductor_pass.py +133 -0
  78. vllm/compilation/matcher_utils.py +315 -0
  79. vllm/compilation/monitor.py +62 -0
  80. vllm/compilation/noop_elimination.py +134 -0
  81. vllm/compilation/partition_rules.py +72 -0
  82. vllm/compilation/pass_manager.py +136 -0
  83. vllm/compilation/piecewise_backend.py +121 -0
  84. vllm/compilation/post_cleanup.py +21 -0
  85. vllm/compilation/qk_norm_rope_fusion.py +238 -0
  86. vllm/compilation/sequence_parallelism.py +363 -0
  87. vllm/compilation/torch25_custom_graph_pass.py +44 -0
  88. vllm/compilation/vllm_inductor_pass.py +173 -0
  89. vllm/compilation/wrapper.py +260 -0
  90. vllm/config/__init__.py +102 -0
  91. vllm/config/cache.py +220 -0
  92. vllm/config/compilation.py +1154 -0
  93. vllm/config/device.py +75 -0
  94. vllm/config/ec_transfer.py +110 -0
  95. vllm/config/kv_events.py +56 -0
  96. vllm/config/kv_transfer.py +114 -0
  97. vllm/config/load.py +124 -0
  98. vllm/config/lora.py +96 -0
  99. vllm/config/model.py +2274 -0
  100. vllm/config/multimodal.py +247 -0
  101. vllm/config/observability.py +131 -0
  102. vllm/config/parallel.py +653 -0
  103. vllm/config/pooler.py +124 -0
  104. vllm/config/scheduler.py +297 -0
  105. vllm/config/speculative.py +643 -0
  106. vllm/config/speech_to_text.py +38 -0
  107. vllm/config/structured_outputs.py +94 -0
  108. vllm/config/utils.py +324 -0
  109. vllm/config/vllm.py +1353 -0
  110. vllm/connections.py +189 -0
  111. vllm/device_allocator/__init__.py +0 -0
  112. vllm/device_allocator/cumem.py +327 -0
  113. vllm/distributed/__init__.py +6 -0
  114. vllm/distributed/communication_op.py +43 -0
  115. vllm/distributed/device_communicators/__init__.py +0 -0
  116. vllm/distributed/device_communicators/all2all.py +490 -0
  117. vllm/distributed/device_communicators/all_reduce_utils.py +344 -0
  118. vllm/distributed/device_communicators/base_device_communicator.py +297 -0
  119. vllm/distributed/device_communicators/cpu_communicator.py +209 -0
  120. vllm/distributed/device_communicators/cuda_communicator.py +340 -0
  121. vllm/distributed/device_communicators/cuda_wrapper.py +216 -0
  122. vllm/distributed/device_communicators/custom_all_reduce.py +326 -0
  123. vllm/distributed/device_communicators/mnnvl_compat.py +27 -0
  124. vllm/distributed/device_communicators/pynccl.py +386 -0
  125. vllm/distributed/device_communicators/pynccl_allocator.py +191 -0
  126. vllm/distributed/device_communicators/pynccl_wrapper.py +564 -0
  127. vllm/distributed/device_communicators/quick_all_reduce.py +290 -0
  128. vllm/distributed/device_communicators/ray_communicator.py +259 -0
  129. vllm/distributed/device_communicators/shm_broadcast.py +733 -0
  130. vllm/distributed/device_communicators/shm_object_storage.py +697 -0
  131. vllm/distributed/device_communicators/symm_mem.py +156 -0
  132. vllm/distributed/device_communicators/tpu_communicator.py +99 -0
  133. vllm/distributed/device_communicators/xpu_communicator.py +95 -0
  134. vllm/distributed/ec_transfer/__init__.py +14 -0
  135. vllm/distributed/ec_transfer/ec_connector/__init__.py +0 -0
  136. vllm/distributed/ec_transfer/ec_connector/base.py +247 -0
  137. vllm/distributed/ec_transfer/ec_connector/factory.py +85 -0
  138. vllm/distributed/ec_transfer/ec_connector/shared_storage_connector.py +201 -0
  139. vllm/distributed/ec_transfer/ec_transfer_state.py +42 -0
  140. vllm/distributed/eplb/__init__.py +8 -0
  141. vllm/distributed/eplb/async_worker.py +115 -0
  142. vllm/distributed/eplb/eplb_state.py +1154 -0
  143. vllm/distributed/eplb/rebalance_algo.py +260 -0
  144. vllm/distributed/eplb/rebalance_execute.py +532 -0
  145. vllm/distributed/kv_events.py +371 -0
  146. vllm/distributed/kv_transfer/README.md +29 -0
  147. vllm/distributed/kv_transfer/__init__.py +20 -0
  148. vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg +0 -0
  149. vllm/distributed/kv_transfer/kv_connector/__init__.py +0 -0
  150. vllm/distributed/kv_transfer/kv_connector/base.py +10 -0
  151. vllm/distributed/kv_transfer/kv_connector/factory.py +192 -0
  152. vllm/distributed/kv_transfer/kv_connector/utils.py +268 -0
  153. vllm/distributed/kv_transfer/kv_connector/v1/__init__.py +19 -0
  154. vllm/distributed/kv_transfer/kv_connector/v1/base.py +575 -0
  155. vllm/distributed/kv_transfer/kv_connector/v1/decode_bench_connector.py +419 -0
  156. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +216 -0
  157. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/__init__.py +18 -0
  158. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/multi_process_adapter.py +378 -0
  159. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/utils.py +221 -0
  160. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_integration/vllm_v1_adapter.py +1411 -0
  161. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_mp_connector.py +895 -0
  162. vllm/distributed/kv_transfer/kv_connector/v1/metrics.py +189 -0
  163. vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +454 -0
  164. vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +2480 -0
  165. vllm/distributed/kv_transfer/kv_connector/v1/offloading_connector.py +538 -0
  166. vllm/distributed/kv_transfer/kv_connector/v1/p2p/__init__.py +0 -0
  167. vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_connector.py +531 -0
  168. vllm/distributed/kv_transfer/kv_connector/v1/p2p/p2p_nccl_engine.py +632 -0
  169. vllm/distributed/kv_transfer/kv_connector/v1/p2p/tensor_memory_pool.py +273 -0
  170. vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +450 -0
  171. vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py +0 -0
  172. vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +179 -0
  173. vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py +164 -0
  174. vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +242 -0
  175. vllm/distributed/kv_transfer/kv_pipe/__init__.py +0 -0
  176. vllm/distributed/kv_transfer/kv_pipe/base.py +66 -0
  177. vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +295 -0
  178. vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +285 -0
  179. vllm/distributed/kv_transfer/kv_transfer_state.py +78 -0
  180. vllm/distributed/parallel_state.py +1790 -0
  181. vllm/distributed/tpu_distributed_utils.py +188 -0
  182. vllm/distributed/utils.py +545 -0
  183. vllm/engine/__init__.py +0 -0
  184. vllm/engine/arg_utils.py +2106 -0
  185. vllm/engine/async_llm_engine.py +6 -0
  186. vllm/engine/llm_engine.py +6 -0
  187. vllm/engine/protocol.py +188 -0
  188. vllm/entrypoints/__init__.py +0 -0
  189. vllm/entrypoints/anthropic/__init__.py +0 -0
  190. vllm/entrypoints/anthropic/protocol.py +162 -0
  191. vllm/entrypoints/anthropic/serving_messages.py +460 -0
  192. vllm/entrypoints/api_server.py +184 -0
  193. vllm/entrypoints/chat_utils.py +1837 -0
  194. vllm/entrypoints/cli/__init__.py +13 -0
  195. vllm/entrypoints/cli/benchmark/__init__.py +0 -0
  196. vllm/entrypoints/cli/benchmark/base.py +25 -0
  197. vllm/entrypoints/cli/benchmark/latency.py +21 -0
  198. vllm/entrypoints/cli/benchmark/main.py +56 -0
  199. vllm/entrypoints/cli/benchmark/serve.py +21 -0
  200. vllm/entrypoints/cli/benchmark/sweep.py +21 -0
  201. vllm/entrypoints/cli/benchmark/throughput.py +21 -0
  202. vllm/entrypoints/cli/collect_env.py +38 -0
  203. vllm/entrypoints/cli/main.py +79 -0
  204. vllm/entrypoints/cli/openai.py +256 -0
  205. vllm/entrypoints/cli/run_batch.py +68 -0
  206. vllm/entrypoints/cli/serve.py +249 -0
  207. vllm/entrypoints/cli/types.py +29 -0
  208. vllm/entrypoints/constants.py +10 -0
  209. vllm/entrypoints/context.py +572 -0
  210. vllm/entrypoints/dynamic_lora.py +57 -0
  211. vllm/entrypoints/harmony_utils.py +535 -0
  212. vllm/entrypoints/launcher.py +175 -0
  213. vllm/entrypoints/llm.py +1762 -0
  214. vllm/entrypoints/logger.py +84 -0
  215. vllm/entrypoints/openai/__init__.py +0 -0
  216. vllm/entrypoints/openai/api_server.py +1891 -0
  217. vllm/entrypoints/openai/cli_args.py +302 -0
  218. vllm/entrypoints/openai/orca_metrics.py +120 -0
  219. vllm/entrypoints/openai/protocol.py +2465 -0
  220. vllm/entrypoints/openai/run_batch.py +631 -0
  221. vllm/entrypoints/openai/serving_chat.py +1782 -0
  222. vllm/entrypoints/openai/serving_completion.py +716 -0
  223. vllm/entrypoints/openai/serving_engine.py +1478 -0
  224. vllm/entrypoints/openai/serving_models.py +304 -0
  225. vllm/entrypoints/openai/serving_responses.py +2032 -0
  226. vllm/entrypoints/openai/serving_tokenization.py +203 -0
  227. vllm/entrypoints/openai/serving_tokens.py +281 -0
  228. vllm/entrypoints/openai/serving_transcription.py +168 -0
  229. vllm/entrypoints/openai/speech_to_text.py +559 -0
  230. vllm/entrypoints/openai/tool_parsers/__init__.py +142 -0
  231. vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +273 -0
  232. vllm/entrypoints/openai/tool_parsers/deepseekv31_tool_parser.py +390 -0
  233. vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +390 -0
  234. vllm/entrypoints/openai/tool_parsers/ernie45_tool_parser.py +210 -0
  235. vllm/entrypoints/openai/tool_parsers/glm4_moe_tool_parser.py +200 -0
  236. vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +273 -0
  237. vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +253 -0
  238. vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +494 -0
  239. vllm/entrypoints/openai/tool_parsers/hunyuan_a13b_tool_parser.py +420 -0
  240. vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +227 -0
  241. vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +322 -0
  242. vllm/entrypoints/openai/tool_parsers/kimi_k2_tool_parser.py +590 -0
  243. vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +341 -0
  244. vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +324 -0
  245. vllm/entrypoints/openai/tool_parsers/longcat_tool_parser.py +37 -0
  246. vllm/entrypoints/openai/tool_parsers/minimax_m2_tool_parser.py +643 -0
  247. vllm/entrypoints/openai/tool_parsers/minimax_tool_parser.py +849 -0
  248. vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +390 -0
  249. vllm/entrypoints/openai/tool_parsers/olmo3_tool_parser.py +366 -0
  250. vllm/entrypoints/openai/tool_parsers/openai_tool_parser.py +97 -0
  251. vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +120 -0
  252. vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +332 -0
  253. vllm/entrypoints/openai/tool_parsers/qwen3coder_tool_parser.py +781 -0
  254. vllm/entrypoints/openai/tool_parsers/qwen3xml_tool_parser.py +1316 -0
  255. vllm/entrypoints/openai/tool_parsers/seed_oss_tool_parser.py +744 -0
  256. vllm/entrypoints/openai/tool_parsers/step3_tool_parser.py +303 -0
  257. vllm/entrypoints/openai/tool_parsers/utils.py +229 -0
  258. vllm/entrypoints/openai/tool_parsers/xlam_tool_parser.py +556 -0
  259. vllm/entrypoints/openai/utils.py +49 -0
  260. vllm/entrypoints/pooling/__init__.py +16 -0
  261. vllm/entrypoints/pooling/classify/__init__.py +0 -0
  262. vllm/entrypoints/pooling/classify/api_router.py +50 -0
  263. vllm/entrypoints/pooling/classify/protocol.py +181 -0
  264. vllm/entrypoints/pooling/classify/serving.py +237 -0
  265. vllm/entrypoints/pooling/embed/__init__.py +0 -0
  266. vllm/entrypoints/pooling/embed/api_router.py +67 -0
  267. vllm/entrypoints/pooling/embed/protocol.py +208 -0
  268. vllm/entrypoints/pooling/embed/serving.py +697 -0
  269. vllm/entrypoints/pooling/pooling/__init__.py +0 -0
  270. vllm/entrypoints/pooling/pooling/api_router.py +63 -0
  271. vllm/entrypoints/pooling/pooling/protocol.py +148 -0
  272. vllm/entrypoints/pooling/pooling/serving.py +348 -0
  273. vllm/entrypoints/pooling/score/__init__.py +0 -0
  274. vllm/entrypoints/pooling/score/api_router.py +149 -0
  275. vllm/entrypoints/pooling/score/protocol.py +145 -0
  276. vllm/entrypoints/pooling/score/serving.py +505 -0
  277. vllm/entrypoints/renderer.py +409 -0
  278. vllm/entrypoints/responses_utils.py +148 -0
  279. vllm/entrypoints/sagemaker/__init__.py +4 -0
  280. vllm/entrypoints/sagemaker/routes.py +118 -0
  281. vllm/entrypoints/score_utils.py +240 -0
  282. vllm/entrypoints/ssl.py +78 -0
  283. vllm/entrypoints/tool.py +143 -0
  284. vllm/entrypoints/tool_server.py +234 -0
  285. vllm/entrypoints/utils.py +319 -0
  286. vllm/env_override.py +378 -0
  287. vllm/envs.py +1710 -0
  288. vllm/forward_context.py +358 -0
  289. vllm/inputs/__init__.py +44 -0
  290. vllm/inputs/data.py +359 -0
  291. vllm/inputs/parse.py +137 -0
  292. vllm/inputs/preprocess.py +716 -0
  293. vllm/logger.py +298 -0
  294. vllm/logging_utils/__init__.py +13 -0
  295. vllm/logging_utils/dump_input.py +83 -0
  296. vllm/logging_utils/formatter.py +127 -0
  297. vllm/logging_utils/lazy.py +20 -0
  298. vllm/logging_utils/log_time.py +34 -0
  299. vllm/logits_process.py +121 -0
  300. vllm/logprobs.py +206 -0
  301. vllm/lora/__init__.py +0 -0
  302. vllm/lora/layers/__init__.py +42 -0
  303. vllm/lora/layers/base.py +66 -0
  304. vllm/lora/layers/base_linear.py +165 -0
  305. vllm/lora/layers/column_parallel_linear.py +577 -0
  306. vllm/lora/layers/fused_moe.py +747 -0
  307. vllm/lora/layers/logits_processor.py +203 -0
  308. vllm/lora/layers/replicated_linear.py +70 -0
  309. vllm/lora/layers/row_parallel_linear.py +176 -0
  310. vllm/lora/layers/utils.py +74 -0
  311. vllm/lora/layers/vocal_parallel_embedding.py +140 -0
  312. vllm/lora/lora_weights.py +227 -0
  313. vllm/lora/models.py +903 -0
  314. vllm/lora/ops/__init__.py +0 -0
  315. vllm/lora/ops/ipex_ops/__init__.py +6 -0
  316. vllm/lora/ops/ipex_ops/lora_ops.py +57 -0
  317. vllm/lora/ops/torch_ops/__init__.py +20 -0
  318. vllm/lora/ops/torch_ops/lora_ops.py +128 -0
  319. vllm/lora/ops/triton_ops/README_TUNING.md +60 -0
  320. vllm/lora/ops/triton_ops/__init__.py +21 -0
  321. vllm/lora/ops/triton_ops/fused_moe_lora_op.py +661 -0
  322. vllm/lora/ops/triton_ops/kernel_utils.py +340 -0
  323. vllm/lora/ops/triton_ops/lora_expand_op.py +310 -0
  324. vllm/lora/ops/triton_ops/lora_kernel_metadata.py +154 -0
  325. vllm/lora/ops/triton_ops/lora_shrink_op.py +287 -0
  326. vllm/lora/ops/triton_ops/utils.py +295 -0
  327. vllm/lora/ops/xla_ops/__init__.py +6 -0
  328. vllm/lora/ops/xla_ops/lora_ops.py +141 -0
  329. vllm/lora/peft_helper.py +128 -0
  330. vllm/lora/punica_wrapper/__init__.py +10 -0
  331. vllm/lora/punica_wrapper/punica_base.py +493 -0
  332. vllm/lora/punica_wrapper/punica_cpu.py +351 -0
  333. vllm/lora/punica_wrapper/punica_gpu.py +412 -0
  334. vllm/lora/punica_wrapper/punica_selector.py +21 -0
  335. vllm/lora/punica_wrapper/punica_tpu.py +358 -0
  336. vllm/lora/punica_wrapper/punica_xpu.py +276 -0
  337. vllm/lora/punica_wrapper/utils.py +150 -0
  338. vllm/lora/request.py +100 -0
  339. vllm/lora/resolver.py +88 -0
  340. vllm/lora/utils.py +306 -0
  341. vllm/lora/worker_manager.py +268 -0
  342. vllm/model_executor/__init__.py +11 -0
  343. vllm/model_executor/custom_op.py +194 -0
  344. vllm/model_executor/layers/__init__.py +0 -0
  345. vllm/model_executor/layers/activation.py +595 -0
  346. vllm/model_executor/layers/attention_layer_base.py +32 -0
  347. vllm/model_executor/layers/batch_invariant.py +1058 -0
  348. vllm/model_executor/layers/conv.py +256 -0
  349. vllm/model_executor/layers/fla/__init__.py +8 -0
  350. vllm/model_executor/layers/fla/ops/__init__.py +17 -0
  351. vllm/model_executor/layers/fla/ops/chunk.py +240 -0
  352. vllm/model_executor/layers/fla/ops/chunk_delta_h.py +344 -0
  353. vllm/model_executor/layers/fla/ops/chunk_o.py +183 -0
  354. vllm/model_executor/layers/fla/ops/chunk_scaled_dot_kkt.py +154 -0
  355. vllm/model_executor/layers/fla/ops/cumsum.py +280 -0
  356. vllm/model_executor/layers/fla/ops/fused_recurrent.py +390 -0
  357. vllm/model_executor/layers/fla/ops/index.py +41 -0
  358. vllm/model_executor/layers/fla/ops/kda.py +1351 -0
  359. vllm/model_executor/layers/fla/ops/l2norm.py +146 -0
  360. vllm/model_executor/layers/fla/ops/layernorm_guard.py +396 -0
  361. vllm/model_executor/layers/fla/ops/op.py +60 -0
  362. vllm/model_executor/layers/fla/ops/solve_tril.py +556 -0
  363. vllm/model_executor/layers/fla/ops/utils.py +194 -0
  364. vllm/model_executor/layers/fla/ops/wy_fast.py +158 -0
  365. vllm/model_executor/layers/fused_moe/__init__.py +110 -0
  366. vllm/model_executor/layers/fused_moe/all2all_utils.py +171 -0
  367. vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +406 -0
  368. vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +180 -0
  369. vllm/model_executor/layers/fused_moe/config.py +938 -0
  370. vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  371. vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  372. vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  373. vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  374. vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  375. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  376. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  377. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
  378. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H200,dtype=int8_w8a16.json +146 -0
  379. vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  380. vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  381. vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  382. vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  383. vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  384. vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  385. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  386. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X.json +200 -0
  387. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  388. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H100,dtype=fp8_w8a8.json +123 -0
  389. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  390. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=NVIDIA_H200.json +146 -0
  391. vllm/model_executor/layers/fused_moe/configs/E=128,N=1856,device_name=NVIDIA_H100_80GB_HBM3.json +147 -0
  392. vllm/model_executor/layers/fused_moe/configs/E=128,N=1856,device_name=NVIDIA_L40S.json +147 -0
  393. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  394. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  395. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20-3e.json +146 -0
  396. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  397. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  398. vllm/model_executor/layers/fused_moe/configs/E=128,N=352,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +122 -0
  399. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  400. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  401. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  402. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  403. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  404. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20-3e.json +146 -0
  405. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  406. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  407. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  408. vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  409. vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  410. vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  411. vllm/model_executor/layers/fused_moe/configs/E=128,N=704,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +114 -0
  412. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  413. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=AMD_Instinct_MI308X.json +213 -0
  414. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  415. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  416. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  417. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  418. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  419. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  420. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  421. vllm/model_executor/layers/fused_moe/configs/E=128,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=bf16.json +82 -0
  422. vllm/model_executor/layers/fused_moe/configs/E=128,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +82 -0
  423. vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_H100_80GB_HBM3.json +147 -0
  424. vllm/model_executor/layers/fused_moe/configs/E=128,N=928,device_name=NVIDIA_L40S.json +147 -0
  425. vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  426. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json +200 -0
  427. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +147 -0
  428. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  429. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json +146 -0
  430. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  431. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  432. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  433. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  434. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  435. vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  436. vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  437. vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  438. vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  439. vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  440. vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  441. vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  442. vllm/model_executor/layers/fused_moe/configs/E=16,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  443. vllm/model_executor/layers/fused_moe/configs/E=16,N=2048,device_name=NVIDIA_H200.json +146 -0
  444. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  445. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  446. vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  447. vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  448. vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  449. vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H200,dtype=int8_w8a16.json +146 -0
  450. vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  451. vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  452. vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  453. vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  454. vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  455. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  456. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  457. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=float8.json +146 -0
  458. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  459. vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  460. vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=AMD_Instinct_MI300X.json +201 -0
  461. vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=AMD_Instinct_MI350_OAM,dtype=fp8_w8a8.json +164 -0
  462. vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  463. vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_H20-3e.json +146 -0
  464. vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +147 -0
  465. vllm/model_executor/layers/fused_moe/configs/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  466. vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  467. vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI350_OAM,dtype=fp8_w8a8.json +164 -0
  468. vllm/model_executor/layers/fused_moe/configs/E=160,N=384,device_name=AMD_Instinct_MI355_OAM,dtype=fp8_w8a8.json +164 -0
  469. vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  470. vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  471. vllm/model_executor/layers/fused_moe/configs/E=160,N=640,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  472. vllm/model_executor/layers/fused_moe/configs/E=20,N=1536,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Server_Edition,dtype=fp8_w8a8.json +147 -0
  473. vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  474. vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  475. vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  476. vllm/model_executor/layers/fused_moe/configs/E=20,N=2560,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  477. vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json +200 -0
  478. vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  479. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  480. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  481. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  482. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  483. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  484. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  485. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  486. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  487. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  488. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  489. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  490. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  491. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  492. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  493. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  494. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  495. vllm/model_executor/layers/fused_moe/configs/E=256,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +147 -0
  496. vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  497. vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  498. vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  499. vllm/model_executor/layers/fused_moe/configs/E=32,N=1408,device_name=NVIDIA_B200.json +147 -0
  500. vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +147 -0
  501. vllm/model_executor/layers/fused_moe/configs/E=32,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +147 -0
  502. vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  503. vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  504. vllm/model_executor/layers/fused_moe/configs/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  505. vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  506. vllm/model_executor/layers/fused_moe/configs/E=384,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  507. vllm/model_executor/layers/fused_moe/configs/E=40,N=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +147 -0
  508. vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  509. vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_GB200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  510. vllm/model_executor/layers/fused_moe/configs/E=40,N=2560,device_name=NVIDIA_H100,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  511. vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_A100-SXM4-80GB.json +147 -0
  512. vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_B200.json +146 -0
  513. vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json +147 -0
  514. vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  515. vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  516. vllm/model_executor/layers/fused_moe/configs/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  517. vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  518. vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json +146 -0
  519. vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +147 -0
  520. vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  521. vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  522. vllm/model_executor/layers/fused_moe/configs/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  523. vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_B200.json +146 -0
  524. vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_GB200,dtype=fp8_w8a8.json +146 -0
  525. vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  526. vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H20-3e.json +146 -0
  527. vllm/model_executor/layers/fused_moe/configs/E=512,N=512,device_name=NVIDIA_H200.json +146 -0
  528. vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +147 -0
  529. vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_B200.json +146 -0
  530. vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H20-3e.json +146 -0
  531. vllm/model_executor/layers/fused_moe/configs/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  532. vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json +200 -0
  533. vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json +200 -0
  534. vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json +200 -0
  535. vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json +200 -0
  536. vllm/model_executor/layers/fused_moe/configs/E=62,N=128,device_name=AMD_Instinct_MI300X.json +200 -0
  537. vllm/model_executor/layers/fused_moe/configs/E=62,N=256,device_name=AMD_Instinct_MI300X.json +200 -0
  538. vllm/model_executor/layers/fused_moe/configs/E=62,N=256,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  539. vllm/model_executor/layers/fused_moe/configs/E=62,N=512,device_name=AMD_Instinct_MI300X.json +200 -0
  540. vllm/model_executor/layers/fused_moe/configs/E=62,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  541. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  542. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  543. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  544. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  545. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  546. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  547. vllm/model_executor/layers/fused_moe/configs/E=64,N=1408,device_name=NVIDIA_B200.json +147 -0
  548. vllm/model_executor/layers/fused_moe/configs/E=64,N=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8.json +146 -0
  549. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  550. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  551. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  552. vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20,dtype=fp8_w8a8.json +146 -0
  553. vllm/model_executor/layers/fused_moe/configs/E=64,N=3072,device_name=NVIDIA_H20.json +146 -0
  554. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  555. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  556. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  557. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  558. vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8.json +146 -0
  559. vllm/model_executor/layers/fused_moe/configs/E=64,N=384,device_name=NVIDIA_H20.json +146 -0
  560. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  561. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  562. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  563. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  564. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  565. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  566. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  567. vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H100_PCIe,dtype=fp8_w8a8,block_shape=[128,128].json +147 -0
  568. vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8.json +146 -0
  569. vllm/model_executor/layers/fused_moe/configs/E=64,N=768,device_name=NVIDIA_H20.json +146 -0
  570. vllm/model_executor/layers/fused_moe/configs/E=64,N=896,device_name=NVIDIA_H20.json +146 -0
  571. vllm/model_executor/layers/fused_moe/configs/E=64,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=bf16.json +82 -0
  572. vllm/model_executor/layers/fused_moe/configs/E=64,N=8960,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +82 -0
  573. vllm/model_executor/layers/fused_moe/configs/E=72,N=192,device_name=AMD_Instinct_MI300X.json +200 -0
  574. vllm/model_executor/layers/fused_moe/configs/E=72,N=384,device_name=AMD_Instinct_MI300X.json +200 -0
  575. vllm/model_executor/layers/fused_moe/configs/E=72,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  576. vllm/model_executor/layers/fused_moe/configs/E=72,N=768,device_name=AMD_Instinct_MI300X.json +200 -0
  577. vllm/model_executor/layers/fused_moe/configs/E=72,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  578. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  579. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
  580. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  581. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  582. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
  583. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  584. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  585. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  586. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X.json +200 -0
  587. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  588. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json +200 -0
  589. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  590. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
  591. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  592. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
  593. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  594. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  595. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  596. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  597. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  598. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  599. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X.json +200 -0
  600. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  601. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json +200 -0
  602. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  603. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  604. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  605. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +154 -0
  606. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  607. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  608. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  609. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
  610. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  611. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
  612. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  613. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  614. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  615. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  616. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  617. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  618. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  619. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
  620. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  621. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X.json +200 -0
  622. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  623. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json +200 -0
  624. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  625. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  626. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  627. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  628. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  629. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  630. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
  631. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  632. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
  633. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  634. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  635. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  636. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  637. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  638. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  639. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X.json +200 -0
  640. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  641. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json +200 -0
  642. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  643. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  644. vllm/model_executor/layers/fused_moe/configs/README +12 -0
  645. vllm/model_executor/layers/fused_moe/cpu_fused_moe.py +292 -0
  646. vllm/model_executor/layers/fused_moe/cutlass_moe.py +1052 -0
  647. vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +387 -0
  648. vllm/model_executor/layers/fused_moe/deep_gemm_utils.py +416 -0
  649. vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +420 -0
  650. vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +434 -0
  651. vllm/model_executor/layers/fused_moe/flashinfer_cutedsl_moe.py +376 -0
  652. vllm/model_executor/layers/fused_moe/flashinfer_cutlass_moe.py +307 -0
  653. vllm/model_executor/layers/fused_moe/flashinfer_cutlass_prepare_finalize.py +362 -0
  654. vllm/model_executor/layers/fused_moe/flashinfer_trtllm_moe.py +192 -0
  655. vllm/model_executor/layers/fused_moe/fused_batched_moe.py +1012 -0
  656. vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +821 -0
  657. vllm/model_executor/layers/fused_moe/fused_moe.py +2172 -0
  658. vllm/model_executor/layers/fused_moe/fused_moe_method_base.py +121 -0
  659. vllm/model_executor/layers/fused_moe/fused_moe_modular_method.py +136 -0
  660. vllm/model_executor/layers/fused_moe/gpt_oss_triton_kernels_moe.py +524 -0
  661. vllm/model_executor/layers/fused_moe/layer.py +2152 -0
  662. vllm/model_executor/layers/fused_moe/modular_kernel.py +1332 -0
  663. vllm/model_executor/layers/fused_moe/moe_align_block_size.py +174 -0
  664. vllm/model_executor/layers/fused_moe/moe_pallas.py +83 -0
  665. vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +229 -0
  666. vllm/model_executor/layers/fused_moe/moe_torch_iterative.py +60 -0
  667. vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +362 -0
  668. vllm/model_executor/layers/fused_moe/prepare_finalize.py +78 -0
  669. vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +265 -0
  670. vllm/model_executor/layers/fused_moe/routing_simulator.py +310 -0
  671. vllm/model_executor/layers/fused_moe/shared_fused_moe.py +96 -0
  672. vllm/model_executor/layers/fused_moe/topk_weight_and_reduce.py +171 -0
  673. vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +163 -0
  674. vllm/model_executor/layers/fused_moe/trtllm_moe.py +143 -0
  675. vllm/model_executor/layers/fused_moe/unquantized_fused_moe_method.py +559 -0
  676. vllm/model_executor/layers/fused_moe/utils.py +332 -0
  677. vllm/model_executor/layers/kda.py +442 -0
  678. vllm/model_executor/layers/layernorm.py +442 -0
  679. vllm/model_executor/layers/lightning_attn.py +735 -0
  680. vllm/model_executor/layers/linear.py +1424 -0
  681. vllm/model_executor/layers/logits_processor.py +106 -0
  682. vllm/model_executor/layers/mamba/__init__.py +0 -0
  683. vllm/model_executor/layers/mamba/abstract.py +68 -0
  684. vllm/model_executor/layers/mamba/linear_attn.py +388 -0
  685. vllm/model_executor/layers/mamba/mamba_mixer.py +527 -0
  686. vllm/model_executor/layers/mamba/mamba_mixer2.py +930 -0
  687. vllm/model_executor/layers/mamba/mamba_utils.py +225 -0
  688. vllm/model_executor/layers/mamba/ops/__init__.py +0 -0
  689. vllm/model_executor/layers/mamba/ops/causal_conv1d.py +1240 -0
  690. vllm/model_executor/layers/mamba/ops/layernorm_gated.py +172 -0
  691. vllm/model_executor/layers/mamba/ops/mamba_ssm.py +478 -0
  692. vllm/model_executor/layers/mamba/ops/ssd_bmm.py +211 -0
  693. vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +456 -0
  694. vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +700 -0
  695. vllm/model_executor/layers/mamba/ops/ssd_combined.py +230 -0
  696. vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +157 -0
  697. vllm/model_executor/layers/mamba/short_conv.py +255 -0
  698. vllm/model_executor/layers/mla.py +176 -0
  699. vllm/model_executor/layers/pooler.py +817 -0
  700. vllm/model_executor/layers/quantization/__init__.py +179 -0
  701. vllm/model_executor/layers/quantization/auto_round.py +454 -0
  702. vllm/model_executor/layers/quantization/awq.py +277 -0
  703. vllm/model_executor/layers/quantization/awq_marlin.py +718 -0
  704. vllm/model_executor/layers/quantization/awq_triton.py +337 -0
  705. vllm/model_executor/layers/quantization/base_config.py +170 -0
  706. vllm/model_executor/layers/quantization/bitblas.py +502 -0
  707. vllm/model_executor/layers/quantization/bitsandbytes.py +644 -0
  708. vllm/model_executor/layers/quantization/compressed_tensors/__init__.py +3 -0
  709. vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +963 -0
  710. vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2387 -0
  711. vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +35 -0
  712. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +392 -0
  713. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +55 -0
  714. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +176 -0
  715. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py +124 -0
  716. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +218 -0
  717. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_fp8.py +183 -0
  718. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a8_int.py +153 -0
  719. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +138 -0
  720. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +200 -0
  721. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +125 -0
  722. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +230 -0
  723. vllm/model_executor/layers/quantization/compressed_tensors/transform/__init__.py +0 -0
  724. vllm/model_executor/layers/quantization/compressed_tensors/transform/linear.py +260 -0
  725. vllm/model_executor/layers/quantization/compressed_tensors/transform/module.py +173 -0
  726. vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/__init__.py +0 -0
  727. vllm/model_executor/layers/quantization/compressed_tensors/transform/schemes/linear_qutlass_nvfp4.py +64 -0
  728. vllm/model_executor/layers/quantization/compressed_tensors/transform/utils.py +13 -0
  729. vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +224 -0
  730. vllm/model_executor/layers/quantization/compressed_tensors/utils.py +216 -0
  731. vllm/model_executor/layers/quantization/cpu_wna16.py +625 -0
  732. vllm/model_executor/layers/quantization/deepspeedfp.py +218 -0
  733. vllm/model_executor/layers/quantization/experts_int8.py +225 -0
  734. vllm/model_executor/layers/quantization/fbgemm_fp8.py +195 -0
  735. vllm/model_executor/layers/quantization/fp8.py +1348 -0
  736. vllm/model_executor/layers/quantization/fp_quant.py +420 -0
  737. vllm/model_executor/layers/quantization/gguf.py +687 -0
  738. vllm/model_executor/layers/quantization/gptq.py +393 -0
  739. vllm/model_executor/layers/quantization/gptq_bitblas.py +482 -0
  740. vllm/model_executor/layers/quantization/gptq_marlin.py +842 -0
  741. vllm/model_executor/layers/quantization/gptq_marlin_24.py +320 -0
  742. vllm/model_executor/layers/quantization/hqq_marlin.py +372 -0
  743. vllm/model_executor/layers/quantization/inc.py +65 -0
  744. vllm/model_executor/layers/quantization/input_quant_fp8.py +171 -0
  745. vllm/model_executor/layers/quantization/ipex_quant.py +470 -0
  746. vllm/model_executor/layers/quantization/kernels/__init__.py +0 -0
  747. vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +94 -0
  748. vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +105 -0
  749. vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py +115 -0
  750. vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py +323 -0
  751. vllm/model_executor/layers/quantization/kernels/mixed_precision/conch.py +98 -0
  752. vllm/model_executor/layers/quantization/kernels/mixed_precision/cutlass.py +119 -0
  753. vllm/model_executor/layers/quantization/kernels/mixed_precision/dynamic_4bit.py +111 -0
  754. vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py +161 -0
  755. vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +159 -0
  756. vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +200 -0
  757. vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +73 -0
  758. vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +97 -0
  759. vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +120 -0
  760. vllm/model_executor/layers/quantization/kernels/scaled_mm/cpu.py +219 -0
  761. vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +140 -0
  762. vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py +42 -0
  763. vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +105 -0
  764. vllm/model_executor/layers/quantization/kv_cache.py +146 -0
  765. vllm/model_executor/layers/quantization/modelopt.py +1637 -0
  766. vllm/model_executor/layers/quantization/moe_wna16.py +528 -0
  767. vllm/model_executor/layers/quantization/mxfp4.py +1175 -0
  768. vllm/model_executor/layers/quantization/petit.py +319 -0
  769. vllm/model_executor/layers/quantization/ptpc_fp8.py +136 -0
  770. vllm/model_executor/layers/quantization/quark/__init__.py +0 -0
  771. vllm/model_executor/layers/quantization/quark/quark.py +527 -0
  772. vllm/model_executor/layers/quantization/quark/quark_moe.py +653 -0
  773. vllm/model_executor/layers/quantization/quark/schemes/__init__.py +9 -0
  774. vllm/model_executor/layers/quantization/quark/schemes/quark_ocp_mx.py +343 -0
  775. vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  776. vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +179 -0
  777. vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +139 -0
  778. vllm/model_executor/layers/quantization/quark/utils.py +105 -0
  779. vllm/model_executor/layers/quantization/qutlass_utils.py +185 -0
  780. vllm/model_executor/layers/quantization/rtn.py +639 -0
  781. vllm/model_executor/layers/quantization/schema.py +90 -0
  782. vllm/model_executor/layers/quantization/torchao.py +380 -0
  783. vllm/model_executor/layers/quantization/tpu_int8.py +139 -0
  784. vllm/model_executor/layers/quantization/utils/__init__.py +6 -0
  785. vllm/model_executor/layers/quantization/utils/allspark_utils.py +67 -0
  786. vllm/model_executor/layers/quantization/utils/bitblas_utils.py +229 -0
  787. vllm/model_executor/layers/quantization/utils/configs/N=10240,K=5120,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  788. vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  789. vllm/model_executor/layers/quantization/utils/configs/N=12288,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  790. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  791. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  792. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  793. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  794. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  795. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  796. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  797. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  798. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  799. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  800. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  801. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  802. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  803. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  804. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  805. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  806. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  807. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  808. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  809. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  810. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  811. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  812. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  813. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  814. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  815. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  816. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  817. vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  818. vllm/model_executor/layers/quantization/utils/configs/N=2112,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  819. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  820. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  821. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  822. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  823. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  824. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  825. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  826. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  827. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  828. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  829. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  830. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  831. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  832. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  833. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  834. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  835. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  836. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  837. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  838. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  839. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  840. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  841. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  842. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  843. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  844. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  845. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  846. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  847. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  848. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  849. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  850. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  851. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  852. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  853. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  854. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  855. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  856. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  857. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  858. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  859. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  860. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  861. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  862. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  863. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  864. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  865. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  866. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  867. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  868. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  869. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  870. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  871. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  872. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  873. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  874. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  875. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  876. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  877. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  878. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  879. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  880. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  881. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  882. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  883. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  884. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  885. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  886. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  887. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  888. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  889. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  890. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  891. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  892. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  893. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  894. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  895. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  896. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  897. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  898. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  899. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  900. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  901. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  902. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  903. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  904. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  905. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  906. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  907. vllm/model_executor/layers/quantization/utils/configs/N=5120,K=25600,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  908. vllm/model_executor/layers/quantization/utils/configs/N=5120,K=8192,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  909. vllm/model_executor/layers/quantization/utils/configs/N=51200,K=5120,device_name=NVIDIA_L40S,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  910. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  911. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  912. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  913. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  914. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  915. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  916. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  917. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  918. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  919. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  920. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +18 -0
  921. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  922. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  923. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  924. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  925. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  926. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  927. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  928. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  929. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  930. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  931. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  932. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  933. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  934. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  935. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  936. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  937. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  938. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  939. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  940. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  941. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  942. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  943. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  944. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  945. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  946. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  947. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  948. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  949. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  950. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  951. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  952. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  953. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  954. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  955. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  956. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  957. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  958. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  959. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  960. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  961. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  962. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  963. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  964. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  965. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  966. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  967. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  968. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  969. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  970. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  971. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  972. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  973. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  974. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  975. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  976. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  977. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  978. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  979. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  980. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  981. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  982. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  983. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  984. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  985. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  986. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  987. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  988. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  989. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  990. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  991. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  992. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  993. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  994. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  995. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  996. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  997. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  998. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  999. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  1000. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  1001. vllm/model_executor/layers/quantization/utils/configs/README.md +3 -0
  1002. vllm/model_executor/layers/quantization/utils/flashinfer_fp4_moe.py +333 -0
  1003. vllm/model_executor/layers/quantization/utils/flashinfer_utils.py +311 -0
  1004. vllm/model_executor/layers/quantization/utils/fp8_utils.py +1203 -0
  1005. vllm/model_executor/layers/quantization/utils/gptq_utils.py +158 -0
  1006. vllm/model_executor/layers/quantization/utils/int8_utils.py +489 -0
  1007. vllm/model_executor/layers/quantization/utils/layer_utils.py +41 -0
  1008. vllm/model_executor/layers/quantization/utils/machete_utils.py +56 -0
  1009. vllm/model_executor/layers/quantization/utils/marlin_utils.py +674 -0
  1010. vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +452 -0
  1011. vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +378 -0
  1012. vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +219 -0
  1013. vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py +467 -0
  1014. vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +183 -0
  1015. vllm/model_executor/layers/quantization/utils/mxfp6_utils.py +142 -0
  1016. vllm/model_executor/layers/quantization/utils/mxfp8_utils.py +24 -0
  1017. vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +142 -0
  1018. vllm/model_executor/layers/quantization/utils/nvfp4_moe_support.py +67 -0
  1019. vllm/model_executor/layers/quantization/utils/ocp_mx_utils.py +51 -0
  1020. vllm/model_executor/layers/quantization/utils/petit_utils.py +124 -0
  1021. vllm/model_executor/layers/quantization/utils/quant_utils.py +687 -0
  1022. vllm/model_executor/layers/quantization/utils/w8a8_utils.py +516 -0
  1023. vllm/model_executor/layers/resampler.py +283 -0
  1024. vllm/model_executor/layers/rotary_embedding/__init__.py +292 -0
  1025. vllm/model_executor/layers/rotary_embedding/base.py +240 -0
  1026. vllm/model_executor/layers/rotary_embedding/common.py +188 -0
  1027. vllm/model_executor/layers/rotary_embedding/deepseek_scaling_rope.py +165 -0
  1028. vllm/model_executor/layers/rotary_embedding/dual_chunk_rope.py +215 -0
  1029. vllm/model_executor/layers/rotary_embedding/dynamic_ntk_alpha_rope.py +43 -0
  1030. vllm/model_executor/layers/rotary_embedding/dynamic_ntk_scaling_rope.py +68 -0
  1031. vllm/model_executor/layers/rotary_embedding/ernie45_vl_rope.py +75 -0
  1032. vllm/model_executor/layers/rotary_embedding/linear_scaling_rope.py +115 -0
  1033. vllm/model_executor/layers/rotary_embedding/llama3_rope.py +54 -0
  1034. vllm/model_executor/layers/rotary_embedding/llama4_vision_rope.py +80 -0
  1035. vllm/model_executor/layers/rotary_embedding/mrope.py +397 -0
  1036. vllm/model_executor/layers/rotary_embedding/ntk_scaling_rope.py +47 -0
  1037. vllm/model_executor/layers/rotary_embedding/phi3_long_rope_scaled_rope.py +159 -0
  1038. vllm/model_executor/layers/rotary_embedding/xdrope.py +102 -0
  1039. vllm/model_executor/layers/rotary_embedding/yarn_scaling_rope.py +84 -0
  1040. vllm/model_executor/layers/utils.py +251 -0
  1041. vllm/model_executor/layers/vocab_parallel_embedding.py +558 -0
  1042. vllm/model_executor/model_loader/__init__.py +150 -0
  1043. vllm/model_executor/model_loader/base_loader.py +57 -0
  1044. vllm/model_executor/model_loader/bitsandbytes_loader.py +822 -0
  1045. vllm/model_executor/model_loader/default_loader.py +321 -0
  1046. vllm/model_executor/model_loader/dummy_loader.py +28 -0
  1047. vllm/model_executor/model_loader/gguf_loader.py +349 -0
  1048. vllm/model_executor/model_loader/online_quantization.py +275 -0
  1049. vllm/model_executor/model_loader/runai_streamer_loader.py +116 -0
  1050. vllm/model_executor/model_loader/sharded_state_loader.py +214 -0
  1051. vllm/model_executor/model_loader/tensorizer.py +790 -0
  1052. vllm/model_executor/model_loader/tensorizer_loader.py +151 -0
  1053. vllm/model_executor/model_loader/tpu.py +118 -0
  1054. vllm/model_executor/model_loader/utils.py +296 -0
  1055. vllm/model_executor/model_loader/weight_utils.py +1147 -0
  1056. vllm/model_executor/models/__init__.py +44 -0
  1057. vllm/model_executor/models/adapters.py +543 -0
  1058. vllm/model_executor/models/afmoe.py +697 -0
  1059. vllm/model_executor/models/aimv2.py +248 -0
  1060. vllm/model_executor/models/apertus.py +569 -0
  1061. vllm/model_executor/models/arcee.py +428 -0
  1062. vllm/model_executor/models/arctic.py +634 -0
  1063. vllm/model_executor/models/aria.py +655 -0
  1064. vllm/model_executor/models/aya_vision.py +450 -0
  1065. vllm/model_executor/models/baichuan.py +494 -0
  1066. vllm/model_executor/models/bailing_moe.py +645 -0
  1067. vllm/model_executor/models/bamba.py +516 -0
  1068. vllm/model_executor/models/bee.py +157 -0
  1069. vllm/model_executor/models/bert.py +925 -0
  1070. vllm/model_executor/models/bert_with_rope.py +732 -0
  1071. vllm/model_executor/models/blip.py +350 -0
  1072. vllm/model_executor/models/blip2.py +695 -0
  1073. vllm/model_executor/models/bloom.py +390 -0
  1074. vllm/model_executor/models/chameleon.py +1098 -0
  1075. vllm/model_executor/models/chatglm.py +499 -0
  1076. vllm/model_executor/models/clip.py +1005 -0
  1077. vllm/model_executor/models/cohere2_vision.py +472 -0
  1078. vllm/model_executor/models/commandr.py +470 -0
  1079. vllm/model_executor/models/config.py +510 -0
  1080. vllm/model_executor/models/dbrx.py +485 -0
  1081. vllm/model_executor/models/deepencoder.py +676 -0
  1082. vllm/model_executor/models/deepseek_eagle.py +252 -0
  1083. vllm/model_executor/models/deepseek_mtp.py +446 -0
  1084. vllm/model_executor/models/deepseek_ocr.py +593 -0
  1085. vllm/model_executor/models/deepseek_v2.py +1715 -0
  1086. vllm/model_executor/models/deepseek_vl2.py +644 -0
  1087. vllm/model_executor/models/dots1.py +566 -0
  1088. vllm/model_executor/models/dots_ocr.py +874 -0
  1089. vllm/model_executor/models/ernie45.py +53 -0
  1090. vllm/model_executor/models/ernie45_moe.py +755 -0
  1091. vllm/model_executor/models/ernie45_vl.py +1710 -0
  1092. vllm/model_executor/models/ernie45_vl_moe.py +800 -0
  1093. vllm/model_executor/models/ernie_mtp.py +279 -0
  1094. vllm/model_executor/models/exaone.py +525 -0
  1095. vllm/model_executor/models/exaone4.py +517 -0
  1096. vllm/model_executor/models/fairseq2_llama.py +154 -0
  1097. vllm/model_executor/models/falcon.py +544 -0
  1098. vllm/model_executor/models/falcon_h1.py +680 -0
  1099. vllm/model_executor/models/flex_olmo.py +155 -0
  1100. vllm/model_executor/models/fuyu.py +373 -0
  1101. vllm/model_executor/models/gemma.py +426 -0
  1102. vllm/model_executor/models/gemma2.py +436 -0
  1103. vllm/model_executor/models/gemma3.py +577 -0
  1104. vllm/model_executor/models/gemma3_mm.py +665 -0
  1105. vllm/model_executor/models/gemma3n.py +1167 -0
  1106. vllm/model_executor/models/gemma3n_mm.py +811 -0
  1107. vllm/model_executor/models/glm.py +23 -0
  1108. vllm/model_executor/models/glm4.py +298 -0
  1109. vllm/model_executor/models/glm4_1v.py +1854 -0
  1110. vllm/model_executor/models/glm4_moe.py +738 -0
  1111. vllm/model_executor/models/glm4_moe_mtp.py +359 -0
  1112. vllm/model_executor/models/glm4v.py +785 -0
  1113. vllm/model_executor/models/gpt2.py +397 -0
  1114. vllm/model_executor/models/gpt_bigcode.py +339 -0
  1115. vllm/model_executor/models/gpt_j.py +345 -0
  1116. vllm/model_executor/models/gpt_neox.py +343 -0
  1117. vllm/model_executor/models/gpt_oss.py +745 -0
  1118. vllm/model_executor/models/granite.py +476 -0
  1119. vllm/model_executor/models/granite_speech.py +913 -0
  1120. vllm/model_executor/models/granitemoe.py +561 -0
  1121. vllm/model_executor/models/granitemoehybrid.py +704 -0
  1122. vllm/model_executor/models/granitemoeshared.py +328 -0
  1123. vllm/model_executor/models/gritlm.py +245 -0
  1124. vllm/model_executor/models/grok1.py +555 -0
  1125. vllm/model_executor/models/h2ovl.py +554 -0
  1126. vllm/model_executor/models/hunyuan_v1.py +1042 -0
  1127. vllm/model_executor/models/hunyuan_vision.py +1028 -0
  1128. vllm/model_executor/models/hyperclovax_vision.py +1166 -0
  1129. vllm/model_executor/models/idefics2_vision_model.py +427 -0
  1130. vllm/model_executor/models/idefics3.py +718 -0
  1131. vllm/model_executor/models/interfaces.py +1148 -0
  1132. vllm/model_executor/models/interfaces_base.py +243 -0
  1133. vllm/model_executor/models/intern_vit.py +454 -0
  1134. vllm/model_executor/models/internlm2.py +454 -0
  1135. vllm/model_executor/models/internlm2_ve.py +139 -0
  1136. vllm/model_executor/models/interns1.py +830 -0
  1137. vllm/model_executor/models/interns1_vit.py +433 -0
  1138. vllm/model_executor/models/internvl.py +1452 -0
  1139. vllm/model_executor/models/jais.py +397 -0
  1140. vllm/model_executor/models/jamba.py +609 -0
  1141. vllm/model_executor/models/jina_vl.py +147 -0
  1142. vllm/model_executor/models/keye.py +1765 -0
  1143. vllm/model_executor/models/keye_vl1_5.py +726 -0
  1144. vllm/model_executor/models/kimi_linear.py +658 -0
  1145. vllm/model_executor/models/kimi_vl.py +578 -0
  1146. vllm/model_executor/models/lfm2.py +516 -0
  1147. vllm/model_executor/models/lfm2_moe.py +746 -0
  1148. vllm/model_executor/models/lightonocr.py +195 -0
  1149. vllm/model_executor/models/llama.py +704 -0
  1150. vllm/model_executor/models/llama4.py +857 -0
  1151. vllm/model_executor/models/llama4_eagle.py +216 -0
  1152. vllm/model_executor/models/llama_eagle.py +213 -0
  1153. vllm/model_executor/models/llama_eagle3.py +375 -0
  1154. vllm/model_executor/models/llava.py +842 -0
  1155. vllm/model_executor/models/llava_next.py +583 -0
  1156. vllm/model_executor/models/llava_next_video.py +467 -0
  1157. vllm/model_executor/models/llava_onevision.py +923 -0
  1158. vllm/model_executor/models/longcat_flash.py +743 -0
  1159. vllm/model_executor/models/longcat_flash_mtp.py +349 -0
  1160. vllm/model_executor/models/mamba.py +276 -0
  1161. vllm/model_executor/models/mamba2.py +288 -0
  1162. vllm/model_executor/models/medusa.py +179 -0
  1163. vllm/model_executor/models/midashenglm.py +828 -0
  1164. vllm/model_executor/models/mimo.py +188 -0
  1165. vllm/model_executor/models/mimo_mtp.py +294 -0
  1166. vllm/model_executor/models/minicpm.py +657 -0
  1167. vllm/model_executor/models/minicpm3.py +234 -0
  1168. vllm/model_executor/models/minicpm_eagle.py +385 -0
  1169. vllm/model_executor/models/minicpmo.py +768 -0
  1170. vllm/model_executor/models/minicpmv.py +1744 -0
  1171. vllm/model_executor/models/minimax_m2.py +546 -0
  1172. vllm/model_executor/models/minimax_text_01.py +1010 -0
  1173. vllm/model_executor/models/minimax_vl_01.py +396 -0
  1174. vllm/model_executor/models/mistral3.py +637 -0
  1175. vllm/model_executor/models/mistral_large_3.py +63 -0
  1176. vllm/model_executor/models/mistral_large_3_eagle.py +165 -0
  1177. vllm/model_executor/models/mixtral.py +599 -0
  1178. vllm/model_executor/models/mllama4.py +1151 -0
  1179. vllm/model_executor/models/mlp_speculator.py +235 -0
  1180. vllm/model_executor/models/modernbert.py +452 -0
  1181. vllm/model_executor/models/module_mapping.py +74 -0
  1182. vllm/model_executor/models/molmo.py +1553 -0
  1183. vllm/model_executor/models/moonvit.py +686 -0
  1184. vllm/model_executor/models/mpt.py +335 -0
  1185. vllm/model_executor/models/nano_nemotron_vl.py +1732 -0
  1186. vllm/model_executor/models/nemotron.py +502 -0
  1187. vllm/model_executor/models/nemotron_h.py +850 -0
  1188. vllm/model_executor/models/nemotron_nas.py +473 -0
  1189. vllm/model_executor/models/nemotron_vl.py +653 -0
  1190. vllm/model_executor/models/nvlm_d.py +216 -0
  1191. vllm/model_executor/models/olmo.py +413 -0
  1192. vllm/model_executor/models/olmo2.py +455 -0
  1193. vllm/model_executor/models/olmoe.py +494 -0
  1194. vllm/model_executor/models/opencua.py +271 -0
  1195. vllm/model_executor/models/openpangu.py +1051 -0
  1196. vllm/model_executor/models/openpangu_mtp.py +265 -0
  1197. vllm/model_executor/models/opt.py +426 -0
  1198. vllm/model_executor/models/orion.py +366 -0
  1199. vllm/model_executor/models/ouro.py +508 -0
  1200. vllm/model_executor/models/ovis.py +559 -0
  1201. vllm/model_executor/models/ovis2_5.py +673 -0
  1202. vllm/model_executor/models/paddleocr_vl.py +1380 -0
  1203. vllm/model_executor/models/paligemma.py +412 -0
  1204. vllm/model_executor/models/persimmon.py +376 -0
  1205. vllm/model_executor/models/phi.py +370 -0
  1206. vllm/model_executor/models/phi3.py +18 -0
  1207. vllm/model_executor/models/phi3v.py +737 -0
  1208. vllm/model_executor/models/phi4_multimodal.py +1447 -0
  1209. vllm/model_executor/models/phi4mm.py +1253 -0
  1210. vllm/model_executor/models/phi4mm_audio.py +1296 -0
  1211. vllm/model_executor/models/phi4mm_utils.py +1907 -0
  1212. vllm/model_executor/models/phimoe.py +670 -0
  1213. vllm/model_executor/models/pixtral.py +1380 -0
  1214. vllm/model_executor/models/plamo2.py +966 -0
  1215. vllm/model_executor/models/plamo3.py +441 -0
  1216. vllm/model_executor/models/qwen.py +363 -0
  1217. vllm/model_executor/models/qwen2.py +569 -0
  1218. vllm/model_executor/models/qwen2_5_omni_thinker.py +1220 -0
  1219. vllm/model_executor/models/qwen2_5_vl.py +1594 -0
  1220. vllm/model_executor/models/qwen2_audio.py +473 -0
  1221. vllm/model_executor/models/qwen2_moe.py +590 -0
  1222. vllm/model_executor/models/qwen2_rm.py +123 -0
  1223. vllm/model_executor/models/qwen2_vl.py +1593 -0
  1224. vllm/model_executor/models/qwen3.py +332 -0
  1225. vllm/model_executor/models/qwen3_moe.py +738 -0
  1226. vllm/model_executor/models/qwen3_next.py +1390 -0
  1227. vllm/model_executor/models/qwen3_next_mtp.py +296 -0
  1228. vllm/model_executor/models/qwen3_omni_moe_thinker.py +1765 -0
  1229. vllm/model_executor/models/qwen3_vl.py +1686 -0
  1230. vllm/model_executor/models/qwen3_vl_moe.py +470 -0
  1231. vllm/model_executor/models/qwen_vl.py +803 -0
  1232. vllm/model_executor/models/radio.py +555 -0
  1233. vllm/model_executor/models/registry.py +1183 -0
  1234. vllm/model_executor/models/roberta.py +259 -0
  1235. vllm/model_executor/models/rvl.py +107 -0
  1236. vllm/model_executor/models/seed_oss.py +493 -0
  1237. vllm/model_executor/models/siglip.py +1245 -0
  1238. vllm/model_executor/models/siglip2navit.py +723 -0
  1239. vllm/model_executor/models/skyworkr1v.py +953 -0
  1240. vllm/model_executor/models/smolvlm.py +38 -0
  1241. vllm/model_executor/models/solar.py +485 -0
  1242. vllm/model_executor/models/stablelm.py +359 -0
  1243. vllm/model_executor/models/starcoder2.py +366 -0
  1244. vllm/model_executor/models/step3_text.py +555 -0
  1245. vllm/model_executor/models/step3_vl.py +1149 -0
  1246. vllm/model_executor/models/swin.py +514 -0
  1247. vllm/model_executor/models/tarsier.py +619 -0
  1248. vllm/model_executor/models/telechat2.py +153 -0
  1249. vllm/model_executor/models/teleflm.py +78 -0
  1250. vllm/model_executor/models/terratorch.py +319 -0
  1251. vllm/model_executor/models/transformers/__init__.py +127 -0
  1252. vllm/model_executor/models/transformers/base.py +464 -0
  1253. vllm/model_executor/models/transformers/causal.py +65 -0
  1254. vllm/model_executor/models/transformers/legacy.py +90 -0
  1255. vllm/model_executor/models/transformers/moe.py +325 -0
  1256. vllm/model_executor/models/transformers/multimodal.py +411 -0
  1257. vllm/model_executor/models/transformers/pooling.py +119 -0
  1258. vllm/model_executor/models/transformers/utils.py +213 -0
  1259. vllm/model_executor/models/ultravox.py +686 -0
  1260. vllm/model_executor/models/utils.py +832 -0
  1261. vllm/model_executor/models/vision.py +552 -0
  1262. vllm/model_executor/models/voxtral.py +842 -0
  1263. vllm/model_executor/models/whisper.py +963 -0
  1264. vllm/model_executor/models/zamba2.py +980 -0
  1265. vllm/model_executor/parameter.py +642 -0
  1266. vllm/model_executor/utils.py +94 -0
  1267. vllm/model_executor/warmup/__init__.py +0 -0
  1268. vllm/model_executor/warmup/deep_gemm_warmup.py +314 -0
  1269. vllm/model_executor/warmup/kernel_warmup.py +98 -0
  1270. vllm/multimodal/__init__.py +40 -0
  1271. vllm/multimodal/audio.py +142 -0
  1272. vllm/multimodal/base.py +26 -0
  1273. vllm/multimodal/cache.py +830 -0
  1274. vllm/multimodal/evs.py +294 -0
  1275. vllm/multimodal/hasher.py +106 -0
  1276. vllm/multimodal/image.py +130 -0
  1277. vllm/multimodal/inputs.py +1036 -0
  1278. vllm/multimodal/parse.py +544 -0
  1279. vllm/multimodal/processing.py +2240 -0
  1280. vllm/multimodal/profiling.py +369 -0
  1281. vllm/multimodal/registry.py +357 -0
  1282. vllm/multimodal/utils.py +523 -0
  1283. vllm/multimodal/video.py +333 -0
  1284. vllm/outputs.py +345 -0
  1285. vllm/platforms/__init__.py +277 -0
  1286. vllm/platforms/cpu.py +410 -0
  1287. vllm/platforms/cuda.py +642 -0
  1288. vllm/platforms/interface.py +656 -0
  1289. vllm/platforms/rocm.py +513 -0
  1290. vllm/platforms/tpu.py +275 -0
  1291. vllm/platforms/xpu.py +261 -0
  1292. vllm/plugins/__init__.py +81 -0
  1293. vllm/plugins/io_processors/__init__.py +68 -0
  1294. vllm/plugins/io_processors/interface.py +77 -0
  1295. vllm/plugins/lora_resolvers/__init__.py +0 -0
  1296. vllm/plugins/lora_resolvers/filesystem_resolver.py +52 -0
  1297. vllm/pooling_params.py +230 -0
  1298. vllm/profiler/__init__.py +0 -0
  1299. vllm/profiler/gpu_profiler.py +216 -0
  1300. vllm/profiler/layerwise_profile.py +392 -0
  1301. vllm/profiler/utils.py +151 -0
  1302. vllm/py.typed +2 -0
  1303. vllm/ray/__init__.py +0 -0
  1304. vllm/ray/lazy_utils.py +30 -0
  1305. vllm/ray/ray_env.py +79 -0
  1306. vllm/reasoning/__init__.py +92 -0
  1307. vllm/reasoning/abs_reasoning_parsers.py +290 -0
  1308. vllm/reasoning/basic_parsers.py +162 -0
  1309. vllm/reasoning/deepseek_r1_reasoning_parser.py +67 -0
  1310. vllm/reasoning/deepseek_v3_reasoning_parser.py +62 -0
  1311. vllm/reasoning/ernie45_reasoning_parser.py +165 -0
  1312. vllm/reasoning/glm4_moe_reasoning_parser.py +171 -0
  1313. vllm/reasoning/gptoss_reasoning_parser.py +173 -0
  1314. vllm/reasoning/granite_reasoning_parser.py +363 -0
  1315. vllm/reasoning/hunyuan_a13b_reasoning_parser.py +237 -0
  1316. vllm/reasoning/identity_reasoning_parser.py +58 -0
  1317. vllm/reasoning/minimax_m2_reasoning_parser.py +67 -0
  1318. vllm/reasoning/mistral_reasoning_parser.py +55 -0
  1319. vllm/reasoning/olmo3_reasoning_parser.py +302 -0
  1320. vllm/reasoning/qwen3_reasoning_parser.py +67 -0
  1321. vllm/reasoning/seedoss_reasoning_parser.py +27 -0
  1322. vllm/reasoning/step3_reasoning_parser.py +107 -0
  1323. vllm/sampling_params.py +597 -0
  1324. vllm/scalar_type.py +355 -0
  1325. vllm/scripts.py +17 -0
  1326. vllm/sequence.py +98 -0
  1327. vllm/tasks.py +13 -0
  1328. vllm/third_party/__init__.py +0 -0
  1329. vllm/third_party/pynvml.py +6140 -0
  1330. vllm/tokenizers/__init__.py +24 -0
  1331. vllm/tokenizers/detokenizer_utils.py +198 -0
  1332. vllm/tokenizers/hf.py +124 -0
  1333. vllm/tokenizers/mistral.py +554 -0
  1334. vllm/tokenizers/protocol.py +111 -0
  1335. vllm/tokenizers/registry.py +233 -0
  1336. vllm/tracing.py +135 -0
  1337. vllm/transformers_utils/__init__.py +26 -0
  1338. vllm/transformers_utils/chat_templates/__init__.py +5 -0
  1339. vllm/transformers_utils/chat_templates/registry.py +73 -0
  1340. vllm/transformers_utils/chat_templates/template_basic.jinja +3 -0
  1341. vllm/transformers_utils/chat_templates/template_blip2.jinja +11 -0
  1342. vllm/transformers_utils/chat_templates/template_chatml.jinja +10 -0
  1343. vllm/transformers_utils/chat_templates/template_deepseek_ocr.jinja +14 -0
  1344. vllm/transformers_utils/chat_templates/template_deepseek_vl2.jinja +23 -0
  1345. vllm/transformers_utils/chat_templates/template_fuyu.jinja +3 -0
  1346. vllm/transformers_utils/chat_templates/template_minicpmv45.jinja +93 -0
  1347. vllm/transformers_utils/config.py +1081 -0
  1348. vllm/transformers_utils/config_parser_base.py +20 -0
  1349. vllm/transformers_utils/configs/__init__.py +84 -0
  1350. vllm/transformers_utils/configs/afmoe.py +87 -0
  1351. vllm/transformers_utils/configs/arctic.py +216 -0
  1352. vllm/transformers_utils/configs/chatglm.py +75 -0
  1353. vllm/transformers_utils/configs/deepseek_vl2.py +126 -0
  1354. vllm/transformers_utils/configs/dotsocr.py +71 -0
  1355. vllm/transformers_utils/configs/eagle.py +90 -0
  1356. vllm/transformers_utils/configs/falcon.py +89 -0
  1357. vllm/transformers_utils/configs/flex_olmo.py +82 -0
  1358. vllm/transformers_utils/configs/hunyuan_vl.py +322 -0
  1359. vllm/transformers_utils/configs/jais.py +243 -0
  1360. vllm/transformers_utils/configs/kimi_linear.py +148 -0
  1361. vllm/transformers_utils/configs/kimi_vl.py +38 -0
  1362. vllm/transformers_utils/configs/lfm2_moe.py +163 -0
  1363. vllm/transformers_utils/configs/medusa.py +65 -0
  1364. vllm/transformers_utils/configs/midashenglm.py +103 -0
  1365. vllm/transformers_utils/configs/mistral.py +235 -0
  1366. vllm/transformers_utils/configs/mlp_speculator.py +69 -0
  1367. vllm/transformers_utils/configs/moonvit.py +33 -0
  1368. vllm/transformers_utils/configs/nemotron.py +214 -0
  1369. vllm/transformers_utils/configs/nemotron_h.py +282 -0
  1370. vllm/transformers_utils/configs/olmo3.py +83 -0
  1371. vllm/transformers_utils/configs/ovis.py +182 -0
  1372. vllm/transformers_utils/configs/qwen3_next.py +275 -0
  1373. vllm/transformers_utils/configs/radio.py +89 -0
  1374. vllm/transformers_utils/configs/speculators/__init__.py +2 -0
  1375. vllm/transformers_utils/configs/speculators/algos.py +38 -0
  1376. vllm/transformers_utils/configs/speculators/base.py +114 -0
  1377. vllm/transformers_utils/configs/step3_vl.py +178 -0
  1378. vllm/transformers_utils/configs/ultravox.py +118 -0
  1379. vllm/transformers_utils/dynamic_module.py +59 -0
  1380. vllm/transformers_utils/gguf_utils.py +209 -0
  1381. vllm/transformers_utils/processor.py +423 -0
  1382. vllm/transformers_utils/processors/__init__.py +23 -0
  1383. vllm/transformers_utils/processors/deepseek_ocr.py +438 -0
  1384. vllm/transformers_utils/processors/deepseek_vl2.py +406 -0
  1385. vllm/transformers_utils/processors/hunyuan_vl.py +233 -0
  1386. vllm/transformers_utils/processors/hunyuan_vl_image.py +477 -0
  1387. vllm/transformers_utils/processors/ovis.py +453 -0
  1388. vllm/transformers_utils/processors/ovis2_5.py +468 -0
  1389. vllm/transformers_utils/repo_utils.py +287 -0
  1390. vllm/transformers_utils/runai_utils.py +104 -0
  1391. vllm/transformers_utils/s3_utils.py +95 -0
  1392. vllm/transformers_utils/tokenizer.py +127 -0
  1393. vllm/transformers_utils/tokenizer_base.py +33 -0
  1394. vllm/transformers_utils/utils.py +184 -0
  1395. vllm/triton_utils/__init__.py +20 -0
  1396. vllm/triton_utils/importing.py +103 -0
  1397. vllm/usage/__init__.py +0 -0
  1398. vllm/usage/usage_lib.py +294 -0
  1399. vllm/utils/__init__.py +66 -0
  1400. vllm/utils/argparse_utils.py +504 -0
  1401. vllm/utils/async_utils.py +310 -0
  1402. vllm/utils/cache.py +214 -0
  1403. vllm/utils/collection_utils.py +112 -0
  1404. vllm/utils/counter.py +45 -0
  1405. vllm/utils/deep_gemm.py +399 -0
  1406. vllm/utils/flashinfer.py +532 -0
  1407. vllm/utils/func_utils.py +236 -0
  1408. vllm/utils/gc_utils.py +151 -0
  1409. vllm/utils/hashing.py +81 -0
  1410. vllm/utils/import_utils.py +449 -0
  1411. vllm/utils/jsontree.py +158 -0
  1412. vllm/utils/math_utils.py +32 -0
  1413. vllm/utils/mem_constants.py +13 -0
  1414. vllm/utils/mem_utils.py +232 -0
  1415. vllm/utils/nccl.py +64 -0
  1416. vllm/utils/network_utils.py +331 -0
  1417. vllm/utils/platform_utils.py +59 -0
  1418. vllm/utils/profiling.py +56 -0
  1419. vllm/utils/registry.py +51 -0
  1420. vllm/utils/serial_utils.py +169 -0
  1421. vllm/utils/system_utils.py +265 -0
  1422. vllm/utils/tensor_schema.py +255 -0
  1423. vllm/utils/torch_utils.py +647 -0
  1424. vllm/v1/__init__.py +0 -0
  1425. vllm/v1/attention/__init__.py +0 -0
  1426. vllm/v1/attention/backends/__init__.py +0 -0
  1427. vllm/v1/attention/backends/cpu_attn.py +497 -0
  1428. vllm/v1/attention/backends/flash_attn.py +1050 -0
  1429. vllm/v1/attention/backends/flashinfer.py +1572 -0
  1430. vllm/v1/attention/backends/flex_attention.py +945 -0
  1431. vllm/v1/attention/backends/gdn_attn.py +387 -0
  1432. vllm/v1/attention/backends/linear_attn.py +77 -0
  1433. vllm/v1/attention/backends/mamba1_attn.py +165 -0
  1434. vllm/v1/attention/backends/mamba2_attn.py +354 -0
  1435. vllm/v1/attention/backends/mamba_attn.py +117 -0
  1436. vllm/v1/attention/backends/mla/__init__.py +0 -0
  1437. vllm/v1/attention/backends/mla/aiter_triton_mla.py +74 -0
  1438. vllm/v1/attention/backends/mla/common.py +2069 -0
  1439. vllm/v1/attention/backends/mla/cutlass_mla.py +278 -0
  1440. vllm/v1/attention/backends/mla/flashattn_mla.py +340 -0
  1441. vllm/v1/attention/backends/mla/flashinfer_mla.py +174 -0
  1442. vllm/v1/attention/backends/mla/flashmla.py +317 -0
  1443. vllm/v1/attention/backends/mla/flashmla_sparse.py +551 -0
  1444. vllm/v1/attention/backends/mla/indexer.py +369 -0
  1445. vllm/v1/attention/backends/mla/rocm_aiter_mla.py +275 -0
  1446. vllm/v1/attention/backends/mla/rocm_aiter_mla_sparse.py +325 -0
  1447. vllm/v1/attention/backends/mla/triton_mla.py +171 -0
  1448. vllm/v1/attention/backends/pallas.py +436 -0
  1449. vllm/v1/attention/backends/rocm_aiter_fa.py +1000 -0
  1450. vllm/v1/attention/backends/rocm_aiter_unified_attn.py +206 -0
  1451. vllm/v1/attention/backends/rocm_attn.py +359 -0
  1452. vllm/v1/attention/backends/short_conv_attn.py +105 -0
  1453. vllm/v1/attention/backends/tree_attn.py +428 -0
  1454. vllm/v1/attention/backends/triton_attn.py +377 -0
  1455. vllm/v1/attention/backends/utils.py +1149 -0
  1456. vllm/v1/core/__init__.py +0 -0
  1457. vllm/v1/core/block_pool.py +466 -0
  1458. vllm/v1/core/encoder_cache_manager.py +343 -0
  1459. vllm/v1/core/kv_cache_coordinator.py +570 -0
  1460. vllm/v1/core/kv_cache_manager.py +408 -0
  1461. vllm/v1/core/kv_cache_metrics.py +96 -0
  1462. vllm/v1/core/kv_cache_utils.py +1471 -0
  1463. vllm/v1/core/sched/__init__.py +0 -0
  1464. vllm/v1/core/sched/async_scheduler.py +68 -0
  1465. vllm/v1/core/sched/interface.py +187 -0
  1466. vllm/v1/core/sched/output.py +230 -0
  1467. vllm/v1/core/sched/request_queue.py +217 -0
  1468. vllm/v1/core/sched/scheduler.py +1726 -0
  1469. vllm/v1/core/sched/utils.py +72 -0
  1470. vllm/v1/core/single_type_kv_cache_manager.py +801 -0
  1471. vllm/v1/cudagraph_dispatcher.py +183 -0
  1472. vllm/v1/engine/__init__.py +214 -0
  1473. vllm/v1/engine/async_llm.py +874 -0
  1474. vllm/v1/engine/coordinator.py +377 -0
  1475. vllm/v1/engine/core.py +1421 -0
  1476. vllm/v1/engine/core_client.py +1406 -0
  1477. vllm/v1/engine/detokenizer.py +351 -0
  1478. vllm/v1/engine/exceptions.py +18 -0
  1479. vllm/v1/engine/input_processor.py +636 -0
  1480. vllm/v1/engine/llm_engine.py +416 -0
  1481. vllm/v1/engine/logprobs.py +189 -0
  1482. vllm/v1/engine/output_processor.py +658 -0
  1483. vllm/v1/engine/parallel_sampling.py +145 -0
  1484. vllm/v1/engine/processor.py +20 -0
  1485. vllm/v1/engine/utils.py +1068 -0
  1486. vllm/v1/executor/__init__.py +6 -0
  1487. vllm/v1/executor/abstract.py +352 -0
  1488. vllm/v1/executor/multiproc_executor.py +888 -0
  1489. vllm/v1/executor/ray_distributed_executor.py +8 -0
  1490. vllm/v1/executor/ray_executor.py +626 -0
  1491. vllm/v1/executor/ray_utils.py +465 -0
  1492. vllm/v1/executor/uniproc_executor.py +183 -0
  1493. vllm/v1/kv_cache_interface.py +404 -0
  1494. vllm/v1/kv_offload/__init__.py +0 -0
  1495. vllm/v1/kv_offload/abstract.py +161 -0
  1496. vllm/v1/kv_offload/arc_manager.py +237 -0
  1497. vllm/v1/kv_offload/backend.py +97 -0
  1498. vllm/v1/kv_offload/backends/__init__.py +0 -0
  1499. vllm/v1/kv_offload/backends/cpu.py +62 -0
  1500. vllm/v1/kv_offload/cpu.py +86 -0
  1501. vllm/v1/kv_offload/factory.py +56 -0
  1502. vllm/v1/kv_offload/lru_manager.py +139 -0
  1503. vllm/v1/kv_offload/mediums.py +39 -0
  1504. vllm/v1/kv_offload/spec.py +66 -0
  1505. vllm/v1/kv_offload/worker/__init__.py +0 -0
  1506. vllm/v1/kv_offload/worker/cpu_gpu.py +191 -0
  1507. vllm/v1/kv_offload/worker/worker.py +144 -0
  1508. vllm/v1/metrics/__init__.py +0 -0
  1509. vllm/v1/metrics/loggers.py +1268 -0
  1510. vllm/v1/metrics/prometheus.py +82 -0
  1511. vllm/v1/metrics/ray_wrappers.py +194 -0
  1512. vllm/v1/metrics/reader.py +257 -0
  1513. vllm/v1/metrics/stats.py +431 -0
  1514. vllm/v1/outputs.py +237 -0
  1515. vllm/v1/pool/__init__.py +0 -0
  1516. vllm/v1/pool/metadata.py +82 -0
  1517. vllm/v1/request.py +280 -0
  1518. vllm/v1/sample/__init__.py +0 -0
  1519. vllm/v1/sample/logits_processor/__init__.py +352 -0
  1520. vllm/v1/sample/logits_processor/builtin.py +278 -0
  1521. vllm/v1/sample/logits_processor/interface.py +106 -0
  1522. vllm/v1/sample/logits_processor/state.py +165 -0
  1523. vllm/v1/sample/metadata.py +44 -0
  1524. vllm/v1/sample/ops/__init__.py +0 -0
  1525. vllm/v1/sample/ops/bad_words.py +52 -0
  1526. vllm/v1/sample/ops/logprobs.py +25 -0
  1527. vllm/v1/sample/ops/penalties.py +57 -0
  1528. vllm/v1/sample/ops/topk_topp_sampler.py +384 -0
  1529. vllm/v1/sample/rejection_sampler.py +805 -0
  1530. vllm/v1/sample/sampler.py +319 -0
  1531. vllm/v1/sample/tpu/__init__.py +0 -0
  1532. vllm/v1/sample/tpu/metadata.py +120 -0
  1533. vllm/v1/sample/tpu/sampler.py +215 -0
  1534. vllm/v1/serial_utils.py +532 -0
  1535. vllm/v1/spec_decode/__init__.py +0 -0
  1536. vllm/v1/spec_decode/eagle.py +1325 -0
  1537. vllm/v1/spec_decode/medusa.py +73 -0
  1538. vllm/v1/spec_decode/metadata.py +66 -0
  1539. vllm/v1/spec_decode/metrics.py +225 -0
  1540. vllm/v1/spec_decode/ngram_proposer.py +291 -0
  1541. vllm/v1/spec_decode/suffix_decoding.py +101 -0
  1542. vllm/v1/spec_decode/utils.py +121 -0
  1543. vllm/v1/structured_output/__init__.py +338 -0
  1544. vllm/v1/structured_output/backend_guidance.py +265 -0
  1545. vllm/v1/structured_output/backend_lm_format_enforcer.py +177 -0
  1546. vllm/v1/structured_output/backend_outlines.py +324 -0
  1547. vllm/v1/structured_output/backend_types.py +136 -0
  1548. vllm/v1/structured_output/backend_xgrammar.py +362 -0
  1549. vllm/v1/structured_output/request.py +94 -0
  1550. vllm/v1/structured_output/utils.py +469 -0
  1551. vllm/v1/utils.py +414 -0
  1552. vllm/v1/worker/__init__.py +0 -0
  1553. vllm/v1/worker/block_table.py +343 -0
  1554. vllm/v1/worker/cpu_model_runner.py +122 -0
  1555. vllm/v1/worker/cpu_worker.py +210 -0
  1556. vllm/v1/worker/dp_utils.py +250 -0
  1557. vllm/v1/worker/ec_connector_model_runner_mixin.py +87 -0
  1558. vllm/v1/worker/gpu/README.md +4 -0
  1559. vllm/v1/worker/gpu/__init__.py +0 -0
  1560. vllm/v1/worker/gpu/async_utils.py +97 -0
  1561. vllm/v1/worker/gpu/attn_utils.py +189 -0
  1562. vllm/v1/worker/gpu/block_table.py +314 -0
  1563. vllm/v1/worker/gpu/cudagraph_utils.py +259 -0
  1564. vllm/v1/worker/gpu/dp_utils.py +31 -0
  1565. vllm/v1/worker/gpu/input_batch.py +430 -0
  1566. vllm/v1/worker/gpu/model_runner.py +1007 -0
  1567. vllm/v1/worker/gpu/sample/__init__.py +0 -0
  1568. vllm/v1/worker/gpu/sample/gumbel.py +101 -0
  1569. vllm/v1/worker/gpu/sample/logprob.py +167 -0
  1570. vllm/v1/worker/gpu/sample/metadata.py +179 -0
  1571. vllm/v1/worker/gpu/sample/penalties.py +154 -0
  1572. vllm/v1/worker/gpu/sample/sampler.py +75 -0
  1573. vllm/v1/worker/gpu/spec_decode/__init__.py +18 -0
  1574. vllm/v1/worker/gpu/spec_decode/eagle.py +565 -0
  1575. vllm/v1/worker/gpu/spec_decode/eagle_cudagraph.py +115 -0
  1576. vllm/v1/worker/gpu/spec_decode/rejection_sample.py +83 -0
  1577. vllm/v1/worker/gpu/states.py +309 -0
  1578. vllm/v1/worker/gpu/structured_outputs.py +76 -0
  1579. vllm/v1/worker/gpu_input_batch.py +971 -0
  1580. vllm/v1/worker/gpu_model_runner.py +5360 -0
  1581. vllm/v1/worker/gpu_ubatch_wrapper.py +472 -0
  1582. vllm/v1/worker/gpu_worker.py +922 -0
  1583. vllm/v1/worker/kv_connector_model_runner_mixin.py +309 -0
  1584. vllm/v1/worker/lora_model_runner_mixin.py +212 -0
  1585. vllm/v1/worker/tpu_input_batch.py +583 -0
  1586. vllm/v1/worker/tpu_model_runner.py +2196 -0
  1587. vllm/v1/worker/tpu_worker.py +351 -0
  1588. vllm/v1/worker/ubatch_utils.py +73 -0
  1589. vllm/v1/worker/ubatching.py +231 -0
  1590. vllm/v1/worker/utils.py +365 -0
  1591. vllm/v1/worker/worker_base.py +377 -0
  1592. vllm/v1/worker/xpu_model_runner.py +48 -0
  1593. vllm/v1/worker/xpu_worker.py +198 -0
  1594. vllm/version.py +39 -0
  1595. vllm/vllm_flash_attn/.gitkeep +0 -0
  1596. vllm_cpu-0.12.0.dist-info/METADATA +300 -0
  1597. vllm_cpu-0.12.0.dist-info/RECORD +1600 -0
  1598. vllm_cpu-0.12.0.dist-info/WHEEL +5 -0
  1599. vllm_cpu-0.12.0.dist-info/entry_points.txt +5 -0
  1600. vllm_cpu-0.12.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,2069 @@
1
+ # SPDX-License-Identifier: Apache-2.0
2
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
3
+ """
4
+ # MLA Common Components
5
+
6
+ This file implements common components for MLA implementations.
7
+
8
+ First we define:
9
+
10
+ Sq as Q sequence length
11
+ Skv as KV sequence length
12
+
13
+ MLA has two possible ways of computing, a data-movement friendly approach and a
14
+ compute friendly approach, we generally want to use the compute friendly
15
+ approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1)
16
+ and the data-movement friendly approach for "decode" (i.e. the ratio
17
+ Sq / Skv is "small").
18
+
19
+ NOTE: what we deem small and large is currently determined by whether it's labelled
20
+ prefill or decode by the scheduler, but this is something we should probably
21
+ tune.
22
+
23
+ Main reference: DeepseekV2 paper, and FlashInfer Implementation
24
+ (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
25
+
26
+ Deepseek's MLA attention works the following way:
27
+ * Use a single latent vector to represent the per-token entry of the KV cache.
28
+ * For decode (i.e. the memory friendly approach) the attention "simulates" a
29
+ multi-head attention, while the compute is similar to multi-query attention.
30
+
31
+ Below is an example of both paths assuming batch size = 1
32
+
33
+ ## Further Definitions:
34
+
35
+ C Context length, `Skv - Sq`
36
+ H hidden size
37
+ N number of attention heads
38
+ Lq latent dimension for Q 1536 in DSV3
39
+ Lkv latent dimension for K/V 512 in DSV3
40
+ P nope dimension, no rope. 128 in DSV3
41
+ R rope dimension, goes through rope. 64 in DSV3
42
+ V V head dim. 128 in DSV3
43
+
44
+ ## Vector/Matrix Definitions
45
+
46
+ h_t hidden states (input to attention) shape [Sq, H]
47
+ q_c latent/compressed Q shape [Sq, Lq]
48
+ q_nope uncompressed Q (no-rope) shape [Sq, N, P]
49
+ q_pe uncompressed Q (rope) shape [Sq, N, R]
50
+ kv_c latent/compressed KV shape [Skv, Lkv]
51
+ k_pe decoupled k position embeddings shape [Skv, R]
52
+ new_kv_c new kv_c from current iter shape [Sq, Lkv]
53
+ new_k_pe new k_pe from current iter shape [Sq, R]
54
+ cache_kv_c cached k_c from previous iters shape [C, Lkv]
55
+ cache_k_pe cached k_pe from previous iters shape [C, R]
56
+ W_DQ project h_t to q_c shape [H, Lq]
57
+ W_UQ project q_c to q_nope shape [Lq, N * P]
58
+ W_QR project q_c to q_pe shape [Lq, N * R]
59
+ W_DKV project h_t to kv_c shape [H, Lkv]
60
+ W_UK project kv_c to k_nope shape [Lkv, N, P]
61
+ W_KR project h_t to k_pe shape [H, R]
62
+ W_UV project kv_c to v shape [Lkv, N, V]
63
+ W_O project v to h_t shape [N * V, H]
64
+
65
+
66
+ ## Compute Friendly Approach (i.e. "_forward_prefill"):
67
+
68
+ q_c = h_t @ W_DQ
69
+ q_nope = (q_c @ W_UQ).view(Sq, N, P)
70
+ q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
71
+ new_kv_c = h_t @ W_DKV
72
+ new_k_pe = RoPE(h_t @ W_KR)
73
+ kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
74
+ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
75
+ k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
76
+ v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
77
+
78
+ // MHA with QK headdim = P + R
79
+ // V headdim = V
80
+ // spda_o shape [Sq, N, V]
81
+ spda_o = scaled_dot_product_attention(
82
+ torch.cat([q_nope, q_pe], dim=-1),
83
+ torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
84
+ v
85
+ )
86
+ return spda_o @ W_O
87
+
88
+ NOTE: in the actual code,
89
+ `kv_b_proj` is [W_UK; W_UV] concatenated per head
90
+ `q_b_proj` is [W_UQ; W_QR] concatenated per head
91
+ `out_proj` is W_O
92
+
93
+
94
+ ## Data-Movement Friendly Approach (i.e. "_forward_decode"):
95
+
96
+ Runtime
97
+ q_c = h_t @ W_DQ
98
+ q_nope = (q_c @ W_UQ).view(-1, N, P)
99
+ ql_nope = einsum("snh,lnh->snl", q_nope, W_UK)
100
+ q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
101
+ new_kv_c = h_t @ W_DKV
102
+ new_k_pe = RoPE(h_t @ W_KR)
103
+ kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
104
+ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
105
+
106
+ // MQA with QK headdim = Lkv + R
107
+ // V headdim = Lkv
108
+ // spda_o shape [Sq, N, Lkv]
109
+ // NOTE: this is less compute-friendly since Lkv > P
110
+ // but is more data-movement friendly since it's MQA vs MHA
111
+ spda_o = scaled_dot_product_attention(
112
+ torch.cat([ql_nope, q_pe], dim=-1),
113
+ torch.cat([kv_c, k_pe], dim=-1),
114
+ kv_c
115
+ )
116
+
117
+ o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
118
+ return o.view(-1, N * V) @ W_O
119
+
120
+
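
The equivalence of the two paths above follows from associativity of the matmuls: W_UK can be folded into the query and W_UV applied after attention over the latent cache. A minimal sketch of that equivalence (editor's illustration, not part of this file; RoPE, scaling, and the causal mask are omitted, and all sizes are toy values):

```python
import torch

torch.manual_seed(0)
Sq, Skv, N, P, Lkv, V = 2, 6, 4, 8, 16, 8
q_nope = torch.randn(Sq, N, P, dtype=torch.float64)
kv_c = torch.randn(Skv, Lkv, dtype=torch.float64)
W_UK = torch.randn(Lkv, N, P, dtype=torch.float64)
W_UV = torch.randn(Lkv, N, V, dtype=torch.float64)

# Compute friendly path: up-project the latent cache, then do MHA.
k_nope = torch.einsum("sl,lnp->snp", kv_c, W_UK)
v = torch.einsum("sl,lnv->snv", kv_c, W_UV)
attn = torch.einsum("qnp,snp->nqs", q_nope, k_nope).softmax(dim=-1)
o_mha = torch.einsum("nqs,snv->qnv", attn, v)

# Data-movement friendly path: absorb W_UK into q, do MQA over kv_c,
# then up-project the attention output with W_UV.
ql_nope = torch.einsum("qnp,lnp->qnl", q_nope, W_UK)
attn_mqa = torch.einsum("qnl,sl->nqs", ql_nope, kv_c).softmax(dim=-1)
o_latent = torch.einsum("nqs,sl->qnl", attn_mqa, kv_c)
o_mqa = torch.einsum("qnl,lnv->qnv", o_latent, W_UV)

assert torch.allclose(o_mha, o_mqa)
```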
121
+ ## Chunked Prefill
122
+
123
+ For chunked prefill we want to use the compute friendly algorithm. We are
124
+ assuming a sufficiently large Sq / Skv ratio; in the future we may want to switch to
125
+ the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
126
+
127
+ However, the compute-friendly approach can potentially run out of memory if Skv
128
+ is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
129
+
130
+ To mitigate this, we chunk the computation of attention with respect to the
131
+ current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
132
+ fixed workspace size.
133
+
134
+ The chunked prefill approach is as follows:
135
+
136
+ MCC Max chunk of context to process per iter, computed dynamically,
137
+ used to bound the memory usage
138
+
139
+ q_c = h_t @ W_DQ
140
+ q_nope = (q_c @ W_UQ).view(Sq, N, P)
141
+ q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
142
+ new_kv_c = h_t @ W_DKV
143
+ new_k_pe = RoPE(h_t @ W_KR)
144
+ new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
145
+ new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
146
+
147
+ // MHA between queries and new KV
148
+ // with QK headdim = P + R
149
+ // V headdim = V
150
+ // curr_o shape [Sq, N, V]
151
+ // curr_lse shape [N, Sq], this is just the order FA returns
152
+ curr_o, curr_lse = scaled_dot_product_attention(
153
+ torch.cat([q_nope, q_pe], dim=-1),
154
+ torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
155
+ new_v,
156
+ causal=True,
157
+ return_softmax_lse=True
158
+ )
159
+
160
+ // Compute attention with the already existing context
161
+ for chunk_idx in range(cdiv(C, MCC)):
162
+ chunk_start = chunk_idx * MCC
163
+ chunk_end = min(chunk_start + MCC, C)
164
+ Sc = chunk_end - chunk_start
165
+ cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
166
+ cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
167
+ cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
168
+ cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
169
+
170
+ chunk_o, chunk_lse = scaled_dot_product_attention(
171
+ torch.cat([q_nope, q_pe], dim=-1),
172
+ torch.cat([cache_k_nope_chunk,
173
+ cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
174
+ dim=-1),
175
+ cache_v_chunk,
176
+ causal=False,
177
+ return_softmax_lse=True
178
+ )
179
+
180
+ curr_o, curr_lse = merge_attn_states(
181
+ suffix_output=curr_o,
182
+ suffix_lse=curr_lse,
183
+ prefix_output=chunk_o,
184
+ prefix_lse=chunk_lse,
185
+ )
186
+
187
+ return curr_o @ W_O
188
+ """
189
+
190
+ import functools
191
+ from abc import abstractmethod
192
+ from dataclasses import dataclass, field
193
+ from enum import Enum
194
+ from typing import ClassVar, Generic, TypeVar
195
+
196
+ import torch
197
+ from tqdm import tqdm
198
+
199
+ from vllm import _custom_ops as ops
200
+ from vllm import envs
201
+ from vllm._aiter_ops import rocm_aiter_ops
202
+ from vllm.attention.backends.abstract import (
203
+ AttentionBackend,
204
+ AttentionLayer,
205
+ MLAAttentionImpl,
206
+ )
207
+ from vllm.attention.backends.utils import get_mla_dims
208
+ from vllm.attention.ops.common import cp_lse_ag_out_rs
209
+ from vllm.attention.ops.merge_attn_states import merge_attn_states
210
+ from vllm.attention.utils.fa_utils import get_flash_attn_version
211
+ from vllm.config import VllmConfig, get_current_vllm_config
212
+ from vllm.distributed.parallel_state import get_dcp_group, is_global_first_rank
213
+ from vllm.logger import init_logger
214
+ from vllm.model_executor.layers.batch_invariant import (
215
+ vllm_is_batch_invariant,
216
+ )
217
+ from vllm.model_executor.layers.linear import (
218
+ ColumnParallelLinear,
219
+ LinearBase,
220
+ UnquantizedLinearMethod,
221
+ )
222
+ from vllm.platforms import current_platform
223
+ from vllm.utils.flashinfer import has_nvidia_artifactory
224
+ from vllm.utils.math_utils import cdiv, round_down
225
+ from vllm.v1.attention.backends.utils import (
226
+ AttentionMetadataBuilder,
227
+ CommonAttentionMetadata,
228
+ get_dcp_local_seq_lens,
229
+ get_per_layer_parameters,
230
+ infer_global_hyperparameters,
231
+ split_decodes_and_prefills,
232
+ )
233
+ from vllm.v1.kv_cache_interface import AttentionSpec
234
+
235
+
236
+ class QueryLenSupport(Enum):
237
+ """Defines the level of query length support for an attention backend's
238
+ decode pipeline.
239
+
240
+ - SINGLE_ONLY: Decode pipeline only supports single-token queries
241
+ (query_len=1)
242
+ - UNIFORM: Decode pipeline supports uniform multi-token queries
243
+ (all requests must have same query_len > 1)
244
+ - VARLEN: Decode pipeline supports variable-length queries
245
+ (mixed query lengths in same batch)
246
+ """
247
+
248
+ SINGLE_ONLY = "single_only"
249
+ UNIFORM = "uniform"
250
+ VARLEN = "varlen"
251
+
252
+
253
+ try:
254
+ from vllm.vllm_flash_attn import flash_attn_varlen_func
255
+
256
+ is_vllm_fa = True
257
+ except ImportError:
258
+ # For ROCm, use upstream flash attention
259
+ if current_platform.is_rocm():
260
+ from flash_attn import flash_attn_varlen_func
261
+ is_vllm_fa = False
262
+
263
+ try:
264
+ from flashinfer import BatchPrefillWithRaggedKVCacheWrapper
265
+ from flashinfer.prefill import cudnn_batch_prefill_with_kv_cache # noqa: F401
266
+
267
+ flashinfer_available = True
268
+ except ImportError:
269
+ BatchPrefillWithRaggedKVCacheWrapper = object
270
+
271
+ flashinfer_available = False
272
+
273
+
274
+ def dynamic_per_batched_tensor_quant(
275
+ x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
276
+ ):
277
+ DTYPE_MAX = torch.finfo(dtype).max
278
+ min_val, max_val = x.aminmax()
279
+ amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-10)
280
+ scale = DTYPE_MAX / amax
281
+ x_scl_sat = (x * scale).clamp(min=-DTYPE_MAX, max=DTYPE_MAX)
282
+ return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
283
+
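
A hypothetical usage sketch for the helper above (editor's illustration, not part of this file): it returns the saturated, scaled tensor plus the reciprocal scale, so dequantization is a single multiply. float16 is used here only so the round trip runs without fp8 dtype support; the function's default is the fp8 path.

```python
import torch

w = torch.randn(16, 128, 512) * 3.0
w_q, inv_scale = dynamic_per_batched_tensor_quant(w, dtype=torch.float16)

# Undo the per-tensor scaling; the residual is the quantization error.
w_deq = w_q.to(torch.float32) * inv_scale
print(inv_scale.item(), (w - w_deq).abs().max().item())
```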
284
+
285
+ logger = init_logger(__name__)
286
+
287
+ CUDNN_WORKSPACE_SIZE = 12800
288
+
289
+
290
+ class MLACommonBackend(AttentionBackend):
291
+ accept_output_buffer: bool = True
292
+
293
+ @staticmethod
294
+ def get_name() -> str:
295
+ return "TRITON_MLA"
296
+
297
+ @staticmethod
298
+ def get_builder_cls() -> type["MLACommonMetadataBuilder"]:
299
+ return MLACommonMetadataBuilder
300
+
301
+ @staticmethod
302
+ def get_kv_cache_shape(
303
+ num_blocks: int,
304
+ block_size: int,
305
+ num_kv_heads: int, # assumed to be 1 for MLA
306
+ head_size: int,
307
+ cache_dtype_str: str = "auto",
308
+ ) -> tuple[int, ...]:
309
+ return (num_blocks, block_size, head_size)
310
+
311
+ @staticmethod
312
+ def get_kv_cache_stride_order(
313
+ include_num_layers_dimension: bool = False,
314
+ ) -> tuple[int, ...]:
315
+ # `stride_order` indicates the permutation that gets
316
+ # us from `get_kv_cache_shape` to the actual memory layout we want.
317
+ # (num_blocks, num_layers, block_size, head_size)
318
+ return (1, 0, 2, 3) if include_num_layers_dimension else (0, 1, 2)
319
+
320
+ @classmethod
321
+ def get_supported_head_sizes(cls) -> list[int]:
322
+ return [576]
323
+
324
+ @classmethod
325
+ def is_mla(cls) -> bool:
326
+ return True
327
+
328
+
329
+ @dataclass
330
+ class MLACommonPrefillMetadata:
331
+ """Prefill Specific Metadata"""
332
+
333
+ @dataclass
334
+ class ChunkedContextMetadata:
335
+ # New for MLA (compared to FlashAttention)
336
+ # For handling chunked prefill
337
+ cu_seq_lens: torch.Tensor
338
+ starts: torch.Tensor
339
+ seq_tot: list[int]
340
+ max_seq_lens: list[int]
341
+ seq_lens: torch.Tensor
342
+ workspace: torch.Tensor
343
+ token_to_seq: torch.Tensor
344
+ chunk_total_token: list[int]
345
+
346
+ # for mla DCP
347
+ padded_local_chunk_seq_lens: list[list[int]] | None = None
348
+ local_context_lens_allranks: list[list[int]] | None = None
349
+ padded_local_cu_seq_lens: torch.Tensor | None = None
350
+ cu_seq_lens_lst: list[list[int]] | None = None
351
+ chunk_size: int | None = None
352
+
353
+ block_table: torch.Tensor
354
+ query_start_loc: torch.Tensor
355
+ max_query_len: int
356
+ chunked_context: ChunkedContextMetadata | None = None
357
+ query_seq_lens: torch.Tensor | None = None
358
+
359
+
360
+ @dataclass
361
+ class FlashInferPrefillMetadata(MLACommonPrefillMetadata):
362
+ prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
363
+ prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = field(
364
+ default_factory=list
365
+ )
366
+
367
+
368
+ @dataclass
369
+ class CudnnPrefillMetadata(MLACommonPrefillMetadata):
370
+ class ChunkedContextMetadata(MLACommonPrefillMetadata.ChunkedContextMetadata):
371
+ seq_lens: torch.Tensor
372
+
373
+ cudnn_workspace: torch.Tensor | None = None
374
+
375
+
376
+ @dataclass
377
+ class MLACommonDecodeMetadata:
378
+ block_table: torch.Tensor
379
+ seq_lens: torch.Tensor
380
+ dcp_tot_seq_lens: torch.Tensor | None
381
+
382
+
383
+ D = TypeVar("D", bound=MLACommonDecodeMetadata)
384
+
385
+
386
+ @dataclass
387
+ class MLACommonMetadata(Generic[D]):
388
+ """Metadata for MLACommon.
389
+
390
+ NOTE: Please read the comment at the top of the file before trying to
391
+ understand this class
392
+ """
393
+
394
+ # NOTE(sang): Definition of context_len, query_len, and seq_len.
395
+ # |---------- N-1 iteration --------|
396
+ # |---------------- N iteration ---------------------|
397
+ # |- tokenA -|......................|-- newTokens ---|
398
+ # |---------- context_len ----------|
399
+ # |-------------------- seq_len ---------------------|
400
+ # |-- query_len ---|
401
+
402
+ num_reqs: int
403
+ max_query_len: int
404
+ max_seq_len: int
405
+
406
+ num_actual_tokens: int # Number of tokens excluding padding.
407
+ query_start_loc: torch.Tensor
408
+ slot_mapping: torch.Tensor
409
+
410
+ # New for MLA (compared to FlashAttention)
411
+ # For handling prefill decode split
412
+ num_decodes: int
413
+ num_decode_tokens: int
414
+ num_prefills: int
415
+
416
+ # The dimension of the attention heads
417
+ head_dim: int | None = None
418
+
419
+ decode: D | None = None
420
+ prefill: (
421
+ MLACommonPrefillMetadata
422
+ | FlashInferPrefillMetadata
423
+ | CudnnPrefillMetadata
424
+ | None
425
+ ) = None
426
+
427
+ def __post_init__(self):
428
+ if self.head_dim is not None and not MLACommonBackend.supports_head_size(
429
+ self.head_dim
430
+ ):
431
+ raise ValueError(f"Head dimension {self.head_dim} is not supported by MLA.")
432
+
433
+
434
+ M = TypeVar("M", bound=MLACommonMetadata)
435
+ A = TypeVar("A")
436
+
437
+
438
+ def use_flashinfer_prefill() -> bool:
439
+ # For blackwell default to flashinfer prefill if it's available since
440
+ # it is faster than FA2.
441
+ return (
442
+ not envs.VLLM_DISABLE_FLASHINFER_PREFILL
443
+ and flashinfer_available
444
+ and not envs.VLLM_USE_CUDNN_PREFILL
445
+ and not envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
446
+ and current_platform.is_device_capability(100)
447
+ )
448
+
449
+
450
+ def use_cudnn_prefill() -> bool:
451
+ return (
452
+ flashinfer_available
453
+ and envs.VLLM_USE_CUDNN_PREFILL
454
+ and current_platform.is_device_capability(100)
455
+ and has_nvidia_artifactory()
456
+ )
457
+
458
+
459
+ def use_trtllm_ragged_deepseek_prefill() -> bool:
460
+ """Check if TRT-LLM ragged DeepSeek prefill should be used."""
461
+ return (
462
+ flashinfer_available
463
+ and envs.VLLM_USE_TRTLLM_RAGGED_DEEPSEEK_PREFILL
464
+ and current_platform.is_device_capability(100)
465
+ )
466
+
467
+
468
+ class MLACommonMetadataBuilder(AttentionMetadataBuilder[M]):
469
+ """
470
+ NOTE: Please read the comment at the top of the file before trying to
471
+ understand this class
472
+ """
473
+
474
+ # Defines the level of query length support for this backend.
475
+ # - SINGLE_ONLY: Only single-token queries (no spec decode support)
476
+ # - UNIFORM: Supports uniform multi-token queries (spec decode with uniform lengths)
477
+ # - VARLEN: Supports variable-length queries (spec decode with mixed lengths)
478
+ # If set to UNIFORM or VARLEN, this will increase `reorder_batch_threshold` when
479
+ # speculative decoding is enabled.
480
+ query_len_support: ClassVar[QueryLenSupport] = QueryLenSupport.SINGLE_ONLY
481
+
482
+ # The threshold for reordering the batch into decode and prefill requests.
483
+ # If > 1, the batch will be reordered such that requests with
484
+ # query length <= threshold are classified as decode requests.
485
+ # Use `query_len_support` (above) to set this automatically
486
+ # when speculative decoding is enabled.
487
+ reorder_batch_threshold: int = 1
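
As a concrete illustration of the threshold (editor's sketch, not part of this file; the numbers are assumed): with the batch already reordered so that short queries come first, every request with query length <= the threshold is classified as a decode request.

```python
reorder_batch_threshold = 2
query_lens = [1, 2, 1, 6, 9]  # assumed example, decode requests already first

num_decodes = sum(1 for q in query_lens if q <= reorder_batch_threshold)
num_decode_tokens = sum(query_lens[:num_decodes])

assert (num_decodes, num_decode_tokens) == (3, 4)
```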
488
+
489
+ @staticmethod
490
+ def determine_chunked_prefill_workspace_size(vllm_config: VllmConfig) -> int:
491
+ scheduler_config = vllm_config.scheduler_config
492
+ cache_config = vllm_config.cache_config
493
+ model_config = vllm_config.model_config
494
+
495
+ chunked_prefill_workspace_size = min(
496
+ # Try for 8 full-length requests or at least 4 pages per request
497
+ max(
498
+ 8 * model_config.max_model_len,
499
+ 4 * scheduler_config.max_num_seqs * cache_config.block_size,
500
+ ),
501
+ # For long-context models, avoid over-allocating (which limits
502
+ # kv-cache space) by capping the workspace at 64k tokens,
503
+ # which would result in the workspace being:
504
+ # 2*(576)*(64*1024) bytes ≈ 75MB
505
+ # (assuming 576 MLA head dim, and fp16)
506
+ # which would result in up-projected context being
507
+ # 2*(192*128)*(64*1024) = 3gb
508
+ # (assuming 192 QK head dim, 128 heads, and fp16)
509
+ 64 * 1024,
510
+ )
511
+
512
+ # Enforce that we have enough for at least 1 page per request
513
+ chunked_prefill_workspace_size = max(
514
+ chunked_prefill_workspace_size,
515
+ scheduler_config.max_num_seqs * cache_config.block_size,
516
+ )
517
+
518
+ return chunked_prefill_workspace_size
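
A worked example of the sizing above (editor's sketch, not part of this file; the config values are assumed and do not come from any real VllmConfig):

```python
max_model_len, max_num_seqs, block_size = 163_840, 256, 16

# Try for 8 full-length requests or 4 pages per request, capped at 64k tokens.
size = min(max(8 * max_model_len, 4 * max_num_seqs * block_size), 64 * 1024)
# ...but always leave room for at least one page per request.
size = max(size, max_num_seqs * block_size)

assert size == 64 * 1024  # the 64k-token cap wins for this long-context setup
```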
519
+
520
+ def __init__(
521
+ self,
522
+ kv_cache_spec: AttentionSpec,
523
+ layer_names: list[str],
524
+ vllm_config: VllmConfig,
525
+ device: torch.device,
526
+ metadata_cls: type[M] | None = None,
527
+ supports_dcp_with_varlen: bool = False,
528
+ ):
529
+ self.metadata_cls = (
530
+ metadata_cls if metadata_cls is not None else MLACommonMetadata
531
+ )
532
+ self.kv_cache_spec = kv_cache_spec
533
+ scheduler_config = vllm_config.scheduler_config
534
+ self.model_config = vllm_config.model_config
535
+ parallel_config = vllm_config.parallel_config
536
+ self.compilation_config = vllm_config.compilation_config
537
+ self.vllm_config = vllm_config
538
+ self.device = device
539
+
540
+ self.num_heads = self.model_config.get_num_attention_heads(parallel_config)
541
+ self.mla_dims = get_mla_dims(self.model_config)
542
+ self.aot_schedule = current_platform.is_cuda()
543
+ try:
544
+ self.dcp_world_size = get_dcp_group().world_size
545
+ self.dcp_rank = get_dcp_group().rank_in_group
546
+ except AssertionError:
547
+ # DCP might not be initialized in testing
548
+ self.dcp_world_size = 1
549
+ self.dcp_rank = 0
550
+ self.dcp_local_block_size = parallel_config.cp_kv_cache_interleave_size
551
+ self.dcp_virtual_block_size = self.dcp_local_block_size * self.dcp_world_size
552
+
553
+ # Don't try to access the runner on AMD
554
+ if self.aot_schedule:
555
+ self.page_size = self.kv_cache_spec.block_size
556
+
557
+ self.chunked_prefill_workspace_size = (
558
+ self.determine_chunked_prefill_workspace_size(vllm_config)
559
+ )
560
+
561
+ if self.dcp_world_size > 1:
562
+ # Note(hc): The local kvcache is incomplete when DCP is triggered,
563
+ # an additional kvcache allgather across the DCP group is therefore
564
+ # required, so the workspace has to be enlarged by 1/DCP relative
565
+ # to the original TP allocation.
566
+ assert self.chunked_prefill_workspace_size % self.dcp_world_size == 0
567
+ self.chunked_prefill_workspace = torch.empty(
568
+ (
569
+ self.chunked_prefill_workspace_size
570
+ + self.chunked_prefill_workspace_size // self.dcp_world_size,
571
+ self.model_config.get_head_size(),
572
+ ),
573
+ dtype=self.model_config.dtype,
574
+ device=device,
575
+ )
576
+ else:
577
+ self.chunked_prefill_workspace = torch.empty(
578
+ (
579
+ self.chunked_prefill_workspace_size,
580
+ self.model_config.get_head_size(),
581
+ ),
582
+ dtype=self.model_config.dtype,
583
+ device=device,
584
+ )
585
+
586
+ self._use_cudnn_prefill = use_cudnn_prefill()
587
+ self._use_fi_prefill = use_flashinfer_prefill()
588
+ self._use_trtllm_ragged_prefill = use_trtllm_ragged_deepseek_prefill()
589
+ self.prefill_metadata_cls = (
590
+ FlashInferPrefillMetadata
591
+ if self._use_fi_prefill
592
+ else CudnnPrefillMetadata
593
+ if self._use_cudnn_prefill
594
+ else MLACommonPrefillMetadata
595
+ )
596
+
597
+ if self._use_fi_prefill:
598
+ self._workspace_buffer = torch.empty(
599
+ envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
600
+ dtype=torch.uint8,
601
+ device=device,
602
+ )
603
+
604
+ self._fi_prefill_main: BatchPrefillWithRaggedKVCacheWrapper | None = None
605
+ self._fi_prefill_chunks: list[BatchPrefillWithRaggedKVCacheWrapper] = []
606
+
607
+ self._global_hyperparameters = infer_global_hyperparameters(
608
+ get_per_layer_parameters(vllm_config, layer_names, MLACommonImpl)
609
+ )
610
+
611
+ if self._use_trtllm_ragged_prefill:
612
+ self._workspace_buffer = torch.empty(
613
+ envs.VLLM_FLASHINFER_WORKSPACE_BUFFER_SIZE,
614
+ dtype=torch.uint8,
615
+ device=device,
616
+ )
617
+
618
+ if self._use_cudnn_prefill:
619
+ self.cudnn_workspace = torch.empty(
620
+ CUDNN_WORKSPACE_SIZE * scheduler_config.max_num_seqs,
621
+ dtype=torch.int8,
622
+ device=device,
623
+ )
624
+
625
+ supports_spec_decode = self.query_len_support != QueryLenSupport.SINGLE_ONLY
626
+ self._init_reorder_batch_threshold(
627
+ self.reorder_batch_threshold, supports_spec_decode, supports_dcp_with_varlen
628
+ )
629
+
630
+ # Validate consistency between query_len_support and reorder_batch_threshold
631
+ if self.query_len_support == QueryLenSupport.SINGLE_ONLY:
632
+ assert self.reorder_batch_threshold == 1, (
633
+ f"reorder_batch_threshold must be 1 when query_len_support is "
634
+ f"SINGLE_ONLY, got {self.reorder_batch_threshold}"
635
+ )
636
+
637
+ def _build_fi_prefill_wrappers(self, prefill: FlashInferPrefillMetadata):
638
+ qo_indptr = prefill.query_start_loc
639
+
640
+ has_context = False
641
+ if prefill.chunked_context is not None:
642
+ chunked_context = prefill.chunked_context
643
+ has_context = True
644
+
645
+ if self._fi_prefill_main is None:
646
+ self._fi_prefill_main = BatchPrefillWithRaggedKVCacheWrapper(
647
+ self._workspace_buffer, "NHD", backend="cutlass"
648
+ )
649
+
650
+ if has_context:
651
+ num_chunks = chunked_context.cu_seq_lens.shape[0]
652
+ # Allocate more prefill chunk wrappers if needed
653
+ if len(self._fi_prefill_chunks) < num_chunks:
654
+ for _ in range(len(self._fi_prefill_chunks), num_chunks):
655
+ self._fi_prefill_chunks.append(
656
+ BatchPrefillWithRaggedKVCacheWrapper(
657
+ self._workspace_buffer, "NHD", backend="cutlass"
658
+ )
659
+ )
660
+ assert num_chunks <= len(self._fi_prefill_chunks)
661
+
662
+ # In MLA, the non-latent num_qo_heads == num_kv_heads
663
+ num_qo_heads = self.num_heads
664
+ num_kv_heads = num_qo_heads
665
+
666
+ # Sanity: Verify that num_kv_heads == 1 since it is latent space
667
+ assert self.kv_cache_spec.num_kv_heads == 1
668
+
669
+ # Get non-latent head_dim_qk and head_dim_vo
670
+ head_dim_qk = self.mla_dims.qk_nope_head_dim + self.mla_dims.qk_rope_head_dim
671
+ head_dim_vo = self.mla_dims.v_head_dim
672
+
673
+ # For main run, qo_indptr == kv_indptr
674
+ kv_indptr = qo_indptr.clone()
675
+
676
+ # Prepare main prefill
677
+ self._fi_prefill_main.plan(
678
+ qo_indptr=qo_indptr,
679
+ kv_indptr=kv_indptr,
680
+ num_qo_heads=num_qo_heads,
681
+ num_kv_heads=num_kv_heads,
682
+ head_dim_qk=head_dim_qk,
683
+ head_dim_vo=head_dim_vo,
684
+ causal=True, # This is main run
685
+ sm_scale=self._global_hyperparameters.sm_scale,
686
+ window_left=self._global_hyperparameters.window_left,
687
+ logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
688
+ q_data_type=self.model_config.dtype,
689
+ )
690
+
691
+ # Prepare context prefills
692
+ if has_context:
693
+ for i in range(num_chunks):
694
+ kv_indptr_chunk = chunked_context.cu_seq_lens[i]
695
+
696
+ self._fi_prefill_chunks[i].plan(
697
+ qo_indptr=qo_indptr,
698
+ kv_indptr=kv_indptr_chunk,
699
+ num_qo_heads=num_qo_heads,
700
+ num_kv_heads=num_kv_heads,
701
+ head_dim_qk=head_dim_qk,
702
+ head_dim_vo=head_dim_vo,
703
+ causal=False, # This is context run
704
+ sm_scale=self._global_hyperparameters.sm_scale,
705
+ window_left=self._global_hyperparameters.window_left,
706
+ logits_soft_cap=self._global_hyperparameters.logits_soft_cap,
707
+ q_data_type=self.model_config.dtype,
708
+ )
709
+
710
+ prefill.prefill_main = self._fi_prefill_main
711
+ prefill.prefill_chunks = self._fi_prefill_chunks
712
+
713
+ def _build_decode(
714
+ self,
715
+ block_table_tensor: torch.Tensor,
716
+ seq_lens_cpu: torch.Tensor,
717
+ seq_lens_device: torch.Tensor,
718
+ query_start_loc_cpu: torch.Tensor,
719
+ query_start_loc_device: torch.Tensor,
720
+ num_decode_tokens: int,
721
+ dcp_tot_seq_lens_device: torch.Tensor | None,
722
+ ) -> MLACommonDecodeMetadata:
723
+ return MLACommonDecodeMetadata(
724
+ block_table=block_table_tensor,
725
+ seq_lens=seq_lens_device,
726
+ dcp_tot_seq_lens=dcp_tot_seq_lens_device,
727
+ )
728
+
729
+ def build_for_cudagraph_capture(
730
+ self, common_attn_metadata: CommonAttentionMetadata
731
+ ) -> M:
732
+ """
733
+ This method builds the metadata for full cudagraph capture.
734
+ Currently, only decode is supported for full cudagraphs with MLA.
735
+ """
736
+ m = common_attn_metadata
737
+ assert m.num_reqs <= (m.num_actual_tokens * self.reorder_batch_threshold), (
738
+ "MLA only supports decode-only full CUDAGraph capture. "
739
+ "Make sure all cudagraph capture sizes <= max_num_seq."
740
+ )
741
+
742
+ assert m.max_query_len <= self.reorder_batch_threshold # decode only
743
+
744
+ return self.build(0, m)
745
+
746
+ def build(
747
+ self,
748
+ common_prefix_len: int,
749
+ common_attn_metadata: CommonAttentionMetadata,
750
+ fast_build: bool = False,
751
+ ) -> M:
752
+ num_reqs = common_attn_metadata.num_reqs
753
+ num_tokens = common_attn_metadata.num_actual_tokens
754
+ max_query_len = common_attn_metadata.max_query_len
755
+ max_seq_len = common_attn_metadata.max_seq_len
756
+
757
+ # Note(simon): be careful about the CPU <> GPU memory movement in this
758
+ # function. We should avoid GPU -> CPU sync as much as possible because
759
+ # it blocks on all previous kernels.
760
+ device = self.device
761
+ block_table_tensor = common_attn_metadata.block_table_tensor
762
+ slot_mapping = common_attn_metadata.slot_mapping
763
+
764
+ query_start_loc = common_attn_metadata.query_start_loc
765
+ query_start_loc_cpu = common_attn_metadata.query_start_loc_cpu
766
+ seq_lens = common_attn_metadata.seq_lens
767
+ seq_lens_cpu = common_attn_metadata.seq_lens_cpu
768
+ dcp_local_seq_lens = common_attn_metadata.dcp_local_seq_lens
769
+ dcp_local_seq_lens_cpu = common_attn_metadata.dcp_local_seq_lens_cpu
770
+
771
+ query_seq_lens_cpu = query_start_loc_cpu[1:] - query_start_loc_cpu[:-1]
772
+
773
+ num_computed_tokens_cpu = common_attn_metadata.seq_lens_cpu - query_seq_lens_cpu
774
+
775
+ num_decodes, num_prefills, num_decode_tokens, num_prefill_tokens = (
776
+ split_decodes_and_prefills(
777
+ common_attn_metadata,
778
+ decode_threshold=self.reorder_batch_threshold,
779
+ require_uniform=(self.query_len_support != QueryLenSupport.VARLEN),
780
+ )
781
+ )
782
+
783
+ assert num_decodes + num_prefills == num_reqs
784
+ assert num_decode_tokens + num_prefill_tokens == num_tokens
785
+
786
+ prefill_metadata = None
787
+ if num_prefills > 0:
788
+ reqs_start = num_decodes # prefill_start
789
+
790
+ context_lens_cpu = num_computed_tokens_cpu[reqs_start:num_reqs]
791
+ max_context_len_cpu = context_lens_cpu.max().item()
792
+ num_prefills_with_context_cpu = (context_lens_cpu > 0).sum().item()
793
+ prefill_query_start_loc = (
794
+ query_start_loc[reqs_start:] - query_start_loc[reqs_start]
795
+ )
796
+
797
+ chunked_context_metadata = None
798
+ if max_context_len_cpu > 0:
799
+ # NOTE: it is recommended you read the `Chunked Prefill` section
800
+ # in the comment at the top of the file before trying to
801
+ # understand the following code
802
+
803
+ # currently we allocate an equal amount of workspace for each
804
+ # prefill in the batch; we could probably use a more advanced
805
+ # algorithm here and allocate more workspace to prefills with
806
+ # longer context lengths
807
+ max_context_chunk = (
808
+ self.chunked_prefill_workspace_size // num_prefills_with_context_cpu
809
+ )
810
+
811
+ if self.aot_schedule:
812
+ # align max_context_chunk to page_size by rounding down;
813
+ # currently the `gather_and_maybe_dequant_cache` kernel
814
+ # cannot handle `context_chunk_starts` that are not aligned
815
+ # to page_size
816
+ max_context_chunk = round_down(max_context_chunk, self.page_size)
817
+
818
+ assert max_context_chunk > 0
819
+ num_chunks = cdiv(max_context_len_cpu, max_context_chunk)
820
+
821
+ # if `max_context_chunk = 256`, `num_chunks = 3`, and
822
+ # `num_prefills_with_context = 4`, create a tensor that looks
823
+ # like
824
+ # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
825
+ # Note(simon): this is done on CPU because of the downstream use
826
+ # of `to_list`.
827
+ chunk_starts = (
828
+ torch.arange(num_chunks, dtype=torch.int32)
829
+ .unsqueeze(1)
830
+ .expand(-1, num_prefills)
831
+ * max_context_chunk
832
+ )
833
+ chunk_ends = torch.min(
834
+ context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk
835
+ )
836
+ chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)
837
+
838
+ cu_seq_lens_cpu = torch.zeros(
839
+ num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True
840
+ )
841
+ torch.cumsum(
842
+ chunk_seq_lens, dim=1, out=cu_seq_lens_cpu[:, 1:], dtype=torch.int32
843
+ )
844
+ chunk_total_token = cu_seq_lens_cpu[:, -1]
845
+
846
+ max_token_num_over_chunk = chunk_total_token.max().item()
847
+ token_to_seq_tensor_cpu = torch.zeros(
848
+ [num_chunks, max_token_num_over_chunk], dtype=torch.int32
849
+ )
850
+ range_idx = torch.arange(num_prefills, dtype=torch.int32)
851
+ for i in range(num_chunks):
852
+ chunk_token_to_seq_tensor = torch.repeat_interleave(
853
+ range_idx, chunk_seq_lens[i]
854
+ )
855
+ chunk_len = chunk_token_to_seq_tensor.shape[0]
856
+ token_to_seq_tensor_cpu[i, :chunk_len] = chunk_token_to_seq_tensor
857
+
858
+ if self.dcp_world_size > 1:
859
+ local_context_lens_allranks = get_dcp_local_seq_lens(
860
+ context_lens_cpu,
861
+ self.dcp_world_size,
862
+ None,
863
+ self.dcp_local_block_size,
864
+ )
865
+ # Note(qcs): The max local context lengths
866
+ # padded to `dcp_local_block_size`.
867
+ padded_local_context_lens_cpu = (
868
+ cdiv(
869
+ context_lens_cpu,
870
+ self.dcp_virtual_block_size,
871
+ )
872
+ * self.dcp_local_block_size
873
+ )
874
+ # Note(hc): The above max_context_chunk already enforces
875
+ # block_size alignment; DCP only needs block_size to be
876
+ # divisible by dcp_world_size, because DCP uses
877
+ # cp_gather_cache, which does not require `cp_chunk_starts`
878
+ # to be aligned to page_size.
879
+ assert max_context_chunk % self.dcp_world_size == 0
880
+ padded_local_max_context_chunk_across_ranks = (
881
+ cdiv(
882
+ max_context_chunk,
883
+ self.dcp_virtual_block_size,
884
+ )
885
+ * self.dcp_local_block_size
886
+ )
887
+ local_chunk_starts = (
888
+ torch.arange(num_chunks, dtype=torch.int32)
889
+ .unsqueeze(1)
890
+ .expand(-1, num_prefills)
891
+ * padded_local_max_context_chunk_across_ranks
892
+ )
893
+ local_chunk_ends = torch.min(
894
+ padded_local_context_lens_cpu.unsqueeze(0),
895
+ local_chunk_starts
896
+ + padded_local_max_context_chunk_across_ranks,
897
+ )
898
+ padded_local_chunk_seq_lens = (
899
+ local_chunk_ends - local_chunk_starts
900
+ ).clamp(min=0)
901
+
902
+ padded_local_cu_chunk_seq_lens_cpu = torch.zeros(
903
+ num_chunks, num_prefills + 1, dtype=torch.int32, pin_memory=True
904
+ )
905
+ torch.cumsum(
906
+ padded_local_chunk_seq_lens,
907
+ dim=1,
908
+ out=padded_local_cu_chunk_seq_lens_cpu[:, 1:],
909
+ dtype=torch.int32,
910
+ )
911
+
912
+ chunked_context_metadata_cls = (
913
+ CudnnPrefillMetadata.ChunkedContextMetadata
914
+ if self._use_cudnn_prefill
915
+ else MLACommonPrefillMetadata.ChunkedContextMetadata
916
+ )
917
+ if self.dcp_world_size > 1:
918
+ chunked_context_metadata = chunked_context_metadata_cls(
919
+ cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
920
+ starts=local_chunk_starts.to(device, non_blocking=True),
921
+ seq_tot=padded_local_chunk_seq_lens.sum(dim=1).tolist(),
922
+ max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
923
+ seq_lens=chunk_seq_lens,
924
+ token_to_seq=token_to_seq_tensor_cpu.to(
925
+ device, non_blocking=True
926
+ ),
927
+ chunk_total_token=chunk_total_token.tolist(),
928
+ workspace=self.chunked_prefill_workspace,
929
+ padded_local_chunk_seq_lens=padded_local_chunk_seq_lens.tolist(),
930
+ local_context_lens_allranks=local_context_lens_allranks.tolist(),
931
+ padded_local_cu_seq_lens=padded_local_cu_chunk_seq_lens_cpu.to(
932
+ device, non_blocking=True
933
+ ),
934
+ cu_seq_lens_lst=cu_seq_lens_cpu.tolist(),
935
+ chunk_size=padded_local_max_context_chunk_across_ranks,
936
+ )
937
+ else:
938
+ chunked_context_metadata = chunked_context_metadata_cls(
939
+ cu_seq_lens=cu_seq_lens_cpu.to(device, non_blocking=True),
940
+ starts=chunk_starts.to(device, non_blocking=True),
941
+ seq_tot=chunk_seq_lens.sum(dim=1).tolist(),
942
+ max_seq_lens=chunk_seq_lens.max(dim=1).values.tolist(),
943
+ seq_lens=chunk_seq_lens,
944
+ token_to_seq=token_to_seq_tensor_cpu.to(
945
+ device, non_blocking=True
946
+ ),
947
+ chunk_total_token=chunk_total_token,
948
+ workspace=self.chunked_prefill_workspace,
949
+ )
950
+
951
+ if self._use_cudnn_prefill:
952
+ chunked_context_metadata.seq_lens = chunk_seq_lens
953
+
954
+ assert (
955
+ max(chunked_context_metadata.max_seq_lens)
956
+ <= self.chunked_prefill_workspace_size
957
+ )
958
+
959
+ prefill_metadata = self.prefill_metadata_cls(
960
+ block_table=block_table_tensor[reqs_start:, ...],
961
+ query_start_loc=prefill_query_start_loc,
962
+ max_query_len=max_query_len,
963
+ chunked_context=chunked_context_metadata,
964
+ )
965
+
966
+ if self._use_cudnn_prefill:
967
+ assert isinstance(prefill_metadata, CudnnPrefillMetadata)
968
+ prefill_metadata.query_seq_lens = (
969
+ prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
970
+ )
971
+ prefill_metadata.cudnn_workspace = self.cudnn_workspace
972
+
973
+ if self._use_trtllm_ragged_prefill:
974
+ prefill_metadata.query_seq_lens = (
975
+ prefill_query_start_loc[1:] - prefill_query_start_loc[:-1]
976
+ )
977
+
978
+ decode_metadata = None
979
+ if num_decodes > 0:
980
+ dcp_tot_seq_lens_device = None
981
+ if self.dcp_world_size > 1:
982
+ dcp_tot_seq_lens_device = seq_lens[:num_decodes]
983
+ seq_lens_cpu = dcp_local_seq_lens_cpu
984
+ seq_lens = dcp_local_seq_lens
985
+
986
+ decode_metadata = self._build_decode(
987
+ block_table_tensor=block_table_tensor[:num_decodes, ...],
988
+ seq_lens_cpu=seq_lens_cpu[:num_decodes],
989
+ seq_lens_device=seq_lens[:num_decodes],
990
+ query_start_loc_cpu=query_start_loc_cpu[: num_decodes + 1],
991
+ query_start_loc_device=query_start_loc[: num_decodes + 1],
992
+ num_decode_tokens=num_decode_tokens,
993
+ dcp_tot_seq_lens_device=dcp_tot_seq_lens_device,
994
+ )
995
+
996
+ attn_metadata = self.metadata_cls(
997
+ num_reqs=common_attn_metadata.num_reqs,
998
+ max_query_len=common_attn_metadata.max_query_len,
999
+ max_seq_len=max_seq_len,
1000
+ num_actual_tokens=num_tokens,
1001
+ query_start_loc=query_start_loc,
1002
+ slot_mapping=slot_mapping,
1003
+ head_dim=self.model_config.get_head_size(),
1004
+ # MLACommonMetadata Chunk prefill specific
1005
+ num_decodes=num_decodes,
1006
+ num_decode_tokens=num_decode_tokens,
1007
+ num_prefills=num_prefills,
1008
+ prefill=prefill_metadata,
1009
+ decode=decode_metadata,
1010
+ )
1011
+
1012
+ if self._use_fi_prefill and num_prefills > 0:
1013
+ assert isinstance(attn_metadata.prefill, FlashInferPrefillMetadata)
1014
+ self._build_fi_prefill_wrappers(attn_metadata.prefill)
1015
+
1016
+ return attn_metadata
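
The chunked-context bookkeeping in the loop above can be checked in isolation. A standalone sketch (editor's illustration, not part of this file) reproducing the example from the inline comment, with assumed per-request context lengths:

```python
import torch

context_lens_cpu = torch.tensor([100, 300, 600, 700], dtype=torch.int32)
max_context_chunk = 256
num_chunks = 3
num_prefills = 4

chunk_starts = (
    torch.arange(num_chunks, dtype=torch.int32).unsqueeze(1).expand(-1, num_prefills)
    * max_context_chunk
)
chunk_ends = torch.min(context_lens_cpu.unsqueeze(0), chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - chunk_starts).clamp(min=0)

cu_seq_lens = torch.zeros(num_chunks, num_prefills + 1, dtype=torch.int32)
torch.cumsum(chunk_seq_lens, dim=1, out=cu_seq_lens[:, 1:], dtype=torch.int32)

# chunk_starts:       [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
# chunk_seq_lens[0]:  [100, 256, 256, 256]  (each request's first context chunk)
```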
1017
+
1018
+
1019
+ def reorg_kvcache(
1020
+ allgatered_kv_c_normed: torch.Tensor,
1021
+ allgatered_k_pe: torch.Tensor,
1022
+ padded_local_chunk_seq_lens_lst: list[int],
1023
+ local_context_lens_allranks: list[list[int]],
1024
+ sum_seq_len: int,
1025
+ max_seq_len: int,
1026
+ chunk_size: int,
1027
+ chunk_idx: int,
1028
+ toks: int,
1029
+ ) -> tuple[torch.Tensor, torch.Tensor]:
1030
+ """
1031
+ reorg and unpad kvcache after cp local gather to tp layout for attn kernel.
1032
+ e.g.
1033
+ allgatered_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T1_0, T1_1, ...,
1034
+ T0_4, T0_5, pad, pad, T1_2, pad, ...]
1035
+ -> reorganized_kv_c_normed = [T0_0, T0_1, T0_2, T0_3, T0_4, T0_5,
1036
+ T1_0, T1_1, T1_2, ...]
1037
+ Args:
1038
+ padded_local_chunk_seq_lens_lst: local chunk context lengths
1039
+ under current CP rank.
1040
+ local_context_lens_allranks: local context lengths on each CP rank.
1041
+ sum_seq_len: the sum of cp_chunk_seq_lens_lst.
1042
+ max_seq_len: the max value of cp_chunk_seq_lens_lst.
1043
+ chunk_size: the local padded max context chunk from
1044
+ chunked_context_metadata building.
1045
+ chunk_idx: chunk idx of chunked_prefill.
1046
+ toks: the number of tokens for local gather cache.
1047
+ """
1048
+ kv_c_segments = []
1049
+ k_pe_segments = []
1050
+ src_token_idx = 0
1051
+ max_seq_len_check = 0
1052
+ for padded_local_chunk_seq_len, local_context_lens in zip(
1053
+ padded_local_chunk_seq_lens_lst, local_context_lens_allranks
1054
+ ):
1055
+ cur_seq_len = 0
1056
+ for rank, local_context_len in enumerate(local_context_lens):
1057
+ # Note(qcs): We split the context into multiple chunks,
1058
+ # depending on the size of the workspace.
1059
+ # local_context in dcp0: |-----------------|
1060
+ # local_context in dcp1: |--------------|
1061
+ # n*padded_local_chunk: |-----|-----|-----|
1062
+ # local_chunk_len in dcp1: |-----|-----|--|
1063
+ # so we need to update the last chunk length in dcp1.
1064
+ local_chunk_len = min(
1065
+ max(0, local_context_len - chunk_idx * chunk_size),
1066
+ padded_local_chunk_seq_len,
1067
+ )
1068
+ if local_chunk_len != 0:
1069
+ kv_c_segment = allgatered_kv_c_normed[
1070
+ rank * toks + src_token_idx : rank * toks
1071
+ + src_token_idx
1072
+ + local_chunk_len
1073
+ ]
1074
+ k_pe_segment = allgatered_k_pe[
1075
+ rank * toks + src_token_idx : rank * toks
1076
+ + src_token_idx
1077
+ + local_chunk_len
1078
+ ]
1079
+ kv_c_segments.append(kv_c_segment)
1080
+ k_pe_segments.append(k_pe_segment)
1081
+ cur_seq_len += local_chunk_len
1082
+ max_seq_len_check = max(max_seq_len_check, cur_seq_len)
1083
+ src_token_idx += padded_local_chunk_seq_len
1084
+ reorganized_kv_c_normed = torch.cat(kv_c_segments, dim=0)
1085
+ reorganized_k_pe = torch.cat(k_pe_segments, dim=0)
1086
+ assert reorganized_kv_c_normed.shape[0] == sum_seq_len
1087
+ assert reorganized_k_pe.shape[0] == sum_seq_len
1088
+ assert max_seq_len_check == max_seq_len
1089
+ return reorganized_kv_c_normed, reorganized_k_pe
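
A toy run of `reorg_kvcache` (editor's sketch, not part of this file; all values are assumed) following the e.g. in its docstring: 2 DCP ranks, one request whose context is split 3 + 2 tokens across ranks and gathered with one padding slot.

```python
import torch

toks = 3  # workspace tokens gathered per rank for this chunk
rank0 = torch.arange(0, 3).unsqueeze(1).float()   # T0_0, T0_1, T0_2
rank1 = torch.tensor([[3.0], [4.0], [-1.0]])      # T0_3, T0_4, pad
allgathered = torch.cat([rank0, rank1], dim=0)    # [2 * toks, 1]

kv_c, k_pe = reorg_kvcache(
    allgatered_kv_c_normed=allgathered,
    allgatered_k_pe=allgathered.clone(),
    padded_local_chunk_seq_lens_lst=[3],     # per-request padded local length
    local_context_lens_allranks=[[3, 2]],    # rank0 holds 3 tokens, rank1 holds 2
    sum_seq_len=5,
    max_seq_len=5,
    chunk_size=3,
    chunk_idx=0,
    toks=toks,
)

assert kv_c.squeeze(1).tolist() == [0.0, 1.0, 2.0, 3.0, 4.0]  # padding removed
```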
1090
+
1091
+
1092
+ # TODO(Lucas): rename MLACommonBaseImpl -> MLACommonImpl,
1093
+ # and MLACommonImpl -> MLACommonDenseImpl or something like that
1094
+ class MLACommonBaseImpl(MLAAttentionImpl[A], Generic[A]):
1095
+ """
1096
+ NOTE: Please read the comment at the top of the file before trying to
1097
+ understand this class
1098
+ """
1099
+
1100
+ def __init__(
1101
+ self,
1102
+ num_heads: int,
1103
+ head_size: int,
1104
+ scale: float,
1105
+ num_kv_heads: int,
1106
+ alibi_slopes: list[float] | None,
1107
+ sliding_window: int | None,
1108
+ kv_cache_dtype: str,
1109
+ logits_soft_cap: float | None,
1110
+ attn_type: str,
1111
+ kv_sharing_target_layer_name: str | None,
1112
+ # MLA Specific Arguments
1113
+ q_lora_rank: int | None,
1114
+ kv_lora_rank: int,
1115
+ qk_nope_head_dim: int,
1116
+ qk_rope_head_dim: int,
1117
+ qk_head_dim: int,
1118
+ v_head_dim: int,
1119
+ kv_b_proj: ColumnParallelLinear,
1120
+ indexer=None,
1121
+ q_pad_num_heads: int | None = None,
1122
+ ) -> None:
1123
+ if kv_sharing_target_layer_name is not None:
1124
+ raise NotImplementedError("KV sharing is not supported for MLA")
1125
+
1126
+ self.num_heads = num_heads
1127
+ self.head_size = head_size
1128
+ self.scale = float(scale)
1129
+ self.num_kv_heads = num_kv_heads
1130
+ self.kv_cache_dtype = kv_cache_dtype
1131
+
1132
+ self.q_lora_rank = q_lora_rank
1133
+ self.kv_lora_rank = kv_lora_rank
1134
+ self.qk_nope_head_dim = qk_nope_head_dim
1135
+ self.qk_rope_head_dim = qk_rope_head_dim
1136
+ self.qk_head_dim = qk_head_dim
1137
+ self.v_head_dim = v_head_dim
1138
+ self.kv_b_proj = kv_b_proj
1139
+ self.indexer = indexer
1140
+ self.q_pad_num_heads = q_pad_num_heads
1141
+ self.is_aiter_triton_fp8_bmm_enabled = rocm_aiter_ops.is_fp8bmm_enabled()
1142
+
1143
+ def process_weights_after_loading(self, act_dtype: torch.dtype):
1144
+ def get_layer_weight(layer):
1145
+ WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
1146
+ for attr in WEIGHT_NAMES:
1147
+ if hasattr(layer, attr):
1148
+ return getattr(layer, attr)
1149
+ raise AttributeError(
1150
+ f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
1151
+ )
1152
+
1153
+ def get_and_maybe_dequant_weights(layer: LinearBase):
1154
+ if not isinstance(layer.quant_method, UnquantizedLinearMethod):
1155
+ # NOTE: This should only be used offline, since it's O(N^3)
1156
+ eye = torch.eye(
1157
+ layer.input_size_per_partition,
1158
+ dtype=act_dtype,
1159
+ device=get_layer_weight(layer).device,
1160
+ )
1161
+ dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
1162
+ del eye
1163
+ # standardize to (output, input)
1164
+ return dequant_weights.T
1165
+ return layer.weight
1166
+
1167
+ # We currently do not have quantized bmm's, which are needed for
1168
+ # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
1169
+ # the bmm's in 16-bit; the extra memory overhead of this is fairly low.
1170
+ kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
1171
+ assert kv_b_proj_weight.shape == (
1172
+ self.kv_lora_rank,
1173
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
1174
+ ), (
1175
+ f"{kv_b_proj_weight.shape=}, "
1176
+ f"{self.kv_lora_rank=}, "
1177
+ f"{self.num_heads=}, "
1178
+ f"{self.qk_nope_head_dim=}, "
1179
+ f"{self.v_head_dim=}"
1180
+ )
1181
+ kv_b_proj_weight = kv_b_proj_weight.view(
1182
+ self.kv_lora_rank,
1183
+ self.num_heads,
1184
+ self.qk_nope_head_dim + self.v_head_dim,
1185
+ )
1186
+
1187
+ W_UK, W_UV = kv_b_proj_weight.split(
1188
+ [self.qk_nope_head_dim, self.v_head_dim], dim=-1
1189
+ )
1190
+
1191
+ if self.is_aiter_triton_fp8_bmm_enabled:
1192
+ W_K = W_UK.transpose(0, 1) # 16 512 128
1193
+ W_V = W_UV.permute(1, 2, 0) # 16 128 512
1194
+ self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
1195
+ W_K, dtype=current_platform.fp8_dtype()
1196
+ )
1197
+ self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
1198
+ W_V, dtype=current_platform.fp8_dtype()
1199
+ )
1200
+
1201
+ # The kernel operates on non-padded inputs. Hence, we pre-compile the
1202
+ # Triton kernel to avoid runtime compilation for unseen batch sizes.
1203
+ # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
1204
+ # On DS-R1, this step adds roughly 50s to the model loading time.
1205
+ max_batch_size = 1024 # [ToDo] Find the optimal upper limit
1206
+ pre_compilation_list = list(range(1, max_batch_size + 1))
1207
+ if is_global_first_rank():
1208
+ pre_compilation_list = tqdm(
1209
+ pre_compilation_list,
1210
+ desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
1211
+ total=max_batch_size,
1212
+ )
1213
+
1214
+ for m in pre_compilation_list:
1215
+ x = torch.empty(
1216
+ (self.W_K.shape[0], m, self.W_K.shape[2]),
1217
+ dtype=torch.bfloat16,
1218
+ device=self.W_K.device,
1219
+ )
1220
+ rocm_aiter_ops.triton_fp8_bmm(
1221
+ x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
1222
+ )
1223
+
1224
+ x = torch.empty(
1225
+ (self.W_V.shape[0], m, self.W_V.shape[2]),
1226
+ dtype=torch.bfloat16,
1227
+ device=self.W_V.device,
1228
+ )
1229
+ rocm_aiter_ops.triton_fp8_bmm(
1230
+ x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
1231
+ )
1232
+ else:
1233
+ # Convert from (L, N, V) to (N, L, V)
1234
+ self.W_UV = W_UV.transpose(0, 1)
1235
+ # Convert from (L, N, P) to (N, P, L)
1236
+ self.W_UK_T = W_UK.permute(1, 2, 0)
1237
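The identity-matrix trick in `get_and_maybe_dequant_weights` above works because applying a linear layer to the identity returns the (transposed) effective weight matrix. A minimal sketch with a plain, unquantized `nn.Linear` (purely illustrative; real quantized layers go through `quant_method.apply` instead):

    import torch
    from torch import nn

    layer = nn.Linear(in_features=8, out_features=4, bias=False)
    eye = torch.eye(8)

    # y = x @ W.T, so feeding the identity recovers W.T; transpose back to
    # the standard (output, input) layout and compare with the stored weight.
    recovered = layer(eye).T
    assert torch.allclose(recovered, layer.weight)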
+
1238
+ def _v_up_proj(self, x: torch.Tensor, out: torch.Tensor):
1239
+ # Convert from (B, N, L) to (N, B, L)
1240
+ x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
1241
+
1242
+ if self.is_aiter_triton_fp8_bmm_enabled:
1243
+ out = out.view(-1, self.num_heads, self.v_head_dim)
1244
+ # Multiply + Transpose (N, B, L) x (N, L, V)->(N, B, V)->(B, N, V)
1245
+ x = rocm_aiter_ops.triton_fp8_bmm(
1246
+ x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True, YQ=out
1247
+ )
1248
+ else:
1249
+ # Convert from (B, N * V) to (N, B, V)
1250
+ out = out.view(-1, self.num_heads, self.v_head_dim).transpose(0, 1)
1251
+
1252
+ # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
1253
+ torch.bmm(x, self.W_UV, out=out) # Reuse "out" to make it "hot"
1254
+
1255
+ # Convert from (N, B, V) to (B, N * V)
1256
+ out_new = out.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
1257
+
1258
+ # Adjust output buffer shape back to the original (B, N * V)
1259
+ N, B, V = out.shape
1260
+ out.resize_((B, N * V))
1261
+ out.copy_(out_new) # Copy result
1262
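The non-fp8 branch of `_v_up_proj` is just a batched matmul over heads followed by a reshape. A shape-only sketch with made-up dimensions (B tokens, N heads, latent rank L, value dim V):

    import torch

    B, N, L, V = 5, 16, 512, 128
    x = torch.randn(B, N, L)          # per-token latent attention output
    W_UV = torch.randn(N, L, V)       # per-head up-projection weights

    # (B, N, L) -> (N, B, L), bmm with (N, L, V) -> (N, B, V), then back to (B, N * V)
    out = torch.bmm(x.transpose(0, 1), W_UV)
    out = out.transpose(0, 1).reshape(B, N * V)
    print(out.shape)  # torch.Size([5, 2048])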
+
1263
+
1264
+ class MLACommonImpl(MLACommonBaseImpl[M], Generic[M]):
1265
+ """
1266
+ NOTE: Please read the comment at the top of the file before trying to
1267
+ understand this class
1268
+ """
1269
+
1270
+ def __init__(self, *args, **kwargs) -> None:
1271
+ super().__init__(*args, **kwargs)
1272
+
1273
+ if use_flashinfer_prefill():
1274
+ logger.debug_once("Using FlashInfer prefill for MLA")
1275
+ self._run_prefill_context_chunk = self._run_prefill_context_chunk_fi
1276
+ self._run_prefill_new_tokens = self._run_prefill_new_tokens_fi
1277
+ self._pad_v = False
1278
+ elif use_trtllm_ragged_deepseek_prefill():
1279
+ logger.debug_once("Using TRT-LLM ragged DeepSeek prefill for MLA")
1280
+ self._run_prefill_context_chunk = (
1281
+ self._run_prefill_context_chunk_trtllm_ragged
1282
+ )
1283
+ self._run_prefill_new_tokens = self._run_prefill_new_tokens_trtllm_ragged
1284
+ self._pad_v = False
1285
+ elif use_cudnn_prefill():
1286
+ logger.debug_once("Using CUDNN prefill for MLA")
1287
+ self._run_prefill_context_chunk = self._run_prefill_context_chunk_cudnn
1288
+ self._run_prefill_new_tokens = self._run_prefill_new_tokens_cudnn
1289
+ self._pad_v = False
1290
+ else: # Use FlashAttention
1291
+ logger.debug_once("Using FlashAttention prefill for MLA")
1292
+ self._run_prefill_context_chunk = self._run_prefill_context_chunk_fa
1293
+ self._run_prefill_new_tokens = self._run_prefill_new_tokens_fa
1294
+
1295
+ # Handle the differences between the flash_attn_varlen from
1296
+ # flash_attn and the one from vllm_flash_attn. The former is used on
1297
+ # ROCm and the latter has an additional parameter to control
1298
+ # FA2 vs FA3
1299
+ self.flash_attn_varlen_func = flash_attn_varlen_func
1300
+ self.vllm_flash_attn_version = get_flash_attn_version()
1301
+ if self.vllm_flash_attn_version is not None:
1302
+ self.flash_attn_varlen_func = functools.partial(
1303
+ flash_attn_varlen_func, fa_version=self.vllm_flash_attn_version
1304
+ )
1305
+
1306
+ # For MLA, the V head dim is smaller than the QK head dim, so we pad
1307
+ # V with 0s to match the QK head dim for attention backends that do
1308
+ # not support different head dims.
1309
+ # We don't need to pad V if we are on a Hopper system with FA3.
1310
+ self._pad_v = self.vllm_flash_attn_version is None or not (
1311
+ self.vllm_flash_attn_version == 3
1312
+ and current_platform.get_device_capability()[0] == 9
1313
+ )
1314
+
1315
+ self.dcp_world_size: int | None = None
1316
+
1317
+ self.chunked_prefill_workspace_size = (
1318
+ MLACommonMetadataBuilder.determine_chunked_prefill_workspace_size(
1319
+ get_current_vllm_config()
1320
+ )
1321
+ )
1322
+ self.cp_kv_cache_interleave_size: int = (
1323
+ get_current_vllm_config().parallel_config.cp_kv_cache_interleave_size
1324
+ )
1325
+
1326
+ def _flash_attn_varlen_diff_headdims(
1327
+ self, q, k, v, return_softmax_lse=False, softmax_scale=None, **kwargs
1328
+ ):
1329
+ maybe_padded_v = v
1330
+ if self._pad_v:
1331
+ maybe_padded_v = torch.nn.functional.pad(
1332
+ v, [0, q.shape[-1] - v.shape[-1]], value=0
1333
+ )
1334
+
1335
+ if is_vllm_fa:
1336
+ kwargs["return_softmax_lse"] = return_softmax_lse
1337
+ else:
1338
+ # ROCm leverages the upstream flash_attn, which takes a parameter
1339
+ # called "return_attn_probs" instead of return_softmax_lse
1340
+ kwargs["return_attn_probs"] = return_softmax_lse
1341
+ if vllm_is_batch_invariant():
1342
+ kwargs["num_splits"] = 1
1343
+
1344
+ attn_out = self.flash_attn_varlen_func(
1345
+ q=q,
1346
+ k=k,
1347
+ v=maybe_padded_v,
1348
+ softmax_scale=softmax_scale,
1349
+ **kwargs,
1350
+ )
1351
+
1352
+ # Unpack the output if there are multiple results
1353
+ lse = None
1354
+ if isinstance(attn_out, tuple):
1355
+ attn_out, lse = attn_out[0], attn_out[1]
1356
+
1357
+ # Remain consistent with old `flash_attn_varlen_func` where there
1358
+ # is only one output tensor if `return_softmax_lse` is False.
1359
+ if return_softmax_lse:
1360
+ return attn_out, lse
1361
+ return attn_out
1362
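The `_pad_v` path above only zero-pads the last (head) dimension of V up to the query head dim so that kernels requiring equal Q/K/V head dims can be used; the padded columns are dropped again after attention. A minimal sketch (the dimensions are illustrative, roughly matching MLA's 128-dim values and 192-dim queries):

    import torch
    import torch.nn.functional as F

    qk_head_dim, v_head_dim = 192, 128
    v = torch.randn(10, 16, v_head_dim)

    # F.pad pads the last dim: 0 elements on the left, (192 - 128) zeros on the right.
    padded_v = F.pad(v, [0, qk_head_dim - v_head_dim], value=0)
    print(padded_v.shape)                      # torch.Size([10, 16, 192])
    assert torch.equal(padded_v[..., :v_head_dim], v)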
+
1363
+ def _run_prefill_new_tokens_fa(
1364
+ self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
1365
+ ):
1366
+ return self._flash_attn_varlen_diff_headdims(
1367
+ q=q,
1368
+ k=k,
1369
+ v=v,
1370
+ cu_seqlens_q=prefill.query_start_loc,
1371
+ cu_seqlens_k=prefill.query_start_loc,
1372
+ max_seqlen_q=prefill.max_query_len,
1373
+ max_seqlen_k=prefill.max_query_len,
1374
+ softmax_scale=self.scale,
1375
+ causal=True,
1376
+ return_softmax_lse=return_softmax_lse,
1377
+ )
1378
+
1379
+ def _run_prefill_new_tokens_fi(
1380
+ self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
1381
+ ):
1382
+ assert isinstance(prefill, FlashInferPrefillMetadata)
1383
+ assert prefill.prefill_main is not None
1384
+
1385
+ ret = prefill.prefill_main.run(
1386
+ q=q,
1387
+ k=k,
1388
+ v=v,
1389
+ return_lse=return_softmax_lse,
1390
+ )
1391
+
1392
+ if isinstance(ret, tuple):
1393
+ return ret[0], ret[1].transpose(0, 1).contiguous()
1394
+ return ret
1395
+
1396
+ def _run_prefill_new_tokens_cudnn(
1397
+ self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
1398
+ ):
1399
+ assert isinstance(prefill, CudnnPrefillMetadata)
1400
+ assert prefill.query_seq_lens is not None
1401
+ output, lse = cudnn_batch_prefill_with_kv_cache(
1402
+ q=q,
1403
+ k_cache=k,
1404
+ v_cache=v,
1405
+ scale=self.scale,
1406
+ workspace_buffer=prefill.cudnn_workspace,
1407
+ max_token_per_sequence=prefill.max_query_len,
1408
+ max_sequence_kv=prefill.max_query_len,
1409
+ actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
1410
+ actual_seq_lens_kv=prefill.query_seq_lens.view(-1, 1, 1, 1),
1411
+ causal=True,
1412
+ # return_lse=False is not supported for now
1413
+ return_lse=True,
1414
+ # Indicates actual_seq_lens are on GPU or CPU.
1415
+ is_cuda_graph_compatible=True,
1416
+ )
1417
+ if return_softmax_lse:
1418
+ return output, lse
1419
+ return output
1420
+
1421
+ def _run_prefill_context_chunk_fa(
1422
+ self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
1423
+ ):
1424
+ assert prefill.chunked_context is not None
1425
+ return self._flash_attn_varlen_diff_headdims(
1426
+ q=q,
1427
+ k=k,
1428
+ v=v,
1429
+ cu_seqlens_q=prefill.query_start_loc,
1430
+ cu_seqlens_k=prefill.chunked_context.cu_seq_lens[chunk_idx],
1431
+ max_seqlen_q=prefill.max_query_len,
1432
+ max_seqlen_k=prefill.chunked_context.max_seq_lens[chunk_idx],
1433
+ softmax_scale=self.scale,
1434
+ causal=False, # Context is unmasked
1435
+ return_softmax_lse=True,
1436
+ )
1437
+
1438
+ def _run_prefill_context_chunk_fi(
1439
+ self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
1440
+ ):
1441
+ assert isinstance(prefill, FlashInferPrefillMetadata)
1442
+
1443
+ attn_out, lse = prefill.prefill_chunks[chunk_idx].run(
1444
+ q=q,
1445
+ k=k,
1446
+ v=v,
1447
+ return_lse=True,
1448
+ )
1449
+
1450
+ # Convert from (q_len, num_heads) to (num_heads, q_len)
1451
+ return attn_out, lse.transpose(0, 1).contiguous()
1452
+
1453
+ def _run_prefill_context_chunk_cudnn(
1454
+ self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
1455
+ ):
1456
+ assert isinstance(prefill, CudnnPrefillMetadata)
1457
+ assert prefill.chunked_context is not None
1458
+ assert prefill.chunked_context.seq_lens[chunk_idx] is not None
1459
+ assert prefill.query_seq_lens is not None
1460
+ return cudnn_batch_prefill_with_kv_cache(
1461
+ q=q,
1462
+ k_cache=k,
1463
+ v_cache=v,
1464
+ scale=self.scale,
1465
+ workspace_buffer=prefill.cudnn_workspace,
1466
+ max_token_per_sequence=prefill.max_query_len,
1467
+ max_sequence_kv=prefill.chunked_context.max_seq_lens[chunk_idx],
1468
+ actual_seq_lens_q=prefill.query_seq_lens.view(-1, 1, 1, 1),
1469
+ actual_seq_lens_kv=prefill.chunked_context.seq_lens[chunk_idx].view(
1470
+ -1, 1, 1, 1
1471
+ ),
1472
+ causal=False,
1473
+ return_lse=True,
1474
+ # Indicates actual_seq_lens are on GPU or CPU.
1475
+ is_cuda_graph_compatible=True,
1476
+ )
1477
+
1478
+ def _run_prefill_new_tokens_trtllm_ragged(
1479
+ self, prefill: MLACommonPrefillMetadata, q, k, v, return_softmax_lse
1480
+ ):
1481
+ """TRT-LLM ragged attention for new tokens (causal)."""
1482
+ from flashinfer.prefill import trtllm_ragged_attention_deepseek
1483
+
1484
+ assert prefill.query_seq_lens is not None
1485
+
1486
+ ret = trtllm_ragged_attention_deepseek(
1487
+ query=q,
1488
+ key=k,
1489
+ value=v,
1490
+ workspace_buffer=self._workspace_buffer,
1491
+ seq_lens=prefill.query_seq_lens,
1492
+ max_q_len=prefill.max_query_len,
1493
+ max_kv_len=prefill.max_query_len,
1494
+ bmm1_scale=self.scale,
1495
+ bmm2_scale=1.0,
1496
+ o_sf_scale=1.0,
1497
+ batch_size=prefill.query_seq_lens.shape[0],
1498
+ window_left=-1,
1499
+ cum_seq_lens_q=prefill.query_start_loc,
1500
+ cum_seq_lens_kv=prefill.query_start_loc,
1501
+ enable_pdl=False,
1502
+ is_causal=True,
1503
+ return_lse=return_softmax_lse,
1504
+ )
1505
+
1506
+ if isinstance(ret, tuple):
1507
+ # Convert from (q_len, num_heads) to (num_heads, q_len)
1508
+ return ret[0], ret[1].transpose(0, 1).contiguous()
1509
+ return ret
1510
+
1511
+ def _run_prefill_context_chunk_trtllm_ragged(
1512
+ self, prefill: MLACommonPrefillMetadata, chunk_idx: int, q, k, v
1513
+ ):
1514
+ """TRT-LLM ragged attention for context chunks (non-causal)."""
1515
+ from flashinfer.prefill import trtllm_ragged_attention_deepseek
1516
+
1517
+ assert prefill.chunked_context is not None
1518
+ assert prefill.chunked_context.seq_lens[chunk_idx] is not None
1519
+
1520
+ out = torch.zeros(
1521
+ q.shape[0],
1522
+ q.shape[1],
1523
+ v.shape[2],
1524
+ device=q.device,
1525
+ dtype=q.dtype,
1526
+ )
1527
+ self._workspace_buffer.fill_(0)
1528
+
1529
+ attn_out, lse = trtllm_ragged_attention_deepseek(
1530
+ query=q,
1531
+ key=k,
1532
+ value=v,
1533
+ workspace_buffer=self._workspace_buffer,
1534
+ seq_lens=prefill.chunked_context.seq_lens[chunk_idx],
1535
+ max_q_len=prefill.max_query_len,
1536
+ max_kv_len=prefill.chunked_context.max_seq_lens[chunk_idx],
1537
+ bmm1_scale=self.scale,
1538
+ bmm2_scale=1.0,
1539
+ o_sf_scale=1.0,
1540
+ batch_size=prefill.chunked_context.seq_lens[chunk_idx].shape[0],
1541
+ window_left=-1,
1542
+ cum_seq_lens_q=prefill.query_start_loc,
1543
+ cum_seq_lens_kv=prefill.chunked_context.cu_seq_lens[chunk_idx],
1544
+ enable_pdl=False,
1545
+ is_causal=False,
1546
+ return_lse=True,
1547
+ out=out,
1548
+ )
1549
+
1550
+ # Convert from (q_len, num_heads) to (num_heads, q_len)
1551
+ return attn_out, lse.transpose(0, 1).contiguous()
1552
+
1553
+ def process_weights_after_loading(self, act_dtype: torch.dtype):
1554
+ def get_layer_weight(layer):
1555
+ WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
1556
+ for attr in WEIGHT_NAMES:
1557
+ if hasattr(layer, attr):
1558
+ return getattr(layer, attr)
1559
+ raise AttributeError(
1560
+ f"Layer '{layer}' has no recognized weight attribute: {WEIGHT_NAMES}."
1561
+ )
1562
+
1563
+ def get_and_maybe_dequant_weights(layer: LinearBase):
1564
+ if not isinstance(layer.quant_method, UnquantizedLinearMethod):
1565
+ # NOTE: This should only be used offline, since it's O(N^3)
1566
+ eye = torch.eye(
1567
+ layer.input_size_per_partition,
1568
+ dtype=act_dtype,
1569
+ device=get_layer_weight(layer).device,
1570
+ )
1571
+ dequant_weights = layer.quant_method.apply(layer, eye, bias=None)
1572
+ del eye
1573
+ # standardize to (output, input)
1574
+ return dequant_weights.T
1575
+ return layer.weight
1576
+
1577
+ # We currently do not have quantized bmm's, which are needed for
1578
+ # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
1579
+ # the bmm's in 16-bit; the extra memory overhead of this is fairly low.
1580
+ kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
1581
+ assert kv_b_proj_weight.shape == (
1582
+ self.kv_lora_rank,
1583
+ self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
1584
+ ), (
1585
+ f"{kv_b_proj_weight.shape=}, "
1586
+ f"{self.kv_lora_rank=}, "
1587
+ f"{self.num_heads=}, "
1588
+ f"{self.qk_nope_head_dim=}, "
1589
+ f"{self.v_head_dim=}"
1590
+ )
1591
+ kv_b_proj_weight = kv_b_proj_weight.view(
1592
+ self.kv_lora_rank,
1593
+ self.num_heads,
1594
+ self.qk_nope_head_dim + self.v_head_dim,
1595
+ )
1596
+
1597
+ W_UK, W_UV = kv_b_proj_weight.split(
1598
+ [self.qk_nope_head_dim, self.v_head_dim], dim=-1
1599
+ )
1600
+
1601
+ if self.is_aiter_triton_fp8_bmm_enabled:
1602
+ W_K = W_UK.transpose(0, 1) # 16 512 128
1603
+ W_V = W_UV.permute(1, 2, 0) # 16 128 512
1604
+ self.W_K, self.W_K_scale = dynamic_per_batched_tensor_quant(
1605
+ W_K, dtype=current_platform.fp8_dtype()
1606
+ )
1607
+ self.W_V, self.W_V_scale = dynamic_per_batched_tensor_quant(
1608
+ W_V, dtype=current_platform.fp8_dtype()
1609
+ )
1610
+
1611
+ # The kernel operates on non-padded inputs. Hence, we pre-compile the
1612
+ # Triton kernel to avoid runtime compilation for unseen batch sizes.
1613
+ # Pre-compile for batch sizes 1 to 1024 to cover most use-cases.
1614
+ # On DS-R1, this step adds roughly 50s to the model loading time.
1615
+ max_batch_size = 1024 # [ToDo] Find the optimal upper limit
1616
+ pre_compilation_list = list(range(1, max_batch_size + 1))
1617
+ if is_global_first_rank():
1618
+ pre_compilation_list = tqdm(
1619
+ pre_compilation_list,
1620
+ desc="[Aiter Triton] Pre-compiling fp8 BMM kernel",
1621
+ total=max_batch_size,
1622
+ )
1623
+
1624
+ for m in pre_compilation_list:
1625
+ x = torch.empty(
1626
+ (self.W_K.shape[0], m, self.W_K.shape[2]),
1627
+ dtype=torch.bfloat16,
1628
+ device=self.W_K.device,
1629
+ )
1630
+ rocm_aiter_ops.triton_fp8_bmm(
1631
+ x, self.W_K, self.W_K_scale, group_size=128, transpose_bm=True
1632
+ )
1633
+
1634
+ x = torch.empty(
1635
+ (self.W_V.shape[0], m, self.W_V.shape[2]),
1636
+ dtype=torch.bfloat16,
1637
+ device=self.W_V.device,
1638
+ )
1639
+ rocm_aiter_ops.triton_fp8_bmm(
1640
+ x, self.W_V, self.W_V_scale, group_size=128, transpose_bm=True
1641
+ )
1642
+ else:
1643
+ # Convert from (L, N, V) to (N, L, V)
1644
+ self.W_UV = W_UV.transpose(0, 1)
1645
+ # Convert from (L, N, P) to (N, P, L)
1646
+ self.W_UK_T = W_UK.permute(1, 2, 0)
1647
+
1648
+ def _compute_prefill_context(
1649
+ self,
1650
+ q: torch.Tensor,
1651
+ kv_c_and_k_pe_cache: torch.Tensor,
1652
+ attn_metadata: MLACommonMetadata,
1653
+ k_scale: torch.Tensor,
1654
+ ):
1655
+ assert attn_metadata.prefill is not None
1656
+ prefill_metadata = attn_metadata.prefill
1657
+ assert prefill_metadata.chunked_context is not None
1658
+
1659
+ output = None
1660
+ iters = len(prefill_metadata.chunked_context.seq_tot)
1661
+ workspace = prefill_metadata.chunked_context.workspace
1662
+ for i in range(iters):
1663
+ toks = prefill_metadata.chunked_context.seq_tot[i]
1664
+ ops.gather_and_maybe_dequant_cache(
1665
+ src_cache=kv_c_and_k_pe_cache,
1666
+ dst=workspace,
1667
+ block_table=prefill_metadata.block_table,
1668
+ cu_seq_lens=prefill_metadata.chunked_context.cu_seq_lens[i],
1669
+ token_to_seq=prefill_metadata.chunked_context.token_to_seq[i],
1670
+ num_tokens=prefill_metadata.chunked_context.chunk_total_token[i],
1671
+ kv_cache_dtype=self.kv_cache_dtype,
1672
+ scale=k_scale,
1673
+ seq_starts=prefill_metadata.chunked_context.starts[i],
1674
+ )
1675
+
1676
+ kv_c_normed = workspace[:toks][..., : self.kv_lora_rank]
1677
+ k_pe = workspace[:toks][..., self.kv_lora_rank :].unsqueeze(1)
1678
+
1679
+ kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
1680
+ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
1681
+ )
1682
+ k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
1683
+
1684
+ k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
1685
+
1686
+ attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
1687
+ prefill=prefill_metadata,
1688
+ chunk_idx=i,
1689
+ q=q,
1690
+ k=k,
1691
+ v=v,
1692
+ )
1693
+
1694
+ if output is None:
1695
+ output = attn_output
1696
+ output_lse = attn_softmax_lse
1697
+ else:
1698
+ output_tmp = torch.empty_like(output)
1699
+ output_lse_tmp = torch.empty_like(output_lse)
1700
+ merge_attn_states(
1701
+ output=output_tmp,
1702
+ output_lse=output_lse_tmp,
1703
+ prefix_output=output,
1704
+ prefix_lse=output_lse,
1705
+ suffix_output=attn_output,
1706
+ suffix_lse=attn_softmax_lse,
1707
+ )
1708
+ output = output_tmp
1709
+ output_lse = output_lse_tmp
1710
+
1711
+ return output, output_lse
1712
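`_compute_prefill_context` merges each context chunk's partial attention output with the running result via `merge_attn_states`. Assuming the standard natural-log log-sum-exp combination (a sketch of the math, not the kernel's actual implementation), two partial softmax-attention results can be merged like this:

    import torch

    def merge_two_states(o1, lse1, o2, lse2):
        # Per-query log-sum-exp of the combined softmax denominators.
        lse = torch.logaddexp(lse1, lse2)
        # Reweight each partial output by its share of the total denominator.
        w1 = torch.exp(lse1 - lse).unsqueeze(-1)
        w2 = torch.exp(lse2 - lse).unsqueeze(-1)
        return w1 * o1 + w2 * o2, lse

    o1, o2 = torch.randn(4, 8, 128), torch.randn(4, 8, 128)   # (tokens, heads, v_dim)
    lse1, lse2 = torch.randn(4, 8), torch.randn(4, 8)
    merged, merged_lse = merge_two_states(o1, lse1, o2, lse2)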
+
1713
+ def _context_parallel_compute_prefill_context(
1714
+ self,
1715
+ q: torch.Tensor,
1716
+ kv_c_and_k_pe_cache: torch.Tensor,
1717
+ attn_metadata: MLACommonMetadata,
1718
+ k_scale: torch.Tensor,
1719
+ dcp_world_size: int,
1720
+ ):
1721
+ assert k_scale is None, "DCP does not support scaled kvcache yet."
1722
+ assert attn_metadata.prefill is not None
1723
+ prefill_metadata = attn_metadata.prefill
1724
+ assert prefill_metadata.chunked_context is not None
1725
+ assert prefill_metadata.chunked_context.padded_local_chunk_seq_lens is not None
1726
+ assert prefill_metadata.chunked_context.local_context_lens_allranks is not None
1727
+ assert prefill_metadata.chunked_context.padded_local_cu_seq_lens is not None
1728
+ assert prefill_metadata.chunked_context.cu_seq_lens_lst is not None
1729
+ assert prefill_metadata.chunked_context.chunk_size is not None
1730
+
1731
+ output = None
1732
+ iters = len(prefill_metadata.chunked_context.seq_tot)
1733
+ workspace = prefill_metadata.chunked_context.workspace
1734
+
1735
+ for i in range(iters):
1736
+ toks = prefill_metadata.chunked_context.seq_tot[i]
1737
+ ops.cp_gather_cache(
1738
+ src_cache=kv_c_and_k_pe_cache,
1739
+ dst=workspace,
1740
+ block_table=prefill_metadata.block_table,
1741
+ cu_seq_lens=prefill_metadata.chunked_context.padded_local_cu_seq_lens[
1742
+ i
1743
+ ],
1744
+ batch_size=attn_metadata.num_prefills,
1745
+ seq_starts=prefill_metadata.chunked_context.starts[i],
1746
+ )
1747
+ # workspace
1748
+ # |------- N tokens --------|--------- N*dcp_size tokens ----------|
1749
+ # |<- use for local_gather->|<--------- use for allgather -------->|
1750
+ allgather_offset = workspace.shape[0] // (dcp_world_size + 1)
1751
+ assert allgather_offset * (dcp_world_size + 1) == workspace.shape[0]
1752
+ assert toks <= allgather_offset
1753
+ local_gathered_kvcache = workspace[:toks]
1754
+ cur_allgather_workspace = workspace[
1755
+ allgather_offset : allgather_offset * (1 + dcp_world_size)
1756
+ ]
1757
+ assert toks * dcp_world_size <= cur_allgather_workspace.shape[0]
1758
+ cur_allgather_kvcache = cur_allgather_workspace[: toks * dcp_world_size]
1759
+ cur_allgather_kvcache.copy_(
1760
+ get_dcp_group().all_gather(local_gathered_kvcache, dim=0)
1761
+ )
1762
+ assert (
1763
+ cur_allgather_kvcache.shape[-1]
1764
+ == self.kv_lora_rank + self.qk_rope_head_dim
1765
+ )
1766
+ allgatered_kv_c_normed, allgatered_k_pe = cur_allgather_kvcache.unsqueeze(
1767
+ 1
1768
+ ).split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
1769
+
1770
+ kv_c_normed, k_pe = reorg_kvcache(
1771
+ allgatered_kv_c_normed,
1772
+ allgatered_k_pe,
1773
+ padded_local_chunk_seq_lens_lst=prefill_metadata.chunked_context.padded_local_chunk_seq_lens[
1774
+ i
1775
+ ],
1776
+ local_context_lens_allranks=prefill_metadata.chunked_context.local_context_lens_allranks,
1777
+ sum_seq_len=prefill_metadata.chunked_context.cu_seq_lens_lst[i][-1],
1778
+ max_seq_len=prefill_metadata.chunked_context.max_seq_lens[i],
1779
+ chunk_size=prefill_metadata.chunked_context.chunk_size,
1780
+ chunk_idx=i,
1781
+ toks=toks,
1782
+ )
1783
+
1784
+ kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
1785
+ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
1786
+ )
1787
+ k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
1788
+ k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
1789
+
1790
+ attn_output, attn_softmax_lse = self._run_prefill_context_chunk(
1791
+ prefill=prefill_metadata,
1792
+ chunk_idx=i,
1793
+ q=q,
1794
+ k=k,
1795
+ v=v,
1796
+ )
1797
+
1798
+ if output is None:
1799
+ output = attn_output
1800
+ output_lse = attn_softmax_lse
1801
+ else:
1802
+ output_tmp = torch.empty_like(output)
1803
+ output_lse_tmp = torch.empty_like(output_lse)
1804
+ merge_attn_states(
1805
+ output=output_tmp,
1806
+ output_lse=output_lse_tmp,
1807
+ prefix_output=output,
1808
+ prefix_lse=output_lse,
1809
+ suffix_output=attn_output,
1810
+ suffix_lse=attn_softmax_lse,
1811
+ )
1812
+ output = output_tmp
1813
+ output_lse = output_lse_tmp
1814
+
1815
+ return output, output_lse
1816
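The workspace split in `_context_parallel_compute_prefill_context` above reserves 1/(dcp_world_size + 1) of the rows for the local gather and the remaining dcp_world_size/(dcp_world_size + 1) for the all-gather result. A quick arithmetic sketch with invented sizes:

    import torch

    dcp_world_size = 4
    workspace = torch.empty(1000 * (dcp_world_size + 1), 576)   # rows divisible by 5

    allgather_offset = workspace.shape[0] // (dcp_world_size + 1)   # 1000
    toks = 800                                                      # tokens gathered locally

    local_region = workspace[:toks]                                 # this rank's local gather
    allgather_region = workspace[allgather_offset:allgather_offset * (1 + dcp_world_size)]
    assert toks * dcp_world_size <= allgather_region.shape[0]       # 3200 <= 4000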
+
1817
+ def _forward_prefill(
1818
+ self,
1819
+ q: torch.Tensor,
1820
+ kv_c_normed: torch.Tensor,
1821
+ k_pe: torch.Tensor,
1822
+ kv_c_and_k_pe_cache: torch.Tensor,
1823
+ attn_metadata: MLACommonMetadata,
1824
+ k_scale: torch.Tensor,
1825
+ output: torch.Tensor,
1826
+ ) -> None:
1827
+ # TODO (zyongye): Prefill function here
1828
+ assert attn_metadata.prefill is not None
1829
+ assert self.dcp_world_size is not None
1830
+
1831
+ has_context = attn_metadata.prefill.chunked_context is not None
1832
+ kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
1833
+ -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim
1834
+ )
1835
+ k_nope, v = kv_nope.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
1836
+
1837
+ k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
1838
+
1839
+ output_prefill = self._run_prefill_new_tokens(
1840
+ prefill=attn_metadata.prefill,
1841
+ q=q,
1842
+ k=k,
1843
+ v=v,
1844
+ return_softmax_lse=has_context,
1845
+ )
1846
+
1847
+ if has_context:
1848
+ suffix_output, suffix_lse = output_prefill
1849
+ if self.dcp_world_size > 1:
1850
+ context_output, context_lse = (
1851
+ self._context_parallel_compute_prefill_context(
1852
+ q,
1853
+ kv_c_and_k_pe_cache,
1854
+ attn_metadata,
1855
+ k_scale=None,
1856
+ dcp_world_size=self.dcp_world_size,
1857
+ )
1858
+ )
1859
+ else:
1860
+ context_output, context_lse = self._compute_prefill_context(
1861
+ q, kv_c_and_k_pe_cache, attn_metadata, k_scale
1862
+ )
1863
+
1864
+ # unpad if necessary
1865
+ if self._pad_v:
1866
+ context_output = context_output[..., : v.shape[-1]]
1867
+ suffix_output = suffix_output[..., : v.shape[-1]]
1868
+
1869
+ output = output.view(-1, self.num_heads, self.v_head_dim)
1870
+ merge_attn_states(
1871
+ output=output,
1872
+ prefix_output=context_output,
1873
+ prefix_lse=context_lse,
1874
+ suffix_output=suffix_output,
1875
+ suffix_lse=suffix_lse,
1876
+ )
1877
+ else:
1878
+ output_prefill = output_prefill[..., : v.shape[-1]].flatten(start_dim=-2)
1879
+ output.copy_(output_prefill)
1880
+
1881
+ @abstractmethod
1882
+ def _forward_decode(
1883
+ self,
1884
+ q: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
1885
+ kv_c_and_k_pe_cache: torch.Tensor,
1886
+ attn_metadata: M,
1887
+ layer: AttentionLayer,
1888
+ ) -> tuple[torch.Tensor, torch.Tensor | None]:
1889
+ raise NotImplementedError
1890
+
1891
+ def forward(
1892
+ self,
1893
+ layer: AttentionLayer,
1894
+ q: torch.Tensor,
1895
+ k_c_normed: torch.Tensor, # key in unified attn
1896
+ k_pe: torch.Tensor, # value in unified attn
1897
+ kv_cache: torch.Tensor,
1898
+ attn_metadata: M,
1899
+ output: torch.Tensor | None = None,
1900
+ output_scale: torch.Tensor | None = None,
1901
+ output_block_scale: torch.Tensor | None = None,
1902
+ ) -> torch.Tensor:
1903
+ assert output is not None, "Output tensor must be provided."
1904
+
1905
+ if output_scale is not None or output_block_scale is not None:
1906
+ raise NotImplementedError(
1907
+ "fused output quantization is not yet supported for MLACommonImpl"
1908
+ )
1909
+
1910
+ if attn_metadata is None:
1911
+ # During the profile run, try to simulate the worst-case output size
1912
+ # for `self.kv_b_proj(kv_c_normed)` in `_compute_prefill_context`
1913
+ # since this can be large
1914
+ _ = torch.empty(
1915
+ (
1916
+ self.chunked_prefill_workspace_size,
1917
+ self.num_heads,
1918
+ self.qk_nope_head_dim + self.v_head_dim,
1919
+ ),
1920
+ device=k_c_normed.device,
1921
+ dtype=k_c_normed.dtype,
1922
+ )
1923
+
1924
+ # The zero fill is required when used with DP + EP
1925
+ # to ensure all ranks within a DP group compute the
1926
+ # same expert outputs.
1927
+ return output.fill_(0)
1928
+
1929
+ if self.dcp_world_size is None:
1930
+ self.dcp_world_size = get_dcp_group().world_size
1931
+
1932
+ fp8_attention = self.kv_cache_dtype.startswith("fp8")
1933
+
1934
+ num_actual_toks = attn_metadata.num_actual_tokens
1935
+
1936
+ # Inputs and outputs may be padded for CUDA graphs
1937
+ output_padded = output
1938
+ output = output[:num_actual_toks, ...]
1939
+ q = q[:num_actual_toks, ...]
1940
+ k_c_normed = k_c_normed[:num_actual_toks, ...]
1941
+ k_pe = k_pe[:num_actual_toks, ...]
1942
+
1943
+ assert (
1944
+ attn_metadata.num_decodes is not None
1945
+ and attn_metadata.num_prefills is not None
1946
+ and attn_metadata.num_decode_tokens is not None
1947
+ )
1948
+
1949
+ has_decode = attn_metadata.num_decodes > 0
1950
+ has_prefill = attn_metadata.num_prefills > 0
1951
+ num_decode_tokens = attn_metadata.num_decode_tokens
1952
+
1953
+ decode_q = q[:num_decode_tokens]
1954
+
1955
+ prefill_q = q[num_decode_tokens:]
1956
+ prefill_k_pe = k_pe[num_decode_tokens:]
1957
+ prefill_k_c_normed = k_c_normed[num_decode_tokens:]
1958
+
1959
+ # write the latent and rope to kv cache
1960
+ if kv_cache.numel() > 0:
1961
+ ops.concat_and_cache_mla(
1962
+ k_c_normed,
1963
+ k_pe.squeeze(1),
1964
+ kv_cache,
1965
+ attn_metadata.slot_mapping.flatten(),
1966
+ kv_cache_dtype=self.kv_cache_dtype,
1967
+ scale=layer._k_scale,
1968
+ )
1969
+
1970
+ if fp8_attention:
1971
+ kv_cache = kv_cache.view(current_platform.fp8_dtype())
1972
+
1973
+ if has_prefill:
1974
+ self._forward_prefill(
1975
+ prefill_q,
1976
+ prefill_k_c_normed,
1977
+ prefill_k_pe,
1978
+ kv_cache,
1979
+ attn_metadata,
1980
+ layer._k_scale,
1981
+ output=output[num_decode_tokens:],
1982
+ )
1983
+
1984
+ if has_decode:
1985
+ assert attn_metadata.decode is not None
1986
+
1987
+ decode_q_nope, decode_q_pe = decode_q.split(
1988
+ [self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1
1989
+ )
1990
+
1991
+ # Convert from (B, N, P) to (N, B, P)
1992
+ decode_q_nope = decode_q_nope.transpose(0, 1)
1993
+
1994
+ if self.q_pad_num_heads is not None:
1995
+ B, N, L = decode_q_pe.shape
1996
+ decode_pe_padded = decode_q_pe.new_empty((B, self.q_pad_num_heads, L))
1997
+ decode_pe_padded.resize_((B, N, L))
1998
+ decode_pe_padded.copy_(decode_q_pe)
1999
+ decode_q_pe = decode_pe_padded
2000
+
2001
+ if self.is_aiter_triton_fp8_bmm_enabled:
2002
+ # Multiply+Transpose (N, B, P)x(N, P, L)->(N, B, L)->(B, N, L)
2003
+ decode_ql_nope = rocm_aiter_ops.triton_fp8_bmm(
2004
+ decode_q_nope,
2005
+ self.W_K,
2006
+ self.W_K_scale,
2007
+ group_size=128,
2008
+ transpose_bm=True,
2009
+ )
2010
+ else:
2011
+ # Pads the head_dim if necessary (for the underlying kernel)
2012
+ N, B, P = decode_q_nope.shape
2013
+ _, _, L = self.W_UK_T.shape
2014
+
2015
+ if self.q_pad_num_heads is not None:
2016
+ decode_ql_nope = decode_q_nope.new_empty(
2017
+ (self.q_pad_num_heads, B, L)
2018
+ )
2019
+ decode_ql_nope.resize_((N, B, L))
2020
+ else:
2021
+ decode_ql_nope = decode_q_nope.new_empty((N, B, L))
2022
+
2023
+ # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
2024
+ torch.bmm(decode_q_nope, self.W_UK_T, out=decode_ql_nope)
2025
+
2026
+ # Convert from (N, B, L) to (B, N, L)
2027
+ decode_ql_nope = decode_ql_nope.transpose(0, 1)
2028
+
2029
+ if fp8_attention:
2030
+ ql_nope_shape = decode_ql_nope.shape
2031
+ decode_ql_nope, _ = ops.scaled_fp8_quant(
2032
+ decode_ql_nope.reshape(
2033
+ [ql_nope_shape[0], ql_nope_shape[1] * ql_nope_shape[2]]
2034
+ ),
2035
+ layer._q_scale,
2036
+ )
2037
+ decode_ql_nope = decode_ql_nope.reshape(ql_nope_shape)
2038
+ q_pe_shape = decode_q_pe.shape
2039
+ decode_q_pe, _ = ops.scaled_fp8_quant(
2040
+ decode_q_pe.reshape([q_pe_shape[0], q_pe_shape[1] * q_pe_shape[2]]),
2041
+ layer._q_scale,
2042
+ )
2043
+ decode_q_pe = decode_q_pe.reshape(q_pe_shape)
2044
+
2045
+ decode_q = (decode_ql_nope, decode_q_pe)
2046
+ if self.dcp_world_size > 1:
2047
+ assert not fp8_attention, "DCP does not support fp8 kvcache yet."
2048
+ # concatenate decode_ql_nope and decode_q_pe -> (B, N, L + P)
2049
+ decode_q = torch.cat(decode_q, dim=-1)
2050
+ # decode_q does an allgather in the head dim.
2051
+ decode_q = get_dcp_group().all_gather(decode_q, dim=1)
2052
+
2053
+ # call decode attn
2054
+ attn_out, lse = self._forward_decode(
2055
+ decode_q, kv_cache, attn_metadata, layer
2056
+ )
2057
+
2058
+ # Correct the DCP attn_out using the LSE.
2059
+ if self.dcp_world_size > 1:
2060
+ attn_out = cp_lse_ag_out_rs(
2061
+ attn_out,
2062
+ lse,
2063
+ get_dcp_group(),
2064
+ is_lse_base_on_e=not getattr(self, "_use_fi_prefill", False),
2065
+ )
2066
+
2067
+ # v_up projection
2068
+ self._v_up_proj(attn_out, out=output[:num_decode_tokens])
2069
+ return output_padded
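For the DCP decode path above, the two query pieces are concatenated along the last dim and then all-gathered along the head dim, so each rank attends over its local KV shard with the full head set before the LSE-based correction. A shape-only sketch with invented sizes; the repeated `torch.cat` stands in for the actual `all_gather` collective, which is unavailable in a single-process example:

    import torch

    B, N, L, P = 4, 8, 512, 64          # tokens, local heads, latent rank, rope dim
    dcp_world_size = 2

    decode_ql_nope = torch.randn(B, N, L)
    decode_q_pe = torch.randn(B, N, P)

    # Concatenate along the last dim -> (B, N, L + P); an all-gather over the
    # head dim (dim=1) would then give every rank (B, N * dcp_world_size, L + P).
    decode_q = torch.cat((decode_ql_nope, decode_q_pe), dim=-1)
    gathered = torch.cat([decode_q] * dcp_world_size, dim=1)  # stand-in for the collective
    print(gathered.shape)  # torch.Size([4, 16, 576])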