vllm_cpu_amxbf16-0.9.1-cp312-cp312-manylinux_2_17_x86_64.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. The information in this diff is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
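As a quick sanity check on a listing like the one below, the wheel can also be inspected locally. The following is a minimal sketch, assuming the wheel has been downloaded to the working directory (the filename is reconstructed from this page's header, so the exact artifact name is an assumption); it uses only Python's standard-library zipfile module:

    import zipfile

    # Wheel filename reconstructed from this page's header (an assumption);
    # adjust the path if the file was downloaded elsewhere.
    WHEEL = "vllm_cpu_amxbf16-0.9.1-cp312-cp312-manylinux_2_17_x86_64.whl"

    # A wheel is a standard zip archive, so its members can be listed directly.
    with zipfile.ZipFile(WHEEL) as wheel:
        for info in wheel.infolist():
            # One line per archive member, mirroring the rows in the
            # "Files changed" list below.
            print(f"{info.filename}  ({info.file_size} bytes)")

Every entry below shows -0 removed lines, consistent with the diff being taken against an empty baseline (i.e., the first published version of this package).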
Files changed (1197)
  1. vllm/_C.abi3.so +0 -0
  2. vllm/__init__.py +53 -0
  3. vllm/_custom_ops.py +1828 -0
  4. vllm/_ipex_ops.py +244 -0
  5. vllm/_version.py +34 -0
  6. vllm/adapter_commons/__init__.py +0 -0
  7. vllm/adapter_commons/layers.py +16 -0
  8. vllm/adapter_commons/models.py +106 -0
  9. vllm/adapter_commons/request.py +26 -0
  10. vllm/adapter_commons/utils.py +93 -0
  11. vllm/adapter_commons/worker_manager.py +39 -0
  12. vllm/assets/__init__.py +0 -0
  13. vllm/assets/audio.py +45 -0
  14. vllm/assets/base.py +41 -0
  15. vllm/assets/image.py +34 -0
  16. vllm/assets/video.py +115 -0
  17. vllm/attention/__init__.py +20 -0
  18. vllm/attention/backends/__init__.py +0 -0
  19. vllm/attention/backends/abstract.py +308 -0
  20. vllm/attention/backends/blocksparse_attn.py +461 -0
  21. vllm/attention/backends/cpu_mla.py +307 -0
  22. vllm/attention/backends/dual_chunk_flash_attn.py +1498 -0
  23. vllm/attention/backends/flash_attn.py +1003 -0
  24. vllm/attention/backends/flashinfer.py +1104 -0
  25. vllm/attention/backends/flashmla.py +244 -0
  26. vllm/attention/backends/hpu_attn.py +313 -0
  27. vllm/attention/backends/ipex_attn.py +398 -0
  28. vllm/attention/backends/mla/__init__.py +0 -0
  29. vllm/attention/backends/mla/common.py +1385 -0
  30. vllm/attention/backends/pallas.py +351 -0
  31. vllm/attention/backends/placeholder_attn.py +400 -0
  32. vllm/attention/backends/rocm_aiter_mla.py +435 -0
  33. vllm/attention/backends/rocm_flash_attn.py +975 -0
  34. vllm/attention/backends/torch_sdpa.py +703 -0
  35. vllm/attention/backends/triton_mla.py +115 -0
  36. vllm/attention/backends/utils.py +610 -0
  37. vllm/attention/backends/xformers.py +802 -0
  38. vllm/attention/layer.py +468 -0
  39. vllm/attention/ops/__init__.py +0 -0
  40. vllm/attention/ops/blocksparse_attention/__init__.py +0 -0
  41. vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py +433 -0
  42. vllm/attention/ops/blocksparse_attention/interface.py +239 -0
  43. vllm/attention/ops/blocksparse_attention/utils.py +246 -0
  44. vllm/attention/ops/chunked_prefill_paged_decode.py +368 -0
  45. vllm/attention/ops/flashmla.py +116 -0
  46. vllm/attention/ops/hpu_paged_attn.py +88 -0
  47. vllm/attention/ops/ipex_attn.py +195 -0
  48. vllm/attention/ops/merge_attn_states.py +43 -0
  49. vllm/attention/ops/nki_flash_attn.py +906 -0
  50. vllm/attention/ops/paged_attn.py +256 -0
  51. vllm/attention/ops/prefix_prefill.py +902 -0
  52. vllm/attention/ops/rocm_aiter_mla.py +100 -0
  53. vllm/attention/ops/rocm_aiter_paged_attn.py +102 -0
  54. vllm/attention/ops/triton_decode_attention.py +674 -0
  55. vllm/attention/ops/triton_flash_attention.py +979 -0
  56. vllm/attention/ops/triton_merge_attn_states.py +97 -0
  57. vllm/attention/ops/triton_unified_attention.py +334 -0
  58. vllm/attention/selector.py +187 -0
  59. vllm/attention/utils/fa_utils.py +55 -0
  60. vllm/beam_search.py +87 -0
  61. vllm/benchmarks/__init__.py +0 -0
  62. vllm/benchmarks/datasets.py +1185 -0
  63. vllm/benchmarks/endpoint_request_func.py +381 -0
  64. vllm/benchmarks/latency.py +168 -0
  65. vllm/benchmarks/serve.py +1135 -0
  66. vllm/benchmarks/throughput.py +609 -0
  67. vllm/benchmarks/utils.py +70 -0
  68. vllm/collect_env.py +820 -0
  69. vllm/compilation/__init__.py +0 -0
  70. vllm/compilation/activation_quant_fusion.py +89 -0
  71. vllm/compilation/backends.py +563 -0
  72. vllm/compilation/base_piecewise_backend.py +72 -0
  73. vllm/compilation/collective_fusion.py +127 -0
  74. vllm/compilation/compiler_interface.py +544 -0
  75. vllm/compilation/counter.py +38 -0
  76. vllm/compilation/cuda_piecewise_backend.py +214 -0
  77. vllm/compilation/decorators.py +250 -0
  78. vllm/compilation/fix_functionalization.py +191 -0
  79. vllm/compilation/fusion.py +618 -0
  80. vllm/compilation/fx_utils.py +62 -0
  81. vllm/compilation/inductor_pass.py +115 -0
  82. vllm/compilation/monitor.py +39 -0
  83. vllm/compilation/multi_output_match.py +109 -0
  84. vllm/compilation/noop_elimination.py +137 -0
  85. vllm/compilation/pass_manager.py +78 -0
  86. vllm/compilation/sequence_parallelism.py +268 -0
  87. vllm/compilation/torch25_custom_graph_pass.py +42 -0
  88. vllm/compilation/vllm_inductor_pass.py +67 -0
  89. vllm/compilation/wrapper.py +135 -0
  90. vllm/config.py +4746 -0
  91. vllm/connections.py +174 -0
  92. vllm/core/__init__.py +0 -0
  93. vllm/core/block/__init__.py +0 -0
  94. vllm/core/block/block_table.py +399 -0
  95. vllm/core/block/common.py +371 -0
  96. vllm/core/block/cpu_gpu_block_allocator.py +441 -0
  97. vllm/core/block/interfaces.py +319 -0
  98. vllm/core/block/naive_block.py +466 -0
  99. vllm/core/block/prefix_caching_block.py +1135 -0
  100. vllm/core/block/utils.py +28 -0
  101. vllm/core/block_manager.py +521 -0
  102. vllm/core/evictor.py +157 -0
  103. vllm/core/interfaces.py +135 -0
  104. vllm/core/placeholder_block_space_manager.py +100 -0
  105. vllm/core/scheduler.py +2093 -0
  106. vllm/device_allocator/__init__.py +0 -0
  107. vllm/device_allocator/cumem.py +281 -0
  108. vllm/distributed/__init__.py +6 -0
  109. vllm/distributed/communication_op.py +41 -0
  110. vllm/distributed/device_communicators/__init__.py +0 -0
  111. vllm/distributed/device_communicators/all2all.py +264 -0
  112. vllm/distributed/device_communicators/base_device_communicator.py +260 -0
  113. vllm/distributed/device_communicators/cpu_communicator.py +145 -0
  114. vllm/distributed/device_communicators/cuda_communicator.py +176 -0
  115. vllm/distributed/device_communicators/cuda_wrapper.py +180 -0
  116. vllm/distributed/device_communicators/custom_all_reduce.py +304 -0
  117. vllm/distributed/device_communicators/custom_all_reduce_utils.py +259 -0
  118. vllm/distributed/device_communicators/hpu_communicator.py +46 -0
  119. vllm/distributed/device_communicators/neuron_communicator.py +20 -0
  120. vllm/distributed/device_communicators/pynccl.py +218 -0
  121. vllm/distributed/device_communicators/pynccl_wrapper.py +341 -0
  122. vllm/distributed/device_communicators/shm_broadcast.py +585 -0
  123. vllm/distributed/device_communicators/tpu_communicator.py +103 -0
  124. vllm/distributed/device_communicators/xpu_communicator.py +55 -0
  125. vllm/distributed/kv_events.py +356 -0
  126. vllm/distributed/kv_transfer/README.md +29 -0
  127. vllm/distributed/kv_transfer/__init__.py +12 -0
  128. vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg +0 -0
  129. vllm/distributed/kv_transfer/kv_connector/__init__.py +0 -0
  130. vllm/distributed/kv_transfer/kv_connector/base.py +128 -0
  131. vllm/distributed/kv_transfer/kv_connector/factory.py +128 -0
  132. vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py +99 -0
  133. vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py +203 -0
  134. vllm/distributed/kv_transfer/kv_connector/simple_connector.py +329 -0
  135. vllm/distributed/kv_transfer/kv_connector/utils.py +108 -0
  136. vllm/distributed/kv_transfer/kv_connector/v1/__init__.py +6 -0
  137. vllm/distributed/kv_transfer/kv_connector/v1/base.py +283 -0
  138. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +134 -0
  139. vllm/distributed/kv_transfer/kv_connector/v1/multi_connector.py +201 -0
  140. vllm/distributed/kv_transfer/kv_connector/v1/nixl_connector.py +1030 -0
  141. vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +384 -0
  142. vllm/distributed/kv_transfer/kv_connector_agent.py +77 -0
  143. vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py +0 -0
  144. vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +175 -0
  145. vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py +161 -0
  146. vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +237 -0
  147. vllm/distributed/kv_transfer/kv_pipe/__init__.py +0 -0
  148. vllm/distributed/kv_transfer/kv_pipe/base.py +67 -0
  149. vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +280 -0
  150. vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +280 -0
  151. vllm/distributed/kv_transfer/kv_transfer_state.py +71 -0
  152. vllm/distributed/parallel_state.py +1296 -0
  153. vllm/distributed/tpu_distributed_utils.py +177 -0
  154. vllm/distributed/utils.py +536 -0
  155. vllm/engine/__init__.py +0 -0
  156. vllm/engine/arg_utils.py +1708 -0
  157. vllm/engine/async_llm_engine.py +1200 -0
  158. vllm/engine/async_timeout.py +173 -0
  159. vllm/engine/llm_engine.py +2097 -0
  160. vllm/engine/metrics.py +629 -0
  161. vllm/engine/metrics_types.py +94 -0
  162. vllm/engine/multiprocessing/__init__.py +148 -0
  163. vllm/engine/multiprocessing/client.py +681 -0
  164. vllm/engine/multiprocessing/engine.py +460 -0
  165. vllm/engine/output_processor/__init__.py +0 -0
  166. vllm/engine/output_processor/interfaces.py +75 -0
  167. vllm/engine/output_processor/multi_step.py +216 -0
  168. vllm/engine/output_processor/single_step.py +145 -0
  169. vllm/engine/output_processor/stop_checker.py +131 -0
  170. vllm/engine/output_processor/util.py +28 -0
  171. vllm/engine/protocol.py +317 -0
  172. vllm/entrypoints/__init__.py +0 -0
  173. vllm/entrypoints/api_server.py +178 -0
  174. vllm/entrypoints/chat_utils.py +1299 -0
  175. vllm/entrypoints/cli/__init__.py +0 -0
  176. vllm/entrypoints/cli/benchmark/__init__.py +0 -0
  177. vllm/entrypoints/cli/benchmark/base.py +39 -0
  178. vllm/entrypoints/cli/benchmark/latency.py +30 -0
  179. vllm/entrypoints/cli/benchmark/main.py +54 -0
  180. vllm/entrypoints/cli/benchmark/serve.py +30 -0
  181. vllm/entrypoints/cli/benchmark/throughput.py +30 -0
  182. vllm/entrypoints/cli/collect_env.py +35 -0
  183. vllm/entrypoints/cli/main.py +65 -0
  184. vllm/entrypoints/cli/openai.py +205 -0
  185. vllm/entrypoints/cli/run_batch.py +62 -0
  186. vllm/entrypoints/cli/serve.py +328 -0
  187. vllm/entrypoints/cli/types.py +25 -0
  188. vllm/entrypoints/launcher.py +147 -0
  189. vllm/entrypoints/llm.py +1544 -0
  190. vllm/entrypoints/logger.py +50 -0
  191. vllm/entrypoints/openai/__init__.py +0 -0
  192. vllm/entrypoints/openai/api_server.py +1387 -0
  193. vllm/entrypoints/openai/cli_args.py +315 -0
  194. vllm/entrypoints/openai/logits_processors.py +90 -0
  195. vllm/entrypoints/openai/protocol.py +1913 -0
  196. vllm/entrypoints/openai/run_batch.py +463 -0
  197. vllm/entrypoints/openai/serving_chat.py +1221 -0
  198. vllm/entrypoints/openai/serving_classification.py +160 -0
  199. vllm/entrypoints/openai/serving_completion.py +592 -0
  200. vllm/entrypoints/openai/serving_embedding.py +201 -0
  201. vllm/entrypoints/openai/serving_engine.py +986 -0
  202. vllm/entrypoints/openai/serving_models.py +315 -0
  203. vllm/entrypoints/openai/serving_pooling.py +232 -0
  204. vllm/entrypoints/openai/serving_score.py +433 -0
  205. vllm/entrypoints/openai/serving_tokenization.py +157 -0
  206. vllm/entrypoints/openai/serving_transcription.py +424 -0
  207. vllm/entrypoints/openai/tool_parsers/__init__.py +23 -0
  208. vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +164 -0
  209. vllm/entrypoints/openai/tool_parsers/deepseekv3_tool_parser.py +370 -0
  210. vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +259 -0
  211. vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +237 -0
  212. vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +371 -0
  213. vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +216 -0
  214. vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +308 -0
  215. vllm/entrypoints/openai/tool_parsers/llama4_pythonic_tool_parser.py +316 -0
  216. vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +267 -0
  217. vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +369 -0
  218. vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +112 -0
  219. vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +308 -0
  220. vllm/entrypoints/openai/tool_parsers/utils.py +124 -0
  221. vllm/entrypoints/score_utils.py +50 -0
  222. vllm/entrypoints/ssl.py +75 -0
  223. vllm/entrypoints/utils.py +233 -0
  224. vllm/env_override.py +41 -0
  225. vllm/envs.py +944 -0
  226. vllm/executor/__init__.py +0 -0
  227. vllm/executor/executor_base.py +401 -0
  228. vllm/executor/mp_distributed_executor.py +244 -0
  229. vllm/executor/msgspec_utils.py +30 -0
  230. vllm/executor/multiproc_worker_utils.py +313 -0
  231. vllm/executor/ray_distributed_executor.py +701 -0
  232. vllm/executor/ray_utils.py +399 -0
  233. vllm/executor/uniproc_executor.py +139 -0
  234. vllm/forward_context.py +179 -0
  235. vllm/inputs/__init__.py +41 -0
  236. vllm/inputs/data.py +331 -0
  237. vllm/inputs/parse.py +151 -0
  238. vllm/inputs/preprocess.py +909 -0
  239. vllm/inputs/registry.py +237 -0
  240. vllm/jsontree.py +80 -0
  241. vllm/logger.py +212 -0
  242. vllm/logging_utils/__init__.py +8 -0
  243. vllm/logging_utils/dump_input.py +85 -0
  244. vllm/logging_utils/formatter.py +18 -0
  245. vllm/logits_process.py +119 -0
  246. vllm/lora/__init__.py +0 -0
  247. vllm/lora/fully_sharded_layers.py +355 -0
  248. vllm/lora/layers.py +1285 -0
  249. vllm/lora/lora.py +199 -0
  250. vllm/lora/models.py +818 -0
  251. vllm/lora/ops/__init__.py +0 -0
  252. vllm/lora/ops/torch_ops/__init__.py +16 -0
  253. vllm/lora/ops/torch_ops/lora_ops.py +119 -0
  254. vllm/lora/ops/triton_ops/__init__.py +12 -0
  255. vllm/lora/ops/triton_ops/kernel_utils.py +243 -0
  256. vllm/lora/ops/triton_ops/lora_expand_op.py +290 -0
  257. vllm/lora/ops/triton_ops/lora_kernel_metadata.py +148 -0
  258. vllm/lora/ops/triton_ops/lora_shrink_op.py +244 -0
  259. vllm/lora/ops/triton_ops/utils.py +120 -0
  260. vllm/lora/ops/xla_ops/__init__.py +7 -0
  261. vllm/lora/ops/xla_ops/lora_ops.py +145 -0
  262. vllm/lora/peft_helper.py +136 -0
  263. vllm/lora/punica_wrapper/__init__.py +10 -0
  264. vllm/lora/punica_wrapper/punica_base.py +485 -0
  265. vllm/lora/punica_wrapper/punica_cpu.py +349 -0
  266. vllm/lora/punica_wrapper/punica_gpu.py +290 -0
  267. vllm/lora/punica_wrapper/punica_hpu.py +145 -0
  268. vllm/lora/punica_wrapper/punica_selector.py +20 -0
  269. vllm/lora/punica_wrapper/punica_tpu.py +405 -0
  270. vllm/lora/punica_wrapper/utils.py +164 -0
  271. vllm/lora/request.py +99 -0
  272. vllm/lora/resolver.py +85 -0
  273. vllm/lora/utils.py +240 -0
  274. vllm/lora/worker_manager.py +259 -0
  275. vllm/model_executor/__init__.py +16 -0
  276. vllm/model_executor/custom_op.py +152 -0
  277. vllm/model_executor/guided_decoding/__init__.py +181 -0
  278. vllm/model_executor/guided_decoding/guidance_decoding.py +63 -0
  279. vllm/model_executor/guided_decoding/guidance_logits_processors.py +104 -0
  280. vllm/model_executor/guided_decoding/guided_fields.py +41 -0
  281. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +67 -0
  282. vllm/model_executor/guided_decoding/outlines_decoding.py +155 -0
  283. vllm/model_executor/guided_decoding/outlines_logits_processors.py +284 -0
  284. vllm/model_executor/guided_decoding/utils.py +242 -0
  285. vllm/model_executor/guided_decoding/xgrammar_decoding.py +426 -0
  286. vllm/model_executor/layers/__init__.py +0 -0
  287. vllm/model_executor/layers/activation.py +369 -0
  288. vllm/model_executor/layers/fused_moe/__init__.py +54 -0
  289. vllm/model_executor/layers/fused_moe/batched_deep_gemm_moe.py +125 -0
  290. vllm/model_executor/layers/fused_moe/batched_triton_or_deep_gemm_moe.py +117 -0
  291. vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  292. vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  293. vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  294. vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  295. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  296. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +218 -0
  297. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
  298. vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  299. vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  300. vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  301. vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  302. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  303. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X.json +200 -0
  304. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  305. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  306. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20-3e.json +146 -0
  307. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  308. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  309. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  310. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  311. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20-3e.json +146 -0
  312. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  313. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  314. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  315. vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  316. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  317. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  318. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  319. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  320. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  321. vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  322. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json +200 -0
  323. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json +146 -0
  324. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  325. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  326. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  327. vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  328. vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  329. vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  330. vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  331. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  332. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  333. vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  334. vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  335. vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  336. vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  337. vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  338. vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  339. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  340. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  341. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  342. vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  343. vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  344. vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json +200 -0
  345. vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  346. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  347. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  348. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  349. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  350. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  351. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  352. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  353. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  354. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  355. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  356. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  357. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  358. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  359. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  360. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  361. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  362. vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  363. vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  364. vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json +200 -0
  365. vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json +200 -0
  366. vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json +200 -0
  367. vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json +200 -0
  368. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  369. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  370. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  371. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  372. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  373. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  374. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  375. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  376. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  377. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  378. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  379. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  380. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  381. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  382. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  383. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  384. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  385. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  386. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  387. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  388. vllm/model_executor/layers/fused_moe/configs/E=64,N=896,device_name=NVIDIA_H20.json +146 -0
  389. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  390. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
  391. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  392. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  393. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
  394. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  395. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  396. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  397. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X.json +200 -0
  398. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  399. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json +200 -0
  400. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  401. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
  402. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  403. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
  404. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  405. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  406. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  407. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  408. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  409. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  410. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X.json +200 -0
  411. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  412. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json +200 -0
  413. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  414. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  415. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  416. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  417. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  418. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  419. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
  420. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  421. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
  422. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  423. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  424. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  425. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  426. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  427. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  428. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  429. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
  430. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  431. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X.json +200 -0
  432. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  433. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json +200 -0
  434. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  435. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  436. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  437. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  438. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  439. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  440. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
  441. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  442. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
  443. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  444. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  445. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  446. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  447. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  448. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  449. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X.json +200 -0
  450. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  451. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json +200 -0
  452. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  453. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  454. vllm/model_executor/layers/fused_moe/configs/README +12 -0
  455. vllm/model_executor/layers/fused_moe/cutlass_moe.py +461 -0
  456. vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +240 -0
  457. vllm/model_executor/layers/fused_moe/deepep_ht_prepare_finalize.py +240 -0
  458. vllm/model_executor/layers/fused_moe/deepep_ll_prepare_finalize.py +186 -0
  459. vllm/model_executor/layers/fused_moe/fused_batched_moe.py +775 -0
  460. vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +232 -0
  461. vllm/model_executor/layers/fused_moe/fused_moe.py +1724 -0
  462. vllm/model_executor/layers/fused_moe/layer.py +1535 -0
  463. vllm/model_executor/layers/fused_moe/modular_kernel.py +446 -0
  464. vllm/model_executor/layers/fused_moe/moe_align_block_size.py +243 -0
  465. vllm/model_executor/layers/fused_moe/moe_pallas.py +80 -0
  466. vllm/model_executor/layers/fused_moe/moe_permute_unpermute.py +190 -0
  467. vllm/model_executor/layers/fused_moe/moe_torch_iterative.py +60 -0
  468. vllm/model_executor/layers/fused_moe/pplx_prepare_finalize.py +159 -0
  469. vllm/model_executor/layers/fused_moe/prepare_finalize.py +69 -0
  470. vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +421 -0
  471. vllm/model_executor/layers/fused_moe/triton_deep_gemm_moe.py +117 -0
  472. vllm/model_executor/layers/fused_moe/utils.py +98 -0
  473. vllm/model_executor/layers/layernorm.py +288 -0
  474. vllm/model_executor/layers/lightning_attn.py +652 -0
  475. vllm/model_executor/layers/linear.py +1524 -0
  476. vllm/model_executor/layers/logits_processor.py +197 -0
  477. vllm/model_executor/layers/mamba/__init__.py +0 -0
  478. vllm/model_executor/layers/mamba/mamba2_metadata.py +125 -0
  479. vllm/model_executor/layers/mamba/mamba_mixer.py +245 -0
  480. vllm/model_executor/layers/mamba/mamba_mixer2.py +616 -0
  481. vllm/model_executor/layers/mamba/ops/__init__.py +0 -0
  482. vllm/model_executor/layers/mamba/ops/causal_conv1d.py +105 -0
  483. vllm/model_executor/layers/mamba/ops/mamba_ssm.py +414 -0
  484. vllm/model_executor/layers/mamba/ops/ssd_bmm.py +262 -0
  485. vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +589 -0
  486. vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +751 -0
  487. vllm/model_executor/layers/mamba/ops/ssd_combined.py +232 -0
  488. vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +206 -0
  489. vllm/model_executor/layers/pooler.py +350 -0
  490. vllm/model_executor/layers/quantization/__init__.py +157 -0
  491. vllm/model_executor/layers/quantization/aqlm.py +376 -0
  492. vllm/model_executor/layers/quantization/auto_round.py +310 -0
  493. vllm/model_executor/layers/quantization/awq.py +194 -0
  494. vllm/model_executor/layers/quantization/awq_marlin.py +519 -0
  495. vllm/model_executor/layers/quantization/awq_triton.py +320 -0
  496. vllm/model_executor/layers/quantization/base_config.py +151 -0
  497. vllm/model_executor/layers/quantization/bitblas.py +461 -0
  498. vllm/model_executor/layers/quantization/bitsandbytes.py +396 -0
  499. vllm/model_executor/layers/quantization/compressed_tensors/__init__.py +0 -0
  500. vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +668 -0
  501. vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +1260 -0
  502. vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +24 -0
  503. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +358 -0
  504. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +55 -0
  505. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +160 -0
  506. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_nvfp4.py +93 -0
  507. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a4_nvfp4.py +178 -0
  508. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +121 -0
  509. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +150 -0
  510. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +111 -0
  511. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +201 -0
  512. vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +206 -0
  513. vllm/model_executor/layers/quantization/compressed_tensors/utils.py +216 -0
  514. vllm/model_executor/layers/quantization/deepspeedfp.py +195 -0
  515. vllm/model_executor/layers/quantization/experts_int8.py +196 -0
  516. vllm/model_executor/layers/quantization/fbgemm_fp8.py +172 -0
  517. vllm/model_executor/layers/quantization/fp8.py +906 -0
  518. vllm/model_executor/layers/quantization/gguf.py +565 -0
  519. vllm/model_executor/layers/quantization/gptq.py +278 -0
  520. vllm/model_executor/layers/quantization/gptq_bitblas.py +445 -0
  521. vllm/model_executor/layers/quantization/gptq_marlin.py +648 -0
  522. vllm/model_executor/layers/quantization/gptq_marlin_24.py +297 -0
  523. vllm/model_executor/layers/quantization/hqq_marlin.py +332 -0
  524. vllm/model_executor/layers/quantization/ipex_quant.py +250 -0
  525. vllm/model_executor/layers/quantization/kernels/__init__.py +0 -0
  526. vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +90 -0
  527. vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +83 -0
  528. vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py +116 -0
  529. vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py +300 -0
  530. vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py +143 -0
  531. vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +120 -0
  532. vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +131 -0
  533. vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +67 -0
  534. vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +87 -0
  535. vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +120 -0
  536. vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +137 -0
  537. vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py +41 -0
  538. vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +105 -0
  539. vllm/model_executor/layers/quantization/kv_cache.py +139 -0
  540. vllm/model_executor/layers/quantization/marlin.py +261 -0
  541. vllm/model_executor/layers/quantization/modelopt.py +737 -0
  542. vllm/model_executor/layers/quantization/moe_wna16.py +449 -0
  543. vllm/model_executor/layers/quantization/neuron_quant.py +76 -0
  544. vllm/model_executor/layers/quantization/ptpc_fp8.py +127 -0
  545. vllm/model_executor/layers/quantization/qqq.py +275 -0
  546. vllm/model_executor/layers/quantization/quark/__init__.py +0 -0
  547. vllm/model_executor/layers/quantization/quark/quark.py +441 -0
  548. vllm/model_executor/layers/quantization/quark/quark_moe.py +237 -0
  549. vllm/model_executor/layers/quantization/quark/schemes/__init__.py +9 -0
  550. vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py +55 -0
  551. vllm/model_executor/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +126 -0
  552. vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +146 -0
  553. vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +122 -0
  554. vllm/model_executor/layers/quantization/quark/utils.py +105 -0
  555. vllm/model_executor/layers/quantization/schema.py +86 -0
  556. vllm/model_executor/layers/quantization/torchao.py +161 -0
  557. vllm/model_executor/layers/quantization/tpu_int8.py +121 -0
  558. vllm/model_executor/layers/quantization/utils/__init__.py +6 -0
  559. vllm/model_executor/layers/quantization/utils/allspark_utils.py +52 -0
  560. vllm/model_executor/layers/quantization/utils/bitblas_utils.py +208 -0
  561. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  562. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  563. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  564. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  565. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  566. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  567. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  568. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  569. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  570. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  571. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  572. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  573. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  574. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  575. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  576. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  577. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  578. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  579. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  580. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  581. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  582. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  583. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  584. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  585. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  586. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  587. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  588. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  589. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  590. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  591. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  592. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  593. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  594. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  595. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  596. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  597. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  598. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  599. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  600. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  601. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  602. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  603. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  604. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  605. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  606. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  607. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  608. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  609. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  610. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  611. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  612. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  613. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  614. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  615. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  616. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  617. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  618. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  619. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  620. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  621. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  622. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  623. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  624. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  625. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  626. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  627. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  628. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  629. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  630. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  631. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  632. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  633. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  634. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  635. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  636. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  637. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  638. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  639. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  640. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  641. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  642. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  643. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  644. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  645. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  646. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  647. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  648. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  649. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  650. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  651. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  652. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  653. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  654. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  655. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  656. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  657. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  658. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  659. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  660. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  661. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  662. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  663. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  664. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  665. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  666. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  667. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  668. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  669. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  670. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  671. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  672. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  673. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  674. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  675. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  676. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  677. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  678. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  679. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  680. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  681. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  682. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +18 -0
  683. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  684. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  685. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  686. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  687. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  688. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  689. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  690. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  691. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  692. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  693. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  694. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  695. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  696. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  697. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  698. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  699. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  700. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  701. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  702. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  703. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  704. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  705. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  706. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  707. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  708. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  709. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  710. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  711. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  712. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  713. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  714. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  715. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  716. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  717. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  718. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  719. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  720. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  721. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  722. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  723. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  724. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  725. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  726. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  727. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  728. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  729. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  730. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  731. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  732. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  733. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  734. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  735. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  736. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  737. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  738. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  739. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  740. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  741. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  742. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  743. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  744. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  745. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  746. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  747. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  748. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  749. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  750. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  751. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  752. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  753. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  754. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  755. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  756. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  757. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  758. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  759. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  760. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  761. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  762. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  763. vllm/model_executor/layers/quantization/utils/fp8_utils.py +618 -0
  764. vllm/model_executor/layers/quantization/utils/gptq_utils.py +95 -0
  765. vllm/model_executor/layers/quantization/utils/int8_utils.py +485 -0
  766. vllm/model_executor/layers/quantization/utils/layer_utils.py +40 -0
  767. vllm/model_executor/layers/quantization/utils/machete_utils.py +33 -0
  768. vllm/model_executor/layers/quantization/utils/marlin_utils.py +476 -0
  769. vllm/model_executor/layers/quantization/utils/marlin_utils_fp4.py +283 -0
  770. vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +325 -0
  771. vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +165 -0
  772. vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py +464 -0
  773. vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py +126 -0
  774. vllm/model_executor/layers/quantization/utils/mxfp4_utils.py +45 -0
  775. vllm/model_executor/layers/quantization/utils/nvfp4_emulation_utils.py +104 -0
  776. vllm/model_executor/layers/quantization/utils/quant_utils.py +573 -0
  777. vllm/model_executor/layers/quantization/utils/w8a8_utils.py +405 -0
  778. vllm/model_executor/layers/rejection_sampler.py +406 -0
  779. vllm/model_executor/layers/resampler.py +270 -0
  780. vllm/model_executor/layers/rotary_embedding.py +1862 -0
  781. vllm/model_executor/layers/sampler.py +1204 -0
  782. vllm/model_executor/layers/spec_decode_base_sampler.py +259 -0
  783. vllm/model_executor/layers/typical_acceptance_sampler.py +166 -0
  784. vllm/model_executor/layers/utils.py +95 -0
  785. vllm/model_executor/layers/vocab_parallel_embedding.py +487 -0
  786. vllm/model_executor/model_loader/__init__.py +76 -0
  787. vllm/model_executor/model_loader/base_loader.py +43 -0
  788. vllm/model_executor/model_loader/bitsandbytes_loader.py +570 -0
  789. vllm/model_executor/model_loader/default_loader.py +282 -0
  790. vllm/model_executor/model_loader/dummy_loader.py +27 -0
  791. vllm/model_executor/model_loader/gguf_loader.py +120 -0
  792. vllm/model_executor/model_loader/neuron.py +476 -0
  793. vllm/model_executor/model_loader/neuronx_distributed.py +685 -0
  794. vllm/model_executor/model_loader/runai_streamer_loader.py +109 -0
  795. vllm/model_executor/model_loader/sharded_state_loader.py +201 -0
  796. vllm/model_executor/model_loader/tensorizer.py +600 -0
  797. vllm/model_executor/model_loader/tensorizer_loader.py +123 -0
  798. vllm/model_executor/model_loader/tpu.py +112 -0
  799. vllm/model_executor/model_loader/utils.py +302 -0
  800. vllm/model_executor/model_loader/weight_utils.py +782 -0
  801. vllm/model_executor/models/__init__.py +28 -0
  802. vllm/model_executor/models/adapters.py +248 -0
  803. vllm/model_executor/models/aimv2.py +246 -0
  804. vllm/model_executor/models/arctic.py +559 -0
  805. vllm/model_executor/models/aria.py +657 -0
  806. vllm/model_executor/models/aya_vision.py +466 -0
  807. vllm/model_executor/models/baichuan.py +474 -0
  808. vllm/model_executor/models/bamba.py +543 -0
  809. vllm/model_executor/models/bart.py +938 -0
  810. vllm/model_executor/models/bert.py +523 -0
  811. vllm/model_executor/models/bert_with_rope.py +769 -0
  812. vllm/model_executor/models/blip.py +339 -0
  813. vllm/model_executor/models/blip2.py +718 -0
  814. vllm/model_executor/models/bloom.py +373 -0
  815. vllm/model_executor/models/chameleon.py +1136 -0
  816. vllm/model_executor/models/chatglm.py +478 -0
  817. vllm/model_executor/models/clip.py +407 -0
  818. vllm/model_executor/models/commandr.py +472 -0
  819. vllm/model_executor/models/constant_size_cache.py +137 -0
  820. vllm/model_executor/models/dbrx.py +472 -0
  821. vllm/model_executor/models/deepseek.py +486 -0
  822. vllm/model_executor/models/deepseek_mtp.py +269 -0
  823. vllm/model_executor/models/deepseek_v2.py +843 -0
  824. vllm/model_executor/models/deepseek_vl2.py +648 -0
  825. vllm/model_executor/models/eagle.py +260 -0
  826. vllm/model_executor/models/exaone.py +551 -0
  827. vllm/model_executor/models/fairseq2_llama.py +154 -0
  828. vllm/model_executor/models/falcon.py +510 -0
  829. vllm/model_executor/models/falcon_h1.py +685 -0
  830. vllm/model_executor/models/florence2.py +1103 -0
  831. vllm/model_executor/models/fuyu.py +389 -0
  832. vllm/model_executor/models/gemma.py +425 -0
  833. vllm/model_executor/models/gemma2.py +425 -0
  834. vllm/model_executor/models/gemma3.py +533 -0
  835. vllm/model_executor/models/gemma3_mm.py +709 -0
  836. vllm/model_executor/models/glm.py +23 -0
  837. vllm/model_executor/models/glm4.py +305 -0
  838. vllm/model_executor/models/glm4v.py +648 -0
  839. vllm/model_executor/models/gpt2.py +328 -0
  840. vllm/model_executor/models/gpt_bigcode.py +335 -0
  841. vllm/model_executor/models/gpt_j.py +339 -0
  842. vllm/model_executor/models/gpt_neox.py +332 -0
  843. vllm/model_executor/models/granite.py +493 -0
  844. vllm/model_executor/models/granite_speech.py +779 -0
  845. vllm/model_executor/models/granitemoe.py +437 -0
  846. vllm/model_executor/models/granitemoehybrid.py +586 -0
  847. vllm/model_executor/models/granitemoeshared.py +341 -0
  848. vllm/model_executor/models/gritlm.py +224 -0
  849. vllm/model_executor/models/grok1.py +546 -0
  850. vllm/model_executor/models/h2ovl.py +546 -0
  851. vllm/model_executor/models/idefics2_vision_model.py +389 -0
  852. vllm/model_executor/models/idefics3.py +776 -0
  853. vllm/model_executor/models/interfaces.py +572 -0
  854. vllm/model_executor/models/interfaces_base.py +164 -0
  855. vllm/model_executor/models/intern_vit.py +480 -0
  856. vllm/model_executor/models/internlm2.py +455 -0
  857. vllm/model_executor/models/internlm2_ve.py +147 -0
  858. vllm/model_executor/models/internvl.py +1418 -0
  859. vllm/model_executor/models/jais.py +373 -0
  860. vllm/model_executor/models/jamba.py +592 -0
  861. vllm/model_executor/models/kimi_vl.py +577 -0
  862. vllm/model_executor/models/llama.py +644 -0
  863. vllm/model_executor/models/llama4.py +532 -0
  864. vllm/model_executor/models/llama_eagle.py +165 -0
  865. vllm/model_executor/models/llama_eagle3.py +263 -0
  866. vllm/model_executor/models/llava.py +866 -0
  867. vllm/model_executor/models/llava_next.py +586 -0
  868. vllm/model_executor/models/llava_next_video.py +471 -0
  869. vllm/model_executor/models/llava_onevision.py +956 -0
  870. vllm/model_executor/models/mamba.py +273 -0
  871. vllm/model_executor/models/mamba2.py +308 -0
  872. vllm/model_executor/models/mamba_cache.py +76 -0
  873. vllm/model_executor/models/medusa.py +219 -0
  874. vllm/model_executor/models/mimo.py +192 -0
  875. vllm/model_executor/models/mimo_mtp.py +285 -0
  876. vllm/model_executor/models/minicpm.py +592 -0
  877. vllm/model_executor/models/minicpm3.py +230 -0
  878. vllm/model_executor/models/minicpm_eagle.py +391 -0
  879. vllm/model_executor/models/minicpmo.py +759 -0
  880. vllm/model_executor/models/minicpmv.py +1287 -0
  881. vllm/model_executor/models/minimax_cache.py +36 -0
  882. vllm/model_executor/models/minimax_text_01.py +1301 -0
  883. vllm/model_executor/models/minimax_vl_01.py +364 -0
  884. vllm/model_executor/models/mistral3.py +604 -0
  885. vllm/model_executor/models/mixtral.py +488 -0
  886. vllm/model_executor/models/mixtral_quant.py +453 -0
  887. vllm/model_executor/models/mllama.py +1624 -0
  888. vllm/model_executor/models/mllama4.py +938 -0
  889. vllm/model_executor/models/mlp_speculator.py +206 -0
  890. vllm/model_executor/models/modernbert.py +331 -0
  891. vllm/model_executor/models/module_mapping.py +72 -0
  892. vllm/model_executor/models/molmo.py +1568 -0
  893. vllm/model_executor/models/moonvit.py +630 -0
  894. vllm/model_executor/models/mpt.py +331 -0
  895. vllm/model_executor/models/nemotron.py +508 -0
  896. vllm/model_executor/models/nemotron_h.py +573 -0
  897. vllm/model_executor/models/nemotron_nas.py +484 -0
  898. vllm/model_executor/models/nvlm_d.py +216 -0
  899. vllm/model_executor/models/olmo.py +389 -0
  900. vllm/model_executor/models/olmo2.py +414 -0
  901. vllm/model_executor/models/olmoe.py +468 -0
  902. vllm/model_executor/models/opt.py +412 -0
  903. vllm/model_executor/models/orion.py +349 -0
  904. vllm/model_executor/models/ovis.py +567 -0
  905. vllm/model_executor/models/paligemma.py +398 -0
  906. vllm/model_executor/models/persimmon.py +344 -0
  907. vllm/model_executor/models/phi.py +356 -0
  908. vllm/model_executor/models/phi3.py +19 -0
  909. vllm/model_executor/models/phi3_small.py +465 -0
  910. vllm/model_executor/models/phi3v.py +723 -0
  911. vllm/model_executor/models/phi4mm.py +1246 -0
  912. vllm/model_executor/models/phi4mm_audio.py +1233 -0
  913. vllm/model_executor/models/phi4mm_utils.py +1884 -0
  914. vllm/model_executor/models/phimoe.py +665 -0
  915. vllm/model_executor/models/pixtral.py +1316 -0
  916. vllm/model_executor/models/plamo2.py +738 -0
  917. vllm/model_executor/models/prithvi_geospatial_mae.py +232 -0
  918. vllm/model_executor/models/qwen.py +362 -0
  919. vllm/model_executor/models/qwen2.py +497 -0
  920. vllm/model_executor/models/qwen2_5_omni_thinker.py +904 -0
  921. vllm/model_executor/models/qwen2_5_vl.py +1166 -0
  922. vllm/model_executor/models/qwen2_audio.py +410 -0
  923. vllm/model_executor/models/qwen2_moe.py +540 -0
  924. vllm/model_executor/models/qwen2_rm.py +132 -0
  925. vllm/model_executor/models/qwen2_vl.py +1405 -0
  926. vllm/model_executor/models/qwen3.py +321 -0
  927. vllm/model_executor/models/qwen3_moe.py +535 -0
  928. vllm/model_executor/models/qwen_vl.py +785 -0
  929. vllm/model_executor/models/registry.py +622 -0
  930. vllm/model_executor/models/roberta.py +276 -0
  931. vllm/model_executor/models/siglip.py +524 -0
  932. vllm/model_executor/models/skyworkr1v.py +951 -0
  933. vllm/model_executor/models/smolvlm.py +52 -0
  934. vllm/model_executor/models/solar.py +506 -0
  935. vllm/model_executor/models/stablelm.py +343 -0
  936. vllm/model_executor/models/starcoder2.py +356 -0
  937. vllm/model_executor/models/tarsier.py +643 -0
  938. vllm/model_executor/models/telechat2.py +140 -0
  939. vllm/model_executor/models/teleflm.py +79 -0
  940. vllm/model_executor/models/transformers.py +508 -0
  941. vllm/model_executor/models/ultravox.py +656 -0
  942. vllm/model_executor/models/utils.py +731 -0
  943. vllm/model_executor/models/vision.py +147 -0
  944. vllm/model_executor/models/whisper.py +747 -0
  945. vllm/model_executor/models/zamba2.py +1009 -0
  946. vllm/model_executor/parameter.py +459 -0
  947. vllm/model_executor/pooling_metadata.py +72 -0
  948. vllm/model_executor/sampling_metadata.py +597 -0
  949. vllm/model_executor/utils.py +77 -0
  950. vllm/multimodal/__init__.py +33 -0
  951. vllm/multimodal/audio.py +106 -0
  952. vllm/multimodal/base.py +219 -0
  953. vllm/multimodal/hasher.py +118 -0
  954. vllm/multimodal/image.py +97 -0
  955. vllm/multimodal/inputs.py +876 -0
  956. vllm/multimodal/parse.py +461 -0
  957. vllm/multimodal/processing.py +1895 -0
  958. vllm/multimodal/profiling.py +258 -0
  959. vllm/multimodal/registry.py +331 -0
  960. vllm/multimodal/utils.py +436 -0
  961. vllm/multimodal/video.py +198 -0
  962. vllm/outputs.py +512 -0
  963. vllm/platforms/__init__.py +291 -0
  964. vllm/platforms/cpu.py +266 -0
  965. vllm/platforms/cuda.py +526 -0
  966. vllm/platforms/hpu.py +106 -0
  967. vllm/platforms/interface.py +538 -0
  968. vllm/platforms/neuron.py +150 -0
  969. vllm/platforms/rocm.py +435 -0
  970. vllm/platforms/tpu.py +216 -0
  971. vllm/platforms/xpu.py +156 -0
  972. vllm/plugins/__init__.py +94 -0
  973. vllm/plugins/lora_resolvers/README.md +15 -0
  974. vllm/plugins/lora_resolvers/__init__.py +0 -0
  975. vllm/plugins/lora_resolvers/filesystem_resolver.py +50 -0
  976. vllm/pooling_params.py +54 -0
  977. vllm/profiler/__init__.py +0 -0
  978. vllm/profiler/layerwise_profile.py +375 -0
  979. vllm/profiler/utils.py +148 -0
  980. vllm/prompt_adapter/__init__.py +0 -0
  981. vllm/prompt_adapter/layers.py +83 -0
  982. vllm/prompt_adapter/models.py +358 -0
  983. vllm/prompt_adapter/request.py +37 -0
  984. vllm/prompt_adapter/utils.py +98 -0
  985. vllm/prompt_adapter/worker_manager.py +179 -0
  986. vllm/py.typed +2 -0
  987. vllm/reasoning/__init__.py +15 -0
  988. vllm/reasoning/abs_reasoning_parsers.py +192 -0
  989. vllm/reasoning/deepseek_r1_reasoning_parser.py +173 -0
  990. vllm/reasoning/granite_reasoning_parser.py +363 -0
  991. vllm/reasoning/qwen3_reasoning_parser.py +151 -0
  992. vllm/sampling_params.py +602 -0
  993. vllm/scalar_type.py +347 -0
  994. vllm/scripts.py +15 -0
  995. vllm/sequence.py +1568 -0
  996. vllm/spec_decode/__init__.py +0 -0
  997. vllm/spec_decode/batch_expansion.py +506 -0
  998. vllm/spec_decode/draft_model_runner.py +349 -0
  999. vllm/spec_decode/interfaces.py +99 -0
  1000. vllm/spec_decode/medusa_worker.py +138 -0
  1001. vllm/spec_decode/metrics.py +213 -0
  1002. vllm/spec_decode/mlp_speculator_worker.py +94 -0
  1003. vllm/spec_decode/mqa_scorer.py +160 -0
  1004. vllm/spec_decode/multi_step_worker.py +423 -0
  1005. vllm/spec_decode/ngram_worker.py +196 -0
  1006. vllm/spec_decode/proposer_worker_base.py +59 -0
  1007. vllm/spec_decode/smaller_tp_proposer_worker.py +196 -0
  1008. vllm/spec_decode/spec_decode_worker.py +1326 -0
  1009. vllm/spec_decode/target_model_runner.py +45 -0
  1010. vllm/spec_decode/top1_proposer.py +275 -0
  1011. vllm/spec_decode/util.py +277 -0
  1012. vllm/test_utils.py +130 -0
  1013. vllm/third_party/__init__.py +0 -0
  1014. vllm/third_party/pynvml.py +6140 -0
  1015. vllm/tracing.py +131 -0
  1016. vllm/transformers_utils/__init__.py +24 -0
  1017. vllm/transformers_utils/chat_templates/__init__.py +5 -0
  1018. vllm/transformers_utils/chat_templates/registry.py +60 -0
  1019. vllm/transformers_utils/chat_templates/template_basic.jinja +3 -0
  1020. vllm/transformers_utils/chat_templates/template_blip2.jinja +11 -0
  1021. vllm/transformers_utils/chat_templates/template_chatml.jinja +10 -0
  1022. vllm/transformers_utils/chat_templates/template_deepseek_vl2.jinja +23 -0
  1023. vllm/transformers_utils/chat_templates/template_fuyu.jinja +3 -0
  1024. vllm/transformers_utils/config.py +887 -0
  1025. vllm/transformers_utils/configs/__init__.py +61 -0
  1026. vllm/transformers_utils/configs/arctic.py +207 -0
  1027. vllm/transformers_utils/configs/chatglm.py +72 -0
  1028. vllm/transformers_utils/configs/cohere2.py +195 -0
  1029. vllm/transformers_utils/configs/dbrx.py +280 -0
  1030. vllm/transformers_utils/configs/deepseek_vl2.py +216 -0
  1031. vllm/transformers_utils/configs/eagle.py +85 -0
  1032. vllm/transformers_utils/configs/exaone.py +190 -0
  1033. vllm/transformers_utils/configs/falcon.py +90 -0
  1034. vllm/transformers_utils/configs/h2ovl.py +16 -0
  1035. vllm/transformers_utils/configs/internvl.py +54 -0
  1036. vllm/transformers_utils/configs/jais.py +238 -0
  1037. vllm/transformers_utils/configs/kimi_vl.py +37 -0
  1038. vllm/transformers_utils/configs/medusa.py +63 -0
  1039. vllm/transformers_utils/configs/minimax_text_01.py +70 -0
  1040. vllm/transformers_utils/configs/minimax_vl_01.py +71 -0
  1041. vllm/transformers_utils/configs/mllama.py +31 -0
  1042. vllm/transformers_utils/configs/mlp_speculator.py +68 -0
  1043. vllm/transformers_utils/configs/moonvit.py +33 -0
  1044. vllm/transformers_utils/configs/mpt.py +180 -0
  1045. vllm/transformers_utils/configs/nemotron.py +205 -0
  1046. vllm/transformers_utils/configs/nemotron_h.py +258 -0
  1047. vllm/transformers_utils/configs/nvlm_d.py +15 -0
  1048. vllm/transformers_utils/configs/ovis.py +184 -0
  1049. vllm/transformers_utils/configs/skyworkr1v.py +54 -0
  1050. vllm/transformers_utils/configs/solar.py +247 -0
  1051. vllm/transformers_utils/configs/telechat2.py +64 -0
  1052. vllm/transformers_utils/configs/ultravox.py +108 -0
  1053. vllm/transformers_utils/detokenizer.py +168 -0
  1054. vllm/transformers_utils/detokenizer_utils.py +189 -0
  1055. vllm/transformers_utils/processor.py +221 -0
  1056. vllm/transformers_utils/processors/__init__.py +8 -0
  1057. vllm/transformers_utils/processors/deepseek_vl2.py +363 -0
  1058. vllm/transformers_utils/processors/ovis.py +420 -0
  1059. vllm/transformers_utils/s3_utils.py +162 -0
  1060. vllm/transformers_utils/tokenizer.py +302 -0
  1061. vllm/transformers_utils/tokenizer_base.py +149 -0
  1062. vllm/transformers_utils/tokenizer_group.py +120 -0
  1063. vllm/transformers_utils/tokenizers/__init__.py +10 -0
  1064. vllm/transformers_utils/tokenizers/mistral.py +493 -0
  1065. vllm/transformers_utils/utils.py +99 -0
  1066. vllm/triton_utils/__init__.py +14 -0
  1067. vllm/triton_utils/importing.py +50 -0
  1068. vllm/usage/__init__.py +0 -0
  1069. vllm/usage/usage_lib.py +256 -0
  1070. vllm/utils.py +2910 -0
  1071. vllm/v1/__init__.py +0 -0
  1072. vllm/v1/attention/__init__.py +0 -0
  1073. vllm/v1/attention/backends/__init__.py +0 -0
  1074. vllm/v1/attention/backends/cpu_attn.py +163 -0
  1075. vllm/v1/attention/backends/flash_attn.py +869 -0
  1076. vllm/v1/attention/backends/flashinfer.py +651 -0
  1077. vllm/v1/attention/backends/flex_attention.py +477 -0
  1078. vllm/v1/attention/backends/mla/__init__.py +0 -0
  1079. vllm/v1/attention/backends/mla/common.py +931 -0
  1080. vllm/v1/attention/backends/mla/cutlass_mla.py +97 -0
  1081. vllm/v1/attention/backends/mla/flashmla.py +152 -0
  1082. vllm/v1/attention/backends/mla/rocm_aiter_mla.py +220 -0
  1083. vllm/v1/attention/backends/mla/triton_mla.py +120 -0
  1084. vllm/v1/attention/backends/pallas.py +240 -0
  1085. vllm/v1/attention/backends/triton_attn.py +285 -0
  1086. vllm/v1/attention/backends/utils.py +52 -0
  1087. vllm/v1/core/__init__.py +0 -0
  1088. vllm/v1/core/block_pool.py +349 -0
  1089. vllm/v1/core/encoder_cache_manager.py +150 -0
  1090. vllm/v1/core/kv_cache_coordinator.py +363 -0
  1091. vllm/v1/core/kv_cache_manager.py +392 -0
  1092. vllm/v1/core/kv_cache_utils.py +996 -0
  1093. vllm/v1/core/sched/__init__.py +0 -0
  1094. vllm/v1/core/sched/interface.py +150 -0
  1095. vllm/v1/core/sched/output.py +154 -0
  1096. vllm/v1/core/sched/scheduler.py +1044 -0
  1097. vllm/v1/core/sched/utils.py +23 -0
  1098. vllm/v1/core/single_type_kv_cache_manager.py +403 -0
  1099. vllm/v1/engine/__init__.py +173 -0
  1100. vllm/v1/engine/async_llm.py +558 -0
  1101. vllm/v1/engine/coordinator.py +253 -0
  1102. vllm/v1/engine/core.py +961 -0
  1103. vllm/v1/engine/core_client.py +1129 -0
  1104. vllm/v1/engine/detokenizer.py +261 -0
  1105. vllm/v1/engine/exceptions.py +17 -0
  1106. vllm/v1/engine/llm_engine.py +317 -0
  1107. vllm/v1/engine/logprobs.py +199 -0
  1108. vllm/v1/engine/mm_input_cache.py +91 -0
  1109. vllm/v1/engine/output_processor.py +428 -0
  1110. vllm/v1/engine/parallel_sampling.py +133 -0
  1111. vllm/v1/engine/processor.py +407 -0
  1112. vllm/v1/executor/__init__.py +0 -0
  1113. vllm/v1/executor/abstract.py +113 -0
  1114. vllm/v1/executor/multiproc_executor.py +537 -0
  1115. vllm/v1/executor/ray_distributed_executor.py +62 -0
  1116. vllm/v1/kv_cache_interface.py +194 -0
  1117. vllm/v1/metrics/__init__.py +0 -0
  1118. vllm/v1/metrics/loggers.py +523 -0
  1119. vllm/v1/metrics/prometheus.py +82 -0
  1120. vllm/v1/metrics/ray_wrappers.py +131 -0
  1121. vllm/v1/metrics/reader.py +246 -0
  1122. vllm/v1/metrics/stats.py +239 -0
  1123. vllm/v1/outputs.py +116 -0
  1124. vllm/v1/request.py +193 -0
  1125. vllm/v1/sample/__init__.py +0 -0
  1126. vllm/v1/sample/metadata.py +44 -0
  1127. vllm/v1/sample/ops/__init__.py +0 -0
  1128. vllm/v1/sample/ops/bad_words.py +39 -0
  1129. vllm/v1/sample/ops/penalties.py +59 -0
  1130. vllm/v1/sample/ops/topk_topp_sampler.py +293 -0
  1131. vllm/v1/sample/rejection_sampler.py +631 -0
  1132. vllm/v1/sample/sampler.py +286 -0
  1133. vllm/v1/sample/tpu/__init__.py +0 -0
  1134. vllm/v1/sample/tpu/metadata.py +124 -0
  1135. vllm/v1/sample/tpu/sampler.py +145 -0
  1136. vllm/v1/serial_utils.py +315 -0
  1137. vllm/v1/spec_decode/__init__.py +0 -0
  1138. vllm/v1/spec_decode/eagle.py +432 -0
  1139. vllm/v1/spec_decode/medusa.py +62 -0
  1140. vllm/v1/spec_decode/metadata.py +62 -0
  1141. vllm/v1/spec_decode/metrics.py +178 -0
  1142. vllm/v1/spec_decode/ngram_proposer.py +132 -0
  1143. vllm/v1/spec_decode/utils.py +46 -0
  1144. vllm/v1/structured_output/__init__.py +222 -0
  1145. vllm/v1/structured_output/backend_guidance.py +245 -0
  1146. vllm/v1/structured_output/backend_types.py +134 -0
  1147. vllm/v1/structured_output/backend_xgrammar.py +318 -0
  1148. vllm/v1/structured_output/request.py +86 -0
  1149. vllm/v1/structured_output/utils.py +175 -0
  1150. vllm/v1/utils.py +743 -0
  1151. vllm/v1/worker/__init__.py +0 -0
  1152. vllm/v1/worker/block_table.py +142 -0
  1153. vllm/v1/worker/cpu_model_runner.py +86 -0
  1154. vllm/v1/worker/cpu_worker.py +152 -0
  1155. vllm/v1/worker/gpu_input_batch.py +681 -0
  1156. vllm/v1/worker/gpu_model_runner.py +2320 -0
  1157. vllm/v1/worker/gpu_worker.py +393 -0
  1158. vllm/v1/worker/lora_model_runner_mixin.py +173 -0
  1159. vllm/v1/worker/tpu_model_runner.py +1673 -0
  1160. vllm/v1/worker/tpu_worker.py +299 -0
  1161. vllm/v1/worker/utils.py +111 -0
  1162. vllm/v1/worker/worker_base.py +65 -0
  1163. vllm/version.py +41 -0
  1164. vllm/vllm_flash_attn/.gitkeep +0 -0
  1165. vllm/worker/__init__.py +0 -0
  1166. vllm/worker/cache_engine.py +145 -0
  1167. vllm/worker/cpu_enc_dec_model_runner.py +326 -0
  1168. vllm/worker/cpu_model_runner.py +671 -0
  1169. vllm/worker/cpu_pooling_model_runner.py +125 -0
  1170. vllm/worker/cpu_worker.py +450 -0
  1171. vllm/worker/enc_dec_model_runner.py +555 -0
  1172. vllm/worker/hpu_model_runner.py +2320 -0
  1173. vllm/worker/hpu_worker.py +484 -0
  1174. vllm/worker/model_runner.py +2178 -0
  1175. vllm/worker/model_runner_base.py +282 -0
  1176. vllm/worker/multi_step_hpu_worker.py +123 -0
  1177. vllm/worker/multi_step_model_runner.py +911 -0
  1178. vllm/worker/multi_step_neuron_model_runner.py +84 -0
  1179. vllm/worker/multi_step_neuronx_distributed_model_runner.py +63 -0
  1180. vllm/worker/multi_step_tpu_worker.py +108 -0
  1181. vllm/worker/multi_step_worker.py +197 -0
  1182. vllm/worker/neuron_model_runner.py +460 -0
  1183. vllm/worker/neuron_worker.py +193 -0
  1184. vllm/worker/neuronx_distributed_model_runner.py +294 -0
  1185. vllm/worker/pooling_model_runner.py +211 -0
  1186. vllm/worker/tpu_model_runner.py +909 -0
  1187. vllm/worker/tpu_worker.py +337 -0
  1188. vllm/worker/utils.py +53 -0
  1189. vllm/worker/worker.py +577 -0
  1190. vllm/worker/worker_base.py +646 -0
  1191. vllm/worker/xpu_model_runner.py +606 -0
  1192. vllm/worker/xpu_worker.py +186 -0
  1193. vllm_cpu_amxbf16-0.9.1.dist-info/METADATA +305 -0
  1194. vllm_cpu_amxbf16-0.9.1.dist-info/RECORD +1197 -0
  1195. vllm_cpu_amxbf16-0.9.1.dist-info/WHEEL +5 -0
  1196. vllm_cpu_amxbf16-0.9.1.dist-info/entry_points.txt +5 -0
  1197. vllm_cpu_amxbf16-0.9.1.dist-info/top_level.txt +1 -0
vllm/model_executor/models/phi4mm_utils.py
@@ -0,0 +1,1884 @@
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+ # Copyright (c) Microsoft Corporation.
+ # Licensed under the MIT license.
+ # Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com)
+ # but implemented by the Phi-Speech team
+ #!/usr/bin/env python3
+ import math
+ from typing import Optional, Union
+
+ import torch
+ import torch.nn.functional as F
+ from torch import Tensor, nn
+
+
+ class BlockBase(nn.Module):
+     """Block abstract module"""
+
+     def __init__(self, input_size, output_size):
+         super().__init__()
+         self.input_size = input_size
+         self.output_size = output_size
+
+
+ def get_activation(name="relu"):
+     """Select an activation function by name
+
+     Args:
+         name: str
+             activation function name,
+             one of ["relu", "gelu", "swish", "sigmoid"],
+             default "relu".
+     """
+     name = name.lower()
+     if name == "relu":
+         return nn.ReLU(inplace=True)
+     if name == "gelu":
+         return nn.GELU()
+     if name == "swish":
+         return Swish()
+     if name == "sigmoid":
+         return torch.nn.Sigmoid()
+     return nn.Identity()
+
+
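Editorial note (not part of the diff): unrecognized names silently fall back to nn.Identity() rather than raising, e.g.:

    act = get_activation("gelu")   # nn.GELU()
    noop = get_activation("tanh")  # nn.Identity(), not an error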
+ def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
+     """
+     The function is very important for Transformer Transducer Streaming mode
+     Args:
+         x_len (int): sequence length
+         chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48].
+             It also supports adaptive chunk size [0,10,15,45]
+         left_window (int): how many left chunks can be seen
+         right_window (int): how many right chunks can be seen. It is used for
+             chunk overlap model.
+     Returns:
+         mask (torch.Tensor): a mask tensor for streaming model
+             Torch 1.0.1
+             tensor([[1., 1., 0., 0.],
+                     [0., 1., 1., 0.],
+                     [0., 0., 1., 1.]])
+             Torch 1.4.1
+             tensor([[True, True, False, False],
+                     [False, True, True, False],
+                     [False, False, True, True]])
+     """
+     # first idx of each chunk, such as [0, 18, 36, 48]
+     chunk_start_idx = torch.Tensor(chunk_start_idx).long()
+     # prepend 0, so it becomes [0, 0, 18, 36, 48]
+     start_pad = torch.nn.functional.pad(chunk_start_idx, (1, 0))
+     # append x_len to the end, so it becomes [0, 18, 36, 48, x_len]
+     end_pad = torch.nn.functional.pad(chunk_start_idx, (0, 1), value=x_len)
+     seq_range = torch.arange(0, x_len).unsqueeze(-1)  # size: [x_len, 1]
+     # chunk index of each frame; idx size: [x_len]
+     idx = ((seq_range < end_pad) & (seq_range >= start_pad)).nonzero()[:, 1]
+     # boundary = end_pad[idx]  # boundary size: [x_len]
+     # seq_range_expand size: [x_len, x_len]
+     seq_range_expand = torch.arange(0, x_len).unsqueeze(0).expand(x_len, -1)
+     idx_left = idx - left_window
+     idx_left[idx_left < 0] = 0
+     boundary_left = start_pad[idx_left]
+     mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
+     idx_right = idx + right_window
+     idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
+     boundary_right = end_pad[idx_right]
+     mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
+     return mask_left & mask_right
+
+
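To make the chunk semantics concrete, here is a small hand-checked sketch (editorial example, not part of the diff; assumes torch and the function above are in scope): a 4-frame sequence split into chunks starting at frames 0 and 2, with no look-back or look-ahead.

    mask = adaptive_enc_mask(4, [0, 2], left_window=0, right_window=0)
    # tensor([[ True,  True, False, False],
    #         [ True,  True, False, False],
    #         [False, False,  True,  True],
    #         [False, False,  True,  True]])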
+ class Swish(nn.Module):
+     """Implement Swish activation module.
+     From https://arxiv.org/pdf/2005.03191.pdf
+     """
+
+     def __init__(self) -> None:
+         super().__init__()
+         self.act_fn = nn.Sigmoid()
+
+     def forward(self, x: Tensor) -> Tensor:
+         """Apply Swish function
+
+         Args:
+             x: torch.Tensor
+                 Input.
+         """
+         return x * self.act_fn(x)
+
+
+ class GLU(nn.Module):
+     """Implement Gated Linear Unit (GLU) module"""
+
+     def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
+         super().__init__()
+         self.dim = dim
+         self.act_name = act_name.lower()
+
+         if self.act_name == "relu":
+             self.act_fn = nn.ReLU(inplace=True)
+         elif self.act_name == "gelu":
+             self.act_fn = nn.GELU()
+         elif self.act_name == "swish":
+             self.act_fn = Swish()
+         elif self.act_name == "sigmoid":
+             self.act_fn = nn.Sigmoid()
+         else:
+             self.act_fn = nn.Identity()
+
+     def forward(self, x: Tensor) -> Tensor:
+         """GLU forward
+         Gate the first half of the input along `self.dim` with the
+         activation of the second half.
+
+         Args:
+             x: torch.Tensor
+                 Input.
+         """
+         half_x, gate = x.chunk(2, dim=self.dim)
+         return half_x * self.act_fn(gate)
+
+
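Editorial sketch (not part of the diff): GLU halves the chosen dimension, gating one half with the activation of the other.

    import torch
    glu = GLU(dim=-1, act_name="sigmoid")
    x = torch.randn(2, 10, 8)          # (batch, time, 2 * features)
    assert glu(x).shape == (2, 10, 4)  # last dim halved by the gate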
+ # TODO: Abdel, this can be improved using GLU module
+ class GLUPointWiseConv(nn.Module):
+     """GLUPointWiseConv module
+     used for conformer architecture,
+     for more details see:
+     https://arxiv.org/pdf/2005.08100v1.pdf
+
+     Args:
+         input_dim: int
+             input channel size.
+         output_dim: int
+             output channel size.
+         kernel_size: int
+             kernel size
+         glu_type: str, optional
+             activation function one of
+             ["sigmoid", "relu", "gelu", "swish"],
+             default "sigmoid".
+         bias_in_glu: bool, optional
+             use additive bias in glu
+         causal: bool, optional
+             if set to True, padding is set to kernel_size - 1,
+             i.e., the convolution can't see future frames.
+             default False.
+     """
+
+     def __init__(
+         self,
+         input_dim,
+         output_dim,
+         kernel_size,
+         glu_type="sigmoid",
+         bias_in_glu=True,
+         causal=False,
+     ):
+         super().__init__()
+
+         self.glu_type = glu_type
+         self.output_dim = output_dim
+         self.bias_in_glu = bias_in_glu
+         if causal:
+             self.ext_pw_conv_1d = nn.Conv1d(
+                 input_dim,
+                 output_dim * 2,
+                 kernel_size,
+                 1,
+                 padding=(kernel_size - 1),
+             )
+         else:
+             self.ext_pw_conv_1d = nn.Conv1d(
+                 input_dim,
+                 output_dim * 2,
+                 kernel_size,
+                 1,
+                 padding=(kernel_size - 1) // 2,
+             )
+
+         if glu_type == "sigmoid":
+             self.glu_act = nn.Sigmoid()
+         elif glu_type == "relu":
+             self.glu_act = nn.ReLU()
+         elif glu_type == "gelu":
+             self.glu_act = nn.GELU()
+         elif glu_type == "swish":
+             self.glu_act = Swish()
+         else:
+             raise ValueError(f"Unsupported activation type {glu_type}")
+
+         if bias_in_glu:
+             self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
+             self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1))
+
+     def forward(self, x):
+         """
+         Args:
+             x: torch.Tensor
+                 input tensor
+         """
+         # to be consistent with GLULinear, we assume the input always has the
+         # #channel (#dim) in the last dimension of the tensor, so need to
+         # switch the dimension first for 1D-Conv case
+         x = x.permute([0, 2, 1])
+         x = self.ext_pw_conv_1d(x)
+         if self.glu_type == "bilinear":
+             if self.bias_in_glu:
+                 x = (x[:, 0:self.output_dim, :] + self.b1) * (
+                     x[:, self.output_dim:self.output_dim * 2, :] + self.b2)
+             else:
+                 x = (x[:, 0:self.output_dim, :]) * (
+                     x[:, self.output_dim:self.output_dim * 2, :])
+         else:
+             if self.bias_in_glu:
+                 x = (x[:, 0:self.output_dim, :] + self.b1) * self.glu_act(
+                     x[:, self.output_dim:self.output_dim * 2, :] + self.b2)
+             else:
+                 x = (x[:, 0:self.output_dim, :]) * self.glu_act(
+                     x[:, self.output_dim:self.output_dim * 2, :])
+
+         x = x.permute([0, 2, 1])
+         return x
+
+
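Editorial sketch (not part of the diff): the pointwise conv produces 2 * output_dim channels and the gate halves them, so the module maps (batch, time, input_dim) to (batch, time, output_dim).

    import torch
    pw = GLUPointWiseConv(input_dim=64, output_dim=128, kernel_size=3)
    x = torch.randn(2, 50, 64)          # (batch, time, channels)
    assert pw(x).shape == (2, 50, 128)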
+ class DepthWiseSeperableConv1d(nn.Module):
+     """DepthWiseSeperableConv1d module used in Convnet module
+     for the conformer, for more details see:
+     https://arxiv.org/pdf/2005.08100v1.pdf
+
+     Args:
+         input_dim: int
+             input channel size.
+         depthwise_seperable_out_channel: int
+             if set different from 0, depthwise_seperable_out_channel
+             will be used as the channel_out of the second conv1d layer;
+             if it equals 0, the second conv1d layer is skipped.
+         kernel_size: int
+             kernel_size
+         depthwise_multiplier: int
+             number of input_dim channels duplication. this value
+             will be used to compute the hidden channels of the Conv1D.
+         padding: int, optional
+             padding for the conv1d,
+             default: 0.
+     """
+
+     def __init__(
+         self,
+         input_dim,
+         depthwise_seperable_out_channel,
+         kernel_size,
+         depthwise_multiplier,
+         padding=0,
+     ):
+         super().__init__()
+
+         self.dw_conv = nn.Conv1d(
+             input_dim,
+             input_dim * depthwise_multiplier,
+             kernel_size,
+             1,
+             padding=padding,
+             groups=input_dim,
+         )
+
+         if depthwise_seperable_out_channel != 0:
+             self.pw_conv = nn.Conv1d(
+                 input_dim * depthwise_multiplier,
+                 depthwise_seperable_out_channel,
+                 1,
+                 1,
+                 0,
+             )
+         else:
+             self.pw_conv = nn.Identity()
+         self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
+
+     def forward(self, x):
+         """
+         Args:
+             x: torch.Tensor
+                 input tensor
+         """
+         x = self.dw_conv(x)
+         if self.depthwise_seperable_out_channel != 0:
+             x = self.pw_conv(x)
+         return x
+
+
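Editorial sketch (not part of the diff): the depthwise conv expands channels by depthwise_multiplier, then the 1x1 pointwise conv projects to the requested width. Note this module expects channels-first input.

    import torch
    conv = DepthWiseSeperableConv1d(
        input_dim=64,
        depthwise_seperable_out_channel=128,
        kernel_size=3,
        depthwise_multiplier=2,
        padding=1,
    )
    x = torch.randn(2, 64, 50)            # (batch, channels, time)
    assert conv(x).shape == (2, 128, 50)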
+ class ConvModule(nn.Module):
+     """ConvModule Module for the conformer block.
+     for more details see:
+     https://arxiv.org/pdf/2005.08100v1.pdf
+
+     Args:
+         input_dim: int
+             input channel size.
+         ext_pw_out_channel: int
+             if > 0, ext_pw_out_channel is a dim channel size
+             for the last pointwise conv after swish activation.
+         depthwise_seperable_out_channel: int
+             if set different from 0, depthwise_seperable_out_channel
+             will be used as the channel_out of the second conv1d layer;
+             if it equals 0, the second conv1d layer is skipped.
+         ext_pw_kernel_size: int
+             kernel size of the conv pointwise of the conformer.
+         kernel_size: int
+             kernel size.
+         depthwise_multiplier: int
+             number of input_dim channels duplication. this value
+             will be used to compute the hidden channels of the Conv1D.
+         dropout_rate: float
+             dropout rate.
+         causal: bool, optional
+             if set to True, the convolution has no access
+             to future frames. default False.
+         batch_norm: bool, optional
+             if set to True, apply batchnorm before activation.
+             default False
+         chunk_se: int, optional
+             0 for offline SE.
+             1 for streaming SE, where mean is computed
+             by accumulated history until the current chunk.
+             2 for streaming SE, where mean is computed
+             by only the current chunk.
+         chunk_size: int, optional
+             chunk size for cnn. default 18
+         activation: str, optional
+             activation function used in ConvModule,
+             default: "relu".
+         glu_type: str, optional
+             activation function used for the glu,
+             default: "sigmoid".
+         bias_in_glu: bool, optional
+             if set to True, use additive bias in the weight module
+             before GLU.
+         linear_glu_in_convm: bool, optional
+             if set to True, use the GLULinear module;
+             otherwise, use the GLUPointWiseConv module.
+             default False.
+         export: bool, optional
+             if set to True, padding is equal to 0. This is for inference,
+             or onnx export. Typically this is set by the export program or
+             the decoder program, and it isn't present in your config file.
+             default False
+     """
+     def __init__(
+         self,
+         input_dim,
+         ext_pw_out_channel,
+         depthwise_seperable_out_channel,
+         ext_pw_kernel_size,
+         kernel_size,
+         depthwise_multiplier,
+         dropout_rate,
+         causal=False,
+         batch_norm=False,
+         chunk_se=0,
+         chunk_size=18,
+         activation="relu",
+         glu_type="sigmoid",
+         bias_in_glu=True,
+         linear_glu_in_convm=False,
+         export=False,
+     ):
+         super().__init__()
+         self.layer_norm = nn.LayerNorm(input_dim)
+         self.input_dim = input_dim
+         self.ext_pw_out_channel = ext_pw_out_channel
+         self.ext_pw_kernel_size = ext_pw_kernel_size
+         self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
+         self.glu_type = glu_type
+         self.bias_in_glu = bias_in_glu
+         self.linear_glu_in_convm = linear_glu_in_convm
+         self.causal = causal
+
+         self._add_ext_pw_layer()
+
+         self.batch_norm = batch_norm
+         self.kernel_size = kernel_size
+
+         if batch_norm:
+             self.bn_layer = nn.BatchNorm1d(input_dim)
+
+         self.act = get_activation(activation)
+         self.dropout = nn.Dropout(dropout_rate)
+         self.export = export
+
+         if causal:
+             padding = 0 if export else kernel_size - 1
+         else:
+             padding = (kernel_size - 1) // 2
+
+         self.dw_sep_conv_1d = DepthWiseSeperableConv1d(
+             input_dim,
+             depthwise_seperable_out_channel,
+             kernel_size,
+             depthwise_multiplier,
+             padding=padding,
+         )
+
+         if depthwise_seperable_out_channel != 0:
+             if input_dim != depthwise_seperable_out_channel:
+                 self.ln2 = nn.Linear(depthwise_seperable_out_channel,
+                                      input_dim)
+         else:
+             if depthwise_multiplier != 1:
+                 self.ln2 = nn.Linear(input_dim * depthwise_multiplier,
+                                      input_dim)
+
+     def _add_ext_pw_layer(self):
+         """
+         This function is an extension of __init__ function
+         and dedicated to the convolution module creation
+         of the conformer.
+         """
+         self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = (
+             nn.Identity())  # jit hacks.
+         self.squeeze_excitation = nn.Identity()  # jit.
+         self.apply_ln1 = self.fix_len1 = False  # jit.
+
+         if self.ext_pw_out_channel != 0:
+             if self.causal:
+                 self.ext_pw_conv_1d = nn.Conv1d(
+                     self.input_dim,
+                     self.ext_pw_out_channel,
+                     self.ext_pw_kernel_size,
+                     1,
+                     padding=(self.ext_pw_kernel_size - 1),
+                 )
+                 if self.ext_pw_kernel_size > 1:
+                     self.fix_len1 = True
+                 else:
+                     self.fix_len1 = False
+             else:
+                 self.ext_pw_conv_1d = nn.Conv1d(
+                     self.input_dim,
+                     self.ext_pw_out_channel,
+                     self.ext_pw_kernel_size,
+                     1,
+                     padding=(self.ext_pw_kernel_size - 1) // 2,
+                 )
+                 self.fix_len1 = False
+
+             if self.linear_glu_in_convm:
+                 self.glu = GLULinear(
+                     self.input_dim,
+                     self.ext_pw_out_channel,
+                     self.glu_type,
+                     self.bias_in_glu,
+                 )
+             else:
+                 self.glu = GLUPointWiseConv(
+                     self.input_dim,
+                     self.ext_pw_out_channel,
+                     self.ext_pw_kernel_size,
+                     self.glu_type,
+                     self.bias_in_glu,
+                     self.causal,
+                 )
+
+             if self.input_dim != self.ext_pw_out_channel:
+                 self.apply_ln1 = True
+                 self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim)
+             else:
+                 self.apply_ln1 = False
+         else:
+             self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3))
+             self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3))
+
+     def forward(self, x):
+         """ConvModule Forward.
+
+         Args:
+             x: torch.Tensor
+                 input tensor.
+         """
+         x = self.layer_norm(x)
+
+         if self.ext_pw_out_channel != 0:
+             x = self.glu(x)
+             if self.causal and self.ext_pw_kernel_size > 1:
+                 x = x[:, :-(self.ext_pw_kernel_size - 1), :]
+             if self.apply_ln1:
+                 x = self.ln1(x)
+         else:
+             x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0]
+             x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1]
+             x = x_0 + x_1
+
+         x = x.permute([0, 2, 1])
+
+         x = self.dw_sep_conv_1d(x)
+         if self.causal and self.kernel_size > 1:
+             x = x[:, :, :-(self.kernel_size - 1)]
+         if hasattr(self, "ln2"):
+             x = x.permute([0, 2, 1])
+             x = self.ln2(x)
+             x = x.permute([0, 2, 1])
+         if self.batch_norm:
+             x = self.bn_layer(x)
+         x = self.act(x)
+
+         if self.ext_pw_out_channel != 0:
+             x = self.ext_pw_conv_1d(x)
+             if self.fix_len1:
+                 x = x[:, :, :-(self.ext_pw_kernel_size - 1)]
+
+             if self.apply_ln1:
+                 x = x.permute([0, 2, 1])
+                 x = self.ln1(x)
+                 x = x.permute([0, 2, 1])
+
+             x = x.permute([0, 2, 1])
+         else:
+             x = x.unsqueeze(1).permute([0, 1, 3, 2])
+             x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2]
+             x = x.squeeze(1)
+
+         x = self.dropout(x)
+         return x
+
+
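Editorial sketch (not part of the diff): with matching channel sizes the block is shape-preserving on (batch, time, dim) input; the hyperparameters below are hypothetical, chosen only to keep the example small.

    import torch
    conv = ConvModule(
        input_dim=64,
        ext_pw_out_channel=64,
        depthwise_seperable_out_channel=64,
        ext_pw_kernel_size=1,
        kernel_size=3,
        depthwise_multiplier=1,
        dropout_rate=0.0,
    )
    x = torch.randn(2, 50, 64)            # (batch, time, dim)
    assert conv(x).shape == (2, 50, 64)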
+ class GLULinear(nn.Module):
+     """Linear + GLU module
+
+     Args:
+         input_dim: int
+             input size
+         output_dim: int
+             output size.
+         glu_type:
+             activation function name used in the glu module,
+             default "sigmoid".
+         bias_in_glu: bool, optional
+             If True, the additive bias is added. Default True.
+     """
+
+     def __init__(
+         self,
+         input_dim,
+         output_dim,
+         glu_type="sigmoid",
+         bias_in_glu=True,
+     ):
+         super().__init__()
+         self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu)
+         self.glu_act = GLU(-1, glu_type)
+
+     def forward(self, x):
+         """GLULinear forward
+
+         Args:
+             x: torch.Tensor
+                 input tensor.
+         """
+         x = self.linear(x)
+         return self.glu_act(x)
+
+
590
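+ # Editor's note: illustrative sketch, not part of the original source. The
+ # single Linear above projects to 2 * output_dim and the GLU gates one half
+ # with an activation of the other (halving the last dimension back to
+ # output_dim, as nn.GLU does), i.e. roughly out = a * act(b) for
+ # a, b = split(W x).
+ def _example_glu_linear():
+     layer = GLULinear(input_dim=256, output_dim=1024)
+     x = torch.randn(2, 50, 256)  # (batch, time, input_dim)
+     assert layer(x).shape == (2, 50, 1024)
+
+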
+ class FeedForward(nn.Module):
+     """FeedForward Module.
+     For more details see the Conformer paper:
+     https://arxiv.org/pdf/2005.08100.pdf
+
+     Args:
+         d_model: int
+             input size.
+         d_inner: int
+             inner (hidden) size.
+         dropout_rate: float
+             dropout rate.
+         activation: str
+             activation function name,
+             one of ["relu", "swish", "sigmoid"];
+             the sigmoid activation is only used with "glu_in_fnn=True",
+             default "sigmoid".
+         bias_in_glu: bool, optional
+             If True, an additive bias is used in the GLU. Default True.
+     """
+
+     def __init__(
+         self,
+         d_model,
+         d_inner,
+         dropout_rate,
+         activation="sigmoid",
+         bias_in_glu=True,
+     ):
+         super().__init__()
+         self.d_model = d_model
+         self.d_inner = d_inner
+
+         self.layer_norm = nn.LayerNorm(d_model)
+         module = GLULinear(d_model, d_inner, activation, bias_in_glu)
+         self.net = nn.Sequential(
+             module,
+             nn.Dropout(dropout_rate),
+             nn.Linear(d_inner, d_model),
+             nn.Dropout(dropout_rate),
+         )
+
+     def forward(self, x):
+         """FeedForward forward function.
+
+         Args:
+             x: torch.Tensor
+                 input tensor.
+         """
+         out = self.net(self.layer_norm(x))
+
+         return out
+
+
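+ # Editor's note: illustrative sketch, not part of the original source. The
+ # feed-forward block above is shape-preserving: LayerNorm, GLULinear up to
+ # d_inner, dropout, Linear back to d_model, dropout.
+ def _example_feed_forward():
+     ff = FeedForward(d_model=512, d_inner=2048, dropout_rate=0.1)
+     x = torch.randn(2, 100, 512)  # (batch, time, d_model)
+     assert ff(x).shape == x.shape
+
+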
+ #### positional encoding starts here
+ def _pre_hook(
+     state_dict,
+     prefix,
+     local_metadata,
+     strict,
+     missing_keys,
+     unexpected_keys,
+     error_msgs,
+ ):
+     """Perform the pre-hook in load_state_dict for backward compatibility.
+
+     Note:
+         We saved self.pe up to v0.5.2 but have omitted it since.
+         Therefore, we remove the item "pe" from `state_dict` for backward
+         compatibility.
+
+     """
+     k = prefix + "pe"
+     if k in state_dict:
+         state_dict.pop(k)
+
+
+ class T5RelativeAttentionLogitBias(nn.Module):
+     """
+     This module implements the relative position bias described in Section
+     2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf
+
+     The Huggingface implementation is used as a reference
+     https://github.com/huggingface/transformers/blob/v4.30.0/src/
+     transformers/models/t5/modeling_t5.py#L435
+
+     Modifies attention as Q*K^T + B, where B is a learned scalar bias based
+     on the relative position of the query and key. It is HxNxN, where H is
+     the number of heads and N is the sequence length.
+
+     I've made these modifications to the original T5 bias:
+     - Skipping of the bucketing step. The original T5 bias converted relative
+       position distances into logarithmically increasing buckets. This is
+       supposed to help with length generalization.
+     - I just directly use the relative position index as the bias value, as
+       we don't need length generalization (40s max is good enough for an ASR
+       encoder), and it keeps ONNX export simple.
+     - I've also extended it so that biases can be asymmetric; the default
+       implementation treats L->R and R->L the same. Asymmetric was found to
+       yield better results in my experiments.
+
+     Args:
+         num_heads: int
+             Number of attention heads
+         num_buckets: int
+             Number of buckets to use for relative attention bias. This is
+             the size of the learnable bias parameter. Bucketing is not yet
+             supported, so this defaults to -1, which means no bucketing is
+             used (max_distance determines the size of the bias parameter).
+         max_distance: int
+             Maximum distance to use for relative attention bias. With
+             num_buckets=-1, this directly controls the max size of the bias
+             parameter. When num_buckets > 0 is supported, this will control
+             the maximum distance for logarithmic bucketing, after which all
+             positions are in the same bucket.
+         symmetric: bool
+             Whether to use symmetric or asymmetric biases. symmetric=False
+             uses 2x the number of bias params to distinguish L->R from R->L.
+             This was found to be better for the encoder.
+     """
+
+     def __init__(self,
+                  num_heads,
+                  num_buckets=-1,
+                  max_distance=1000,
+                  symmetric=False):
+         super().__init__()
+         self.num_heads = num_heads
+         self.num_buckets = num_buckets
+         self.max_distance = max_distance
+         self.symmetric = symmetric
+         self._skip_bucketing = self.num_buckets < 0
+         if self._skip_bucketing:
+             self.num_buckets = max_distance
+         else:
+             raise NotImplementedError(
+                 "T5 attention bias with bucketed positions is not yet tested")
+         if not self.symmetric:
+             self.num_buckets *= 2
+         self.bias_values = nn.Embedding(self.num_buckets, self.num_heads)
+
+     def forward(self, x):
+         # instantiate a bias compatible with the shape of x
+         maxpos = x.size(1)
+         context_position = torch.arange(maxpos,
+                                         device=x.device,
+                                         dtype=torch.long)[:, None]
+         memory_position = torch.arange(maxpos,
+                                        device=x.device,
+                                        dtype=torch.long)[None, :]
+         relative_position = memory_position - context_position
+         # clipping to a maximum distance using ops that play well with ONNX
+         # export
+         relative_position = relative_position.masked_fill(
+             relative_position < -self.max_distance, -self.max_distance)
+         relative_position = relative_position.masked_fill(
+             relative_position > self.max_distance - 1, self.max_distance - 1)
+
+         # mapping from relative position to an index in the bias parameter
+         if self._skip_bucketing:
+             bias_idx = relative_position
+         else:
+             bias_idx = self._bucket_relative_position(relative_position)
+         if self.symmetric:
+             bias_idx = bias_idx.abs()
+         else:
+             bias_idx += self.num_buckets // 2
+
+         t5_rel_att_bias = self.bias_values(bias_idx)  # [L, L, H]
+         t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(
+             0)  # [1, H, L, L]
+
+         return t5_rel_att_bias
+
+     def _bucket_relative_position(self, relative_position):
+         # this is a placeholder (isn't tested, likely buggy) using the
+         # HuggingFace implementation as a reference; it also needs to be
+         # extended to support asymmetric +/- positions.
+         # NOTE: self.causal is never defined on this module; this path is
+         # currently unreachable because __init__ raises unless bucketing is
+         # skipped (num_buckets < 0).
+         relative_buckets = 0
+         if not self.causal:
+             self.num_buckets //= 2
+             relative_buckets += (relative_position > 0).to(
+                 torch.long) * self.num_buckets
+             relative_position = torch.abs(relative_position)
+         else:
+             relative_position = -torch.min(
+                 relative_position, torch.zeros_like(relative_position))
+         # now relative_position is in the range [0, inf)
+
+         # half of the buckets are for exact increments in positions
+         max_exact = self.num_buckets // 2
+         is_small = relative_position < max_exact
+
+         # The other half of the buckets are for logarithmically bigger bins
+         # in positions up to max_distance
+         relative_position_if_large = max_exact + (
+             torch.log(relative_position.float() / max_exact) /
+             math.log(self.max_distance / max_exact) *
+             (self.num_buckets - max_exact)).to(torch.long)
+         relative_position_if_large = torch.min(
+             relative_position_if_large,
+             torch.full_like(relative_position_if_large, self.num_buckets - 1),
+         )
+
+         relative_buckets += torch.where(is_small, relative_position,
+                                         relative_position_if_large)
+         return relative_buckets
+
+
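+ # Editor's note: illustrative sketch, not part of the original source. With
+ # the default num_buckets=-1 (bucketing skipped), the bias table holds
+ # max_distance entries, doubled for the asymmetric default; forward uses x
+ # only for its sequence length and device.
+ def _example_t5_rel_bias():
+     bias_mod = T5RelativeAttentionLogitBias(num_heads=8, max_distance=1000)
+     x = torch.randn(2, 64, 512)  # (batch, time, feat)
+     assert bias_mod(x).shape == (1, 8, 64, 64)  # (1, H, L, L)
+
+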
+ class AbsolutePositionalEncoding(nn.Module):
+     """Absolute positional encoding module.
+     This module implements the absolute sinusoidal positional encoding
+     from: https://arxiv.org/pdf/1706.03762.pdf
+
+     Args:
+         d_model: int
+             Input embedding size.
+         dropout_rate: float
+             dropout rate
+         max_len: int, optional
+             Maximum input sequence length. Default 5000
+
+     """
+
+     def __init__(self, d_model, dropout_rate, max_len=5000):
+         """Construct a PositionalEncoding object."""
+         super().__init__()
+         self.d_model = d_model
+         self.xscale = math.sqrt(self.d_model)
+         self.dropout = torch.nn.Dropout(p=dropout_rate)
+         self.pe = None
+         self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+         self._register_load_state_dict_pre_hook(_pre_hook)
+
+     def extend_pe(self, x):
+         """Reset the positional encodings.
+
+         Args:
+             x: torch.Tensor
+         """
+         if self.pe is not None and self.pe.size(1) >= x.size(1):
+             if self.pe.dtype != x.dtype or self.pe.device != x.device:
+                 self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+             return
+         pe = torch.zeros(x.size(1), self.d_model)
+         position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+         div_term = torch.exp(
+             torch.arange(0, self.d_model, 2, dtype=torch.float32) *
+             -(math.log(10000.0) / self.d_model))
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+         pe = pe.unsqueeze(0)
+         self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+     def forward(self, x: torch.Tensor):
+         """Add positional encoding.
+
+         Args:
+             x: torch.Tensor
+                 Input tensor. shape is (batch, time, ...)
+
+         Returns:
+             torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+
+         """
+         self.extend_pe(x)
+         x = x * self.xscale + self.pe[:, :x.size(1)]
+         return self.dropout(x)
+
+
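+ # Editor's note: illustrative sketch, not part of the original source. The
+ # table built in extend_pe above is the standard sinusoidal encoding
+ #     pe[pos, 2i]   = sin(pos / 10000**(2i / d_model))
+ #     pe[pos, 2i+1] = cos(pos / 10000**(2i / d_model))
+ # and forward returns dropout(x * sqrt(d_model) + pe[:, :T]).
+ def _example_abs_pos_encoding():
+     pos_enc = AbsolutePositionalEncoding(d_model=256, dropout_rate=0.0)
+     x = torch.randn(4, 120, 256)  # (batch, time, d_model)
+     assert pos_enc(x).shape == x.shape
+
+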
+ #### forward embedding layers start here
+ class MeanVarianceNormLayer(nn.Module):
+     """Mean/variance normalization layer.
+
+     Will subtract the mean and multiply the input by the inverse standard
+     deviation. Typically used as a very first layer in a model.
+
+     Args:
+         input_size: int
+             layer input size.
+     """
+
+     def __init__(self, input_size):
+         super().__init__()
+         self.input_size = input_size
+         self.global_mean = nn.Parameter(torch.zeros(input_size))
+         self.global_invstd = nn.Parameter(torch.ones(input_size))
+
+     def forward(self, input_: Tensor) -> Tensor:
+         """MeanVarianceNormLayer Forward
+
+         Args:
+             input_: torch.Tensor
+                 input tensor.
+         """
+         return (input_ - self.global_mean) * self.global_invstd
+
+
+ class CausalConv1D(nn.Conv1d):
+     """
+     A causal version of nn.Conv1d where each step has only limited access to
+     locations to its right or left.
+     All arguments are the same as nn.Conv1d except padding.
+
+     If padding is set to None, paddings are set automatically to make it a
+     causal convolution where each location does not see any step to its
+     right.
+
+     If padding is set as a list (of size 2), then padding[0] is used as the
+     left padding and padding[1] as the right padding.
+     This makes it possible to control the number of steps accessible on the
+     right and on the left.
+     This mode is not supported when stride > 1. padding[0]+padding[1] should
+     be equal to (kernel_size - 1).
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         padding: Union[str, int] = 0,
+         dilation: int = 1,
+         groups: int = 1,
+         bias: bool = True,
+         padding_mode: str = "zeros",
+         device=None,
+         dtype=None,
+     ) -> None:
+         # NOTE: cache_drop_size starts as None and is expected to be set by
+         # the surrounding encoder before cached/streaming inference.
+         self.cache_drop_size = None
+         if padding is None:
+             self._left_padding = kernel_size - 1
+             self._right_padding = stride - 1
+         else:
+             if stride != 1 and padding != kernel_size - 1:
+                 raise ValueError(
+                     "No striding allowed for non-symmetric convolutions!")
+             if isinstance(padding, int):
+                 self._left_padding = padding
+                 self._right_padding = padding
+             elif (isinstance(padding, list) and len(padding) == 2
+                   and padding[0] + padding[1] == kernel_size - 1):
+                 self._left_padding = padding[0]
+                 self._right_padding = padding[1]
+             else:
+                 raise ValueError(f"Invalid padding param: {padding}!")
+
+         self._max_cache_len = self._left_padding
+
+         super().__init__(
+             in_channels=in_channels,
+             out_channels=out_channels,
+             kernel_size=kernel_size,
+             stride=stride,
+             padding=0,
+             dilation=dilation,
+             groups=groups,
+             bias=bias,
+             padding_mode=padding_mode,
+             device=device,
+             dtype=dtype,
+         )
+
+     def update_cache(self, x, cache=None):
+         if cache is None:
+             new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
+             next_cache = cache
+         else:
+             new_x = F.pad(x, pad=(0, self._right_padding))
+             new_x = torch.cat([cache, new_x], dim=-1)
+             if self.cache_drop_size > 0:
+                 next_cache = new_x[:, :, :-self.cache_drop_size]
+             else:
+                 next_cache = new_x
+             next_cache = next_cache[:, :, -cache.size(-1):]
+         return new_x, next_cache
+
+     def forward(self, x, cache=None):
+         x, cache = self.update_cache(x, cache=cache)
+         x = super().forward(x)
+         if cache is None:
+             return x
+         else:
+             return x, cache
+
+
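+ # Editor's note: illustrative sketch, not part of the original source. With
+ # padding=None, CausalConv1D left-pads by kernel_size - 1 and right-pads by
+ # stride - 1, so at stride 1 the output length matches the input and frame
+ # t depends only on inputs at positions <= t.
+ def _example_causal_conv1d():
+     conv = CausalConv1D(in_channels=80, out_channels=80, kernel_size=3,
+                         padding=None)
+     x = torch.randn(1, 80, 100)  # (batch, channels, time)
+     assert conv(x).shape == (1, 80, 100)
+
+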
+ class CausalConv2D(nn.Conv2d):
+     """
+     A causal version of nn.Conv2d where each location in the 2D matrix has
+     no access to locations to its right or below.
+     All arguments are the same as nn.Conv2d except padding, which must be
+     set to None.
+     """
+
+     def __init__(
+         self,
+         in_channels: int,
+         out_channels: int,
+         kernel_size: int,
+         stride: int = 1,
+         padding: Union[str, int] = 0,
+         dilation: int = 1,
+         groups: int = 1,
+         bias: bool = True,
+         padding_mode: str = "zeros",
+         device=None,
+         dtype=None,
+     ) -> None:
+         if padding is not None:
+             raise ValueError(
+                 "Argument padding should be set to None for CausalConv2D.")
+         self._left_padding = kernel_size - 1
+         self._right_padding = stride - 1
+
+         padding = 0
+         super().__init__(
+             in_channels,
+             out_channels,
+             kernel_size,
+             stride,
+             padding,
+             dilation,
+             groups,
+             bias,
+             padding_mode,
+             device,
+             dtype,
+         )
+
+     def forward(
+         self,
+         x,
+     ):
+         x = F.pad(
+             x,
+             pad=(self._left_padding, self._right_padding, 0, 0),
+         )
+         x = super().forward(x)
+         return x
+
+
+ class NemoConvSubsampling(torch.nn.Module):
+     """Convolutional subsampling module, taken from NeMo ASR
+     (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a
+     34501479cf/nemo/collections/asr/parts/submodules/subsampling.py)
+
+     Striding Subsampling: "Speech-Transformer: A No-Recurrence
+     Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong
+     et al. (https://ieeexplore.ieee.org/document/8462506)
+
+     Compared with the EncoderConv2D (`input_layer: custom`), this is a
+     much simplified approach, and uses no LayerNorm and far fewer Conv2Ds.
+     Moreover, depthwise convolutions are used to reduce FLOPs, but the first
+     layer is kept as a regular convolution so as not to degrade accuracy.
+
+     `striding` and `dw_striding` are the same except that the latter uses
+     depthwise convolutions after the first layer, whereas the former does
+     not.
+
+     Args:
+         subsampling_factor (int): Time reduction factor
+         feat_in (int): size of the input features
+         feat_out (int): size of the output features
+         subsampling (str): The subsampling technique, choose from
+             {"striding", "dw_striding", "striding_conv1d",
+             "dw_striding_conv1d"}
+         conv_channels (int): Number of channels for the convolution layers,
+             default is 256.
+         subsampling_conv_chunking_factor (int): Input chunking factor which
+             can be -1 (no chunking), 1 (auto), or a power of 2. Default is 1
+         activation (Module): activation function, default is nn.ReLU()
+         is_causal (bool): whether to use causal Conv1/2D, where each step
+             will have limited access to locations on its right or left
+     """
+
+     def __init__(
+         self,
+         feat_in,
+         feat_out,
+         subsampling_factor=4,
+         subsampling="dw_striding",
+         conv_channels=256,
+         subsampling_conv_chunking_factor=1,
+         activation=nn.ReLU(),  # noqa: B008
+         is_causal=False,
+     ):
+         super().__init__()
+         self._subsampling = subsampling
+         self._conv_channels = conv_channels
+         self._feat_in = feat_in
+         self._feat_out = feat_out
+
+         if subsampling_factor % 2 != 0:
+             raise ValueError("Sampling factor should be a multiple of 2!")
+         self._sampling_num = int(math.log(subsampling_factor, 2))
+         self.subsampling_factor = subsampling_factor
+         self.is_causal = is_causal
+         self.subsampling_causal_cond = subsampling in (
+             "dw_striding",
+             "striding",
+             "striding_conv1d",
+         )
+
+         if (subsampling_conv_chunking_factor != -1
+                 and subsampling_conv_chunking_factor != 1
+                 and subsampling_conv_chunking_factor % 2 != 0):
+             raise ValueError(
+                 "subsampling_conv_chunking_factor should be -1, 1, or a "
+                 "power of 2")
+         self.subsampling_conv_chunking_factor = \
+             subsampling_conv_chunking_factor
+
+         in_channels = 1
+         layers = []
+
+         if subsampling == "dw_striding":
+             self._stride = 2
+             self._kernel_size = 3
+             self._ceil_mode = False
+
+             if self.is_causal:
+                 self._left_padding = self._kernel_size - 1
+                 self._right_padding = self._stride - 1
+                 self._max_cache_len = subsampling_factor + 1
+             else:
+                 self._left_padding = (self._kernel_size - 1) // 2
+                 self._right_padding = (self._kernel_size - 1) // 2
+                 self._max_cache_len = 0
+
+             # Layer 1
+             if self.is_causal:
+                 layers.append(
+                     CausalConv2D(
+                         in_channels=in_channels,
+                         out_channels=conv_channels,
+                         kernel_size=self._kernel_size,
+                         stride=self._stride,
+                         padding=None,
+                     ))
+             else:
+                 layers.append(
+                     torch.nn.Conv2d(
+                         in_channels=in_channels,
+                         out_channels=conv_channels,
+                         kernel_size=self._kernel_size,
+                         stride=self._stride,
+                         padding=self._left_padding,
+                     ))
+             in_channels = conv_channels
+             layers.append(activation)
+
+             for i in range(self._sampling_num - 1):
+                 if self.is_causal:
+                     layers.append(
+                         CausalConv2D(
+                             in_channels=in_channels,
+                             out_channels=in_channels,
+                             kernel_size=self._kernel_size,
+                             stride=self._stride,
+                             padding=None,
+                             groups=in_channels,
+                         ))
+                 else:
+                     layers.append(
+                         torch.nn.Conv2d(
+                             in_channels=in_channels,
+                             out_channels=in_channels,
+                             kernel_size=self._kernel_size,
+                             stride=self._stride,
+                             padding=self._left_padding,
+                             groups=in_channels,
+                         ))
+
+                 layers.append(
+                     torch.nn.Conv2d(
+                         in_channels=in_channels,
+                         out_channels=conv_channels,
+                         kernel_size=1,
+                         stride=1,
+                         padding=0,
+                         groups=1,
+                     ))
+                 layers.append(activation)
+                 in_channels = conv_channels
+
+         elif subsampling == "striding":
+             self._stride = 2
+             self._kernel_size = 3
+             self._ceil_mode = False
+
+             if self.is_causal:
+                 self._left_padding = self._kernel_size - 1
+                 self._right_padding = self._stride - 1
+                 self._max_cache_len = subsampling_factor + 1
+             else:
+                 self._left_padding = (self._kernel_size - 1) // 2
+                 self._right_padding = (self._kernel_size - 1) // 2
+                 self._max_cache_len = 0
+
+             for i in range(self._sampling_num):
+                 if self.is_causal:
+                     layers.append(
+                         CausalConv2D(
+                             in_channels=in_channels,
+                             out_channels=conv_channels,
+                             kernel_size=self._kernel_size,
+                             stride=self._stride,
+                             padding=None,
+                         ))
+                 else:
+                     layers.append(
+                         torch.nn.Conv2d(
+                             in_channels=in_channels,
+                             out_channels=conv_channels,
+                             kernel_size=self._kernel_size,
+                             stride=self._stride,
+                             padding=self._left_padding,
+                         ))
+                 layers.append(activation)
+                 in_channels = conv_channels
+
+         elif subsampling == "striding_conv1d":
+             in_channels = feat_in
+
+             self._stride = 2
+             self._kernel_size = 5
+             self._ceil_mode = False
+
+             if self.is_causal:
+                 self._left_padding = self._kernel_size - 1
+                 self._right_padding = self._stride - 1
+                 self._max_cache_len = subsampling_factor + 1
+             else:
+                 self._left_padding = (self._kernel_size - 1) // 2
+                 self._right_padding = (self._kernel_size - 1) // 2
+                 self._max_cache_len = 0
+
+             for i in range(self._sampling_num):
+                 if self.is_causal:
+                     layers.append(
+                         CausalConv1D(
+                             in_channels=in_channels,
+                             out_channels=(feat_out if self._sampling_num ==
+                                           i + 1 else conv_channels),
+                             kernel_size=self._kernel_size,
+                             stride=self._stride,
+                             padding=None,
+                         ))
+                 else:
+                     layers.append(
+                         torch.nn.Conv1d(
+                             in_channels=in_channels,
+                             out_channels=(feat_out if self._sampling_num ==
+                                           i + 1 else conv_channels),
+                             kernel_size=self._kernel_size,
+                             stride=self._stride,
+                             padding=self._left_padding,
+                         ))
+                 layers.append(activation)
+                 in_channels = conv_channels
+
+         elif subsampling == "dw_striding_conv1d":
+             in_channels = feat_in
+
+             self._stride = 2
+             self._kernel_size = 5
+             self._ceil_mode = False
+
+             self._left_padding = (self._kernel_size - 1) // 2
+             self._right_padding = (self._kernel_size - 1) // 2
+
+             # Layer 1
+             layers.extend([
+                 torch.nn.Conv1d(
+                     in_channels=in_channels,
+                     out_channels=in_channels,
+                     kernel_size=self._kernel_size,
+                     stride=self._stride,
+                     padding=self._left_padding,
+                     groups=in_channels,
+                 ),
+                 torch.nn.Conv1d(
+                     in_channels=in_channels,
+                     out_channels=(feat_out if self._sampling_num == 1 else
+                                   conv_channels),
+                     kernel_size=1,
+                     stride=1,
+                     padding=0,
+                     groups=1,
+                 ),
+             ])
+             in_channels = conv_channels
+             layers.append(activation)
+
+             for i in range(self._sampling_num - 1):
+                 layers.extend([
+                     torch.nn.Conv1d(
+                         in_channels=in_channels,
+                         out_channels=in_channels,
+                         kernel_size=self._kernel_size,
+                         stride=self._stride,
+                         padding=self._left_padding,
+                         groups=in_channels,
+                     ),
+                     torch.nn.Conv1d(
+                         in_channels=in_channels,
+                         out_channels=(feat_out if self._sampling_num ==
+                                       i + 2 else conv_channels),
+                         kernel_size=1,
+                         stride=1,
+                         padding=0,
+                         groups=1,
+                     ),
+                 ])
+                 layers.append(activation)
+                 in_channels = conv_channels
+
+         else:
+             raise ValueError(f"Not valid sub-sampling: {subsampling}!")
+
+         if subsampling in ["dw_striding", "striding"]:
+             in_length = torch.tensor(feat_in, dtype=torch.float)
+             out_length = calc_length(
+                 lengths=in_length,
+                 all_paddings=self._left_padding + self._right_padding,
+                 kernel_size=self._kernel_size,
+                 stride=self._stride,
+                 ceil_mode=self._ceil_mode,
+                 repeat_num=self._sampling_num,
+             )
+             self.out = torch.nn.Linear(conv_channels * int(out_length),
+                                        feat_out)
+             self.conv2d_subsampling = True
+         elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
+             self.out = None
+             self.conv2d_subsampling = False
+         else:
+             raise ValueError(f"Not valid sub-sampling: {subsampling}!")
+
+         self.conv = torch.nn.Sequential(*layers)
+
+     def get_sampling_frames(self):
+         return [1, self.subsampling_factor]
+
+     def get_streaming_cache_size(self):
+         return [0, self.subsampling_factor + 1]
+
+     def forward(self, x, mask):
+         """
+         Forward method for NeMo subsampling.
+
+         Args:
+             x[Batch, Time, Filters]: torch.Tensor
+                 input tensor
+             mask: torch.Tensor
+                 input mask
+
+         Returns:
+             x: torch.Tensor
+                 Resulting tensor from subsampling (B, T //
+                 time_reduction_factor, feat_out)
+             pad_mask: torch.Tensor
+                 tensor of padded hidden state sequences (B, 1, T //
+                 time_reduction_factor)
+         """
+         x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2)
+
+         # split inputs if chunking_factor is set
+         if (self.subsampling_conv_chunking_factor != -1
+                 and self.conv2d_subsampling):
+             if self.subsampling_conv_chunking_factor == 1:
+                 # if subsampling_conv_chunking_factor is 1, we split only
+                 # if needed.
+                 # avoiding a bug / feature limiting indexing of tensors
+                 # to 2**31.
+                 # see https://github.com/pytorch/pytorch/issues/80020
+                 x_ceil = (2**31 / self._conv_channels * self._stride *
+                           self._stride)
+                 need_to_split = torch.numel(x) > x_ceil
+             else:
+                 # if subsampling_conv_chunking_factor > 1 we always split
+                 need_to_split = True
+
+             if need_to_split:
+                 x, success = self.conv_split_by_batch(x)
+                 if not success:  # if unable to split by batch, try by channel
+                     if self._subsampling == "dw_striding":
+                         x = self.conv_split_by_channel(x)
+                     else:
+                         x = self.conv(x)  # try anyway
+             else:
+                 x = self.conv(x)
+         else:
+             x = self.conv(x)
+
+         # Flatten Channel and Frequency Axes
+         if self.conv2d_subsampling:
+             b, c, t, f = x.size()
+             x = self.out(x.transpose(1, 2).reshape(b, t, -1))
+         # Transpose to Channel Last mode
+         else:
+             x = x.transpose(1, 2)
+
+         if mask is None:
+             return x, None
+
+         max_audio_length = x.shape[1]
+         feature_lens = mask.sum(1)
+         padding_length = torch.ceil(feature_lens / self.subsampling_factor)
+         if self.is_causal and self.subsampling_causal_cond:
+             feature_lens_remainder = feature_lens % self.subsampling_factor
+             padding_length[feature_lens_remainder != 1] += 1
+         pad_mask = torch.arange(0, max_audio_length, device=x.device).expand(
+             padding_length.size(0), -1) < padding_length.unsqueeze(1)
+         return x, pad_mask.unsqueeze(1)
+
+     def reset_parameters(self):
+         # initialize weights
+         if self._subsampling == "dw_striding":
+             with torch.no_grad():
+                 # init conv
+                 scale = 1.0 / self._kernel_size
+                 dw_max = (self._kernel_size**2)**-0.5
+                 pw_max = self._conv_channels**-0.5
+
+                 torch.nn.init.uniform_(self.conv[0].weight, -scale, scale)
+                 torch.nn.init.uniform_(self.conv[0].bias, -scale, scale)
+
+                 for idx in range(2, len(self.conv), 3):
+                     torch.nn.init.uniform_(self.conv[idx].weight, -dw_max,
+                                            dw_max)
+                     torch.nn.init.uniform_(self.conv[idx].bias, -dw_max,
+                                            dw_max)
+                     torch.nn.init.uniform_(self.conv[idx + 1].weight,
+                                            -pw_max, pw_max)
+                     torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max,
+                                            pw_max)
+
+                 # init fc (80 * 64 = 5120 from https://github.com/kssteven418/
+                 # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/
+                 # src/models/conformer_encoder.py#L487)
+                 fc_scale = (self._feat_out * self._feat_in /
+                             self._sampling_num)**-0.5
+                 torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
+                 torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)
+
+     def conv_split_by_batch(self, x):
+         """Tries to split the input by batch, run conv, and concat the
+         results."""
+         b, _, _, _ = x.size()
+         if b == 1:  # can't split if batch size is 1
+             return x, False
+
+         if self.subsampling_conv_chunking_factor > 1:
+             cf = self.subsampling_conv_chunking_factor
+         else:
+             # avoiding a bug / feature limiting indexing of tensors to 2**31
+             # see https://github.com/pytorch/pytorch/issues/80020
+             x_ceil = 2**31 / self._conv_channels * self._stride * self._stride
+             p = math.ceil(math.log(torch.numel(x) / x_ceil, 2))
+             cf = 2**p
+
+         new_batch_size = b // cf
+         if new_batch_size == 0:  # input is too big
+             return x, False
+
+         return (
+             torch.cat([
+                 self.conv(chunk)
+                 for chunk in torch.split(x, new_batch_size, 0)
+             ]),
+             True,
+         )
+
+     def conv_split_by_channel(self, x):
+         """For dw convs, tries to split the input by time, run conv, and
+         concat the results."""
+         x = self.conv[0](x)  # full conv2D
+         x = self.conv[1](x)  # activation
+
+         for i in range(self._sampling_num - 1):
+             _, c, t, _ = x.size()
+
+             if self.subsampling_conv_chunking_factor > 1:
+                 cf = self.subsampling_conv_chunking_factor
+             else:
+                 # avoiding a bug / feature limiting indexing of tensors
+                 # to 2**31
+                 # see https://github.com/pytorch/pytorch/issues/80020
+                 p = math.ceil(math.log(torch.numel(x) / 2**31, 2))
+                 cf = 2**p
+
+             new_c = int(c // cf)
+             if new_c == 0:
+                 new_c = 1
+
+             new_t = int(t // cf)
+             if new_t == 0:
+                 new_t = 1
+
+             x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c,
+                                           x)  # conv2D, depthwise
+
+             # splitting pointwise convs by time
+             x = torch.cat(
+                 [
+                     self.conv[i * 3 + 3](chunk)
+                     for chunk in torch.split(x, new_t, 2)
+                 ],
+                 2,
+             )  # conv2D, pointwise
+             x = self.conv[i * 3 + 4](x)  # activation
+         return x
+
+     def channel_chunked_conv(self, conv, chunk_size, x):
+         """Performs channel chunked convolution"""
+
+         ind = 0
+         out_chunks = []
+         for chunk in torch.split(x, chunk_size, 1):
+             step = chunk.size()[1]
+
+             if self.is_causal:
+                 chunk = nn.functional.pad(
+                     chunk,
+                     pad=(
+                         self._kernel_size - 1,
+                         self._stride - 1,
+                         self._kernel_size - 1,
+                         self._stride - 1,
+                     ),
+                 )
+                 ch_out = nn.functional.conv2d(
+                     chunk,
+                     conv.weight[ind:ind + step, :, :, :],
+                     bias=conv.bias[ind:ind + step],
+                     stride=self._stride,
+                     padding=0,
+                     groups=step,
+                 )
+             else:
+                 ch_out = nn.functional.conv2d(
+                     chunk,
+                     conv.weight[ind:ind + step, :, :, :],
+                     bias=conv.bias[ind:ind + step],
+                     stride=self._stride,
+                     padding=self._left_padding,
+                     groups=step,
+                 )
+             out_chunks.append(ch_out)
+             ind += step
+
+         return torch.cat(out_chunks, 1)
+
+     def change_subsampling_conv_chunking_factor(
+             self, subsampling_conv_chunking_factor: int):
+         if (subsampling_conv_chunking_factor != -1
+                 and subsampling_conv_chunking_factor != 1
+                 and subsampling_conv_chunking_factor % 2 != 0):
+             raise ValueError(
+                 "subsampling_conv_chunking_factor should be -1, 1, or a "
+                 "power of 2")
+         self.subsampling_conv_chunking_factor = \
+             subsampling_conv_chunking_factor
+
+
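+ # Editor's note: illustrative sketch, not part of the original source. Each
+ # stride-2 stage halves the time axis, so subsampling_factor=4 means two
+ # stages; for the 2D variants the flattened (channels x freq) axis is
+ # mapped to feat_out by the final Linear.
+ def _example_nemo_subsampling():
+     sub = NemoConvSubsampling(feat_in=80, feat_out=512,
+                               subsampling_factor=4,
+                               subsampling="dw_striding")
+     x = torch.randn(2, 200, 80)  # (batch, time, feat_in)
+     mask = torch.ones(2, 200, dtype=torch.bool)
+     y, pad_mask = sub(x, mask)
+     assert y.shape == (2, 50, 512)  # time 200 -> 50, features 80 -> 512
+     assert pad_mask.shape == (2, 1, 50)
+
+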
+ def calc_length(lengths,
+                 all_paddings,
+                 kernel_size,
+                 stride,
+                 ceil_mode,
+                 repeat_num=1):
+     """Calculates the output length of a Tensor passed through a convolution
+     or max pooling layer."""
+     add_pad: float = all_paddings - kernel_size
+     one: float = 1.0
+     for i in range(repeat_num):
+         lengths = (torch.div(lengths.to(dtype=torch.float) + add_pad, stride)
+                    + one)
+         lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths)
+     return lengths.to(dtype=torch.int)
+
+
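+ # Editor's note: worked example, not part of the original source. For one
+ # non-causal stride-2 stage of the subsampler above (kernel_size=3,
+ # all_paddings=2, ceil_mode=False):
+ #     out = floor((T + 2 - 3) / 2 + 1) = floor((T - 1) / 2) + 1
+ # e.g. T=200 -> 100, and with repeat_num=2 -> 50, i.e. a time reduction of
+ # subsampling_factor=4.
+
+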
+ #### multihead attention starts here
+ class AttModule(nn.Module):
+     """Attention abstraction module"""
+
+     def __init__(self):
+         super().__init__()
+         self.export_mode = False
+
+     def set_export(self, mode=True):
+         """set the export mode"""
+         self.export_mode = mode
+
+     def forward(
+         self,
+         x: Tensor,
+         memory: Optional[Tensor] = None,
+         pos_emb: Optional[Tensor] = None,
+         att_mask: Optional[Tensor] = None,
+     ) -> tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+         """AttModule forward
+
+         Args:
+             x: torch.Tensor
+                 input tensor.
+             memory: torch.Tensor, optional
+                 memory tensor.
+             pos_emb: torch.Tensor, optional
+                 positional encoder embedding.
+             att_mask: torch.Tensor, optional
+                 attention mask tensor.
+         """
+         return x, memory, pos_emb, att_mask
+
+
+ class AttBlock(BlockBase, AttModule):
+     """Attention block module supporting both the Attention and Block
+     interfaces."""
+
+     def memory_dims(self, max_len=False):
+         """memory dimensions"""
+         return (1, self.input_size)
+
+
+ def masked_softmax(
+     scores,
+     mask: Optional[Tensor],
+ ):
+     if mask is not None:
+         mask = mask.unsqueeze(1).eq(0)  # (batch, 1, time1, time2)
+         scores = scores.masked_fill(mask, -torch.inf)
+         attn = torch.softmax(scores, dim=-1).masked_fill(
+             mask, 0.0)  # (batch, head, time1, time2)
+     else:
+         attn = torch.softmax(scores, dim=-1)  # (batch, head, time1, time2)
+     return attn
+
+
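+ # Editor's note: illustrative sketch, not part of the original source. The
+ # mask convention above is 1 for valid and 0 for padded positions; masked
+ # scores become -inf before the softmax and exactly 0.0 after it.
+ def _example_masked_softmax():
+     scores = torch.zeros(1, 2, 3, 3)  # (batch, head, time1, time2)
+     mask = torch.tensor([[[1, 1, 0]]]).expand(1, 3, 3)  # (batch, t1, t2)
+     attn = masked_softmax(scores, mask)
+     assert torch.allclose(attn[0, 0, 0], torch.tensor([0.5, 0.5, 0.0]))
+
+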
+ class MultiHeadedAttention(nn.Module):
+     """Multi-Head Attention layer with optional relative position embedding
+     and GLU.
+
+     Args:
+         n_head: int
+             the number of heads.
+         n_feat: int
+             input feature size.
+         dropout_rate: float
+             dropout rate.
+         attention_inner_dim: int, optional
+             the attention dimension used in the class;
+             it can be different from the input dimension n_feat.
+             default: -1 (equal to n_feat).
+         use_pt_scaled_dot_product_attention: bool, optional
+             if set to True, use PyTorch scaled dot product attention in
+             training.
+             NOTE: this will NOT be used in ONNX decoding due to a lack of
+             support. In that case, we use the original attention
+             implementation, which shows no regression.
+             default: False.
+         n_value: int, optional
+             if set to a value other than -1, use a different dimension for
+             the value. With the default (i.e. -1), it is backward compatible.
+         group_size: int, optional; must divide `n_head`
+             if group_size > 1: GQA
+             if group_size = 1: MHA
+             if group_size = n_head: MQA
+     """
+
+     inv_sqrt_d_k: torch.jit.Final[float]
+     h: torch.jit.Final[int]
+     h_k: torch.jit.Final[int]
+     g: torch.jit.Final[int]
+
+     def __init__(
+         self,
+         n_head,
+         n_feat,
+         dropout_rate,
+         attention_inner_dim=-1,
+         glu_type="swish",
+         bias_in_glu=True,
+         use_pt_scaled_dot_product_attention=False,
+         n_value=-1,
+         group_size: int = 1,
+     ):
+         super().__init__()
+         if n_value == -1:
+             n_value = n_feat
+         if attention_inner_dim == -1:
+             attention_inner_dim = n_feat
+         assert attention_inner_dim % n_head == 0
+
+         # We assume d_v always equals d_k
+         self.d_k = attention_inner_dim // n_head
+         self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k)
+         self.h = n_head
+         assert n_head % group_size == 0, "group_size must divide n_head"
+         self.g = group_size
+         self.h_k = n_head // group_size
+
+         self.linear_q = nn.Linear(n_feat, attention_inner_dim)
+         self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size)
+         self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size)
+         self.linear_out = nn.Linear(attention_inner_dim // group_size,
+                                     n_value)
+
+         self.attn = torch.jit.Attribute(None, Optional[Tensor])
+         self.dropout = nn.Dropout(p=dropout_rate)
+         self.dropout_rate = dropout_rate
+         self.use_pt_scaled_dot_product_attention = (
+             use_pt_scaled_dot_product_attention)
+
+         if use_pt_scaled_dot_product_attention and group_size > 1:
+             raise ValueError("Cannot use PT Scaled Attention with GQA")
+
+         # Torchscript eager quantization. Note that the functions below are
+         # NOOPs and have very little impact on performance unless
+         # quantization is enabled.
+         self.quant_q = torch.ao.quantization.QuantStub()
+         self.quant_x = torch.ao.quantization.QuantStub()
+         self.dequant = torch.ao.quantization.DeQuantStub()
+         self.ffunc = torch.ao.nn.quantized.FloatFunctional()
+
+     def forward(
+         self,
+         query: Tensor,
+         key: Tensor,
+         value: Tensor,
+         pos_k: Tensor,
+         pos_v: Tensor,
+         mask: Optional[Tensor],
+         relative_attention_bias: Optional[Tensor] = None,
+     ):
+         """Compute 'Scaled Dot Product Attention'.
+
+         Args:
+             query: torch.Tensor
+                 query tensor (batch, time1, size)
+             key: torch.Tensor
+                 key tensor (batch, time2, size)
+             value: torch.Tensor
+                 value tensor (batch, time2, size)
+             pos_k: torch.Tensor
+                 key tensor used for relative positional embedding.
+             pos_v: torch.Tensor
+                 value tensor used for relative positional embedding.
+             mask: torch.Tensor
+                 mask tensor (batch, time1, time2)
+             relative_attention_bias: torch.Tensor
+                 bias added to attention logits w.r.t. relative positions
+                 (1, n_head, time1, time2)
+         """
+         n_batch = query.size(0)
+
+         q = self.linear_q(query).view(n_batch, -1, self.h,
+                                       self.d_k)  # (b, t1, h, d_k)
+         k = self.linear_k(key).view(n_batch, -1, self.h_k,
+                                     self.d_k)  # (b, t2, h_k, d_k)
+         v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k)
+         # PyTorch SDPA applies the 1/sqrt(d_k) scaling internally, so the
+         # explicit scaling is only applied on the manual path.
+         q = (q.transpose(1, 2) if self.use_pt_scaled_dot_product_attention
+              and not torch.jit.is_scripting() else q.transpose(1, 2) *
+              self.inv_sqrt_d_k)
+         k = k.transpose(1, 2)  # (batch, head_k, time2, d_k)
+         v = v.transpose(1, 2)  # (batch, head_k, time2, d_k)
+
+         if (self.use_pt_scaled_dot_product_attention
+                 and not torch.jit.is_scripting()):
+             attn_mask = None
+             if mask is not None:
+                 mask = mask.unsqueeze(1)
+                 if relative_attention_bias is not None:
+                     attn_mask = mask + relative_attention_bias
+                 else:
+                     attn_mask = mask
+                 if mask.dtype != q.dtype:
+                     attn_mask = attn_mask.to(q.dtype)
+
+             with torch.nn.attention.sdpa_kernel([
+                     torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+                     torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+                     torch.nn.attention.SDPBackend.MATH,
+                     torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+             ]):
+                 x = torch.nn.functional.scaled_dot_product_attention(
+                     q,
+                     k,
+                     v,
+                     attn_mask=attn_mask,
+                     dropout_p=self.dropout_rate,
+                 )
+         else:
+             if self.h != self.h_k:
+                 q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k)
+                 A = torch.einsum("b g h t d, b h s d -> b h t s", q, k)
+             else:
+                 A = torch.matmul(q, k.transpose(-2, -1))
+             if pos_k is not None:
+                 if self.h != self.h_k:
+                     B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k)
+                 else:
+                     reshape_q = (q.contiguous().view(n_batch * self.h, -1,
+                                                      self.d_k).transpose(0, 1)
+                                  )  # (t1, nh, d_k)
+                     B = torch.matmul(reshape_q,
+                                      pos_k.transpose(-2,
+                                                      -1))  # pos_k: (t1,dk,t2)
+                     B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0),
+                                                pos_k.size(1))
+                 scores = A + B
+             else:
+                 scores = A
+
+             if relative_attention_bias is not None:
+                 scores = scores + relative_attention_bias
+
+             attn = masked_softmax(scores, mask)  # (batch, head, time1, time2)
+
+             self.attn = attn
+
+             p_attn = self.dropout(attn)
+             x = torch.matmul(p_attn.to(v.dtype),
+                              v)  # (batch, head, time1, d_k)
+             if pos_v is not None:
+                 reshape_attn = (p_attn.contiguous().view(
+                     n_batch * self.h, pos_v.size(0),
+                     pos_v.size(1)).transpose(0, 1))  # (t1, bh, t2)
+
+                 attn_v = (torch.matmul(reshape_attn, pos_v).transpose(
+                     0, 1).contiguous().view(n_batch, self.h, pos_v.size(0),
+                                             self.d_k))
+                 x = x + attn_v
+         x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
+                                                  self.h_k * self.d_k)
+              )  # (batch, time1, d_model)
+
+         return self.linear_out(x)  # (batch, time1, d_model)
+
+
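+ # Editor's note: illustrative sketch, not part of the original source. With
+ # the defaults (group_size=1) the module above is plain MHA; pos_k, pos_v
+ # and mask may all be None. For GQA, group_size g shrinks the k/v
+ # projections to attention_inner_dim // g (e.g. n_head=8, group_size=2
+ # gives h_k=4 key/value heads).
+ def _example_multi_headed_attention():
+     mha = MultiHeadedAttention(n_head=8, n_feat=512, dropout_rate=0.0)
+     x = torch.randn(2, 64, 512)  # (batch, time, n_feat)
+     out = mha(x, x, x, pos_k=None, pos_v=None, mask=None)
+     assert out.shape == (2, 64, 512)
+
+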
+ class MultiSequential(torch.nn.Sequential):
+     """Multi-input multi-output torch.nn.Sequential"""
+
+     @torch.jit.ignore
+     def forward(self, *args):
+         """Forward method implementation."""
+         for m in self:
+             args = m(*args)
+         return args
+
+
+ def get_offset(input_layer: str, time_reduction: int):
+     """Get an offset. We will use the offset for determining the number of
+     frames of a subsampled feature.
+
+     Args:
+         input_layer (str): Type of the input layer
+         time_reduction (int): time reduction factor for downsampling a
+             feature
+     Returns:
+         int: offset
+     """
+     if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4:
+         return 3
+     if input_layer in ("conv2d", ) and time_reduction == 6:
+         return 1
+     if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8:
+         return 7
+     return 0
+
+
+ def unfold_tensor(xs_pad, max_seq_len):
+     """
+     For a given tensor with shape of (N, T, D), if sequence length T is
+     longer than max_seq_len, this function unfolds it to a
+     (N*T', max_seq_len, D) tensor, where T' is T // max_seq_len (any frames
+     beyond T' * max_seq_len are dropped).
+     Args:
+         xs_pad: N, T, D
+     """
+     _, _, D = xs_pad.shape
+     xs_pad = xs_pad.transpose(-1, -2)  # convert to N, D, T
+     # N x D x 1 x T => N x (D x max_seq_len) x T'
+     xs_pad = F.unfold(
+         xs_pad[..., None, :],
+         kernel_size=(1, max_seq_len),
+         stride=(1, max_seq_len),
+     )
+     new_bsz, _, slen = xs_pad.shape
+     # N x D x max_seq_len x T'
+     xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen)
+     # N x T' x max_seq_len x D
+     xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous()
+     # N*T' x max_seq_len x D
+     xs_pad = xs_pad.view(-1, max_seq_len, D)
+     return xs_pad
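+
+
+ # Editor's note: worked example, not part of the original source. Frames
+ # beyond T' * max_seq_len are dropped by F.unfold, so T should be a
+ # multiple of max_seq_len when no information may be lost.
+ def _example_unfold_tensor():
+     xs = torch.randn(2, 512, 80)  # (N, T, D)
+     out = unfold_tensor(xs, max_seq_len=128)
+     assert out.shape == (8, 128, 80)  # (N * T', max_seq_len, D), T' = 4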