vllm_cpu-0.8.5.post2-cp310-cp310-manylinux_2_17_x86_64.whl

This diff represents the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.

Potentially problematic release.

This version of vllm-cpu might be problematic.
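For reference, a wheel is an ordinary zip archive, so a file manifest like the one below can be reproduced locally with nothing but Python's standard library. A minimal sketch, assuming a local copy of the release named above:

import zipfile

# Assumed local copy of the wheel being diffed.
WHEEL = "vllm_cpu-0.8.5.post2-cp310-cp310-manylinux_2_17_x86_64.whl"

with zipfile.ZipFile(WHEEL) as wheel:
    # Print each archived file with its uncompressed size,
    # mirroring the per-file entries listed below.
    for info in wheel.infolist():
        print(f"{info.filename} ({info.file_size} bytes)")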

Files changed (1103)
  1. vllm/_C.abi3.so +0 -0
  2. vllm/__init__.py +170 -0
  3. vllm/_custom_ops.py +1536 -0
  4. vllm/_ipex_ops.py +241 -0
  5. vllm/_version.py +34 -0
  6. vllm/adapter_commons/__init__.py +0 -0
  7. vllm/adapter_commons/layers.py +16 -0
  8. vllm/adapter_commons/models.py +105 -0
  9. vllm/adapter_commons/request.py +25 -0
  10. vllm/adapter_commons/utils.py +92 -0
  11. vllm/adapter_commons/worker_manager.py +38 -0
  12. vllm/assets/__init__.py +0 -0
  13. vllm/assets/audio.py +38 -0
  14. vllm/assets/base.py +40 -0
  15. vllm/assets/image.py +31 -0
  16. vllm/assets/video.py +103 -0
  17. vllm/attention/__init__.py +19 -0
  18. vllm/attention/backends/__init__.py +0 -0
  19. vllm/attention/backends/abstract.py +306 -0
  20. vllm/attention/backends/blocksparse_attn.py +457 -0
  21. vllm/attention/backends/cpu_mla.py +303 -0
  22. vllm/attention/backends/flash_attn.py +999 -0
  23. vllm/attention/backends/flashinfer.py +1092 -0
  24. vllm/attention/backends/flashmla.py +242 -0
  25. vllm/attention/backends/hpu_attn.py +301 -0
  26. vllm/attention/backends/ipex_attn.py +396 -0
  27. vllm/attention/backends/mla/__init__.py +0 -0
  28. vllm/attention/backends/mla/common.py +1444 -0
  29. vllm/attention/backends/pallas.py +346 -0
  30. vllm/attention/backends/placeholder_attn.py +399 -0
  31. vllm/attention/backends/rocm_aiter_mla.py +412 -0
  32. vllm/attention/backends/rocm_flash_attn.py +969 -0
  33. vllm/attention/backends/torch_sdpa.py +691 -0
  34. vllm/attention/backends/triton_mla.py +113 -0
  35. vllm/attention/backends/utils.py +609 -0
  36. vllm/attention/backends/xformers.py +798 -0
  37. vllm/attention/layer.py +443 -0
  38. vllm/attention/ops/__init__.py +0 -0
  39. vllm/attention/ops/blocksparse_attention/__init__.py +0 -0
  40. vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py +432 -0
  41. vllm/attention/ops/blocksparse_attention/interface.py +238 -0
  42. vllm/attention/ops/blocksparse_attention/utils.py +244 -0
  43. vllm/attention/ops/chunked_prefill_paged_decode.py +366 -0
  44. vllm/attention/ops/flashmla.py +115 -0
  45. vllm/attention/ops/hpu_paged_attn.py +105 -0
  46. vllm/attention/ops/ipex_attn.py +193 -0
  47. vllm/attention/ops/merge_attn_states.py +42 -0
  48. vllm/attention/ops/nki_flash_attn.py +905 -0
  49. vllm/attention/ops/paged_attn.py +255 -0
  50. vllm/attention/ops/prefix_prefill.py +902 -0
  51. vllm/attention/ops/rocm_aiter_mla.py +42 -0
  52. vllm/attention/ops/rocm_aiter_paged_attn.py +101 -0
  53. vllm/attention/ops/triton_decode_attention.py +675 -0
  54. vllm/attention/ops/triton_flash_attention.py +1375 -0
  55. vllm/attention/ops/triton_merge_attn_states.py +96 -0
  56. vllm/attention/selector.py +186 -0
  57. vllm/attention/utils/fa_utils.py +54 -0
  58. vllm/beam_search.py +82 -0
  59. vllm/benchmarks/__init__.py +0 -0
  60. vllm/benchmarks/datasets.py +831 -0
  61. vllm/benchmarks/endpoint_request_func.py +160 -0
  62. vllm/benchmarks/latency.py +181 -0
  63. vllm/benchmarks/serve.py +925 -0
  64. vllm/benchmarks/throughput.py +608 -0
  65. vllm/benchmarks/utils.py +69 -0
  66. vllm/collect_env.py +795 -0
  67. vllm/compilation/__init__.py +0 -0
  68. vllm/compilation/backends.py +715 -0
  69. vllm/compilation/compiler_interface.py +437 -0
  70. vllm/compilation/counter.py +33 -0
  71. vllm/compilation/decorators.py +249 -0
  72. vllm/compilation/fix_functionalization.py +182 -0
  73. vllm/compilation/fusion.py +617 -0
  74. vllm/compilation/fx_utils.py +60 -0
  75. vllm/compilation/inductor_pass.py +114 -0
  76. vllm/compilation/monitor.py +38 -0
  77. vllm/compilation/multi_output_match.py +108 -0
  78. vllm/compilation/noop_elimination.py +135 -0
  79. vllm/compilation/pass_manager.py +74 -0
  80. vllm/compilation/sequence_parallelism.py +266 -0
  81. vllm/compilation/torch25_custom_graph_pass.py +41 -0
  82. vllm/compilation/vllm_inductor_pass.py +68 -0
  83. vllm/compilation/wrapper.py +129 -0
  84. vllm/config.py +4179 -0
  85. vllm/connections.py +170 -0
  86. vllm/core/__init__.py +0 -0
  87. vllm/core/block/__init__.py +0 -0
  88. vllm/core/block/block_table.py +398 -0
  89. vllm/core/block/common.py +370 -0
  90. vllm/core/block/cpu_gpu_block_allocator.py +440 -0
  91. vllm/core/block/interfaces.py +318 -0
  92. vllm/core/block/naive_block.py +465 -0
  93. vllm/core/block/prefix_caching_block.py +1134 -0
  94. vllm/core/block/utils.py +27 -0
  95. vllm/core/block_manager.py +520 -0
  96. vllm/core/evictor.py +156 -0
  97. vllm/core/interfaces.py +134 -0
  98. vllm/core/placeholder_block_space_manager.py +99 -0
  99. vllm/core/scheduler.py +2060 -0
  100. vllm/device_allocator/__init__.py +0 -0
  101. vllm/device_allocator/cumem.py +280 -0
  102. vllm/distributed/__init__.py +5 -0
  103. vllm/distributed/communication_op.py +40 -0
  104. vllm/distributed/device_communicators/__init__.py +0 -0
  105. vllm/distributed/device_communicators/base_device_communicator.py +151 -0
  106. vllm/distributed/device_communicators/cpu_communicator.py +139 -0
  107. vllm/distributed/device_communicators/cuda_communicator.py +131 -0
  108. vllm/distributed/device_communicators/cuda_wrapper.py +179 -0
  109. vllm/distributed/device_communicators/custom_all_reduce.py +301 -0
  110. vllm/distributed/device_communicators/custom_all_reduce_utils.py +257 -0
  111. vllm/distributed/device_communicators/hpu_communicator.py +45 -0
  112. vllm/distributed/device_communicators/neuron_communicator.py +19 -0
  113. vllm/distributed/device_communicators/pynccl.py +217 -0
  114. vllm/distributed/device_communicators/pynccl_wrapper.py +340 -0
  115. vllm/distributed/device_communicators/shm_broadcast.py +557 -0
  116. vllm/distributed/device_communicators/tpu_communicator.py +93 -0
  117. vllm/distributed/device_communicators/xpu_communicator.py +54 -0
  118. vllm/distributed/kv_transfer/README.md +29 -0
  119. vllm/distributed/kv_transfer/__init__.py +11 -0
  120. vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg +0 -0
  121. vllm/distributed/kv_transfer/kv_connector/__init__.py +0 -0
  122. vllm/distributed/kv_transfer/kv_connector/base.py +127 -0
  123. vllm/distributed/kv_transfer/kv_connector/factory.py +107 -0
  124. vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py +98 -0
  125. vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py +201 -0
  126. vllm/distributed/kv_transfer/kv_connector/simple_connector.py +328 -0
  127. vllm/distributed/kv_transfer/kv_connector/utils.py +90 -0
  128. vllm/distributed/kv_transfer/kv_connector/v1/__init__.py +8 -0
  129. vllm/distributed/kv_transfer/kv_connector/v1/base.py +209 -0
  130. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +131 -0
  131. vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +383 -0
  132. vllm/distributed/kv_transfer/kv_connector_agent.py +76 -0
  133. vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py +0 -0
  134. vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +174 -0
  135. vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py +160 -0
  136. vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +236 -0
  137. vllm/distributed/kv_transfer/kv_pipe/__init__.py +0 -0
  138. vllm/distributed/kv_transfer/kv_pipe/base.py +66 -0
  139. vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +279 -0
  140. vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +279 -0
  141. vllm/distributed/kv_transfer/kv_transfer_state.py +70 -0
  142. vllm/distributed/parallel_state.py +1209 -0
  143. vllm/distributed/utils.py +366 -0
  144. vllm/engine/__init__.py +0 -0
  145. vllm/engine/arg_utils.py +1724 -0
  146. vllm/engine/async_llm_engine.py +1261 -0
  147. vllm/engine/async_timeout.py +191 -0
  148. vllm/engine/llm_engine.py +2150 -0
  149. vllm/engine/metrics.py +717 -0
  150. vllm/engine/metrics_types.py +96 -0
  151. vllm/engine/multiprocessing/__init__.py +183 -0
  152. vllm/engine/multiprocessing/client.py +745 -0
  153. vllm/engine/multiprocessing/engine.py +450 -0
  154. vllm/engine/output_processor/__init__.py +0 -0
  155. vllm/engine/output_processor/interfaces.py +74 -0
  156. vllm/engine/output_processor/multi_step.py +210 -0
  157. vllm/engine/output_processor/single_step.py +136 -0
  158. vllm/engine/output_processor/stop_checker.py +130 -0
  159. vllm/engine/output_processor/util.py +27 -0
  160. vllm/engine/protocol.py +302 -0
  161. vllm/entrypoints/__init__.py +0 -0
  162. vllm/entrypoints/api_server.py +177 -0
  163. vllm/entrypoints/chat_utils.py +1259 -0
  164. vllm/entrypoints/cli/__init__.py +0 -0
  165. vllm/entrypoints/cli/benchmark/__init__.py +0 -0
  166. vllm/entrypoints/cli/benchmark/base.py +38 -0
  167. vllm/entrypoints/cli/benchmark/latency.py +29 -0
  168. vllm/entrypoints/cli/benchmark/main.py +53 -0
  169. vllm/entrypoints/cli/benchmark/serve.py +29 -0
  170. vllm/entrypoints/cli/benchmark/throughput.py +29 -0
  171. vllm/entrypoints/cli/collect_env.py +35 -0
  172. vllm/entrypoints/cli/main.py +59 -0
  173. vllm/entrypoints/cli/openai.py +175 -0
  174. vllm/entrypoints/cli/serve.py +59 -0
  175. vllm/entrypoints/cli/types.py +24 -0
  176. vllm/entrypoints/launcher.py +146 -0
  177. vllm/entrypoints/llm.py +1450 -0
  178. vllm/entrypoints/logger.py +44 -0
  179. vllm/entrypoints/openai/__init__.py +0 -0
  180. vllm/entrypoints/openai/api_server.py +1130 -0
  181. vllm/entrypoints/openai/cli_args.py +296 -0
  182. vllm/entrypoints/openai/logits_processors.py +89 -0
  183. vllm/entrypoints/openai/protocol.py +1806 -0
  184. vllm/entrypoints/openai/run_batch.py +439 -0
  185. vllm/entrypoints/openai/serving_chat.py +1210 -0
  186. vllm/entrypoints/openai/serving_completion.py +557 -0
  187. vllm/entrypoints/openai/serving_embedding.py +245 -0
  188. vllm/entrypoints/openai/serving_engine.py +569 -0
  189. vllm/entrypoints/openai/serving_models.py +314 -0
  190. vllm/entrypoints/openai/serving_pooling.py +237 -0
  191. vllm/entrypoints/openai/serving_score.py +439 -0
  192. vllm/entrypoints/openai/serving_tokenization.py +147 -0
  193. vllm/entrypoints/openai/serving_transcription.py +421 -0
  194. vllm/entrypoints/openai/tool_parsers/__init__.py +19 -0
  195. vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +163 -0
  196. vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +254 -0
  197. vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +232 -0
  198. vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +370 -0
  199. vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +211 -0
  200. vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +303 -0
  201. vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +262 -0
  202. vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +342 -0
  203. vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +110 -0
  204. vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +292 -0
  205. vllm/entrypoints/openai/tool_parsers/utils.py +123 -0
  206. vllm/entrypoints/score_utils.py +49 -0
  207. vllm/entrypoints/ssl.py +74 -0
  208. vllm/entrypoints/utils.py +136 -0
  209. vllm/env_override.py +34 -0
  210. vllm/envs.py +800 -0
  211. vllm/executor/__init__.py +0 -0
  212. vllm/executor/executor_base.py +400 -0
  213. vllm/executor/mp_distributed_executor.py +243 -0
  214. vllm/executor/msgspec_utils.py +29 -0
  215. vllm/executor/multiproc_worker_utils.py +312 -0
  216. vllm/executor/ray_distributed_executor.py +700 -0
  217. vllm/executor/ray_utils.py +400 -0
  218. vllm/executor/uniproc_executor.py +141 -0
  219. vllm/forward_context.py +159 -0
  220. vllm/inputs/__init__.py +37 -0
  221. vllm/inputs/data.py +248 -0
  222. vllm/inputs/parse.py +121 -0
  223. vllm/inputs/preprocess.py +745 -0
  224. vllm/inputs/registry.py +212 -0
  225. vllm/jsontree.py +79 -0
  226. vllm/logger.py +210 -0
  227. vllm/logging_utils/__init__.py +7 -0
  228. vllm/logging_utils/formatter.py +17 -0
  229. vllm/logits_process.py +121 -0
  230. vllm/lora/__init__.py +0 -0
  231. vllm/lora/fully_sharded_layers.py +335 -0
  232. vllm/lora/layers.py +1263 -0
  233. vllm/lora/lora.py +198 -0
  234. vllm/lora/models.py +802 -0
  235. vllm/lora/ops/__init__.py +0 -0
  236. vllm/lora/ops/torch_ops/__init__.py +15 -0
  237. vllm/lora/ops/torch_ops/lora_ops.py +115 -0
  238. vllm/lora/ops/triton_ops/__init__.py +11 -0
  239. vllm/lora/ops/triton_ops/kernel_utils.py +243 -0
  240. vllm/lora/ops/triton_ops/lora_expand.py +293 -0
  241. vllm/lora/ops/triton_ops/lora_kernel_metadata.py +147 -0
  242. vllm/lora/ops/triton_ops/lora_shrink.py +247 -0
  243. vllm/lora/ops/triton_ops/utils.py +121 -0
  244. vllm/lora/peft_helper.py +115 -0
  245. vllm/lora/punica_wrapper/__init__.py +9 -0
  246. vllm/lora/punica_wrapper/punica_base.py +483 -0
  247. vllm/lora/punica_wrapper/punica_cpu.py +348 -0
  248. vllm/lora/punica_wrapper/punica_gpu.py +289 -0
  249. vllm/lora/punica_wrapper/punica_hpu.py +144 -0
  250. vllm/lora/punica_wrapper/punica_selector.py +20 -0
  251. vllm/lora/punica_wrapper/utils.py +161 -0
  252. vllm/lora/request.py +97 -0
  253. vllm/lora/resolver.py +83 -0
  254. vllm/lora/utils.py +237 -0
  255. vllm/lora/worker_manager.py +251 -0
  256. vllm/model_executor/__init__.py +15 -0
  257. vllm/model_executor/custom_op.py +153 -0
  258. vllm/model_executor/guided_decoding/__init__.py +180 -0
  259. vllm/model_executor/guided_decoding/guidance_decoding.py +63 -0
  260. vllm/model_executor/guided_decoding/guidance_logits_processors.py +85 -0
  261. vllm/model_executor/guided_decoding/guided_fields.py +42 -0
  262. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +66 -0
  263. vllm/model_executor/guided_decoding/outlines_decoding.py +154 -0
  264. vllm/model_executor/guided_decoding/outlines_logits_processors.py +271 -0
  265. vllm/model_executor/guided_decoding/reasoner/__init__.py +35 -0
  266. vllm/model_executor/guided_decoding/utils.py +241 -0
  267. vllm/model_executor/guided_decoding/xgrammar_decoding.py +425 -0
  268. vllm/model_executor/layers/__init__.py +0 -0
  269. vllm/model_executor/layers/activation.py +368 -0
  270. vllm/model_executor/layers/fused_moe/__init__.py +51 -0
  271. vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  272. vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  273. vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  274. vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  275. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  276. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +218 -0
  277. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
  278. vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  279. vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  280. vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  281. vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  282. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  283. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  284. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  285. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  286. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  287. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  288. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  289. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  290. vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  291. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  292. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  293. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  294. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  295. vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  296. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json +200 -0
  297. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json +146 -0
  298. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  299. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  300. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  301. vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  302. vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  303. vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  304. vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  305. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  306. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  307. vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  308. vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  309. vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  310. vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  311. vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  312. vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  313. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  314. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  315. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  316. vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  317. vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  318. vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json +200 -0
  319. vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  320. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  321. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  322. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  323. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  324. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  325. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  326. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  327. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  328. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  329. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  330. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  331. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  332. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  333. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  334. vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  335. vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  336. vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json +200 -0
  337. vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json +200 -0
  338. vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json +200 -0
  339. vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json +200 -0
  340. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  341. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  342. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  343. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  344. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  345. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  346. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  347. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  348. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  349. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  350. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  351. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  352. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  353. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  354. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  355. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  356. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  357. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  358. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  359. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  360. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  361. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
  362. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  363. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  364. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
  365. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  366. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  367. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  368. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X.json +200 -0
  369. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  370. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json +200 -0
  371. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  372. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
  373. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  374. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
  375. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  376. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  377. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  378. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  379. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  380. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  381. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X.json +200 -0
  382. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  383. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json +200 -0
  384. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  385. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  386. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  387. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  388. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  389. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  390. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
  391. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  392. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
  393. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  394. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  395. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  396. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  397. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  398. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  399. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  400. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
  401. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  402. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X.json +200 -0
  403. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  404. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json +200 -0
  405. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  406. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  407. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  408. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  409. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  410. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  411. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
  412. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  413. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
  414. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  415. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  416. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  417. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  418. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  419. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  420. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X.json +200 -0
  421. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  422. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json +200 -0
  423. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  424. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  425. vllm/model_executor/layers/fused_moe/configs/README +12 -0
  426. vllm/model_executor/layers/fused_moe/cutlass_moe.py +180 -0
  427. vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +294 -0
  428. vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +374 -0
  429. vllm/model_executor/layers/fused_moe/fused_moe.py +1539 -0
  430. vllm/model_executor/layers/fused_moe/layer.py +949 -0
  431. vllm/model_executor/layers/fused_moe/moe_align_block_size.py +243 -0
  432. vllm/model_executor/layers/fused_moe/moe_pallas.py +64 -0
  433. vllm/model_executor/layers/fused_moe/moe_torch_iterative.py +59 -0
  434. vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +416 -0
  435. vllm/model_executor/layers/fused_moe/utils.py +48 -0
  436. vllm/model_executor/layers/layernorm.py +277 -0
  437. vllm/model_executor/layers/lightning_attn.py +651 -0
  438. vllm/model_executor/layers/linear.py +1518 -0
  439. vllm/model_executor/layers/logits_processor.py +196 -0
  440. vllm/model_executor/layers/mamba/__init__.py +0 -0
  441. vllm/model_executor/layers/mamba/mamba2_metadata.py +109 -0
  442. vllm/model_executor/layers/mamba/mamba_mixer.py +244 -0
  443. vllm/model_executor/layers/mamba/mamba_mixer2.py +538 -0
  444. vllm/model_executor/layers/mamba/ops/__init__.py +0 -0
  445. vllm/model_executor/layers/mamba/ops/causal_conv1d.py +104 -0
  446. vllm/model_executor/layers/mamba/ops/mamba_ssm.py +415 -0
  447. vllm/model_executor/layers/mamba/ops/ssd_bmm.py +261 -0
  448. vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +588 -0
  449. vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +750 -0
  450. vllm/model_executor/layers/mamba/ops/ssd_combined.py +231 -0
  451. vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +205 -0
  452. vllm/model_executor/layers/pooler.py +336 -0
  453. vllm/model_executor/layers/quantization/__init__.py +153 -0
  454. vllm/model_executor/layers/quantization/aqlm.py +374 -0
  455. vllm/model_executor/layers/quantization/awq.py +184 -0
  456. vllm/model_executor/layers/quantization/awq_marlin.py +518 -0
  457. vllm/model_executor/layers/quantization/awq_triton.py +319 -0
  458. vllm/model_executor/layers/quantization/base_config.py +145 -0
  459. vllm/model_executor/layers/quantization/bitblas.py +459 -0
  460. vllm/model_executor/layers/quantization/bitsandbytes.py +396 -0
  461. vllm/model_executor/layers/quantization/compressed_tensors/__init__.py +0 -0
  462. vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +624 -0
  463. vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +1100 -0
  464. vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +20 -0
  465. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +357 -0
  466. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +54 -0
  467. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +159 -0
  468. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +119 -0
  469. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +149 -0
  470. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +110 -0
  471. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +200 -0
  472. vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +205 -0
  473. vllm/model_executor/layers/quantization/compressed_tensors/utils.py +213 -0
  474. vllm/model_executor/layers/quantization/deepspeedfp.py +193 -0
  475. vllm/model_executor/layers/quantization/experts_int8.py +194 -0
  476. vllm/model_executor/layers/quantization/fbgemm_fp8.py +168 -0
  477. vllm/model_executor/layers/quantization/fp8.py +832 -0
  478. vllm/model_executor/layers/quantization/gguf.py +408 -0
  479. vllm/model_executor/layers/quantization/gptq.py +276 -0
  480. vllm/model_executor/layers/quantization/gptq_bitblas.py +438 -0
  481. vllm/model_executor/layers/quantization/gptq_marlin.py +643 -0
  482. vllm/model_executor/layers/quantization/gptq_marlin_24.py +295 -0
  483. vllm/model_executor/layers/quantization/hqq_marlin.py +328 -0
  484. vllm/model_executor/layers/quantization/ipex_quant.py +250 -0
  485. vllm/model_executor/layers/quantization/kernels/__init__.py +0 -0
  486. vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +89 -0
  487. vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +82 -0
  488. vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py +115 -0
  489. vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py +299 -0
  490. vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py +142 -0
  491. vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +119 -0
  492. vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +132 -0
  493. vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +66 -0
  494. vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +86 -0
  495. vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +119 -0
  496. vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +136 -0
  497. vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py +40 -0
  498. vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +104 -0
  499. vllm/model_executor/layers/quantization/kv_cache.py +137 -0
  500. vllm/model_executor/layers/quantization/marlin.py +259 -0
  501. vllm/model_executor/layers/quantization/modelopt.py +410 -0
  502. vllm/model_executor/layers/quantization/moe_wna16.py +447 -0
  503. vllm/model_executor/layers/quantization/neuron_quant.py +67 -0
  504. vllm/model_executor/layers/quantization/ptpc_fp8.py +125 -0
  505. vllm/model_executor/layers/quantization/qqq.py +273 -0
  506. vllm/model_executor/layers/quantization/quark/__init__.py +0 -0
  507. vllm/model_executor/layers/quantization/quark/quark.py +385 -0
  508. vllm/model_executor/layers/quantization/quark/quark_moe.py +236 -0
  509. vllm/model_executor/layers/quantization/quark/schemes/__init__.py +7 -0
  510. vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py +54 -0
  511. vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +142 -0
  512. vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +121 -0
  513. vllm/model_executor/layers/quantization/quark/utils.py +102 -0
  514. vllm/model_executor/layers/quantization/schema.py +85 -0
  515. vllm/model_executor/layers/quantization/torchao.py +127 -0
  516. vllm/model_executor/layers/quantization/tpu_int8.py +119 -0
  517. vllm/model_executor/layers/quantization/utils/__init__.py +5 -0
  518. vllm/model_executor/layers/quantization/utils/allspark_utils.py +51 -0
  519. vllm/model_executor/layers/quantization/utils/bitblas_utils.py +198 -0
  520. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  521. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  522. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  523. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  524. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  525. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  526. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  527. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  528. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  529. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  530. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  531. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  532. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  533. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  534. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  535. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  536. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  537. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  538. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  539. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  540. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  541. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  542. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  543. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  544. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  545. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  546. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  547. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  548. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  549. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  550. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  551. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  552. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  553. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  554. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  555. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  556. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  557. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  558. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  559. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  560. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  561. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  562. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  563. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  564. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  565. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  566. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  567. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  568. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  569. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  570. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  571. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  572. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  573. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  574. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  575. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  576. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  577. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  578. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  579. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  580. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  581. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  582. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  583. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  584. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  585. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  586. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  587. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  588. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  589. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  590. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  591. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  592. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  593. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  594. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  595. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  596. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  597. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  598. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  599. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  600. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  601. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  602. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  603. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  604. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  605. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  606. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  607. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  608. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  609. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  610. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  611. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  612. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  613. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  614. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  615. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  616. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  617. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  618. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  619. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  620. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  621. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  622. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  623. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  624. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  625. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  626. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  627. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  628. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  629. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  630. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  631. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  632. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  633. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  634. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  635. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  636. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  637. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  638. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  639. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  640. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  641. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +18 -0
  642. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  643. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  644. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  645. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  646. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  647. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  648. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  649. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  650. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  651. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  652. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  653. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  654. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  655. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  656. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  657. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  658. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  659. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  660. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  661. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  662. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  663. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  664. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  665. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  666. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  667. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  668. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  669. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  670. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  671. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  672. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  673. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  674. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  675. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  676. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  677. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  678. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  679. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  680. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  681. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  682. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  683. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  684. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  685. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  686. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  687. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  688. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  689. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  690. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  691. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  692. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  693. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  694. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  695. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  696. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  697. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  698. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  699. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  700. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  701. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  702. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  703. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  704. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  705. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  706. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  707. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  708. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  709. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  710. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  711. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  712. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  713. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  714. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  715. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  716. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  717. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  718. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  719. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  720. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  721. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  722. vllm/model_executor/layers/quantization/utils/fp8_utils.py +523 -0
  723. vllm/model_executor/layers/quantization/utils/gptq_utils.py +94 -0
  724. vllm/model_executor/layers/quantization/utils/int8_utils.py +459 -0
  725. vllm/model_executor/layers/quantization/utils/layer_utils.py +39 -0
  726. vllm/model_executor/layers/quantization/utils/machete_utils.py +32 -0
  727. vllm/model_executor/layers/quantization/utils/marlin_utils.py +413 -0
  728. vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +110 -0
  729. vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +164 -0
  730. vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py +464 -0
  731. vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py +127 -0
  732. vllm/model_executor/layers/quantization/utils/quant_utils.py +571 -0
  733. vllm/model_executor/layers/quantization/utils/w8a8_utils.py +404 -0
  734. vllm/model_executor/layers/rejection_sampler.py +400 -0
  735. vllm/model_executor/layers/resampler.py +269 -0
  736. vllm/model_executor/layers/rotary_embedding.py +1598 -0
  737. vllm/model_executor/layers/sampler.py +1221 -0
  738. vllm/model_executor/layers/spec_decode_base_sampler.py +258 -0
  739. vllm/model_executor/layers/typical_acceptance_sampler.py +172 -0
  740. vllm/model_executor/layers/utils.py +99 -0
  741. vllm/model_executor/layers/vocab_parallel_embedding.py +485 -0
  742. vllm/model_executor/model_loader/__init__.py +20 -0
  743. vllm/model_executor/model_loader/loader.py +1542 -0
  744. vllm/model_executor/model_loader/neuron.py +243 -0
  745. vllm/model_executor/model_loader/tensorizer.py +468 -0
  746. vllm/model_executor/model_loader/utils.py +171 -0
  747. vllm/model_executor/model_loader/weight_utils.py +749 -0
  748. vllm/model_executor/models/__init__.py +27 -0
  749. vllm/model_executor/models/adapters.py +247 -0
  750. vllm/model_executor/models/arctic.py +559 -0
  751. vllm/model_executor/models/aria.py +656 -0
  752. vllm/model_executor/models/aya_vision.py +461 -0
  753. vllm/model_executor/models/baichuan.py +469 -0
  754. vllm/model_executor/models/bamba.py +542 -0
  755. vllm/model_executor/models/bart.py +936 -0
  756. vllm/model_executor/models/bert.py +725 -0
  757. vllm/model_executor/models/blip.py +337 -0
  758. vllm/model_executor/models/blip2.py +717 -0
  759. vllm/model_executor/models/bloom.py +358 -0
  760. vllm/model_executor/models/chameleon.py +1135 -0
  761. vllm/model_executor/models/chatglm.py +476 -0
  762. vllm/model_executor/models/clip.py +410 -0
  763. vllm/model_executor/models/commandr.py +466 -0
  764. vllm/model_executor/models/constant_size_cache.py +136 -0
  765. vllm/model_executor/models/dbrx.py +469 -0
  766. vllm/model_executor/models/deepseek.py +484 -0
  767. vllm/model_executor/models/deepseek_mtp.py +266 -0
  768. vllm/model_executor/models/deepseek_v2.py +830 -0
  769. vllm/model_executor/models/deepseek_vl2.py +647 -0
  770. vllm/model_executor/models/eagle.py +247 -0
  771. vllm/model_executor/models/exaone.py +548 -0
  772. vllm/model_executor/models/fairseq2_llama.py +153 -0
  773. vllm/model_executor/models/falcon.py +508 -0
  774. vllm/model_executor/models/florence2.py +1102 -0
  775. vllm/model_executor/models/fuyu.py +388 -0
  776. vllm/model_executor/models/gemma.py +423 -0
  777. vllm/model_executor/models/gemma2.py +423 -0
  778. vllm/model_executor/models/gemma3.py +531 -0
  779. vllm/model_executor/models/gemma3_mm.py +716 -0
  780. vllm/model_executor/models/glm.py +22 -0
  781. vllm/model_executor/models/glm4.py +303 -0
  782. vllm/model_executor/models/glm4v.py +647 -0
  783. vllm/model_executor/models/gpt2.py +313 -0
  784. vllm/model_executor/models/gpt_bigcode.py +336 -0
  785. vllm/model_executor/models/gpt_j.py +337 -0
  786. vllm/model_executor/models/gpt_neox.py +330 -0
  787. vllm/model_executor/models/granite.py +494 -0
  788. vllm/model_executor/models/granite_speech.py +777 -0
  789. vllm/model_executor/models/granitemoe.py +435 -0
  790. vllm/model_executor/models/granitemoeshared.py +339 -0
  791. vllm/model_executor/models/gritlm.py +245 -0
  792. vllm/model_executor/models/grok1.py +560 -0
  793. vllm/model_executor/models/h2ovl.py +542 -0
  794. vllm/model_executor/models/idefics2_vision_model.py +387 -0
  795. vllm/model_executor/models/idefics3.py +767 -0
  796. vllm/model_executor/models/interfaces.py +569 -0
  797. vllm/model_executor/models/interfaces_base.py +163 -0
  798. vllm/model_executor/models/intern_vit.py +476 -0
  799. vllm/model_executor/models/internlm2.py +453 -0
  800. vllm/model_executor/models/internlm2_ve.py +146 -0
  801. vllm/model_executor/models/internvl.py +945 -0
  802. vllm/model_executor/models/jais.py +371 -0
  803. vllm/model_executor/models/jamba.py +590 -0
  804. vllm/model_executor/models/kimi_vl.py +577 -0
  805. vllm/model_executor/models/llama.py +619 -0
  806. vllm/model_executor/models/llama4.py +530 -0
  807. vllm/model_executor/models/llama_eagle.py +152 -0
  808. vllm/model_executor/models/llama_eagle3.py +232 -0
  809. vllm/model_executor/models/llava.py +869 -0
  810. vllm/model_executor/models/llava_next.py +582 -0
  811. vllm/model_executor/models/llava_next_video.py +470 -0
  812. vllm/model_executor/models/llava_onevision.py +954 -0
  813. vllm/model_executor/models/mamba.py +271 -0
  814. vllm/model_executor/models/mamba2.py +302 -0
  815. vllm/model_executor/models/mamba_cache.py +76 -0
  816. vllm/model_executor/models/medusa.py +210 -0
  817. vllm/model_executor/models/minicpm.py +592 -0
  818. vllm/model_executor/models/minicpm3.py +229 -0
  819. vllm/model_executor/models/minicpmo.py +725 -0
  820. vllm/model_executor/models/minicpmv.py +1287 -0
  821. vllm/model_executor/models/minimax_cache.py +35 -0
  822. vllm/model_executor/models/minimax_text_01.py +1261 -0
  823. vllm/model_executor/models/mistral3.py +598 -0
  824. vllm/model_executor/models/mixtral.py +485 -0
  825. vllm/model_executor/models/mixtral_quant.py +447 -0
  826. vllm/model_executor/models/mllama.py +1623 -0
  827. vllm/model_executor/models/mllama4.py +838 -0
  828. vllm/model_executor/models/mlp_speculator.py +205 -0
  829. vllm/model_executor/models/modernbert.py +325 -0
  830. vllm/model_executor/models/module_mapping.py +71 -0
  831. vllm/model_executor/models/molmo.py +1567 -0
  832. vllm/model_executor/models/moonvit.py +628 -0
  833. vllm/model_executor/models/mpt.py +329 -0
  834. vllm/model_executor/models/nemotron.py +506 -0
  835. vllm/model_executor/models/nemotron_nas.py +446 -0
  836. vllm/model_executor/models/nvlm_d.py +212 -0
  837. vllm/model_executor/models/olmo.py +390 -0
  838. vllm/model_executor/models/olmo2.py +412 -0
  839. vllm/model_executor/models/olmoe.py +449 -0
  840. vllm/model_executor/models/opt.py +410 -0
  841. vllm/model_executor/models/orion.py +356 -0
  842. vllm/model_executor/models/paligemma.py +397 -0
  843. vllm/model_executor/models/persimmon.py +342 -0
  844. vllm/model_executor/models/phi.py +354 -0
  845. vllm/model_executor/models/phi3.py +18 -0
  846. vllm/model_executor/models/phi3_small.py +463 -0
  847. vllm/model_executor/models/phi3v.py +722 -0
  848. vllm/model_executor/models/phi4mm.py +1263 -0
  849. vllm/model_executor/models/phi4mm_audio.py +1232 -0
  850. vllm/model_executor/models/phi4mm_utils.py +1883 -0
  851. vllm/model_executor/models/phimoe.py +666 -0
  852. vllm/model_executor/models/pixtral.py +1281 -0
  853. vllm/model_executor/models/plamo2.py +736 -0
  854. vllm/model_executor/models/prithvi_geospatial_mae.py +231 -0
  855. vllm/model_executor/models/qwen.py +360 -0
  856. vllm/model_executor/models/qwen2.py +552 -0
  857. vllm/model_executor/models/qwen2_5_omni_thinker.py +901 -0
  858. vllm/model_executor/models/qwen2_5_vl.py +1136 -0
  859. vllm/model_executor/models/qwen2_audio.py +402 -0
  860. vllm/model_executor/models/qwen2_moe.py +531 -0
  861. vllm/model_executor/models/qwen2_rm.py +130 -0
  862. vllm/model_executor/models/qwen2_vl.py +1409 -0
  863. vllm/model_executor/models/qwen3.py +319 -0
  864. vllm/model_executor/models/qwen3_moe.py +528 -0
  865. vllm/model_executor/models/qwen_vl.py +784 -0
  866. vllm/model_executor/models/registry.py +611 -0
  867. vllm/model_executor/models/roberta.py +332 -0
  868. vllm/model_executor/models/siglip.py +522 -0
  869. vllm/model_executor/models/skyworkr1v.py +949 -0
  870. vllm/model_executor/models/smolvlm.py +51 -0
  871. vllm/model_executor/models/solar.py +504 -0
  872. vllm/model_executor/models/stablelm.py +349 -0
  873. vllm/model_executor/models/starcoder2.py +355 -0
  874. vllm/model_executor/models/telechat2.py +139 -0
  875. vllm/model_executor/models/teleflm.py +78 -0
  876. vllm/model_executor/models/transformers.py +442 -0
  877. vllm/model_executor/models/ultravox.py +655 -0
  878. vllm/model_executor/models/utils.py +714 -0
  879. vllm/model_executor/models/vision.py +149 -0
  880. vllm/model_executor/models/whisper.py +746 -0
  881. vllm/model_executor/models/zamba2.py +1008 -0
  882. vllm/model_executor/parameter.py +458 -0
  883. vllm/model_executor/pooling_metadata.py +71 -0
  884. vllm/model_executor/sampling_metadata.py +596 -0
  885. vllm/model_executor/utils.py +53 -0
  886. vllm/multimodal/__init__.py +31 -0
  887. vllm/multimodal/audio.py +105 -0
  888. vllm/multimodal/base.py +218 -0
  889. vllm/multimodal/hasher.py +103 -0
  890. vllm/multimodal/image.py +77 -0
  891. vllm/multimodal/inputs.py +843 -0
  892. vllm/multimodal/parse.py +454 -0
  893. vllm/multimodal/processing.py +1760 -0
  894. vllm/multimodal/profiling.py +274 -0
  895. vllm/multimodal/registry.py +321 -0
  896. vllm/multimodal/utils.py +386 -0
  897. vllm/multimodal/video.py +166 -0
  898. vllm/outputs.py +521 -0
  899. vllm/platforms/__init__.py +286 -0
  900. vllm/platforms/cpu.py +182 -0
  901. vllm/platforms/cuda.py +463 -0
  902. vllm/platforms/hpu.py +94 -0
  903. vllm/platforms/interface.py +427 -0
  904. vllm/platforms/neuron.py +69 -0
  905. vllm/platforms/rocm.py +346 -0
  906. vllm/platforms/tpu.py +174 -0
  907. vllm/platforms/xpu.py +142 -0
  908. vllm/plugins/__init__.py +82 -0
  909. vllm/pooling_params.py +53 -0
  910. vllm/profiler/__init__.py +7 -0
  911. vllm/profiler/layerwise_profile.py +374 -0
  912. vllm/profiler/utils.py +147 -0
  913. vllm/prompt_adapter/__init__.py +0 -0
  914. vllm/prompt_adapter/layers.py +82 -0
  915. vllm/prompt_adapter/models.py +357 -0
  916. vllm/prompt_adapter/request.py +36 -0
  917. vllm/prompt_adapter/utils.py +97 -0
  918. vllm/prompt_adapter/worker_manager.py +178 -0
  919. vllm/py.typed +2 -0
  920. vllm/reasoning/__init__.py +12 -0
  921. vllm/reasoning/abs_reasoning_parsers.py +189 -0
  922. vllm/reasoning/deepseek_r1_reasoning_parser.py +172 -0
  923. vllm/reasoning/granite_reasoning_parser.py +362 -0
  924. vllm/sampling_params.py +598 -0
  925. vllm/scalar_type.py +335 -0
  926. vllm/scripts.py +14 -0
  927. vllm/sequence.py +1486 -0
  928. vllm/spec_decode/__init__.py +0 -0
  929. vllm/spec_decode/batch_expansion.py +505 -0
  930. vllm/spec_decode/draft_model_runner.py +335 -0
  931. vllm/spec_decode/interfaces.py +98 -0
  932. vllm/spec_decode/medusa_worker.py +137 -0
  933. vllm/spec_decode/metrics.py +212 -0
  934. vllm/spec_decode/mlp_speculator_worker.py +93 -0
  935. vllm/spec_decode/mqa_scorer.py +159 -0
  936. vllm/spec_decode/multi_step_worker.py +416 -0
  937. vllm/spec_decode/ngram_worker.py +195 -0
  938. vllm/spec_decode/proposer_worker_base.py +58 -0
  939. vllm/spec_decode/smaller_tp_proposer_worker.py +194 -0
  940. vllm/spec_decode/spec_decode_worker.py +1324 -0
  941. vllm/spec_decode/target_model_runner.py +44 -0
  942. vllm/spec_decode/top1_proposer.py +274 -0
  943. vllm/spec_decode/util.py +276 -0
  944. vllm/test_utils.py +129 -0
  945. vllm/third_party/__init__.py +0 -0
  946. vllm/third_party/pynvml.py +6139 -0
  947. vllm/tracing.py +130 -0
  948. vllm/transformers_utils/__init__.py +19 -0
  949. vllm/transformers_utils/config.py +813 -0
  950. vllm/transformers_utils/configs/__init__.py +52 -0
  951. vllm/transformers_utils/configs/arctic.py +206 -0
  952. vllm/transformers_utils/configs/chatglm.py +71 -0
  953. vllm/transformers_utils/configs/cohere2.py +194 -0
  954. vllm/transformers_utils/configs/dbrx.py +280 -0
  955. vllm/transformers_utils/configs/deepseek_vl2.py +216 -0
  956. vllm/transformers_utils/configs/eagle.py +65 -0
  957. vllm/transformers_utils/configs/exaone.py +191 -0
  958. vllm/transformers_utils/configs/falcon.py +89 -0
  959. vllm/transformers_utils/configs/h2ovl.py +15 -0
  960. vllm/transformers_utils/configs/internvl.py +53 -0
  961. vllm/transformers_utils/configs/jais.py +237 -0
  962. vllm/transformers_utils/configs/kimi_vl.py +36 -0
  963. vllm/transformers_utils/configs/medusa.py +62 -0
  964. vllm/transformers_utils/configs/mllama.py +30 -0
  965. vllm/transformers_utils/configs/mlp_speculator.py +67 -0
  966. vllm/transformers_utils/configs/moonvit.py +32 -0
  967. vllm/transformers_utils/configs/mpt.py +179 -0
  968. vllm/transformers_utils/configs/nemotron.py +204 -0
  969. vllm/transformers_utils/configs/nvlm_d.py +14 -0
  970. vllm/transformers_utils/configs/skyworkr1v.py +53 -0
  971. vllm/transformers_utils/configs/solar.py +246 -0
  972. vllm/transformers_utils/configs/telechat2.py +63 -0
  973. vllm/transformers_utils/configs/ultravox.py +107 -0
  974. vllm/transformers_utils/detokenizer.py +167 -0
  975. vllm/transformers_utils/detokenizer_utils.py +188 -0
  976. vllm/transformers_utils/processor.py +210 -0
  977. vllm/transformers_utils/processors/__init__.py +6 -0
  978. vllm/transformers_utils/processors/deepseek_vl2.py +363 -0
  979. vllm/transformers_utils/s3_utils.py +161 -0
  980. vllm/transformers_utils/tokenizer.py +291 -0
  981. vllm/transformers_utils/tokenizer_base.py +146 -0
  982. vllm/transformers_utils/tokenizer_group.py +110 -0
  983. vllm/transformers_utils/tokenizers/__init__.py +9 -0
  984. vllm/transformers_utils/tokenizers/mistral.py +483 -0
  985. vllm/transformers_utils/utils.py +98 -0
  986. vllm/triton_utils/__init__.py +5 -0
  987. vllm/triton_utils/importing.py +53 -0
  988. vllm/usage/__init__.py +0 -0
  989. vllm/usage/usage_lib.py +255 -0
  990. vllm/utils.py +2692 -0
  991. vllm/v1/__init__.py +0 -0
  992. vllm/v1/attention/__init__.py +0 -0
  993. vllm/v1/attention/backends/__init__.py +0 -0
  994. vllm/v1/attention/backends/flash_attn.py +783 -0
  995. vllm/v1/attention/backends/flashinfer.py +638 -0
  996. vllm/v1/attention/backends/mla/__init__.py +0 -0
  997. vllm/v1/attention/backends/mla/common.py +974 -0
  998. vllm/v1/attention/backends/mla/flashmla.py +149 -0
  999. vllm/v1/attention/backends/mla/triton_mla.py +118 -0
  1000. vllm/v1/attention/backends/pallas.py +221 -0
  1001. vllm/v1/attention/backends/triton_attn.py +198 -0
  1002. vllm/v1/core/__init__.py +0 -0
  1003. vllm/v1/core/block_pool.py +281 -0
  1004. vllm/v1/core/encoder_cache_manager.py +149 -0
  1005. vllm/v1/core/kv_cache_manager.py +385 -0
  1006. vllm/v1/core/kv_cache_utils.py +744 -0
  1007. vllm/v1/core/sched/__init__.py +0 -0
  1008. vllm/v1/core/sched/interface.py +134 -0
  1009. vllm/v1/core/sched/output.py +126 -0
  1010. vllm/v1/core/sched/scheduler.py +838 -0
  1011. vllm/v1/core/sched/utils.py +22 -0
  1012. vllm/v1/core/specialized_manager.py +161 -0
  1013. vllm/v1/engine/__init__.py +166 -0
  1014. vllm/v1/engine/async_llm.py +532 -0
  1015. vllm/v1/engine/core.py +701 -0
  1016. vllm/v1/engine/core_client.py +942 -0
  1017. vllm/v1/engine/detokenizer.py +260 -0
  1018. vllm/v1/engine/exceptions.py +16 -0
  1019. vllm/v1/engine/llm_engine.py +285 -0
  1020. vllm/v1/engine/logprobs.py +198 -0
  1021. vllm/v1/engine/mm_input_cache.py +82 -0
  1022. vllm/v1/engine/output_processor.py +420 -0
  1023. vllm/v1/engine/parallel_sampling.py +132 -0
  1024. vllm/v1/engine/processor.py +387 -0
  1025. vllm/v1/executor/__init__.py +0 -0
  1026. vllm/v1/executor/abstract.py +112 -0
  1027. vllm/v1/executor/multiproc_executor.py +480 -0
  1028. vllm/v1/executor/ray_distributed_executor.py +61 -0
  1029. vllm/v1/kv_cache_interface.py +166 -0
  1030. vllm/v1/metrics/__init__.py +0 -0
  1031. vllm/v1/metrics/loggers.py +498 -0
  1032. vllm/v1/metrics/stats.py +238 -0
  1033. vllm/v1/outputs.py +111 -0
  1034. vllm/v1/request.py +178 -0
  1035. vllm/v1/sample/__init__.py +0 -0
  1036. vllm/v1/sample/metadata.py +43 -0
  1037. vllm/v1/sample/ops/__init__.py +0 -0
  1038. vllm/v1/sample/ops/bad_words.py +38 -0
  1039. vllm/v1/sample/ops/penalties.py +58 -0
  1040. vllm/v1/sample/ops/topk_topp_sampler.py +315 -0
  1041. vllm/v1/sample/rejection_sampler.py +631 -0
  1042. vllm/v1/sample/sampler.py +270 -0
  1043. vllm/v1/sample/tpu/__init__.py +0 -0
  1044. vllm/v1/sample/tpu/metadata.py +118 -0
  1045. vllm/v1/sample/tpu/sampler.py +154 -0
  1046. vllm/v1/serial_utils.py +274 -0
  1047. vllm/v1/spec_decode/__init__.py +0 -0
  1048. vllm/v1/spec_decode/eagle.py +318 -0
  1049. vllm/v1/spec_decode/metadata.py +61 -0
  1050. vllm/v1/spec_decode/metrics.py +164 -0
  1051. vllm/v1/spec_decode/ngram_proposer.py +131 -0
  1052. vllm/v1/spec_decode/utils.py +18 -0
  1053. vllm/v1/stats/__init__.py +0 -0
  1054. vllm/v1/stats/common.py +453 -0
  1055. vllm/v1/structured_output/__init__.py +113 -0
  1056. vllm/v1/structured_output/backend_guidance.py +215 -0
  1057. vllm/v1/structured_output/backend_types.py +96 -0
  1058. vllm/v1/structured_output/backend_xgrammar.py +299 -0
  1059. vllm/v1/structured_output/request.py +84 -0
  1060. vllm/v1/structured_output/utils.py +174 -0
  1061. vllm/v1/utils.py +249 -0
  1062. vllm/v1/worker/__init__.py +0 -0
  1063. vllm/v1/worker/block_table.py +87 -0
  1064. vllm/v1/worker/gpu_input_batch.py +677 -0
  1065. vllm/v1/worker/gpu_model_runner.py +1776 -0
  1066. vllm/v1/worker/gpu_worker.py +349 -0
  1067. vllm/v1/worker/lora_model_runner_mixin.py +145 -0
  1068. vllm/v1/worker/tpu_model_runner.py +1419 -0
  1069. vllm/v1/worker/tpu_worker.py +260 -0
  1070. vllm/v1/worker/utils.py +74 -0
  1071. vllm/v1/worker/worker_base.py +64 -0
  1072. vllm/version.py +40 -0
  1073. vllm/vllm_flash_attn/.gitkeep +0 -0
  1074. vllm/worker/__init__.py +0 -0
  1075. vllm/worker/cache_engine.py +144 -0
  1076. vllm/worker/cpu_enc_dec_model_runner.py +323 -0
  1077. vllm/worker/cpu_model_runner.py +668 -0
  1078. vllm/worker/cpu_pooling_model_runner.py +122 -0
  1079. vllm/worker/cpu_worker.py +400 -0
  1080. vllm/worker/enc_dec_model_runner.py +542 -0
  1081. vllm/worker/hpu_model_runner.py +2221 -0
  1082. vllm/worker/hpu_worker.py +483 -0
  1083. vllm/worker/model_runner.py +2056 -0
  1084. vllm/worker/model_runner_base.py +281 -0
  1085. vllm/worker/multi_step_hpu_worker.py +122 -0
  1086. vllm/worker/multi_step_model_runner.py +908 -0
  1087. vllm/worker/multi_step_tpu_worker.py +107 -0
  1088. vllm/worker/multi_step_worker.py +196 -0
  1089. vllm/worker/neuron_model_runner.py +336 -0
  1090. vllm/worker/neuron_worker.py +138 -0
  1091. vllm/worker/pooling_model_runner.py +200 -0
  1092. vllm/worker/tpu_model_runner.py +908 -0
  1093. vllm/worker/tpu_worker.py +332 -0
  1094. vllm/worker/utils.py +52 -0
  1095. vllm/worker/worker.py +570 -0
  1096. vllm/worker/worker_base.py +644 -0
  1097. vllm/worker/xpu_model_runner.py +603 -0
  1098. vllm/worker/xpu_worker.py +185 -0
  1099. vllm_cpu-0.8.5.post2.dist-info/METADATA +309 -0
  1100. vllm_cpu-0.8.5.post2.dist-info/RECORD +1103 -0
  1101. vllm_cpu-0.8.5.post2.dist-info/WHEEL +5 -0
  1102. vllm_cpu-0.8.5.post2.dist-info/entry_points.txt +2 -0
  1103. vllm_cpu-0.8.5.post2.dist-info/top_level.txt +1 -0
vllm/attention/backends/mla/common.py
@@ -0,0 +1,1444 @@
+ # SPDX-License-Identifier: Apache-2.0
+ """
+ This file implements common components for MLA implementations.
+
+ First we define:
+
+ Sq      as Q sequence length
+ Skv     as KV sequence length
+
+ MLA has two possible ways of computing: a data-movement friendly approach and a
+ compute friendly approach. We generally want to use the compute friendly
+ approach for "prefill" (i.e. the ratio Sq / Skv is near 1) and the
+ data-movement friendly approach for "decode" (i.e. the ratio Sq / Skv is
+ small, since Sq is typically 1).
+
+ NOTE: what we deem small and large is currently determined by whether the
+ scheduler labels a request prefill or decode, but this is something we should
+ probably tune.
+
+ Main references: the DeepseekV2 paper and the FlashInfer implementation
+ (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
+
+ Deepseek's MLA attention works the following way:
+ * Use a single latent vector to represent the per-token entry of the KV cache.
+ * For decode (i.e. the data-movement friendly approach) the attention
+   "simulates" multi-head attention, while the compute is similar to
+   multi-query attention.
+
+ Below is an example of both paths, assuming batch size = 1.
+
+ ## Further Definitions:
+
+ C       Context length, `Skv - Sq`
+ H       hidden size
+ N       number of attention heads
+ Lq      latent dimension for Q              1536 in DSV3
+ Lkv     latent dimension for K/V            512 in DSV3
+ P       nope dimension, no rope.            128 in DSV3
+ R       rope dimension, goes through rope.  64 in DSV3
+ V       V head dim.                         128 in DSV3
+
+ ## Vector/Matrix Definitions
+
+ h_t         hidden states (input to attention)  shape [Sq, H]
+ q_c         latent/compressed Q                 shape [Sq, Lq]
+ q_nope      uncompressed Q (no-rope)            shape [Sq, N, P]
+ q_pe        uncompressed Q (rope)               shape [Sq, N, R]
+ kv_c        latent/compressed KV                shape [Skv, Lkv]
+ k_pe        decoupled k position embeddings     shape [Skv, R]
+ new_kv_c    new kv_c from current iter          shape [Sq, Lkv]
+ new_k_pe    new k_pe from current iter          shape [Sq, R]
+ cache_kv_c  cached kv_c from previous iters     shape [C, Lkv]
+ cache_k_pe  cached k_pe from previous iters     shape [C, R]
+ W_DQ        projects h_t to q_c                 shape [H, Lq]
+ W_UQ        projects q_c to q_nope              shape [Lq, N * P]
+ W_QR        projects q_c to q_pe                shape [Lq, N * R]
+ W_DKV       projects h_t to kv_c                shape [H, Lkv]
+ W_UK        projects kv_c to k_nope             shape [Lkv, N, P]
+ W_KR        projects h_t to k_pe                shape [H, R]
+ W_UV        projects kv_c to v                  shape [Lkv, N, V]
+ W_O         projects v to h_t                   shape [N * V, H]
+
+
+ ## Compute Friendly Approach (i.e. "_forward_prefill"):
+
+ q_c      = h_t @ W_DQ
+ q_nope   = (q_c @ W_UQ).view(Sq, N, P)
+ q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
+ new_kv_c = h_t @ W_DKV
+ new_k_pe = RoPE(h_t @ W_KR)
+ kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
+ k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)
+ k_nope   = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
+ v        = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
+
+ // MHA with QK headdim = P + R
+ //           V headdim = V
+ //      spda_o shape [Sq, N, V]
+ spda_o = scaled_dot_product_attention(
+     torch.cat([q_nope, q_pe], dim=-1),
+     torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
+     v
+ )
+ return spda_o @ W_O
+
+ NOTE: in the actual code,
+     `kv_b_proj` is [W_UK; W_UV] concatenated per head
+     `q_b_proj`  is [W_UQ; W_QR] concatenated per head
+     `out_proj`  is W_O
+
+
+ ## Data-Movement Friendly Approach (i.e. "_forward_decode"):
+
+ Runtime
+ q_c      = h_t @ W_DQ
+ q_nope   = (q_c @ W_UQ).view(-1, N, P)
+ ql_nope  = einsum("snh,lnh->snl", q_nope, W_UK)
+ q_pe     = RoPE(q_c @ W_QR).view(Sq, N, R)
+ new_kv_c = h_t @ W_DKV
+ new_k_pe = RoPE(h_t @ W_KR)
+ kv_c     = torch.cat([new_kv_c, cache_kv_c], dim=0)
+ k_pe     = torch.cat([new_k_pe, cache_k_pe], dim=0)
+
+ // MQA with QK headdim = Lkv + R
+ //           V headdim = Lkv
+ //      spda_o shape [Sq, N, Lkv]
+ // NOTE: this is less compute-friendly since Lkv > P
+ //       but is more data-movement friendly since it's MQA vs MHA
+ spda_o = scaled_dot_product_attention(
+     torch.cat([ql_nope, q_pe], dim=-1),
+     torch.cat([kv_c, k_pe], dim=-1),
+     kv_c
+ )
+
+ o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
+ return o.view(-1, N * V) @ W_O
+
+
+ ## Chunked Prefill
+
+ For chunked prefill we want to use the compute friendly algorithm. We are
+ assuming a sufficiently large Sq / Skv ratio; in the future we may want to
+ switch to the data-movement friendly approach if the chunk (i.e. `Sq`) is
+ small.
+
+ However, the compute-friendly approach can potentially run out of memory if
+ Skv is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
+
+ To mitigate this, we chunk the computation of attention with respect to the
+ current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can use a
+ fixed workspace size.
+
+ The chunked prefill approach is as follows:
+
+ MCC   Max chunk of context to process per iter, computed dynamically,
+       used to bound the memory usage
+
+ q_c        = h_t @ W_DQ
+ q_nope     = (q_c @ W_UQ).view(Sq, N, P)
+ q_pe       = RoPE(q_c @ W_QR).view(Sq, N, R)
+ new_kv_c   = h_t @ W_DKV
+ new_k_pe   = RoPE(h_t @ W_KR)
+ new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
+ new_v      = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
+
+ // MHA between queries and new KV
+ //     with QK headdim = P + R
+ //           V headdim = V
+ //     curr_o   shape [Sq, N, V]
+ //     curr_lse shape [N, Sq], this is just the order FA returns
+ curr_o, curr_lse = scaled_dot_product_attention(
+     torch.cat([q_nope, q_pe], dim=-1),
+     torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
+     new_v,
+     causal=True,
+     return_softmax_lse=True
+ )
+
+ // Compute attention with the already existing context
+ for chunk_idx in range(cdiv(C, MCC)):
+     chunk_start = chunk_idx * MCC
+     chunk_end   = min(chunk_start + MCC, C)
+     Sc = chunk_end - chunk_start
+     cache_kv_c_chunk   = cache_kv_c[chunk_start:chunk_end]
+     cache_k_pe_chunk   = cache_k_pe[chunk_start:chunk_end]
+     cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
+     cache_v_chunk      = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
+
+     chunk_o, chunk_lse = scaled_dot_product_attention(
+         torch.cat([q_nope, q_pe], dim=-1),
+         torch.cat([cache_k_nope_chunk,
+                    cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
+                   dim=-1),
+         cache_v_chunk,
+         causal=False,
+         return_softmax_lse=True
+     )
+
+     curr_o, curr_lse = merge_attn_states(
+         suffix_output=curr_o,
+         suffix_lse=curr_lse,
+         prefix_output=chunk_o,
+         prefix_lse=chunk_lse,
+     )
+
+ return curr_o @ W_O
+ """
+
+ import functools
+ from abc import abstractmethod
+ from collections import defaultdict
+ from contextlib import contextmanager
+ from dataclasses import dataclass
+ from itertools import accumulate
+ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
+                     Type, TypeVar)
+
+ import torch
+
+ from vllm import _custom_ops as ops
+ from vllm import envs
+ from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
+                                               AttentionMetadata,
+                                               AttentionMetadataBuilder,
+                                               AttentionState, MLAAttentionImpl)
+ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
+                                            compute_slot_mapping_start_idx,
+                                            is_block_tables_empty)
+ from vllm.attention.ops.merge_attn_states import merge_attn_states
+ from vllm.attention.utils.fa_utils import get_flash_attn_version
+ from vllm.model_executor.layers.linear import (ColumnParallelLinear,
+                                                LinearBase, RowParallelLinear,
+                                                UnquantizedLinearMethod)
+ from vllm.model_executor.layers.rotary_embedding import (
+     DeepseekScalingRotaryEmbedding, RotaryEmbedding)
+ from vllm.multimodal import MultiModalPlaceholderMap
+ from vllm.platforms import current_platform
+ from vllm.triton_utils import HAS_TRITON
+ from vllm.utils import async_tensor_h2d, cdiv, make_tensor_with_pad, round_down
+
+ if HAS_TRITON:
+     from vllm.attention.ops.triton_flash_attention import triton_attention
+ else:
+     triton_attention = None
+
+ try:
+     from vllm.vllm_flash_attn import flash_attn_varlen_func
+     is_vllm_fa = True
+ except ImportError:
+     is_vllm_fa = False
+     try:
+         # For rocm use upstream flash attention
+         from flash_attn import flash_attn_varlen_func
+     except ImportError:
+         flash_attn_varlen_func = None
+
+ if TYPE_CHECKING:
+     from vllm.worker.model_runner import (ModelInputForGPUBuilder,
+                                           ModelInputForGPUWithSamplingMetadata)
+
+ is_hip = current_platform.is_rocm()
+
+
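
merge_attn_states (imported above) is what lets the chunked-prefill loop in
the module docstring combine per-chunk attention outputs. A pure-PyTorch
sketch of that merge, assuming [Sq, N, V] outputs and [N, Sq] log-sum-exps
(illustrative only, not the actual kernel):

    import torch

    def merge_two(prefix_output, prefix_lse, suffix_output, suffix_lse):
        # Each partial output is softmax attention over a disjoint set of
        # keys; the LSEs give each set's share of the combined softmax mass.
        max_lse = torch.maximum(prefix_lse, suffix_lse)
        p = torch.exp(prefix_lse - max_lse)  # [N, Sq]
        s = torch.exp(suffix_lse - max_lse)  # [N, Sq]
        total = p + s
        merged_lse = max_lse + torch.log(total)
        # Broadcast the [N, Sq] weights onto the [Sq, N, V] outputs.
        w_p = (p / total).transpose(0, 1).unsqueeze(-1)
        w_s = (s / total).transpose(0, 1).unsqueeze(-1)
        return w_p * prefix_output + w_s * suffix_output, merged_lse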
+ class MLACommonBackend(AttentionBackend):
+
+     @staticmethod
+     def get_name() -> str:
+         return "TRITON_MLA"
+
+     @staticmethod
+     def get_metadata_cls() -> Type["AttentionMetadata"]:
+         return MLACommonMetadata
+
+     @staticmethod
+     def get_builder_cls() -> Type["MLACommonMetadataBuilder"]:
+         return MLACommonMetadataBuilder
+
+     @staticmethod
+     def get_state_cls() -> Type["MLACommonState"]:
+         return MLACommonState
+
+     @staticmethod
+     def get_kv_cache_shape(
+         num_blocks: int,
+         block_size: int,
+         num_kv_heads: int,  # assumed to be 1 for MLA
+         head_size: int,
+     ) -> Tuple[int, ...]:
+         return (num_blocks, block_size, head_size)
+
+     @staticmethod
+     def swap_blocks(
+         src_kv_cache: torch.Tensor,
+         dst_kv_cache: torch.Tensor,
+         src_to_dst: torch.Tensor,
+     ) -> None:
+         ops.swap_blocks(src_kv_cache, dst_kv_cache, src_to_dst)
+
+     @staticmethod
+     def copy_blocks(
+         kv_caches: List[torch.Tensor],
+         src_to_dists: torch.Tensor,
+     ) -> None:
+         ops.copy_blocks_mla(kv_caches, src_to_dists)
+
+     @staticmethod
+     def get_supported_head_sizes() -> List[int]:
+         return [576]
+
+
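
Note the cache shape above has no num_kv_heads dimension: each token slot
stores one 576-wide latent (Lkv = 512 plus R = 64 rope dims) rather than N
per-head K/V pairs. A back-of-the-envelope sketch with made-up sizes (fp16):

    num_blocks, block_size, head_size = 1024, 16, 576
    cache_shape = (num_blocks, block_size, head_size)  # as get_kv_cache_shape
    fp16_bytes = 2
    cache_mib = num_blocks * block_size * head_size * fp16_bytes / 2**20
    print(cache_shape, f"{cache_mib:.0f} MiB")  # 16384 token slots -> 18 MiB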
+ T = TypeVar("T", bound="MLACommonMetadata")
+
+
+ class MLACommonState(AttentionState, Generic[T]):
+
+     def __init__(self, runner):
+         self.runner = runner
+         self._is_graph_capturing = False
+
+         scheduler_config = runner.scheduler_config
+         self.model_config = runner.model_config
+         cache_config = runner.cache_config
+
+         self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
+         self.enable_prefix_caching = cache_config.enable_prefix_caching
+
+         if self.chunked_prefill_enabled or self.enable_prefix_caching:
+             self.context_chunk_workspace_size = min(
+                 # Make sure there is enough for 8 full-length requests or at
+                 # least 4 pages of cache per request
+                 max(
+                     8 * self.model_config.max_model_len, 4 *
+                     scheduler_config.max_num_seqs * cache_config.block_size),
+                 # For long-context models try not to over-allocate, limiting
+                 # kv-cache space to 128k tokens,
+                 # which would result in the workspace being:
+                 #   2*(576)*(128*1024) = 144mb
+                 # (assuming 576 MLA head dim, and fp16)
+                 # and the up-projected context being
+                 #   2*(192*128)*(128*1024) = 6gb
+                 # (assuming 192 QK head dim, 128 heads, and fp16)
+                 128 * 1024)
+             assert self.context_chunk_workspace_size >= \
+                 scheduler_config.max_num_seqs * cache_config.block_size
+
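
Plugging hypothetical config values into the sizing rule above (the names
mirror the config fields; the numbers are illustrative, not defaults):

    max_model_len, max_num_seqs, block_size = 32768, 256, 16
    workspace_tokens = min(
        max(8 * max_model_len, 4 * max_num_seqs * block_size),
        128 * 1024,
    )
    print(workspace_tokens)  # 131072 tokens; * 576 dims * 2 bytes = 144 MiB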
+     @contextmanager
+     def graph_capture(self, max_batch_size: int):
+         self._is_graph_capturing = True
+
+         self._graph_slot_mapping = torch.full((max_batch_size, ),
+                                               PAD_SLOT_ID,
+                                               dtype=torch.long,
+                                               device=self.runner.device)
+         self._graph_seq_lens = torch.ones(max_batch_size,
+                                           dtype=torch.int32,
+                                           device=self.runner.device)
+         self._graph_block_tables = torch.from_numpy(
+             self.runner.graph_block_tables).to(device=self.runner.device)
+
+         self._positions = torch.zeros((max_batch_size, ),
+                                       dtype=torch.long,
+                                       device=self.runner.device)
+
+         yield
+
+         self._is_graph_capturing = False
+         del self._graph_slot_mapping
+         del self._graph_seq_lens
+         del self._graph_block_tables
+         del self._positions
+
+     def graph_clone(self, batch_size: int):
+         assert self._is_graph_capturing
+         return self.__class__(self.runner)
+
+     def graph_capture_get_metadata_for_batch(
+             self,
+             batch_size: int,
+             is_encoder_decoder_model: bool = False) -> T:
+         assert self._is_graph_capturing
+
+         attn_metadata = self.runner.attn_backend.make_metadata(
+             multi_modal_placeholder_index_maps=None,
+             enable_kv_scales_calculation=False,
+             use_cuda_graph=True,
+             num_prefills=0,
+             num_prefill_tokens=0,
+             num_decode_tokens=batch_size,
+             slot_mapping=self._graph_slot_mapping[:batch_size],
+             seq_lens=None,
+             seq_lens_tensor=self._graph_seq_lens[:batch_size],
+             max_query_len=1,
+             max_decode_query_len=1,
+             max_prefill_seq_len=0,
+             max_decode_seq_len=self.runner.max_seq_len_to_capture,
+             query_start_loc=None,
+             seq_start_loc=None,
+             context_lens_tensor=None,
+             block_tables=self._graph_block_tables[:batch_size],
+             input_positions=self._positions[:batch_size],
+             head_dim=self.runner.model_config.get_head_size())
+
+         if is_encoder_decoder_model:
+             raise NotImplementedError(
+                 "MLACommonState does not support encoder/decoder yet")
+
+         return attn_metadata
+
+     def get_graph_input_buffers(self,
+                                 attn_metadata,
+                                 is_encoder_decoder_model: bool = False):
+         input_buffers = {
+             "slot_mapping": attn_metadata.slot_mapping,
+             "seq_lens_tensor": attn_metadata.decode_metadata.seq_lens_tensor,
+             "block_tables": attn_metadata.decode_metadata.block_tables,
+             "input_positions": attn_metadata.decode_metadata.input_positions,
+         }
+         if is_encoder_decoder_model:
+             raise NotImplementedError(
+                 "MLACommonState does not support encoder/decoder yet")
+
+         return input_buffers
+
+     def prepare_graph_input_buffers(self,
+                                     input_buffers,
+                                     attn_metadata,
+                                     is_encoder_decoder_model: bool = False):
+         input_positions = attn_metadata.input_positions
+         num_positions = input_positions.shape[0]
+         input_buffers["seq_lens_tensor"].copy_(
+             attn_metadata.decode_metadata.seq_lens_tensor, non_blocking=True)
+         input_buffers["block_tables"].copy_(
+             attn_metadata.decode_metadata.block_tables, non_blocking=True)
+         # The CUDA graph buffer is padded, so only perform a partial copy
+         # based on num_positions
+         input_buffers["input_positions"][:num_positions].copy_(
+             input_positions, non_blocking=True)
+         if is_encoder_decoder_model:
+             raise NotImplementedError(
+                 "MLACommonState does not support encoder/decoder yet")
+
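
The pattern above is what CUDA graph replay requires: the buffers captured
with the graph keep their maximum size, and each step copies only the live
prefix in place. A toy version with hypothetical sizes:

    import torch

    max_batch_size = 8
    positions_buf = torch.zeros(max_batch_size, dtype=torch.long)  # persistent
    new_positions = torch.tensor([5, 6, 7])  # this step has a batch of 3
    positions_buf[:new_positions.shape[0]].copy_(new_positions)  # partial copy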
+     def begin_forward(self, model_input):
+         if self.chunked_prefill_enabled or self.enable_prefix_caching:
+             if not hasattr(self, "context_chunk_workspace"):
+                 # NOTE: self.runner.device does not return the correct device
+                 # for this process (init_device sets the correct device, but
+                 # only on the Worker). The only way I've figured out to get
+                 # the correct device is to allocate the workspace on the first
+                 # call to begin_forward and use the device of the input tokens
+                 assert model_input.input_tokens is not None
+                 self.context_chunk_workspace = torch.empty(
+                     (self.context_chunk_workspace_size,
+                      self.model_config.get_head_size()),
+                     dtype=self.model_config.dtype,
+                     device=model_input.input_tokens.device,
+                 )
+
+             model_input.attn_metadata.context_chunk_workspace = \
+                 self.context_chunk_workspace
+
+
+ @dataclass
+ class MLACommonMetadata(AttentionMetadata):
+     """Metadata for MLACommon.
+
+     NOTE: Please read the comment at the top of the file before trying to
+     understand this class
+
+     NOTE: Any python object stored here is not updated when it is
+     cuda-graph replayed. If you have values that need to be changed
+     dynamically, they should be stored in a tensor. The tensor has to be
+     updated from the `CUDAGraphRunner.forward` API.
+     """
+     # Whether or not cuda graph is enabled.
+     # Cuda-graph is currently enabled for decoding only.
+     # TODO(woosuk): Move `use_cuda_graph` out since it's unrelated to attention.
+     use_cuda_graph: bool
+
+     # New for MLA (compared to FlashAttention)
+     # Input positions for rotary embeddings since for MLA the rotary
+     # position embeddings are applied inside the attention backend
+     input_positions: torch.Tensor
+
+     # NOTE(sang): Definition of context_len, query_len, and seq_len.
+     # |---------- N-1 iteration --------|
+     # |---------------- N iteration ---------------------|
+     # |- tokenA -|......................|-- newTokens ---|
+     # |---------- context_len ----------|
+     # |-------------------- seq_len ---------------------|
+     #                                   |-- query_len ---|
+
+     # (batch_size,). The sequence length per sequence. Sequence length means
+     # the computed tokens + new tokens. None if it is a decoding.
+     seq_lens: Optional[List[int]]
+     # seq_lens stored as a tensor.
+     seq_lens_tensor: Optional[torch.Tensor]
+
+     # Maximum sequence length among prefill batch. 0 if there are decoding
+     # requests only.
+     max_prefill_seq_len: int
+     # Maximum sequence length among decode batch. 0 if there are prefill
+     # requests only.
+     max_decode_seq_len: int
+     # (batch_size,) A tensor of context lengths (tokens that are computed
+     # so far).
+     context_lens_tensor: Optional[torch.Tensor]
+
+     # (batch_size, max_blocks_per_seq).
+     # Block addresses per sequence. (Seq id -> list of physical block)
+     # E.g., [0, 1, 2] means tokens are stored in 0th, 1st, and 2nd blocks
+     # in the kv cache. Each block can contain up to block_size tokens.
+     # 2nd dimension is padded up to max_blocks_per_seq if it is cuda-graph
+     # captured.
+     block_tables: Optional[torch.Tensor]
+
+     # Maximum query length in the batch.
+     max_query_len: Optional[int] = None
+
+     # Max number of query tokens among requests in the batch.
+     max_decode_query_len: Optional[int] = None
+
+     # (batch_size + 1,). The cumulative subquery lengths of the sequences in
+     # the batch, used to index into subquery. E.g., if the subquery length
+     # is [4, 6], it is [0, 4, 10].
+     query_start_loc: Optional[torch.Tensor] = None
+     # (batch_size + 1,). The cumulative sequence lengths of the sequences in
+     # the batch, used to index into sequence. E.g., if the sequence length is
+     # [4, 6], it is [0, 4, 10].
+     seq_start_loc: Optional[torch.Tensor] = None
+
+     _cached_prefill_metadata: Optional[Any] = None
+     _cached_decode_metadata: Optional[Any] = None
+
+     num_prefill_tokens: int
+
+     # The dimension of the attention heads
+     head_dim: Optional[int] = None
+
+     # Used when chunked prefill is enabled to simulate worst case workspace
+     # allocations, hopefully to avoid going OOM
+     is_profile_run: bool = False
+
+     # New for MLA (compared to FlashAttention)
+     # For chunked prefill
+     context_chunk_cu_seq_lens: Optional[torch.Tensor] = None
+     context_chunk_starts: Optional[torch.Tensor] = None
+     context_chunk_seq_tot: Optional[List[int]] = None
+     context_chunk_max_seq_lens: Optional[List[int]] = None
+     # Set by MLAAttentionState in `begin_forward` so it doesn't get broadcasted
+     context_chunk_workspace: Optional[torch.Tensor] = None
+
+     def __post_init__(self):
+         supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
+         if self.head_dim is not None and self.head_dim \
+                 not in supported_head_sizes:
+             raise ValueError(
+                 f"Only {supported_head_sizes} are supported for head_dim, "
+                 f"received {self.head_dim}.")
+
+     @property
+     def prefill_metadata(self):
+         if self.num_prefills == 0:
+             return None
+
+         if self._cached_prefill_metadata is not None:
+             return self._cached_prefill_metadata
+
+         assert self.seq_lens is not None
+         assert self.seq_lens_tensor is not None
+
+         # Compute some attn_metadata fields which default to None
+         query_start_loc = (None if self.query_start_loc is None else
+                            self.query_start_loc[:self.num_prefills + 1])
+         slot_mapping = (None if self.slot_mapping is None else
+                         self.slot_mapping[:self.num_prefill_tokens])
+         seq_lens = (None if self.seq_lens is None else
+                     self.seq_lens[:self.num_prefills])
+         seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                            self.seq_lens_tensor[:self.num_prefills])
+         seq_start_loc = (None if self.seq_start_loc is None else
+                          self.seq_start_loc[:self.num_prefills + 1])
+         context_lens_tensor = (None if self.context_lens_tensor is None else
+                                self.context_lens_tensor[:self.num_prefills])
+         block_tables = (None if self.block_tables is None else
+                         self.block_tables[:self.num_prefills])
+         input_positions = (None if self.input_positions is None else
+                            self.input_positions[:self.num_prefill_tokens])
+
+         self._cached_prefill_metadata = self.__class__(
+             # Required by ModelRunner
+             use_cuda_graph=False,  # Not Attention Related
+             # Required by Attention Metadata
+             num_prefills=self.num_prefills,
+             num_prefill_tokens=self.num_prefill_tokens,
+             num_decode_tokens=0,
+             slot_mapping=slot_mapping,
+             # Required by Attention Metadata (not used)
+             multi_modal_placeholder_index_maps=None,
+             enable_kv_scales_calculation=False,
+             # MLACommonMetadata
+             input_positions=input_positions,
+             seq_lens=seq_lens,
+             seq_lens_tensor=seq_lens_tensor,
+             max_query_len=self.max_query_len,
+             max_prefill_seq_len=self.max_prefill_seq_len,
+             max_decode_query_len=0,
+             max_decode_seq_len=0,
+             query_start_loc=query_start_loc,
+             seq_start_loc=seq_start_loc,
+             context_lens_tensor=context_lens_tensor,
+             block_tables=block_tables,
+             head_dim=self.head_dim,
+             is_profile_run=self.is_profile_run,
+             # MLACommonMetadata Chunk prefill specific
+             context_chunk_cu_seq_lens=self.context_chunk_cu_seq_lens,
+             context_chunk_starts=self.context_chunk_starts,
+             context_chunk_seq_tot=self.context_chunk_seq_tot,
+             context_chunk_max_seq_lens=self.context_chunk_max_seq_lens,
+         )
+         return self._cached_prefill_metadata
+
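
Both this property and decode_metadata below carve a [prefills | decodes]
batch into views: token-indexed tensors (slot_mapping, input_positions) are
sliced by num_prefill_tokens, request-indexed ones (seq_lens, block_tables)
by num_prefills. A toy illustration with made-up values:

    num_prefills, num_prefill_tokens = 2, 7  # prefills of 3 and 4 tokens
    slot_mapping = list(range(10))           # 7 prefill + 3 decode tokens
    seq_lens = [3, 4, 9, 9, 9]               # 2 prefills + 3 decodes
    prefill_slots = slot_mapping[:num_prefill_tokens]  # [0..6]
    decode_seq_lens = seq_lens[num_prefills:]          # [9, 9, 9]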
+     @property
+     def decode_metadata(self):
+         if self.num_decode_tokens == 0:
+             return None
+
+         if self._cached_decode_metadata is not None:
+             return self._cached_decode_metadata
+         assert self.seq_lens_tensor is not None
+
+         # Compute some attn_metadata fields which default to None
+         slot_mapping = (None if self.slot_mapping is None else
+                         self.slot_mapping[self.num_prefill_tokens:])
+         seq_lens_tensor = (None if self.seq_lens_tensor is None else
+                            self.seq_lens_tensor[self.num_prefills:])
+         block_tables = (None if self.block_tables is None else
+                         self.block_tables[self.num_prefills:])
+         input_positions = (None if self.input_positions is None else
+                            self.input_positions[self.num_prefill_tokens:])
+
+         self._cached_decode_metadata = self.__class__(
+             # Required by ModelRunner
+             use_cuda_graph=self.use_cuda_graph,  # Not Attention Related
+             # Required by Attention Metadata
+             num_prefills=0,
+             num_prefill_tokens=0,
+             num_decode_tokens=self.num_decode_tokens,
+             slot_mapping=slot_mapping,
+             # Required by Attention Metadata (not used)
+             multi_modal_placeholder_index_maps=None,
+             enable_kv_scales_calculation=False,
+             # MLACommonMetadata
+             seq_lens=None,
+             seq_lens_tensor=seq_lens_tensor,
+             max_decode_query_len=self.max_decode_query_len,
+             max_query_len=self.max_query_len,
+             max_prefill_seq_len=0,
+             max_decode_seq_len=self.max_decode_seq_len,
+             # Batch may be composed of prefill|decodes, adjust query start
+             # indices to refer to the start of decodes. E.g.
+             # in tokens:[3 prefills|6 decodes], query_start_loc=[3,9] => [0,6].
+             query_start_loc=(self.query_start_loc[self.num_prefills:] -
+                              self.query_start_loc[self.num_prefills])
+             if self.query_start_loc is not None else None,
+             seq_start_loc=self.seq_start_loc[self.num_prefills:]
+             if self.seq_start_loc is not None else None,
+             context_lens_tensor=None,
+             block_tables=block_tables,
+             input_positions=input_positions,
+             head_dim=self.head_dim,
+             is_profile_run=self.is_profile_run)
+         return self._cached_decode_metadata
+
652
+ def advance_step(self,
653
+ model_input: "ModelInputForGPUWithSamplingMetadata",
654
+ sampled_token_ids: Optional[torch.Tensor],
655
+ block_size: int,
656
+ num_seqs: int,
657
+ num_queries: int,
658
+ turn_prefills_into_decodes: bool = False):
659
+ """
660
+ Update metadata in-place to advance one decode step.
661
+ """
662
+ # When using cudagraph, the num_seqs is padded to the next captured
663
+ # batch sized, but num_queries tracks the actual number of requests in
664
+ # the batch. For --enforce-eager mode, num_seqs == num_queries
665
+ if num_seqs != num_queries:
666
+ assert num_seqs > num_queries
667
+
668
+ if turn_prefills_into_decodes:
669
+ # When Multi-Step is enabled with Chunked-Prefill, prefills and
670
+ # decodes are scheduled together. In the first step, all the
671
+ # prefills turn into decodes. This update reflects that
672
+ # conversion.
673
+ assert self.num_decode_tokens + self.num_prefills == num_seqs
674
+ self.num_decode_tokens += self.num_prefills
675
+ self.num_prefills = 0
676
+ self.num_prefill_tokens = 0
677
+ self.max_prefill_seq_len = 0
678
+ self.max_query_len = 1
679
+
680
+ self.slot_mapping = self.slot_mapping[:num_seqs]
681
+ else:
682
+ assert self.seq_lens is not None
683
+ assert self.max_decode_seq_len == max(self.seq_lens)
684
+
685
+ assert self.num_prefills == 0
686
+ assert self.num_prefill_tokens == 0
687
+ assert self.num_decode_tokens == num_seqs
688
+ assert self.slot_mapping.shape == (num_seqs, )
689
+
690
+ assert self.seq_lens is not None
691
+ assert len(self.seq_lens) == num_seqs
692
+ assert self.seq_lens_tensor is not None
693
+ assert self.seq_lens_tensor.shape == (num_seqs, )
694
+ assert self.max_query_len == 1
695
+ assert self.max_prefill_seq_len == 0
696
+
697
+ assert self.query_start_loc is not None
698
+ assert self.query_start_loc.shape == (num_queries + 1, )
699
+ assert self.seq_start_loc is not None
700
+ assert self.seq_start_loc.shape == (num_seqs + 1, )
701
+
702
+ assert self.context_lens_tensor is not None
703
+ assert self.context_lens_tensor.shape == (num_queries, )
704
+
705
+ assert self.block_tables is not None
706
+ assert self.block_tables.shape[0] == num_seqs
707
+
708
+ # Update query lengths. Note that we update only queries and not seqs,
709
+ # since tensors may be padded due to captured cuda graph batch size
710
+ for i in range(num_queries):
711
+ self.seq_lens[i] += 1
712
+ self.max_decode_seq_len = max(self.seq_lens)
713
+
714
+ self._ops_advance_step(num_seqs=num_seqs,
715
+ num_queries=num_queries,
716
+ block_size=block_size,
717
+ input_tokens=model_input.input_tokens,
718
+ sampled_token_ids=sampled_token_ids,
719
+ input_positions=model_input.input_positions)
720
+
721
+ def _ops_advance_step(self, num_seqs: int, num_queries: int,
722
+ block_size: int, input_tokens: torch.Tensor,
723
+ sampled_token_ids: torch.Tensor,
724
+ input_positions: torch.Tensor) -> None:
725
+ # here we use advance_step_flashinfo to update the paged_kv_* tensors
726
+ ops.advance_step_flashattn(num_seqs=num_seqs,
727
+ num_queries=num_queries,
728
+ block_size=block_size,
729
+ input_tokens=input_tokens,
730
+ sampled_token_ids=sampled_token_ids,
731
+ input_positions=input_positions,
732
+ seq_lens=self.seq_lens_tensor,
733
+ slot_mapping=self.slot_mapping,
734
+ block_tables=self.block_tables)
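+ 
+     # A rough pure-Python sketch (an assumption, not the kernel's actual
+     # code) of what advance_step_flashattn does for each query i:
+     #   input_tokens[i] = sampled_token_ids[i]
+     #   seq_lens[i] += 1
+     #   input_positions[i] = seq_lens[i] - 1
+     #   block = block_tables[i][input_positions[i] // block_size]
+     #   slot_mapping[i] = block * block_size + input_positions[i] % block_size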
+ 
+ 
+ class MLACommonMetadataBuilder(AttentionMetadataBuilder[T], Generic[T]):
+     """
+     NOTE: Please read the comment at the top of the file before trying to
+     understand this class
+     """
+     BLOCK_TABLE_EXTENDER: list[list[int]] = []
+ 
+     def __init__(self, input_builder: "ModelInputForGPUBuilder"):
+         self.input_builder = input_builder
+         self.runner = input_builder.runner
+         self.sliding_window = input_builder.sliding_window
+         self.block_size = input_builder.block_size
+         self.chunked_prefill_enabled = \
+             self.runner.scheduler_config.chunked_prefill_enabled
+         self.enable_prefix_caching = \
+             self.runner.cache_config.enable_prefix_caching
+ 
+         if self.chunked_prefill_enabled or self.enable_prefix_caching:
+             attn_state = self.input_builder.runner.attn_state
+             self.context_chunk_workspace_size = \
+                 attn_state.context_chunk_workspace_size
+             self.page_size = self.runner.block_size
+ 
+     def prepare(self):
+         self.slot_mapping: List[int] = []
+         self.prefill_seq_lens: List[int] = []
+         self.context_lens: List[int] = []
+         self.block_tables: List[List[int]] = []
+         self.curr_seq_lens: List[int] = []
+         self.input_positions: List[int] = []
+         self.multimodal_placeholder_maps: Dict[
+             str,
+             MultiModalPlaceholderMap] = defaultdict(MultiModalPlaceholderMap)
+         self.num_prefills = 0
+         self.num_prefill_tokens = 0
+         self.num_decode_tokens = 0
+         self.has_prefix_cache_hit = False
+ 
+     def _add_seq_group(
+             self, inter_data: "ModelInputForGPUBuilder.InterDataForSeqGroup",
+             chunked_prefill_enabled: bool, prefix_cache_hit: bool):
+         """Add a sequence group to the metadata. Specifically update/append
+         1. context length.
+         2. block table.
+         3. slot mapping.
+         """
+         is_prompt = inter_data.is_prompt
+         block_tables = inter_data.block_tables
+ 
+         for (seq_id, token_len, seq_len, curr_seq_len, query_len, context_len,
+              curr_sliding_window_block, input_positions) in zip(
+                  inter_data.seq_ids, [len(t) for t in inter_data.input_tokens],
+                  inter_data.orig_seq_lens, inter_data.seq_lens,
+                  inter_data.query_lens, inter_data.context_lens,
+                  inter_data.curr_sliding_window_blocks,
+                  inter_data.input_positions):
+             self.input_positions.extend(input_positions)
+             self.context_lens.append(context_len)
+             if is_prompt:
+                 self.num_prefills += 1
+                 self.num_prefill_tokens += token_len
+                 self.prefill_seq_lens.append(seq_len)
+             else:
+                 self.num_decode_tokens += query_len
+                 self.curr_seq_lens.append(curr_seq_len)
+ 
+             # Compute block table.
+             # TODO(sang): Combine chunked prefill and prefix caching by
+             # only allowing multiple of block_size chunk size.
+             # NOTE: This only works for oooooooxxx style attention.
+             block_table = []
+             if prefix_cache_hit:
+                 # NOTE(woosuk): For flash-attn, the block table should
+                 # include the entries for the incoming prefill tokens.
+                 block_table = block_tables[seq_id]
+             elif ((chunked_prefill_enabled or not is_prompt)
+                   and block_tables is not None):
+                 if curr_sliding_window_block == 0:
+                     block_table = block_tables[seq_id]
+                 else:
+                     block_table = block_tables[seq_id][
+                         -curr_sliding_window_block:]
+             self.block_tables.append(block_table)
+ 
+             # Compute slot mapping.
+             is_profile_run = is_block_tables_empty(block_tables)
+             start_idx = compute_slot_mapping_start_idx(is_prompt, query_len,
+                                                        context_len,
+                                                        self.sliding_window)
+             compute_slot_mapping(is_profile_run, self.slot_mapping, seq_id,
+                                  seq_len, context_len, start_idx,
+                                  self.block_size, inter_data.block_tables)
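+ 
+     # Slot-mapping example (hypothetical numbers): with block_size=16 and
+     # block_table=[7, 9], the token at position 20 falls in logical block
+     # 1 -> physical block 9, offset 4, so its slot is 9 * 16 + 4 = 148.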
+ 
+     def _get_graph_runner_block_tables(
+             self, num_seqs: int,
+             block_tables: List[List[int]]) -> torch.Tensor:
+         # The shape of graph_block_tables is
+         # [max batch size, max context len // block size].
+         max_batch_size, max_blocks = self.runner.graph_block_tables.shape
+         assert max_batch_size >= num_seqs
+ 
+         graph_block_tables = self.runner.graph_block_tables[:num_seqs]
+         for i, block_table in enumerate(block_tables):
+             if block_table:
+                 num_blocks = len(block_table)
+                 if num_blocks <= max_blocks:
+                     graph_block_tables[i, :num_blocks] = block_table
+                 else:
+                     # It may be possible to have more blocks allocated due
+                     # to lookahead slots of multi-step; however, they are
+                     # not used anyway, so they can be safely ignored.
+                     graph_block_tables[
+                         i, :max_blocks] = block_table[:max_blocks]
+ 
+         return torch.from_numpy(graph_block_tables).to(
+             device=self.runner.device, non_blocking=True)
+ 
+     def build(self, seq_lens: List[int], query_lens: List[int],
+               cuda_graph_pad_size: int, batch_size: int):
+         """Build attention metadata with on-device tensors.
+ 
+         Args:
+             seq_lens: The maybe padded sequence lengths of the input sequences.
+             query_lens: The query lengths of the input sequences.
+             cuda_graph_pad_size: The padding size for cuda graph.
+                 -1 if cuda graph is not used.
+             batch_size: The maybe padded batch size.
+         """
+         prefix_cache_hit = any(
+             inter_data.prefix_cache_hit
+             for inter_data in self.input_builder.inter_data_list)
+ 
+         for inter_data in self.input_builder.inter_data_list:
+             self._add_seq_group(inter_data,
+                                 self.input_builder.chunked_prefill_enabled,
+                                 prefix_cache_hit)
+ 
+         device = self.runner.device
+         use_captured_graph = cuda_graph_pad_size != -1
+ 
+         max_query_len = max(query_lens)
+         decode_query_lens = query_lens[self.num_prefills:]
+         if len(decode_query_lens) > 0:
+             max_decode_query_len = max(decode_query_lens)
+         else:
+             max_decode_query_len = 1
+         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
+         max_decode_seq_len = max(self.curr_seq_lens, default=0)
+         num_decode_tokens = self.num_decode_tokens
+         query_start_loc = list(accumulate(query_lens, initial=0))
+         seq_start_loc = list(accumulate(seq_lens, initial=0))
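+         # e.g. (hypothetical) query_lens=[4, 1, 1] gives
+         # query_start_loc=[0, 4, 5, 6], and seq_lens=[4, 9, 12] gives
+         # seq_start_loc=[0, 4, 13, 25]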
+ 
+         num_seqs = len(seq_lens)
+         if use_captured_graph:
+             self.slot_mapping.extend([PAD_SLOT_ID] * cuda_graph_pad_size)
+             self.block_tables.extend(self.__class__.BLOCK_TABLE_EXTENDER *
+                                      cuda_graph_pad_size)
+             num_decode_tokens = batch_size - self.num_prefill_tokens
+ 
+             block_tables = self._get_graph_runner_block_tables(
+                 num_seqs, self.block_tables)
+         else:
+             block_tables = make_tensor_with_pad(
+                 self.block_tables,
+                 pad=0,
+                 dtype=torch.int,
+                 device=device,
+             )
+         assert max_query_len > 0, ("query_lens: {}".format(query_lens))
+ 
+         assert device is not None
+         context_lens_tensor = async_tensor_h2d(self.context_lens, torch.int,
+                                                device, self.runner.pin_memory)
+         seq_lens_tensor = async_tensor_h2d(seq_lens, torch.int, device,
+                                            self.runner.pin_memory)
+         input_positions = async_tensor_h2d(self.input_positions, torch.long,
+                                            device, self.runner.pin_memory)
+         slot_mapping_tensor = async_tensor_h2d(self.slot_mapping, torch.long,
+                                                device, self.runner.pin_memory)
+         query_start_loc_tensor = async_tensor_h2d(query_start_loc,
+                                                   torch.int32, device,
+                                                   self.runner.pin_memory)
+         seq_start_loc_tensor = async_tensor_h2d(seq_start_loc, torch.int32,
+                                                 device, self.runner.pin_memory)
+ 
+         context_chunk_cu_seq_lens = None
+         context_chunk_starts = None
+         context_chunk_seq_tot = None
+         context_chunk_max_seq_lens = None
+ 
+         if (self.chunked_prefill_enabled or self.enable_prefix_caching) \
+                 and self.num_prefills > 0 \
+                 and context_lens_tensor is not None \
+                 and context_lens_tensor[:self.num_prefills].max() > 0:
+ 
+             # NOTE: it is recommended that you read the `Chunked Prefill`
+             # section in the comment at the top of the file before trying
+             # to understand the following code
+ 
+             num_prefills_with_context = \
+                 (context_lens_tensor[:self.num_prefills] > 0).sum().item()
+ 
+             # Currently we allocate an equal amount of workspace for each
+             # prefill in the batch; we could probably use a more advanced
+             # algorithm here and allocate more workspace to prefills with
+             # longer context lengths.
+             max_context_chunk = \
+                 self.context_chunk_workspace_size // num_prefills_with_context
+ 
+             # Align max_context_chunk to page_size by rounding down, since
+             # currently the `gather_cache` kernel cannot handle
+             # `context_chunk_starts` that are not aligned to page_size.
+             max_context_chunk = round_down(max_context_chunk, self.page_size)
+             assert max_context_chunk > 0
+             num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk)
+ 
+             # If `max_context_chunk = 256`, `num_chunks = 3` and
+             # `num_prefills_with_context = 4`, create a tensor that looks like
+             # [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
+             context_chunk_starts = \
+                 torch.arange(num_chunks, device=device, dtype=torch.int32) \
+                 .unsqueeze(1).expand(-1, self.num_prefills) \
+                 * max_context_chunk
+             chunk_ends = torch.min(
+                 context_lens_tensor[:self.num_prefills].unsqueeze(0),
+                 context_chunk_starts + max_context_chunk)
+             chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0)
+             _context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
+                 torch.int32)
+             zero = torch.zeros(num_chunks, dtype=torch.int32, device=device) \
+                 .unsqueeze(-1)
+             context_chunk_cu_seq_lens = \
+                 torch.cat([zero, _context_chunk_cu_seq_lens], dim=1)
+             context_chunk_max_seq_lens = \
+                 chunk_seq_lens.max(dim=1).values.tolist()
+             context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
+             assert max(context_chunk_seq_tot) <= \
+                 self.context_chunk_workspace_size
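+ 
+             # Worked example (hypothetical numbers): context lengths
+             # [300, 700, 1000], workspace size 1536, page_size 16:
+             #   max_context_chunk = round_down(1536 // 3, 16) = 512
+             #   num_chunks = cdiv(1000, 512) = 2
+             #   context_chunk_starts  = [[0, 0, 0], [512, 512, 512]]
+             #   chunk_seq_lens        = [[300, 512, 512], [0, 188, 488]]
+             #   context_chunk_seq_tot = [1324, 676], both <= 1536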
+ 
+         return self.runner.attn_backend.make_metadata(
+             # Required by ModelRunner
+             use_cuda_graph=use_captured_graph,  # Not Attention Related
+             # Required by Attention Metadata
+             num_prefills=self.num_prefills,
+             slot_mapping=slot_mapping_tensor,
+             num_prefill_tokens=self.num_prefill_tokens,
+             num_decode_tokens=num_decode_tokens,
+             # Required by Attention Metadata (not used)
+             multi_modal_placeholder_index_maps=None,  # Not Attention Related
+             enable_kv_scales_calculation=False,
+             # MLACommonMetadata
+             input_positions=input_positions,
+             seq_lens=seq_lens,
+             seq_lens_tensor=seq_lens_tensor,
+             max_query_len=max_query_len,
+             max_decode_query_len=max_decode_query_len,
+             max_prefill_seq_len=max_prefill_seq_len,
+             max_decode_seq_len=max_decode_seq_len,
+             query_start_loc=query_start_loc_tensor,
+             seq_start_loc=seq_start_loc_tensor,
+             context_lens_tensor=context_lens_tensor,
+             block_tables=block_tables,
+             head_dim=self.runner.model_config.get_head_size(),
+             is_profile_run=self.runner.in_profile_run,
+             # MLACommonMetadata Chunk prefill specific
+             context_chunk_cu_seq_lens=context_chunk_cu_seq_lens,
+             context_chunk_starts=context_chunk_starts,
+             context_chunk_seq_tot=context_chunk_seq_tot,
+             context_chunk_max_seq_lens=context_chunk_max_seq_lens,
+         )
+ 
+ 
+ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
+     """
+     NOTE: Please read the comment at the top of the file before trying to
+     understand this class
+     """
+ 
+     def __init__(
+             self,
+             num_heads: int,
+             head_size: int,
+             scale: float,
+             num_kv_heads: int,
+             alibi_slopes: Optional[List[float]],
+             sliding_window: Optional[int],
+             kv_cache_dtype: str,
+             blocksparse_params: Optional[Dict[str, Any]],
+             logits_soft_cap: Optional[float],
+             attn_type: str,
+             # MLA Specific Arguments
+             q_lora_rank: Optional[int],
+             kv_lora_rank: int,
+             qk_nope_head_dim: int,
+             qk_rope_head_dim: int,
+             qk_head_dim: int,
+             v_head_dim: int,
+             rotary_emb: RotaryEmbedding,
+             # q_proj should be q_b_proj if q_lora_rank is not None, but from
+             # an attention backend perspective we rely on the layer to pass
+             # in the correct matrix
+             q_proj: ColumnParallelLinear,
+             kv_b_proj: ColumnParallelLinear,
+             o_proj: RowParallelLinear,
+     ) -> None:
+         self.num_heads = num_heads
+         self.head_size = head_size
+         self.scale = float(scale)
+         self.num_kv_heads = num_kv_heads
+         self.kv_cache_dtype = kv_cache_dtype
+ 
+         self.q_lora_rank = q_lora_rank
+         self.kv_lora_rank = kv_lora_rank
+         self.qk_nope_head_dim = qk_nope_head_dim
+         self.qk_rope_head_dim = qk_rope_head_dim
+         self.qk_head_dim = qk_head_dim
+         self.v_head_dim = v_head_dim
+ 
+         self.rotary_emb = rotary_emb
+         self.use_yarn_rope = isinstance(rotary_emb,
+                                         DeepseekScalingRotaryEmbedding)
+         self.q_proj = q_proj
+         self.kv_b_proj = kv_b_proj
+         self.o_proj = o_proj
+ 
+         self.triton_fa_func = triton_attention
+         # Handle the differences between the flash_attn_varlen from
+         # flash_attn and the one from vllm_flash_attn. The former is used on
+         # ROCm and the latter has an additional parameter to control
+         # FA2 vs FA3.
+         self.flash_attn_varlen_func = flash_attn_varlen_func
+         self.vllm_flash_attn_version = get_flash_attn_version()
+         if self.vllm_flash_attn_version is not None:
+             self.flash_attn_varlen_func = \
+                 functools.partial(flash_attn_varlen_func,
+                                   fa_version=self.vllm_flash_attn_version)
+ 
+         # For MLA the v head dim is smaller than the qk head dim, so we pad
+         # out v with 0s to match the qk head dim for attention backends that
+         # do not support different head dims.
+         # We don't need to pad v if we are on a Hopper system with FA3.
+         self._pad_v = self.vllm_flash_attn_version is None or not (
+             self.vllm_flash_attn_version == 3
+             and current_platform.get_device_capability()[0] == 9)
+ 
+     def _flash_attn_varlen_diff_headdims(self, q, k, v, softmax_scale,
+                                          return_softmax_lse, **kwargs):
+         maybe_padded_v = v
+         if self._pad_v:
+             maybe_padded_v = torch.nn.functional.pad(
+                 v, [0, q.shape[-1] - v.shape[-1]], value=0)
+ 
+         if is_hip and envs.VLLM_USE_TRITON_FLASH_ATTN \
+                 and not return_softmax_lse:
+             attn_out = self.triton_fa_func(
+                 q,
+                 k,
+                 maybe_padded_v,
+                 None,  # output
+                 kwargs["cu_seqlens_q"],
+                 kwargs["cu_seqlens_k"],
+                 kwargs["max_seqlen_q"],
+                 kwargs["max_seqlen_k"],
+                 kwargs["causal"],
+                 softmax_scale,
+                 None,  # bias
+             )
+         elif is_vllm_fa:
+             attn_out = self.flash_attn_varlen_func(
+                 q=q,
+                 k=k,
+                 v=maybe_padded_v,
+                 return_softmax_lse=return_softmax_lse,
+                 softmax_scale=softmax_scale,
+                 **kwargs,
+             )
+         else:
+             # Use return_attn_probs instead of return_softmax_lse on ROCm
+             attn_out = self.flash_attn_varlen_func(
+                 q=q,
+                 k=k,
+                 v=maybe_padded_v,
+                 return_attn_probs=return_softmax_lse,
+                 softmax_scale=softmax_scale,
+                 **kwargs,
+             )
+ 
+         # Unpack the output if there are multiple results:
+         # triton always returns (output, softmax_lse);
+         # vllm_flash_attn returns (output, softmax_lse) when
+         # `return_softmax_lse = True`;
+         # flash_attn (ROCm) returns (output, softmax_lse, ...) when
+         # `return_attn_probs = True`.
+         rest = None
+         if isinstance(attn_out, tuple):
+             attn_out, *rest = attn_out
+ 
+         # Unpad if necessary
+         if self._pad_v:
+             attn_out = attn_out[..., :v.shape[-1]]
+ 
+         # Remain consistent with old `flash_attn_varlen_func` where there
+         # is only one output tensor if `return_softmax_lse` is False.
+         if return_softmax_lse:
+             assert rest is not None
+             return attn_out, rest[0]
+         return attn_out
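+ 
+     # For DeepSeek-style MLA dims (an assumption, for illustration): the qk
+     # head dim is 192 (128 nope + 64 rope) while the v head dim is 128, so
+     # v is zero-padded by 64 columns on the way in and the pad is sliced
+     # back off the output above.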
+ 
+     def _v_up_proj_and_o_proj(self, x):
+         # Convert from (B, N, L) to (N, B, L)
+         x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
+         # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
+         x = torch.bmm(x, self.W_UV)
+         # Convert from (N, B, V) to (B, N * V)
+         x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
+         return self.o_proj(x)[0]
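+ 
+     # The batched matmul applies each head's (L, V) up-projection to its
+     # (B, L) latents in one call, so attention itself runs entirely over
+     # the compressed latents and only the final output is expanded to V.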
+ 
+     # Return `ql_nope`, `q_pe`
+     def _q_proj_and_k_up_proj(self, x):
+         q_nope, q_pe = self.q_proj(x)[0] \
+             .view(-1, self.num_heads, self.qk_head_dim) \
+             .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+ 
+         # Convert from (B, N, P) to (N, B, P)
+         q_nope = q_nope.transpose(0, 1)
+         # Multiply (N, B, P) x (N, P, L) -> (N, B, L)
+         ql_nope = torch.bmm(q_nope, self.W_UK_T)
+         # Convert from (N, B, L) to (B, N, L)
+         return ql_nope.transpose(0, 1), q_pe
+ 
+     def process_weights_after_loading(self, act_dtype: torch.dtype):
+ 
+         def get_layer_weight(layer):
+             WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
+             for attr in WEIGHT_NAMES:
+                 if hasattr(layer, attr):
+                     return getattr(layer, attr)
+             raise AttributeError(
+                 f"Layer '{layer}' has no recognized weight attribute:"
+                 f" {WEIGHT_NAMES}.")
+ 
+         def get_and_maybe_dequant_weights(layer: LinearBase):
+             if not isinstance(layer.quant_method, UnquantizedLinearMethod):
+                 # NOTE: This should only be used offline, since it's O(N^3)
+                 eye = torch.eye(layer.input_size_per_partition,
+                                 dtype=act_dtype,
+                                 device=get_layer_weight(layer).device)
+                 dequant_weights = layer.quant_method.apply(layer,
+                                                            eye,
+                                                            bias=None)
+                 del eye
+                 # standardize to (output, input)
+                 return dequant_weights.T
+             return layer.weight
+ 
+         # We currently do not have the quantized bmm kernels needed for
+         # `W_UV` and `W_UK_T`, so we just store fp16/bf16 copies and perform
+         # the bmm's in 16-bit; the extra memory overhead of this is fairly
+         # low.
+         kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
+         assert kv_b_proj_weight.shape == (
+             self.kv_lora_rank,
+             self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
+                 f"{kv_b_proj_weight.shape=}, "
+                 f"{self.kv_lora_rank=}, "
+                 f"{self.num_heads=}, "
+                 f"{self.qk_nope_head_dim=}, "
+                 f"{self.v_head_dim=}")
+         kv_b_proj_weight = kv_b_proj_weight.view(
+             self.kv_lora_rank,
+             self.num_heads,
+             self.qk_nope_head_dim + self.v_head_dim,
+         )
+ 
+         W_UK, W_UV = kv_b_proj_weight.split(
+             [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+ 
+         # Convert from (L, N, V) to (N, L, V)
+         self.W_UV = W_UV.transpose(0, 1)
+         # Convert from (L, N, P) to (N, P, L)
+         self.W_UK_T = W_UK.permute(1, 2, 0)
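+ 
+     # Why this works (a sketch): per head, with W_UK of shape (L, P),
+     #   dot(q_nope, kv_c @ W_UK) == dot(q_nope @ W_UK_T, kv_c)
+     # so pre-transposing W_UK lets us fold the k up-projection into the
+     # query once per step ("matrix absorption") and attend directly over
+     # the compressed kv_c latents; W_UV is applied to the attention output
+     # analogously in _v_up_proj_and_o_proj.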
+ 
+     def _compute_prefill_context(
+             self,
+             q: torch.Tensor,
+             kv_c_and_k_pe_cache: torch.Tensor,
+             attn_metadata: MLACommonMetadata,
+     ):
+         prefill_metadata = attn_metadata.prefill_metadata
+         assert prefill_metadata is not None
+         assert prefill_metadata.context_chunk_seq_tot is not None
+         assert prefill_metadata.context_chunk_cu_seq_lens is not None
+         assert prefill_metadata.context_chunk_starts is not None
+         assert prefill_metadata.context_chunk_max_seq_lens is not None
+         assert prefill_metadata.context_lens_tensor is not None
+ 
+         output = None
+         iters = len(prefill_metadata.context_chunk_seq_tot)
+ 
+         # Fetch the workspace from attn_metadata directly: it is late-bound
+         # by MLAAttentionState, so grabbing it from `attn_metadata` avoids
+         # any weirdness around prefill_metadata caching.
+         assert attn_metadata.context_chunk_workspace is not None
+         workspace = attn_metadata.context_chunk_workspace
+ 
+         for i in range(iters):
+             toks = prefill_metadata.context_chunk_seq_tot[i]
+ 
+             ops.gather_cache(
+                 src_cache=kv_c_and_k_pe_cache,
+                 dst=workspace,
+                 block_table=prefill_metadata.block_tables,
+                 cu_seq_lens=prefill_metadata.context_chunk_cu_seq_lens[i],
+                 batch_size=prefill_metadata.num_prefills,
+                 seq_starts=prefill_metadata.context_chunk_starts[i],
+             )
+ 
+             kv_c_normed = workspace[:toks][..., :self.kv_lora_rank]
+             k_pe = workspace[:toks][..., self.kv_lora_rank:].unsqueeze(1)
+ 
+             kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
+                 -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+             k_nope, v = kv_nope \
+                 .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+ 
+             k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
+                           dim=-1)
+ 
+             attn_output, attn_softmax_lse = \
+                 self._flash_attn_varlen_diff_headdims(
+                     q=q,
+                     k=k,
+                     v=v,
+                     cu_seqlens_q=prefill_metadata.query_start_loc,
+                     cu_seqlens_k=prefill_metadata.context_chunk_cu_seq_lens[i],
+                     max_seqlen_q=prefill_metadata.max_query_len,
+                     max_seqlen_k=prefill_metadata.context_chunk_max_seq_lens[i],
+                     softmax_scale=self.scale,
+                     causal=False,  # Context is unmasked
+                     return_softmax_lse=True,
+                 )
+ 
+             if output is None:
+                 output = attn_output
+                 output_lse = attn_softmax_lse
+             else:
+                 output_tmp = torch.empty_like(output)
+                 output_lse_tmp = torch.empty_like(output_lse)
+                 merge_attn_states(
+                     output=output_tmp,
+                     output_lse=output_lse_tmp,
+                     prefix_output=output,
+                     prefix_lse=output_lse,
+                     suffix_output=attn_output,
+                     suffix_lse=attn_softmax_lse,
+                 )
+                 output = output_tmp
+                 output_lse = output_lse_tmp
+ 
+         return output, output_lse
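+ 
+     # merge_attn_states combines two partial attention results via their
+     # log-sum-exp values; conceptually (a sketch, not the actual kernel):
+     #   lse = log(exp(prefix_lse) + exp(suffix_lse))
+     #   out = exp(prefix_lse - lse) * prefix_output
+     #       + exp(suffix_lse - lse) * suffix_output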
+ 
+     def _forward_prefill(
+             self,
+             q: torch.Tensor,
+             kv_c_normed: torch.Tensor,
+             k_pe: torch.Tensor,
+             kv_c_and_k_pe_cache: torch.Tensor,
+             attn_metadata: MLACommonMetadata,
+     ) -> torch.Tensor:
+ 
+         prefill_metadata = attn_metadata.prefill_metadata
+         assert prefill_metadata is not None
+ 
+         has_context = prefill_metadata.context_lens_tensor is not None \
+             and prefill_metadata.context_lens_tensor.max() > 0
+ 
+         kv_nope = self.kv_b_proj(kv_c_normed)[0].view(
+             -1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
+         k_nope, v = kv_nope \
+             .split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
+ 
+         k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
+ 
+         output = self._flash_attn_varlen_diff_headdims(
+             q=q,
+             k=k,
+             v=v,
+             cu_seqlens_q=prefill_metadata.query_start_loc,
+             cu_seqlens_k=prefill_metadata.query_start_loc,
+             max_seqlen_q=prefill_metadata.max_prefill_seq_len,
+             max_seqlen_k=prefill_metadata.max_prefill_seq_len,
+             softmax_scale=self.scale,
+             causal=True,
+             return_softmax_lse=has_context,
+         )
+ 
+         if has_context:
+             # ROCm flash_attn_varlen_func will return 3 objects instead of 2
+             suffix_output, suffix_lse = output
+             context_output, context_lse = self._compute_prefill_context(
+                 q, kv_c_and_k_pe_cache, attn_metadata)
+ 
+             output = torch.empty_like(suffix_output)
+             merge_attn_states(
+                 output=output,
+                 prefix_output=context_output,
+                 prefix_lse=context_lse,
+                 suffix_output=suffix_output,
+                 suffix_lse=suffix_lse,
+             )
+ 
+         return self.o_proj(output.flatten(start_dim=-2))[0]
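+ 
+     # Note: the current-iteration tokens attend causally to themselves
+     # above, while the already-cached context is attended without a causal
+     # mask in _compute_prefill_context; LSE-merging the two partial results
+     # recovers attention over the full sequence.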
+ 
+     @abstractmethod
+     def _forward_decode(
+             self,
+             ql_nope: torch.Tensor,
+             q_pe: torch.Tensor,
+             kv_c_and_k_pe_cache: torch.Tensor,
+             attn_metadata: T,
+     ) -> torch.Tensor:
+         raise NotImplementedError
+ 
+     def forward(
+             self,
+             layer: AttentionLayer,
+             hidden_states_or_q_c: torch.Tensor,  # query in unified attn
+             k_c_normed: torch.Tensor,  # key in unified attn
+             k_pe: torch.Tensor,  # value in unified attn
+             kv_cache: torch.Tensor,
+             attn_metadata: T,
+             output: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         if output is not None:
+             raise NotImplementedError(
+                 "output is not yet supported for MLAImplBase")
+ 
+         if attn_metadata.is_profile_run and \
+                 attn_metadata.context_chunk_workspace is not None:
+             # During the profile run, try to simulate the worst-case output
+             # size for `self.kv_b_proj(kv_c_normed)` in
+             # `_compute_prefill_context`, since this can be large.
+             _ = torch.empty(
+                 (attn_metadata.context_chunk_workspace.shape[0],
+                  self.num_heads, self.qk_nope_head_dim + self.v_head_dim),
+                 device=k_c_normed.device,
+                 dtype=k_c_normed.dtype,
+             )
+ 
+         has_decode = attn_metadata.decode_metadata is not None
+         has_prefill = attn_metadata.prefill_metadata is not None
+ 
+         # Restore head dim (for rotary embedding)
+         k_pe = k_pe.unsqueeze(1)
+         assert hasattr(attn_metadata, "input_positions")
+ 
+         num_prefill_tokens: int = attn_metadata.num_prefill_tokens
+ 
+         decode_hs_or_q_c = hidden_states_or_q_c[num_prefill_tokens:]
+         decode_k_pe = k_pe[num_prefill_tokens:]
+         decode_input_positions = \
+             attn_metadata.input_positions[num_prefill_tokens:]
+ 
+         prefill_hs_or_q_c = hidden_states_or_q_c[:num_prefill_tokens]
+         prefill_k_pe = k_pe[:num_prefill_tokens]
+         prefill_input_positions = \
+             attn_metadata.input_positions[:num_prefill_tokens]
+         prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
+ 
+         if has_decode:
+             decode_ql_nope, decode_q_pe = \
+                 self._q_proj_and_k_up_proj(decode_hs_or_q_c)
+             decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
+                 decode_input_positions, decode_q_pe, decode_k_pe)
+ 
+         if has_prefill:
+             prefill_q = self.q_proj(prefill_hs_or_q_c)[0] \
+                 .view(-1, self.num_heads, self.qk_head_dim)
+             prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
+             prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
+                 prefill_input_positions, prefill_q_pe, prefill_k_pe)
+ 
+         # Write the latent and rope components to the kv cache
+         if kv_cache.numel() > 0:
+             ops.concat_and_cache_mla(
+                 k_c_normed,
+                 k_pe.squeeze(1),
+                 kv_cache,
+                 attn_metadata.slot_mapping.flatten(),
+                 kv_cache_dtype=self.kv_cache_dtype,
+                 scale=layer._k_scale,
+             )
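+ 
+         # (Assumed layout, for illustration) each cache slot stores the
+         # kv_lora_rank latent dims plus the qk_rope_head_dim rope dims,
+         # e.g. 512 + 64 = 576 entries per token for DeepSeek-style models.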
1429
+
1430
+ output = torch.empty(attn_metadata.num_prefill_tokens +
1431
+ attn_metadata.num_decode_tokens,
1432
+ self.o_proj.output_size,
1433
+ device=hidden_states_or_q_c.device,
1434
+ dtype=hidden_states_or_q_c.dtype)
1435
+ if has_prefill:
1436
+ output[:num_prefill_tokens] = self._forward_prefill(
1437
+ prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
1438
+ attn_metadata)
1439
+
1440
+ if has_decode:
1441
+ output[num_prefill_tokens:] = self._forward_decode(
1442
+ decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
1443
+
1444
+ return output