vllm_cpu-0.8.5.post2-cp310-cp310-manylinux_2_17_x86_64.whl

This diff shows the content of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of vllm-cpu might be problematic.

Files changed (1103)
  1. vllm/_C.abi3.so +0 -0
  2. vllm/__init__.py +170 -0
  3. vllm/_custom_ops.py +1536 -0
  4. vllm/_ipex_ops.py +241 -0
  5. vllm/_version.py +34 -0
  6. vllm/adapter_commons/__init__.py +0 -0
  7. vllm/adapter_commons/layers.py +16 -0
  8. vllm/adapter_commons/models.py +105 -0
  9. vllm/adapter_commons/request.py +25 -0
  10. vllm/adapter_commons/utils.py +92 -0
  11. vllm/adapter_commons/worker_manager.py +38 -0
  12. vllm/assets/__init__.py +0 -0
  13. vllm/assets/audio.py +38 -0
  14. vllm/assets/base.py +40 -0
  15. vllm/assets/image.py +31 -0
  16. vllm/assets/video.py +103 -0
  17. vllm/attention/__init__.py +19 -0
  18. vllm/attention/backends/__init__.py +0 -0
  19. vllm/attention/backends/abstract.py +306 -0
  20. vllm/attention/backends/blocksparse_attn.py +457 -0
  21. vllm/attention/backends/cpu_mla.py +303 -0
  22. vllm/attention/backends/flash_attn.py +999 -0
  23. vllm/attention/backends/flashinfer.py +1092 -0
  24. vllm/attention/backends/flashmla.py +242 -0
  25. vllm/attention/backends/hpu_attn.py +301 -0
  26. vllm/attention/backends/ipex_attn.py +396 -0
  27. vllm/attention/backends/mla/__init__.py +0 -0
  28. vllm/attention/backends/mla/common.py +1444 -0
  29. vllm/attention/backends/pallas.py +346 -0
  30. vllm/attention/backends/placeholder_attn.py +399 -0
  31. vllm/attention/backends/rocm_aiter_mla.py +412 -0
  32. vllm/attention/backends/rocm_flash_attn.py +969 -0
  33. vllm/attention/backends/torch_sdpa.py +691 -0
  34. vllm/attention/backends/triton_mla.py +113 -0
  35. vllm/attention/backends/utils.py +609 -0
  36. vllm/attention/backends/xformers.py +798 -0
  37. vllm/attention/layer.py +443 -0
  38. vllm/attention/ops/__init__.py +0 -0
  39. vllm/attention/ops/blocksparse_attention/__init__.py +0 -0
  40. vllm/attention/ops/blocksparse_attention/blocksparse_attention_kernel.py +432 -0
  41. vllm/attention/ops/blocksparse_attention/interface.py +238 -0
  42. vllm/attention/ops/blocksparse_attention/utils.py +244 -0
  43. vllm/attention/ops/chunked_prefill_paged_decode.py +366 -0
  44. vllm/attention/ops/flashmla.py +115 -0
  45. vllm/attention/ops/hpu_paged_attn.py +105 -0
  46. vllm/attention/ops/ipex_attn.py +193 -0
  47. vllm/attention/ops/merge_attn_states.py +42 -0
  48. vllm/attention/ops/nki_flash_attn.py +905 -0
  49. vllm/attention/ops/paged_attn.py +255 -0
  50. vllm/attention/ops/prefix_prefill.py +902 -0
  51. vllm/attention/ops/rocm_aiter_mla.py +42 -0
  52. vllm/attention/ops/rocm_aiter_paged_attn.py +101 -0
  53. vllm/attention/ops/triton_decode_attention.py +675 -0
  54. vllm/attention/ops/triton_flash_attention.py +1375 -0
  55. vllm/attention/ops/triton_merge_attn_states.py +96 -0
  56. vllm/attention/selector.py +186 -0
  57. vllm/attention/utils/fa_utils.py +54 -0
  58. vllm/beam_search.py +82 -0
  59. vllm/benchmarks/__init__.py +0 -0
  60. vllm/benchmarks/datasets.py +831 -0
  61. vllm/benchmarks/endpoint_request_func.py +160 -0
  62. vllm/benchmarks/latency.py +181 -0
  63. vllm/benchmarks/serve.py +925 -0
  64. vllm/benchmarks/throughput.py +608 -0
  65. vllm/benchmarks/utils.py +69 -0
  66. vllm/collect_env.py +795 -0
  67. vllm/compilation/__init__.py +0 -0
  68. vllm/compilation/backends.py +715 -0
  69. vllm/compilation/compiler_interface.py +437 -0
  70. vllm/compilation/counter.py +33 -0
  71. vllm/compilation/decorators.py +249 -0
  72. vllm/compilation/fix_functionalization.py +182 -0
  73. vllm/compilation/fusion.py +617 -0
  74. vllm/compilation/fx_utils.py +60 -0
  75. vllm/compilation/inductor_pass.py +114 -0
  76. vllm/compilation/monitor.py +38 -0
  77. vllm/compilation/multi_output_match.py +108 -0
  78. vllm/compilation/noop_elimination.py +135 -0
  79. vllm/compilation/pass_manager.py +74 -0
  80. vllm/compilation/sequence_parallelism.py +266 -0
  81. vllm/compilation/torch25_custom_graph_pass.py +41 -0
  82. vllm/compilation/vllm_inductor_pass.py +68 -0
  83. vllm/compilation/wrapper.py +129 -0
  84. vllm/config.py +4179 -0
  85. vllm/connections.py +170 -0
  86. vllm/core/__init__.py +0 -0
  87. vllm/core/block/__init__.py +0 -0
  88. vllm/core/block/block_table.py +398 -0
  89. vllm/core/block/common.py +370 -0
  90. vllm/core/block/cpu_gpu_block_allocator.py +440 -0
  91. vllm/core/block/interfaces.py +318 -0
  92. vllm/core/block/naive_block.py +465 -0
  93. vllm/core/block/prefix_caching_block.py +1134 -0
  94. vllm/core/block/utils.py +27 -0
  95. vllm/core/block_manager.py +520 -0
  96. vllm/core/evictor.py +156 -0
  97. vllm/core/interfaces.py +134 -0
  98. vllm/core/placeholder_block_space_manager.py +99 -0
  99. vllm/core/scheduler.py +2060 -0
  100. vllm/device_allocator/__init__.py +0 -0
  101. vllm/device_allocator/cumem.py +280 -0
  102. vllm/distributed/__init__.py +5 -0
  103. vllm/distributed/communication_op.py +40 -0
  104. vllm/distributed/device_communicators/__init__.py +0 -0
  105. vllm/distributed/device_communicators/base_device_communicator.py +151 -0
  106. vllm/distributed/device_communicators/cpu_communicator.py +139 -0
  107. vllm/distributed/device_communicators/cuda_communicator.py +131 -0
  108. vllm/distributed/device_communicators/cuda_wrapper.py +179 -0
  109. vllm/distributed/device_communicators/custom_all_reduce.py +301 -0
  110. vllm/distributed/device_communicators/custom_all_reduce_utils.py +257 -0
  111. vllm/distributed/device_communicators/hpu_communicator.py +45 -0
  112. vllm/distributed/device_communicators/neuron_communicator.py +19 -0
  113. vllm/distributed/device_communicators/pynccl.py +217 -0
  114. vllm/distributed/device_communicators/pynccl_wrapper.py +340 -0
  115. vllm/distributed/device_communicators/shm_broadcast.py +557 -0
  116. vllm/distributed/device_communicators/tpu_communicator.py +93 -0
  117. vllm/distributed/device_communicators/xpu_communicator.py +54 -0
  118. vllm/distributed/kv_transfer/README.md +29 -0
  119. vllm/distributed/kv_transfer/__init__.py +11 -0
  120. vllm/distributed/kv_transfer/disagg_prefill_workflow.jpg +0 -0
  121. vllm/distributed/kv_transfer/kv_connector/__init__.py +0 -0
  122. vllm/distributed/kv_transfer/kv_connector/base.py +127 -0
  123. vllm/distributed/kv_transfer/kv_connector/factory.py +107 -0
  124. vllm/distributed/kv_transfer/kv_connector/lmcache_connector.py +98 -0
  125. vllm/distributed/kv_transfer/kv_connector/mooncake_store_connector.py +201 -0
  126. vllm/distributed/kv_transfer/kv_connector/simple_connector.py +328 -0
  127. vllm/distributed/kv_transfer/kv_connector/utils.py +90 -0
  128. vllm/distributed/kv_transfer/kv_connector/v1/__init__.py +8 -0
  129. vllm/distributed/kv_transfer/kv_connector/v1/base.py +209 -0
  130. vllm/distributed/kv_transfer/kv_connector/v1/lmcache_connector.py +131 -0
  131. vllm/distributed/kv_transfer/kv_connector/v1/shared_storage_connector.py +383 -0
  132. vllm/distributed/kv_transfer/kv_connector_agent.py +76 -0
  133. vllm/distributed/kv_transfer/kv_lookup_buffer/__init__.py +0 -0
  134. vllm/distributed/kv_transfer/kv_lookup_buffer/base.py +174 -0
  135. vllm/distributed/kv_transfer/kv_lookup_buffer/mooncake_store.py +160 -0
  136. vllm/distributed/kv_transfer/kv_lookup_buffer/simple_buffer.py +236 -0
  137. vllm/distributed/kv_transfer/kv_pipe/__init__.py +0 -0
  138. vllm/distributed/kv_transfer/kv_pipe/base.py +66 -0
  139. vllm/distributed/kv_transfer/kv_pipe/mooncake_pipe.py +279 -0
  140. vllm/distributed/kv_transfer/kv_pipe/pynccl_pipe.py +279 -0
  141. vllm/distributed/kv_transfer/kv_transfer_state.py +70 -0
  142. vllm/distributed/parallel_state.py +1209 -0
  143. vllm/distributed/utils.py +366 -0
  144. vllm/engine/__init__.py +0 -0
  145. vllm/engine/arg_utils.py +1724 -0
  146. vllm/engine/async_llm_engine.py +1261 -0
  147. vllm/engine/async_timeout.py +191 -0
  148. vllm/engine/llm_engine.py +2150 -0
  149. vllm/engine/metrics.py +717 -0
  150. vllm/engine/metrics_types.py +96 -0
  151. vllm/engine/multiprocessing/__init__.py +183 -0
  152. vllm/engine/multiprocessing/client.py +745 -0
  153. vllm/engine/multiprocessing/engine.py +450 -0
  154. vllm/engine/output_processor/__init__.py +0 -0
  155. vllm/engine/output_processor/interfaces.py +74 -0
  156. vllm/engine/output_processor/multi_step.py +210 -0
  157. vllm/engine/output_processor/single_step.py +136 -0
  158. vllm/engine/output_processor/stop_checker.py +130 -0
  159. vllm/engine/output_processor/util.py +27 -0
  160. vllm/engine/protocol.py +302 -0
  161. vllm/entrypoints/__init__.py +0 -0
  162. vllm/entrypoints/api_server.py +177 -0
  163. vllm/entrypoints/chat_utils.py +1259 -0
  164. vllm/entrypoints/cli/__init__.py +0 -0
  165. vllm/entrypoints/cli/benchmark/__init__.py +0 -0
  166. vllm/entrypoints/cli/benchmark/base.py +38 -0
  167. vllm/entrypoints/cli/benchmark/latency.py +29 -0
  168. vllm/entrypoints/cli/benchmark/main.py +53 -0
  169. vllm/entrypoints/cli/benchmark/serve.py +29 -0
  170. vllm/entrypoints/cli/benchmark/throughput.py +29 -0
  171. vllm/entrypoints/cli/collect_env.py +35 -0
  172. vllm/entrypoints/cli/main.py +59 -0
  173. vllm/entrypoints/cli/openai.py +175 -0
  174. vllm/entrypoints/cli/serve.py +59 -0
  175. vllm/entrypoints/cli/types.py +24 -0
  176. vllm/entrypoints/launcher.py +146 -0
  177. vllm/entrypoints/llm.py +1450 -0
  178. vllm/entrypoints/logger.py +44 -0
  179. vllm/entrypoints/openai/__init__.py +0 -0
  180. vllm/entrypoints/openai/api_server.py +1130 -0
  181. vllm/entrypoints/openai/cli_args.py +296 -0
  182. vllm/entrypoints/openai/logits_processors.py +89 -0
  183. vllm/entrypoints/openai/protocol.py +1806 -0
  184. vllm/entrypoints/openai/run_batch.py +439 -0
  185. vllm/entrypoints/openai/serving_chat.py +1210 -0
  186. vllm/entrypoints/openai/serving_completion.py +557 -0
  187. vllm/entrypoints/openai/serving_embedding.py +245 -0
  188. vllm/entrypoints/openai/serving_engine.py +569 -0
  189. vllm/entrypoints/openai/serving_models.py +314 -0
  190. vllm/entrypoints/openai/serving_pooling.py +237 -0
  191. vllm/entrypoints/openai/serving_score.py +439 -0
  192. vllm/entrypoints/openai/serving_tokenization.py +147 -0
  193. vllm/entrypoints/openai/serving_transcription.py +421 -0
  194. vllm/entrypoints/openai/tool_parsers/__init__.py +19 -0
  195. vllm/entrypoints/openai/tool_parsers/abstract_tool_parser.py +163 -0
  196. vllm/entrypoints/openai/tool_parsers/granite_20b_fc_tool_parser.py +254 -0
  197. vllm/entrypoints/openai/tool_parsers/granite_tool_parser.py +232 -0
  198. vllm/entrypoints/openai/tool_parsers/hermes_tool_parser.py +370 -0
  199. vllm/entrypoints/openai/tool_parsers/internlm2_tool_parser.py +211 -0
  200. vllm/entrypoints/openai/tool_parsers/jamba_tool_parser.py +303 -0
  201. vllm/entrypoints/openai/tool_parsers/llama_tool_parser.py +262 -0
  202. vllm/entrypoints/openai/tool_parsers/mistral_tool_parser.py +342 -0
  203. vllm/entrypoints/openai/tool_parsers/phi4mini_tool_parser.py +110 -0
  204. vllm/entrypoints/openai/tool_parsers/pythonic_tool_parser.py +292 -0
  205. vllm/entrypoints/openai/tool_parsers/utils.py +123 -0
  206. vllm/entrypoints/score_utils.py +49 -0
  207. vllm/entrypoints/ssl.py +74 -0
  208. vllm/entrypoints/utils.py +136 -0
  209. vllm/env_override.py +34 -0
  210. vllm/envs.py +800 -0
  211. vllm/executor/__init__.py +0 -0
  212. vllm/executor/executor_base.py +400 -0
  213. vllm/executor/mp_distributed_executor.py +243 -0
  214. vllm/executor/msgspec_utils.py +29 -0
  215. vllm/executor/multiproc_worker_utils.py +312 -0
  216. vllm/executor/ray_distributed_executor.py +700 -0
  217. vllm/executor/ray_utils.py +400 -0
  218. vllm/executor/uniproc_executor.py +141 -0
  219. vllm/forward_context.py +159 -0
  220. vllm/inputs/__init__.py +37 -0
  221. vllm/inputs/data.py +248 -0
  222. vllm/inputs/parse.py +121 -0
  223. vllm/inputs/preprocess.py +745 -0
  224. vllm/inputs/registry.py +212 -0
  225. vllm/jsontree.py +79 -0
  226. vllm/logger.py +210 -0
  227. vllm/logging_utils/__init__.py +7 -0
  228. vllm/logging_utils/formatter.py +17 -0
  229. vllm/logits_process.py +121 -0
  230. vllm/lora/__init__.py +0 -0
  231. vllm/lora/fully_sharded_layers.py +335 -0
  232. vllm/lora/layers.py +1263 -0
  233. vllm/lora/lora.py +198 -0
  234. vllm/lora/models.py +802 -0
  235. vllm/lora/ops/__init__.py +0 -0
  236. vllm/lora/ops/torch_ops/__init__.py +15 -0
  237. vllm/lora/ops/torch_ops/lora_ops.py +115 -0
  238. vllm/lora/ops/triton_ops/__init__.py +11 -0
  239. vllm/lora/ops/triton_ops/kernel_utils.py +243 -0
  240. vllm/lora/ops/triton_ops/lora_expand.py +293 -0
  241. vllm/lora/ops/triton_ops/lora_kernel_metadata.py +147 -0
  242. vllm/lora/ops/triton_ops/lora_shrink.py +247 -0
  243. vllm/lora/ops/triton_ops/utils.py +121 -0
  244. vllm/lora/peft_helper.py +115 -0
  245. vllm/lora/punica_wrapper/__init__.py +9 -0
  246. vllm/lora/punica_wrapper/punica_base.py +483 -0
  247. vllm/lora/punica_wrapper/punica_cpu.py +348 -0
  248. vllm/lora/punica_wrapper/punica_gpu.py +289 -0
  249. vllm/lora/punica_wrapper/punica_hpu.py +144 -0
  250. vllm/lora/punica_wrapper/punica_selector.py +20 -0
  251. vllm/lora/punica_wrapper/utils.py +161 -0
  252. vllm/lora/request.py +97 -0
  253. vllm/lora/resolver.py +83 -0
  254. vllm/lora/utils.py +237 -0
  255. vllm/lora/worker_manager.py +251 -0
  256. vllm/model_executor/__init__.py +15 -0
  257. vllm/model_executor/custom_op.py +153 -0
  258. vllm/model_executor/guided_decoding/__init__.py +180 -0
  259. vllm/model_executor/guided_decoding/guidance_decoding.py +63 -0
  260. vllm/model_executor/guided_decoding/guidance_logits_processors.py +85 -0
  261. vllm/model_executor/guided_decoding/guided_fields.py +42 -0
  262. vllm/model_executor/guided_decoding/lm_format_enforcer_decoding.py +66 -0
  263. vllm/model_executor/guided_decoding/outlines_decoding.py +154 -0
  264. vllm/model_executor/guided_decoding/outlines_logits_processors.py +271 -0
  265. vllm/model_executor/guided_decoding/reasoner/__init__.py +35 -0
  266. vllm/model_executor/guided_decoding/utils.py +241 -0
  267. vllm/model_executor/guided_decoding/xgrammar_decoding.py +425 -0
  268. vllm/model_executor/layers/__init__.py +0 -0
  269. vllm/model_executor/layers/activation.py +368 -0
  270. vllm/model_executor/layers/fused_moe/__init__.py +51 -0
  271. vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  272. vllm/model_executor/layers/fused_moe/configs/E=1,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  273. vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  274. vllm/model_executor/layers/fused_moe/configs/E=1,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  275. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  276. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +218 -0
  277. vllm/model_executor/layers/fused_moe/configs/E=1,N=3072,device_name=NVIDIA_H100_80GB_HBM3.json +218 -0
  278. vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  279. vllm/model_executor/layers/fused_moe/configs/E=1,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  280. vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  281. vllm/model_executor/layers/fused_moe/configs/E=1,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  282. vllm/model_executor/layers/fused_moe/configs/E=128,N=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  283. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  284. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H20.json +146 -0
  285. vllm/model_executor/layers/fused_moe/configs/E=128,N=192,device_name=NVIDIA_H200.json +146 -0
  286. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  287. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H20.json +146 -0
  288. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  289. vllm/model_executor/layers/fused_moe/configs/E=128,N=384,device_name=NVIDIA_H200.json +146 -0
  290. vllm/model_executor/layers/fused_moe/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  291. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  292. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  293. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  294. vllm/model_executor/layers/fused_moe/configs/E=128,N=768,device_name=NVIDIA_H200.json +146 -0
  295. vllm/model_executor/layers/fused_moe/configs/E=128,N=96,device_name=NVIDIA_H20.json +146 -0
  296. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=AMD_Instinct_MI300X.json +200 -0
  297. vllm/model_executor/layers/fused_moe/configs/E=16,N=1024,device_name=NVIDIA_H100.json +146 -0
  298. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  299. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  300. vllm/model_executor/layers/fused_moe/configs/E=16,N=1344,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  301. vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  302. vllm/model_executor/layers/fused_moe/configs/E=16,N=14336,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  303. vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +218 -0
  304. vllm/model_executor/layers/fused_moe/configs/E=16,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  305. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  306. vllm/model_executor/layers/fused_moe/configs/E=16,N=2688,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  307. vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  308. vllm/model_executor/layers/fused_moe/configs/E=16,N=3072,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  309. vllm/model_executor/layers/fused_moe/configs/E=16,N=3200,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  310. vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  311. vllm/model_executor/layers/fused_moe/configs/E=16,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +218 -0
  312. vllm/model_executor/layers/fused_moe/configs/E=16,N=6400,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  313. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a16.json +146 -0
  314. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  315. vllm/model_executor/layers/fused_moe/configs/E=16,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=int8_w8a16.json +146 -0
  316. vllm/model_executor/layers/fused_moe/configs/E=16,N=800,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +130 -0
  317. vllm/model_executor/layers/fused_moe/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  318. vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325X,block_shape=[128,128].json +200 -0
  319. vllm/model_executor/layers/fused_moe/configs/E=256,N=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  320. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  321. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  322. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  323. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  324. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  325. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  326. vllm/model_executor/layers/fused_moe/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  327. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  328. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  329. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  330. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  331. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  332. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  333. vllm/model_executor/layers/fused_moe/configs/E=256,N=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  334. vllm/model_executor/layers/fused_moe/configs/E=256,N=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +200 -0
  335. vllm/model_executor/layers/fused_moe/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  336. vllm/model_executor/layers/fused_moe/configs/E=60,N=1408,device_name=AMD_Instinct_MI300X.json +200 -0
  337. vllm/model_executor/layers/fused_moe/configs/E=60,N=176,device_name=AMD_Instinct_MI300X.json +200 -0
  338. vllm/model_executor/layers/fused_moe/configs/E=60,N=352,device_name=AMD_Instinct_MI300X.json +200 -0
  339. vllm/model_executor/layers/fused_moe/configs/E=60,N=704,device_name=AMD_Instinct_MI300X.json +200 -0
  340. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  341. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  342. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  343. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  344. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  345. vllm/model_executor/layers/fused_moe/configs/E=64,N=1280,device_name=NVIDIA_H200.json +146 -0
  346. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  347. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  348. vllm/model_executor/layers/fused_moe/configs/E=64,N=2560,device_name=NVIDIA_H200.json +146 -0
  349. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  350. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  351. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  352. vllm/model_executor/layers/fused_moe/configs/E=64,N=320,device_name=NVIDIA_H200.json +146 -0
  353. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  354. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  355. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  356. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  357. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  358. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  359. vllm/model_executor/layers/fused_moe/configs/E=64,N=640,device_name=NVIDIA_H200.json +146 -0
  360. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  361. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +200 -0
  362. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  363. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  364. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +138 -0
  365. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  366. vllm/model_executor/layers/fused_moe/configs/E=8,N=14336,device_name=NVIDIA_H200.json +146 -0
  367. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  368. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI300X.json +200 -0
  369. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  370. vllm/model_executor/layers/fused_moe/configs/E=8,N=16384,device_name=AMD_Instinct_MI325X.json +200 -0
  371. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  372. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +200 -0
  373. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  374. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
  375. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  376. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  377. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  378. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  379. vllm/model_executor/layers/fused_moe/configs/E=8,N=1792,device_name=NVIDIA_H200.json +146 -0
  380. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  381. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI300X.json +200 -0
  382. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  383. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=AMD_Instinct_MI325X.json +200 -0
  384. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  385. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  386. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  387. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  388. vllm/model_executor/layers/fused_moe/configs/E=8,N=2048,device_name=NVIDIA_H200.json +146 -0
  389. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  390. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +200 -0
  391. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  392. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
  393. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-40GB.json +146 -0
  394. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  395. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_GeForce_RTX_4090,dtype=fp8_w8a8.json +146 -0
  396. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  397. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  398. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  399. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_H200.json +146 -0
  400. vllm/model_executor/layers/fused_moe/configs/E=8,N=3584,device_name=NVIDIA_L40S.json +173 -0
  401. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  402. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X.json +200 -0
  403. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  404. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X.json +200 -0
  405. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  406. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  407. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  408. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  409. vllm/model_executor/layers/fused_moe/configs/E=8,N=4096,device_name=NVIDIA_H200.json +146 -0
  410. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  411. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +200 -0
  412. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  413. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
  414. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  415. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  416. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  417. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  418. vllm/model_executor/layers/fused_moe/configs/E=8,N=7168,device_name=NVIDIA_H200.json +146 -0
  419. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +164 -0
  420. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X.json +200 -0
  421. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +164 -0
  422. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X.json +200 -0
  423. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8.json +146 -0
  424. vllm/model_executor/layers/fused_moe/configs/E=8,N=8192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
  425. vllm/model_executor/layers/fused_moe/configs/README +12 -0
  426. vllm/model_executor/layers/fused_moe/cutlass_moe.py +180 -0
  427. vllm/model_executor/layers/fused_moe/deep_gemm_moe.py +294 -0
  428. vllm/model_executor/layers/fused_moe/fused_marlin_moe.py +374 -0
  429. vllm/model_executor/layers/fused_moe/fused_moe.py +1539 -0
  430. vllm/model_executor/layers/fused_moe/layer.py +949 -0
  431. vllm/model_executor/layers/fused_moe/moe_align_block_size.py +243 -0
  432. vllm/model_executor/layers/fused_moe/moe_pallas.py +64 -0
  433. vllm/model_executor/layers/fused_moe/moe_torch_iterative.py +59 -0
  434. vllm/model_executor/layers/fused_moe/rocm_aiter_fused_moe.py +416 -0
  435. vllm/model_executor/layers/fused_moe/utils.py +48 -0
  436. vllm/model_executor/layers/layernorm.py +277 -0
  437. vllm/model_executor/layers/lightning_attn.py +651 -0
  438. vllm/model_executor/layers/linear.py +1518 -0
  439. vllm/model_executor/layers/logits_processor.py +196 -0
  440. vllm/model_executor/layers/mamba/__init__.py +0 -0
  441. vllm/model_executor/layers/mamba/mamba2_metadata.py +109 -0
  442. vllm/model_executor/layers/mamba/mamba_mixer.py +244 -0
  443. vllm/model_executor/layers/mamba/mamba_mixer2.py +538 -0
  444. vllm/model_executor/layers/mamba/ops/__init__.py +0 -0
  445. vllm/model_executor/layers/mamba/ops/causal_conv1d.py +104 -0
  446. vllm/model_executor/layers/mamba/ops/mamba_ssm.py +415 -0
  447. vllm/model_executor/layers/mamba/ops/ssd_bmm.py +261 -0
  448. vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py +588 -0
  449. vllm/model_executor/layers/mamba/ops/ssd_chunk_state.py +750 -0
  450. vllm/model_executor/layers/mamba/ops/ssd_combined.py +231 -0
  451. vllm/model_executor/layers/mamba/ops/ssd_state_passing.py +205 -0
  452. vllm/model_executor/layers/pooler.py +336 -0
  453. vllm/model_executor/layers/quantization/__init__.py +153 -0
  454. vllm/model_executor/layers/quantization/aqlm.py +374 -0
  455. vllm/model_executor/layers/quantization/awq.py +184 -0
  456. vllm/model_executor/layers/quantization/awq_marlin.py +518 -0
  457. vllm/model_executor/layers/quantization/awq_triton.py +319 -0
  458. vllm/model_executor/layers/quantization/base_config.py +145 -0
  459. vllm/model_executor/layers/quantization/bitblas.py +459 -0
  460. vllm/model_executor/layers/quantization/bitsandbytes.py +396 -0
  461. vllm/model_executor/layers/quantization/compressed_tensors/__init__.py +0 -0
  462. vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors.py +624 -0
  463. vllm/model_executor/layers/quantization/compressed_tensors/compressed_tensors_moe.py +1100 -0
  464. vllm/model_executor/layers/quantization/compressed_tensors/schemes/__init__.py +20 -0
  465. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_24.py +357 -0
  466. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +54 -0
  467. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w4a16_24.py +159 -0
  468. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +119 -0
  469. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +149 -0
  470. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +110 -0
  471. vllm/model_executor/layers/quantization/compressed_tensors/schemes/compressed_tensors_wNa16.py +200 -0
  472. vllm/model_executor/layers/quantization/compressed_tensors/triton_scaled_mm.py +205 -0
  473. vllm/model_executor/layers/quantization/compressed_tensors/utils.py +213 -0
  474. vllm/model_executor/layers/quantization/deepspeedfp.py +193 -0
  475. vllm/model_executor/layers/quantization/experts_int8.py +194 -0
  476. vllm/model_executor/layers/quantization/fbgemm_fp8.py +168 -0
  477. vllm/model_executor/layers/quantization/fp8.py +832 -0
  478. vllm/model_executor/layers/quantization/gguf.py +408 -0
  479. vllm/model_executor/layers/quantization/gptq.py +276 -0
  480. vllm/model_executor/layers/quantization/gptq_bitblas.py +438 -0
  481. vllm/model_executor/layers/quantization/gptq_marlin.py +643 -0
  482. vllm/model_executor/layers/quantization/gptq_marlin_24.py +295 -0
  483. vllm/model_executor/layers/quantization/hqq_marlin.py +328 -0
  484. vllm/model_executor/layers/quantization/ipex_quant.py +250 -0
  485. vllm/model_executor/layers/quantization/kernels/__init__.py +0 -0
  486. vllm/model_executor/layers/quantization/kernels/mixed_precision/MPLinearKernel.py +89 -0
  487. vllm/model_executor/layers/quantization/kernels/mixed_precision/__init__.py +82 -0
  488. vllm/model_executor/layers/quantization/kernels/mixed_precision/allspark.py +115 -0
  489. vllm/model_executor/layers/quantization/kernels/mixed_precision/bitblas.py +299 -0
  490. vllm/model_executor/layers/quantization/kernels/mixed_precision/exllama.py +142 -0
  491. vllm/model_executor/layers/quantization/kernels/mixed_precision/machete.py +119 -0
  492. vllm/model_executor/layers/quantization/kernels/mixed_precision/marlin.py +132 -0
  493. vllm/model_executor/layers/quantization/kernels/scaled_mm/ScaledMMLinearKernel.py +66 -0
  494. vllm/model_executor/layers/quantization/kernels/scaled_mm/__init__.py +86 -0
  495. vllm/model_executor/layers/quantization/kernels/scaled_mm/aiter.py +119 -0
  496. vllm/model_executor/layers/quantization/kernels/scaled_mm/cutlass.py +136 -0
  497. vllm/model_executor/layers/quantization/kernels/scaled_mm/triton.py +40 -0
  498. vllm/model_executor/layers/quantization/kernels/scaled_mm/xla.py +104 -0
  499. vllm/model_executor/layers/quantization/kv_cache.py +137 -0
  500. vllm/model_executor/layers/quantization/marlin.py +259 -0
  501. vllm/model_executor/layers/quantization/modelopt.py +410 -0
  502. vllm/model_executor/layers/quantization/moe_wna16.py +447 -0
  503. vllm/model_executor/layers/quantization/neuron_quant.py +67 -0
  504. vllm/model_executor/layers/quantization/ptpc_fp8.py +125 -0
  505. vllm/model_executor/layers/quantization/qqq.py +273 -0
  506. vllm/model_executor/layers/quantization/quark/__init__.py +0 -0
  507. vllm/model_executor/layers/quantization/quark/quark.py +385 -0
  508. vllm/model_executor/layers/quantization/quark/quark_moe.py +236 -0
  509. vllm/model_executor/layers/quantization/quark/schemes/__init__.py +7 -0
  510. vllm/model_executor/layers/quantization/quark/schemes/quark_scheme.py +54 -0
  511. vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_fp8.py +142 -0
  512. vllm/model_executor/layers/quantization/quark/schemes/quark_w8a8_int8.py +121 -0
  513. vllm/model_executor/layers/quantization/quark/utils.py +102 -0
  514. vllm/model_executor/layers/quantization/schema.py +85 -0
  515. vllm/model_executor/layers/quantization/torchao.py +127 -0
  516. vllm/model_executor/layers/quantization/tpu_int8.py +119 -0
  517. vllm/model_executor/layers/quantization/utils/__init__.py +5 -0
  518. vllm/model_executor/layers/quantization/utils/allspark_utils.py +51 -0
  519. vllm/model_executor/layers/quantization/utils/bitblas_utils.py +198 -0
  520. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  521. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  522. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  523. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  524. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  525. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  526. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  527. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  528. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  529. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  530. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  531. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  532. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  533. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  534. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  535. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  536. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  537. vllm/model_executor/layers/quantization/utils/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  538. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  539. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  540. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  541. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  542. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  543. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  544. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  545. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  546. vllm/model_executor/layers/quantization/utils/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  547. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  548. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  549. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  550. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  551. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  552. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  553. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  554. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  555. vllm/model_executor/layers/quantization/utils/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  556. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  557. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  558. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  559. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  560. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  561. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  562. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  563. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  564. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  565. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  566. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  567. vllm/model_executor/layers/quantization/utils/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  568. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  569. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  570. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  571. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  572. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  573. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  574. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  575. vllm/model_executor/layers/quantization/utils/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  576. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  577. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  578. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  579. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  580. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  581. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  582. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=1536,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  583. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  584. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  585. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  586. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  587. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  588. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  589. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  590. vllm/model_executor/layers/quantization/utils/configs/N=3072,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  591. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  592. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  593. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  594. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  595. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  596. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  597. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  598. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  599. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  600. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  601. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  602. vllm/model_executor/layers/quantization/utils/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  603. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  604. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  605. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  606. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  607. vllm/model_executor/layers/quantization/utils/configs/N=36864,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  608. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  609. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  610. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  611. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  612. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  613. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  614. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  615. vllm/model_executor/layers/quantization/utils/configs/N=4096,K=512,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  616. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  617. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  618. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  619. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  620. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  621. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  622. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  623. vllm/model_executor/layers/quantization/utils/configs/N=4608,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  624. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  625. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  626. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  627. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  628. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  629. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  630. vllm/model_executor/layers/quantization/utils/configs/N=512,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  631. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  632. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  633. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  634. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  635. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  636. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  637. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  638. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  639. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  640. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  641. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +18 -0
  642. vllm/model_executor/layers/quantization/utils/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  643. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  644. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  645. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  646. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  647. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  648. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  649. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  650. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  651. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  652. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  653. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  654. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  655. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  656. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  657. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  658. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  659. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  660. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  661. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  662. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  663. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  664. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  665. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  666. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  667. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  668. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  669. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  670. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  671. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  672. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  673. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  674. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  675. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  676. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  677. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  678. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  679. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  680. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  681. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  682. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  683. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  684. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  685. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  686. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  687. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  688. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  689. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  690. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  691. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  692. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  693. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  694. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  695. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  696. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  697. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  698. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  699. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  700. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2048,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  701. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  702. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  703. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  704. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  705. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  706. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  707. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  708. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=2304,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  709. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  710. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  711. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  712. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  713. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128,128].json +146 -0
  714. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128,128].json +146 -0
  715. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=256,device_name=NVIDIA_L20,dtype=fp8_w8a8,block_shape=[128,128].json +26 -0
  716. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  717. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  718. vllm/model_executor/layers/quantization/utils/configs/N=7168,K=8192,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  719. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  720. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  721. vllm/model_executor/layers/quantization/utils/configs/N=8192,K=1536,device_name=AMD_Instinct_MI325_OAM,dtype=fp8_w8a8,block_shape=[128,128].json +164 -0
  722. vllm/model_executor/layers/quantization/utils/fp8_utils.py +523 -0
  723. vllm/model_executor/layers/quantization/utils/gptq_utils.py +94 -0
  724. vllm/model_executor/layers/quantization/utils/int8_utils.py +459 -0
  725. vllm/model_executor/layers/quantization/utils/layer_utils.py +39 -0
  726. vllm/model_executor/layers/quantization/utils/machete_utils.py +32 -0
  727. vllm/model_executor/layers/quantization/utils/marlin_utils.py +413 -0
  728. vllm/model_executor/layers/quantization/utils/marlin_utils_fp8.py +110 -0
  729. vllm/model_executor/layers/quantization/utils/marlin_utils_test.py +164 -0
  730. vllm/model_executor/layers/quantization/utils/marlin_utils_test_24.py +464 -0
  731. vllm/model_executor/layers/quantization/utils/marlin_utils_test_qqq.py +127 -0
  732. vllm/model_executor/layers/quantization/utils/quant_utils.py +571 -0
  733. vllm/model_executor/layers/quantization/utils/w8a8_utils.py +404 -0
  734. vllm/model_executor/layers/rejection_sampler.py +400 -0
  735. vllm/model_executor/layers/resampler.py +269 -0
  736. vllm/model_executor/layers/rotary_embedding.py +1598 -0
  737. vllm/model_executor/layers/sampler.py +1221 -0
  738. vllm/model_executor/layers/spec_decode_base_sampler.py +258 -0
  739. vllm/model_executor/layers/typical_acceptance_sampler.py +172 -0
  740. vllm/model_executor/layers/utils.py +99 -0
  741. vllm/model_executor/layers/vocab_parallel_embedding.py +485 -0
  742. vllm/model_executor/model_loader/__init__.py +20 -0
  743. vllm/model_executor/model_loader/loader.py +1542 -0
  744. vllm/model_executor/model_loader/neuron.py +243 -0
  745. vllm/model_executor/model_loader/tensorizer.py +468 -0
  746. vllm/model_executor/model_loader/utils.py +171 -0
  747. vllm/model_executor/model_loader/weight_utils.py +749 -0
  748. vllm/model_executor/models/__init__.py +27 -0
  749. vllm/model_executor/models/adapters.py +247 -0
  750. vllm/model_executor/models/arctic.py +559 -0
  751. vllm/model_executor/models/aria.py +656 -0
  752. vllm/model_executor/models/aya_vision.py +461 -0
  753. vllm/model_executor/models/baichuan.py +469 -0
  754. vllm/model_executor/models/bamba.py +542 -0
  755. vllm/model_executor/models/bart.py +936 -0
  756. vllm/model_executor/models/bert.py +725 -0
  757. vllm/model_executor/models/blip.py +337 -0
  758. vllm/model_executor/models/blip2.py +717 -0
  759. vllm/model_executor/models/bloom.py +358 -0
  760. vllm/model_executor/models/chameleon.py +1135 -0
  761. vllm/model_executor/models/chatglm.py +476 -0
  762. vllm/model_executor/models/clip.py +410 -0
  763. vllm/model_executor/models/commandr.py +466 -0
  764. vllm/model_executor/models/constant_size_cache.py +136 -0
  765. vllm/model_executor/models/dbrx.py +469 -0
  766. vllm/model_executor/models/deepseek.py +484 -0
  767. vllm/model_executor/models/deepseek_mtp.py +266 -0
  768. vllm/model_executor/models/deepseek_v2.py +830 -0
  769. vllm/model_executor/models/deepseek_vl2.py +647 -0
  770. vllm/model_executor/models/eagle.py +247 -0
  771. vllm/model_executor/models/exaone.py +548 -0
  772. vllm/model_executor/models/fairseq2_llama.py +153 -0
  773. vllm/model_executor/models/falcon.py +508 -0
  774. vllm/model_executor/models/florence2.py +1102 -0
  775. vllm/model_executor/models/fuyu.py +388 -0
  776. vllm/model_executor/models/gemma.py +423 -0
  777. vllm/model_executor/models/gemma2.py +423 -0
  778. vllm/model_executor/models/gemma3.py +531 -0
  779. vllm/model_executor/models/gemma3_mm.py +716 -0
  780. vllm/model_executor/models/glm.py +22 -0
  781. vllm/model_executor/models/glm4.py +303 -0
  782. vllm/model_executor/models/glm4v.py +647 -0
  783. vllm/model_executor/models/gpt2.py +313 -0
  784. vllm/model_executor/models/gpt_bigcode.py +336 -0
  785. vllm/model_executor/models/gpt_j.py +337 -0
  786. vllm/model_executor/models/gpt_neox.py +330 -0
  787. vllm/model_executor/models/granite.py +494 -0
  788. vllm/model_executor/models/granite_speech.py +777 -0
  789. vllm/model_executor/models/granitemoe.py +435 -0
  790. vllm/model_executor/models/granitemoeshared.py +339 -0
  791. vllm/model_executor/models/gritlm.py +245 -0
  792. vllm/model_executor/models/grok1.py +560 -0
  793. vllm/model_executor/models/h2ovl.py +542 -0
  794. vllm/model_executor/models/idefics2_vision_model.py +387 -0
  795. vllm/model_executor/models/idefics3.py +767 -0
  796. vllm/model_executor/models/interfaces.py +569 -0
  797. vllm/model_executor/models/interfaces_base.py +163 -0
  798. vllm/model_executor/models/intern_vit.py +476 -0
  799. vllm/model_executor/models/internlm2.py +453 -0
  800. vllm/model_executor/models/internlm2_ve.py +146 -0
  801. vllm/model_executor/models/internvl.py +945 -0
  802. vllm/model_executor/models/jais.py +371 -0
  803. vllm/model_executor/models/jamba.py +590 -0
  804. vllm/model_executor/models/kimi_vl.py +577 -0
  805. vllm/model_executor/models/llama.py +619 -0
  806. vllm/model_executor/models/llama4.py +530 -0
  807. vllm/model_executor/models/llama_eagle.py +152 -0
  808. vllm/model_executor/models/llama_eagle3.py +232 -0
  809. vllm/model_executor/models/llava.py +869 -0
  810. vllm/model_executor/models/llava_next.py +582 -0
  811. vllm/model_executor/models/llava_next_video.py +470 -0
  812. vllm/model_executor/models/llava_onevision.py +954 -0
  813. vllm/model_executor/models/mamba.py +271 -0
  814. vllm/model_executor/models/mamba2.py +302 -0
  815. vllm/model_executor/models/mamba_cache.py +76 -0
  816. vllm/model_executor/models/medusa.py +210 -0
  817. vllm/model_executor/models/minicpm.py +592 -0
  818. vllm/model_executor/models/minicpm3.py +229 -0
  819. vllm/model_executor/models/minicpmo.py +725 -0
  820. vllm/model_executor/models/minicpmv.py +1287 -0
  821. vllm/model_executor/models/minimax_cache.py +35 -0
  822. vllm/model_executor/models/minimax_text_01.py +1261 -0
  823. vllm/model_executor/models/mistral3.py +598 -0
  824. vllm/model_executor/models/mixtral.py +485 -0
  825. vllm/model_executor/models/mixtral_quant.py +447 -0
  826. vllm/model_executor/models/mllama.py +1623 -0
  827. vllm/model_executor/models/mllama4.py +838 -0
  828. vllm/model_executor/models/mlp_speculator.py +205 -0
  829. vllm/model_executor/models/modernbert.py +325 -0
  830. vllm/model_executor/models/module_mapping.py +71 -0
  831. vllm/model_executor/models/molmo.py +1567 -0
  832. vllm/model_executor/models/moonvit.py +628 -0
  833. vllm/model_executor/models/mpt.py +329 -0
  834. vllm/model_executor/models/nemotron.py +506 -0
  835. vllm/model_executor/models/nemotron_nas.py +446 -0
  836. vllm/model_executor/models/nvlm_d.py +212 -0
  837. vllm/model_executor/models/olmo.py +390 -0
  838. vllm/model_executor/models/olmo2.py +412 -0
  839. vllm/model_executor/models/olmoe.py +449 -0
  840. vllm/model_executor/models/opt.py +410 -0
  841. vllm/model_executor/models/orion.py +356 -0
  842. vllm/model_executor/models/paligemma.py +397 -0
  843. vllm/model_executor/models/persimmon.py +342 -0
  844. vllm/model_executor/models/phi.py +354 -0
  845. vllm/model_executor/models/phi3.py +18 -0
  846. vllm/model_executor/models/phi3_small.py +463 -0
  847. vllm/model_executor/models/phi3v.py +722 -0
  848. vllm/model_executor/models/phi4mm.py +1263 -0
  849. vllm/model_executor/models/phi4mm_audio.py +1232 -0
  850. vllm/model_executor/models/phi4mm_utils.py +1883 -0
  851. vllm/model_executor/models/phimoe.py +666 -0
  852. vllm/model_executor/models/pixtral.py +1281 -0
  853. vllm/model_executor/models/plamo2.py +736 -0
  854. vllm/model_executor/models/prithvi_geospatial_mae.py +231 -0
  855. vllm/model_executor/models/qwen.py +360 -0
  856. vllm/model_executor/models/qwen2.py +552 -0
  857. vllm/model_executor/models/qwen2_5_omni_thinker.py +901 -0
  858. vllm/model_executor/models/qwen2_5_vl.py +1136 -0
  859. vllm/model_executor/models/qwen2_audio.py +402 -0
  860. vllm/model_executor/models/qwen2_moe.py +531 -0
  861. vllm/model_executor/models/qwen2_rm.py +130 -0
  862. vllm/model_executor/models/qwen2_vl.py +1409 -0
  863. vllm/model_executor/models/qwen3.py +319 -0
  864. vllm/model_executor/models/qwen3_moe.py +528 -0
  865. vllm/model_executor/models/qwen_vl.py +784 -0
  866. vllm/model_executor/models/registry.py +611 -0
  867. vllm/model_executor/models/roberta.py +332 -0
  868. vllm/model_executor/models/siglip.py +522 -0
  869. vllm/model_executor/models/skyworkr1v.py +949 -0
  870. vllm/model_executor/models/smolvlm.py +51 -0
  871. vllm/model_executor/models/solar.py +504 -0
  872. vllm/model_executor/models/stablelm.py +349 -0
  873. vllm/model_executor/models/starcoder2.py +355 -0
  874. vllm/model_executor/models/telechat2.py +139 -0
  875. vllm/model_executor/models/teleflm.py +78 -0
  876. vllm/model_executor/models/transformers.py +442 -0
  877. vllm/model_executor/models/ultravox.py +655 -0
  878. vllm/model_executor/models/utils.py +714 -0
  879. vllm/model_executor/models/vision.py +149 -0
  880. vllm/model_executor/models/whisper.py +746 -0
  881. vllm/model_executor/models/zamba2.py +1008 -0
  882. vllm/model_executor/parameter.py +458 -0
  883. vllm/model_executor/pooling_metadata.py +71 -0
  884. vllm/model_executor/sampling_metadata.py +596 -0
  885. vllm/model_executor/utils.py +53 -0
  886. vllm/multimodal/__init__.py +31 -0
  887. vllm/multimodal/audio.py +105 -0
  888. vllm/multimodal/base.py +218 -0
  889. vllm/multimodal/hasher.py +103 -0
  890. vllm/multimodal/image.py +77 -0
  891. vllm/multimodal/inputs.py +843 -0
  892. vllm/multimodal/parse.py +454 -0
  893. vllm/multimodal/processing.py +1760 -0
  894. vllm/multimodal/profiling.py +274 -0
  895. vllm/multimodal/registry.py +321 -0
  896. vllm/multimodal/utils.py +386 -0
  897. vllm/multimodal/video.py +166 -0
  898. vllm/outputs.py +521 -0
  899. vllm/platforms/__init__.py +286 -0
  900. vllm/platforms/cpu.py +182 -0
  901. vllm/platforms/cuda.py +463 -0
  902. vllm/platforms/hpu.py +94 -0
  903. vllm/platforms/interface.py +427 -0
  904. vllm/platforms/neuron.py +69 -0
  905. vllm/platforms/rocm.py +346 -0
  906. vllm/platforms/tpu.py +174 -0
  907. vllm/platforms/xpu.py +142 -0
  908. vllm/plugins/__init__.py +82 -0
  909. vllm/pooling_params.py +53 -0
  910. vllm/profiler/__init__.py +7 -0
  911. vllm/profiler/layerwise_profile.py +374 -0
  912. vllm/profiler/utils.py +147 -0
  913. vllm/prompt_adapter/__init__.py +0 -0
  914. vllm/prompt_adapter/layers.py +82 -0
  915. vllm/prompt_adapter/models.py +357 -0
  916. vllm/prompt_adapter/request.py +36 -0
  917. vllm/prompt_adapter/utils.py +97 -0
  918. vllm/prompt_adapter/worker_manager.py +178 -0
  919. vllm/py.typed +2 -0
  920. vllm/reasoning/__init__.py +12 -0
  921. vllm/reasoning/abs_reasoning_parsers.py +189 -0
  922. vllm/reasoning/deepseek_r1_reasoning_parser.py +172 -0
  923. vllm/reasoning/granite_reasoning_parser.py +362 -0
  924. vllm/sampling_params.py +598 -0
  925. vllm/scalar_type.py +335 -0
  926. vllm/scripts.py +14 -0
  927. vllm/sequence.py +1486 -0
  928. vllm/spec_decode/__init__.py +0 -0
  929. vllm/spec_decode/batch_expansion.py +505 -0
  930. vllm/spec_decode/draft_model_runner.py +335 -0
  931. vllm/spec_decode/interfaces.py +98 -0
  932. vllm/spec_decode/medusa_worker.py +137 -0
  933. vllm/spec_decode/metrics.py +212 -0
  934. vllm/spec_decode/mlp_speculator_worker.py +93 -0
  935. vllm/spec_decode/mqa_scorer.py +159 -0
  936. vllm/spec_decode/multi_step_worker.py +416 -0
  937. vllm/spec_decode/ngram_worker.py +195 -0
  938. vllm/spec_decode/proposer_worker_base.py +58 -0
  939. vllm/spec_decode/smaller_tp_proposer_worker.py +194 -0
  940. vllm/spec_decode/spec_decode_worker.py +1324 -0
  941. vllm/spec_decode/target_model_runner.py +44 -0
  942. vllm/spec_decode/top1_proposer.py +274 -0
  943. vllm/spec_decode/util.py +276 -0
  944. vllm/test_utils.py +129 -0
  945. vllm/third_party/__init__.py +0 -0
  946. vllm/third_party/pynvml.py +6139 -0
  947. vllm/tracing.py +130 -0
  948. vllm/transformers_utils/__init__.py +19 -0
  949. vllm/transformers_utils/config.py +813 -0
  950. vllm/transformers_utils/configs/__init__.py +52 -0
  951. vllm/transformers_utils/configs/arctic.py +206 -0
  952. vllm/transformers_utils/configs/chatglm.py +71 -0
  953. vllm/transformers_utils/configs/cohere2.py +194 -0
  954. vllm/transformers_utils/configs/dbrx.py +280 -0
  955. vllm/transformers_utils/configs/deepseek_vl2.py +216 -0
  956. vllm/transformers_utils/configs/eagle.py +65 -0
  957. vllm/transformers_utils/configs/exaone.py +191 -0
  958. vllm/transformers_utils/configs/falcon.py +89 -0
  959. vllm/transformers_utils/configs/h2ovl.py +15 -0
  960. vllm/transformers_utils/configs/internvl.py +53 -0
  961. vllm/transformers_utils/configs/jais.py +237 -0
  962. vllm/transformers_utils/configs/kimi_vl.py +36 -0
  963. vllm/transformers_utils/configs/medusa.py +62 -0
  964. vllm/transformers_utils/configs/mllama.py +30 -0
  965. vllm/transformers_utils/configs/mlp_speculator.py +67 -0
  966. vllm/transformers_utils/configs/moonvit.py +32 -0
  967. vllm/transformers_utils/configs/mpt.py +179 -0
  968. vllm/transformers_utils/configs/nemotron.py +204 -0
  969. vllm/transformers_utils/configs/nvlm_d.py +14 -0
  970. vllm/transformers_utils/configs/skyworkr1v.py +53 -0
  971. vllm/transformers_utils/configs/solar.py +246 -0
  972. vllm/transformers_utils/configs/telechat2.py +63 -0
  973. vllm/transformers_utils/configs/ultravox.py +107 -0
  974. vllm/transformers_utils/detokenizer.py +167 -0
  975. vllm/transformers_utils/detokenizer_utils.py +188 -0
  976. vllm/transformers_utils/processor.py +210 -0
  977. vllm/transformers_utils/processors/__init__.py +6 -0
  978. vllm/transformers_utils/processors/deepseek_vl2.py +363 -0
  979. vllm/transformers_utils/s3_utils.py +161 -0
  980. vllm/transformers_utils/tokenizer.py +291 -0
  981. vllm/transformers_utils/tokenizer_base.py +146 -0
  982. vllm/transformers_utils/tokenizer_group.py +110 -0
  983. vllm/transformers_utils/tokenizers/__init__.py +9 -0
  984. vllm/transformers_utils/tokenizers/mistral.py +483 -0
  985. vllm/transformers_utils/utils.py +98 -0
  986. vllm/triton_utils/__init__.py +5 -0
  987. vllm/triton_utils/importing.py +53 -0
  988. vllm/usage/__init__.py +0 -0
  989. vllm/usage/usage_lib.py +255 -0
  990. vllm/utils.py +2692 -0
  991. vllm/v1/__init__.py +0 -0
  992. vllm/v1/attention/__init__.py +0 -0
  993. vllm/v1/attention/backends/__init__.py +0 -0
  994. vllm/v1/attention/backends/flash_attn.py +783 -0
  995. vllm/v1/attention/backends/flashinfer.py +638 -0
  996. vllm/v1/attention/backends/mla/__init__.py +0 -0
  997. vllm/v1/attention/backends/mla/common.py +974 -0
  998. vllm/v1/attention/backends/mla/flashmla.py +149 -0
  999. vllm/v1/attention/backends/mla/triton_mla.py +118 -0
  1000. vllm/v1/attention/backends/pallas.py +221 -0
  1001. vllm/v1/attention/backends/triton_attn.py +198 -0
  1002. vllm/v1/core/__init__.py +0 -0
  1003. vllm/v1/core/block_pool.py +281 -0
  1004. vllm/v1/core/encoder_cache_manager.py +149 -0
  1005. vllm/v1/core/kv_cache_manager.py +385 -0
  1006. vllm/v1/core/kv_cache_utils.py +744 -0
  1007. vllm/v1/core/sched/__init__.py +0 -0
  1008. vllm/v1/core/sched/interface.py +134 -0
  1009. vllm/v1/core/sched/output.py +126 -0
  1010. vllm/v1/core/sched/scheduler.py +838 -0
  1011. vllm/v1/core/sched/utils.py +22 -0
  1012. vllm/v1/core/specialized_manager.py +161 -0
  1013. vllm/v1/engine/__init__.py +166 -0
  1014. vllm/v1/engine/async_llm.py +532 -0
  1015. vllm/v1/engine/core.py +701 -0
  1016. vllm/v1/engine/core_client.py +942 -0
  1017. vllm/v1/engine/detokenizer.py +260 -0
  1018. vllm/v1/engine/exceptions.py +16 -0
  1019. vllm/v1/engine/llm_engine.py +285 -0
  1020. vllm/v1/engine/logprobs.py +198 -0
  1021. vllm/v1/engine/mm_input_cache.py +82 -0
  1022. vllm/v1/engine/output_processor.py +420 -0
  1023. vllm/v1/engine/parallel_sampling.py +132 -0
  1024. vllm/v1/engine/processor.py +387 -0
  1025. vllm/v1/executor/__init__.py +0 -0
  1026. vllm/v1/executor/abstract.py +112 -0
  1027. vllm/v1/executor/multiproc_executor.py +480 -0
  1028. vllm/v1/executor/ray_distributed_executor.py +61 -0
  1029. vllm/v1/kv_cache_interface.py +166 -0
  1030. vllm/v1/metrics/__init__.py +0 -0
  1031. vllm/v1/metrics/loggers.py +498 -0
  1032. vllm/v1/metrics/stats.py +238 -0
  1033. vllm/v1/outputs.py +111 -0
  1034. vllm/v1/request.py +178 -0
  1035. vllm/v1/sample/__init__.py +0 -0
  1036. vllm/v1/sample/metadata.py +43 -0
  1037. vllm/v1/sample/ops/__init__.py +0 -0
  1038. vllm/v1/sample/ops/bad_words.py +38 -0
  1039. vllm/v1/sample/ops/penalties.py +58 -0
  1040. vllm/v1/sample/ops/topk_topp_sampler.py +315 -0
  1041. vllm/v1/sample/rejection_sampler.py +631 -0
  1042. vllm/v1/sample/sampler.py +270 -0
  1043. vllm/v1/sample/tpu/__init__.py +0 -0
  1044. vllm/v1/sample/tpu/metadata.py +118 -0
  1045. vllm/v1/sample/tpu/sampler.py +154 -0
  1046. vllm/v1/serial_utils.py +274 -0
  1047. vllm/v1/spec_decode/__init__.py +0 -0
  1048. vllm/v1/spec_decode/eagle.py +318 -0
  1049. vllm/v1/spec_decode/metadata.py +61 -0
  1050. vllm/v1/spec_decode/metrics.py +164 -0
  1051. vllm/v1/spec_decode/ngram_proposer.py +131 -0
  1052. vllm/v1/spec_decode/utils.py +18 -0
  1053. vllm/v1/stats/__init__.py +0 -0
  1054. vllm/v1/stats/common.py +453 -0
  1055. vllm/v1/structured_output/__init__.py +113 -0
  1056. vllm/v1/structured_output/backend_guidance.py +215 -0
  1057. vllm/v1/structured_output/backend_types.py +96 -0
  1058. vllm/v1/structured_output/backend_xgrammar.py +299 -0
  1059. vllm/v1/structured_output/request.py +84 -0
  1060. vllm/v1/structured_output/utils.py +174 -0
  1061. vllm/v1/utils.py +249 -0
  1062. vllm/v1/worker/__init__.py +0 -0
  1063. vllm/v1/worker/block_table.py +87 -0
  1064. vllm/v1/worker/gpu_input_batch.py +677 -0
  1065. vllm/v1/worker/gpu_model_runner.py +1776 -0
  1066. vllm/v1/worker/gpu_worker.py +349 -0
  1067. vllm/v1/worker/lora_model_runner_mixin.py +145 -0
  1068. vllm/v1/worker/tpu_model_runner.py +1419 -0
  1069. vllm/v1/worker/tpu_worker.py +260 -0
  1070. vllm/v1/worker/utils.py +74 -0
  1071. vllm/v1/worker/worker_base.py +64 -0
  1072. vllm/version.py +40 -0
  1073. vllm/vllm_flash_attn/.gitkeep +0 -0
  1074. vllm/worker/__init__.py +0 -0
  1075. vllm/worker/cache_engine.py +144 -0
  1076. vllm/worker/cpu_enc_dec_model_runner.py +323 -0
  1077. vllm/worker/cpu_model_runner.py +668 -0
  1078. vllm/worker/cpu_pooling_model_runner.py +122 -0
  1079. vllm/worker/cpu_worker.py +400 -0
  1080. vllm/worker/enc_dec_model_runner.py +542 -0
  1081. vllm/worker/hpu_model_runner.py +2221 -0
  1082. vllm/worker/hpu_worker.py +483 -0
  1083. vllm/worker/model_runner.py +2056 -0
  1084. vllm/worker/model_runner_base.py +281 -0
  1085. vllm/worker/multi_step_hpu_worker.py +122 -0
  1086. vllm/worker/multi_step_model_runner.py +908 -0
  1087. vllm/worker/multi_step_tpu_worker.py +107 -0
  1088. vllm/worker/multi_step_worker.py +196 -0
  1089. vllm/worker/neuron_model_runner.py +336 -0
  1090. vllm/worker/neuron_worker.py +138 -0
  1091. vllm/worker/pooling_model_runner.py +200 -0
  1092. vllm/worker/tpu_model_runner.py +908 -0
  1093. vllm/worker/tpu_worker.py +332 -0
  1094. vllm/worker/utils.py +52 -0
  1095. vllm/worker/worker.py +570 -0
  1096. vllm/worker/worker_base.py +644 -0
  1097. vllm/worker/xpu_model_runner.py +603 -0
  1098. vllm/worker/xpu_worker.py +185 -0
  1099. vllm_cpu-0.8.5.post2.dist-info/METADATA +309 -0
  1100. vllm_cpu-0.8.5.post2.dist-info/RECORD +1103 -0
  1101. vllm_cpu-0.8.5.post2.dist-info/WHEEL +5 -0
  1102. vllm_cpu-0.8.5.post2.dist-info/entry_points.txt +2 -0
  1103. vllm_cpu-0.8.5.post2.dist-info/top_level.txt +1 -0
vllm/model_executor/models/phi4mm_utils.py
@@ -0,0 +1,1883 @@
# SPDX-License-Identifier: Apache-2.0
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.
# Code copied from Microsoft/MoE by Jacob Platin (jacobplatin@microsoft.com)
# but implemented by the Phi-Speech team
#!/usr/bin/env python3
import math
from typing import Optional, Tuple, Union

import torch
import torch.nn.functional as F
from torch import Tensor, nn


class Block(nn.Module):
    """Block abstract module"""

    def __init__(self, input_size, output_size):
        super().__init__()
        self.input_size = input_size
        self.output_size = output_size


def get_activation(name="relu"):
    """Select an activation function by name

    Args:
        name: str
            activation function name,
            one of ["relu", "gelu", "swish", "sigmoid"],
            default "relu".
    """
    name = name.lower()
    if name == "relu":
        return nn.ReLU(inplace=True)
    if name == "gelu":
        return nn.GELU()
    if name == "swish":
        return Swish()
    if name == "sigmoid":
        return torch.nn.Sigmoid()
    return nn.Identity()


def adaptive_enc_mask(x_len, chunk_start_idx, left_window=0, right_window=0):
    """
    This function is essential for the streaming mode of the Transformer
    Transducer.
    Args:
        x_len (int): sequence length
        chunk_start_idx (list): first idx of each chunk, such as [0,18,36,48].
            It also supports adaptive chunk sizes, e.g. [0,10,15,45]
        left_window (int): how many left chunks can be seen
        right_window (int): how many right chunks can be seen. It is used
            for the chunk-overlap model.
    Returns:
        mask (torch.Tensor): a mask tensor for the streaming model
            Torch 1.0.1
                tensor([[1., 1., 0., 0.],
                        [0., 1., 1., 0.],
                        [0., 0., 1., 1.]])
            Torch 1.4.1
                tensor([[True, True, False, False],
                        [False, True, True, False],
                        [False, False, True, True]])
    """
    chunk_start_idx = torch.Tensor(
        chunk_start_idx).long()  # first idx of each chunk, e.g. [0,18,36,48]
    start_pad = torch.nn.functional.pad(
        chunk_start_idx,
        (1, 0))  # prepend 0, so it becomes [0, 0, 18, 36, 48]
    end_pad = torch.nn.functional.pad(
        chunk_start_idx, (0, 1), value=x_len
    )  # append x_len to the end, so it becomes [0, 18, 36, 48, x_len]
    seq_range = torch.arange(0, x_len).unsqueeze(-1)  # shape: [x_len, 1]
    idx = ((seq_range < end_pad) &
           (seq_range >= start_pad)).nonzero()[:, 1]  # idx shape: [x_len]
    # boundary = end_pad[idx]  # boundary shape: [x_len]
    seq_range_expand = (torch.arange(0, x_len).unsqueeze(0).expand(
        x_len, -1))  # seq_range_expand shape: [x_len, x_len]
    idx_left = idx - left_window
    idx_left[idx_left < 0] = 0
    boundary_left = start_pad[idx_left]
    mask_left = seq_range_expand >= boundary_left.unsqueeze(-1)
    idx_right = idx + right_window
    idx_right[idx_right > len(chunk_start_idx)] = len(chunk_start_idx)
    boundary_right = end_pad[idx_right]
    mask_right = seq_range_expand < boundary_right.unsqueeze(-1)
    return mask_left & mask_right


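# Illustrative usage sketch (added for exposition; not part of the released
# module): with x_len=4 and chunks starting at [0, 2], each frame can only
# attend within its own chunk.
#
#     mask = adaptive_enc_mask(4, [0, 2])
#     # tensor([[ True,  True, False, False],
#     #         [ True,  True, False, False],
#     #         [False, False,  True,  True],
#     #         [False, False,  True,  True]])
#
# Passing left_window=1 would additionally expose the previous chunk to the
# last two frames.
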
class Swish(nn.Module):
    """Implement Swish activation module.
    From https://arxiv.org/pdf/2005.03191.pdf

    """

    def __init__(self) -> None:
        super().__init__()
        self.act_fn = nn.Sigmoid()

    def forward(self, x: Tensor) -> Tensor:
        """Apply Swish function

        Args:
            x: torch.Tensor
                Input.
        """
        return x * self.act_fn(x)


class GLU(nn.Module):
    """Implement Gated Linear Unit (GLU) module"""

    def __init__(self, dim: int = -1, act_name: str = "sigmoid") -> None:
        super().__init__()
        self.dim = dim
        self.act_name = act_name.lower()

        if self.act_name == "relu":
            self.act_fn = nn.ReLU(inplace=True)
        elif self.act_name == "gelu":
            self.act_fn = nn.GELU()
        elif self.act_name == "swish":
            self.act_fn = Swish()
        elif self.act_name == "sigmoid":
            self.act_fn = nn.Sigmoid()
        else:
            self.act_fn = nn.Identity()

    def forward(self, x: Tensor) -> Tensor:
        """GLU forward
        Split the input in half along `dim`, then gate the first half
        with the chosen activation of the second half.

        Args:
            x: torch.Tensor
                Input.

        """
        half_x, gate = x.chunk(2, dim=self.dim)
        return half_x * self.act_fn(gate)


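# Illustrative usage sketch (added for exposition; not part of the released
# module): GLU halves the size of the gated dimension, so a [2, 5, 8] input
# yields a [2, 5, 4] output.
#
#     glu = GLU(dim=-1, act_name="swish")
#     y = glu(torch.randn(2, 5, 8))  # y.shape == (2, 5, 4)
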
# TODO: Abdel, this can be improved using the GLU module
class GLUPointWiseConv(nn.Module):
    """GLUPointWiseConv module
    used for the conformer architecture;
    for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    Args:
        input_dim: int
            input channel size.
        output_dim: int
            output channel size.
        kernel_size: int
            kernel size
        glu_type: str, optional
            activation function, one of
            ["sigmoid", "relu", "gelu", "swish"],
            default "sigmoid".
        bias_in_glu: bool, optional
            use additive bias in glu
        causal: bool, optional
            if set to True, padding is set to kernel_size - 1 (the excess
            trailing frames are trimmed by the caller), i.e., the
            convolution can't see future frames.
            default False.

    """

    def __init__(
        self,
        input_dim,
        output_dim,
        kernel_size,
        glu_type="sigmoid",
        bias_in_glu=True,
        causal=False,
    ):
        super().__init__()

        self.glu_type = glu_type
        self.output_dim = output_dim
        self.bias_in_glu = bias_in_glu
        if causal:
            self.ext_pw_conv_1d = nn.Conv1d(
                input_dim,
                output_dim * 2,
                kernel_size,
                1,
                padding=(kernel_size - 1),
            )
        else:
            self.ext_pw_conv_1d = nn.Conv1d(
                input_dim,
                output_dim * 2,
                kernel_size,
                1,
                padding=(kernel_size - 1) // 2,
            )

        if glu_type == "sigmoid":
            self.glu_act = nn.Sigmoid()
        elif glu_type == "relu":
            self.glu_act = nn.ReLU()
        elif glu_type == "gelu":
            self.glu_act = nn.GELU()
        elif glu_type == "swish":
            self.glu_act = Swish()
        else:
            raise ValueError(f"Unsupported activation type {glu_type}")

        if bias_in_glu:
            self.b1 = nn.Parameter(torch.zeros(1, output_dim, 1))
            self.b2 = nn.Parameter(torch.zeros(1, output_dim, 1))

    def forward(self, x):
        """
        Args:
            x: torch.Tensor
                input tensor
        """
        # to be consistent with GLULinear, we assume the input always has the
        # #channel (#dim) in the last dimension of the tensor, so we need to
        # switch the dimensions first for the 1D-Conv case
        x = x.permute([0, 2, 1])
        x = self.ext_pw_conv_1d(x)
        if self.glu_type == "bilinear":
            if self.bias_in_glu:
                x = (x[:, 0:self.output_dim, :] + self.b1) * (
                    x[:, self.output_dim:self.output_dim * 2, :] + self.b2)
            else:
                x = (x[:, 0:self.output_dim, :]) * (
                    x[:, self.output_dim:self.output_dim * 2, :])
        else:
            if self.bias_in_glu:
                x = (x[:, 0:self.output_dim, :] + self.b1) * self.glu_act(
                    x[:, self.output_dim:self.output_dim * 2, :] + self.b2)
            else:
                x = (x[:, 0:self.output_dim, :]) * self.glu_act(
                    x[:, self.output_dim:self.output_dim * 2, :])

        x = x.permute([0, 2, 1])
        return x


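# Illustrative note (added for exposition; not part of the released module):
# the pointwise conv projects to 2 * output_dim channels so the result can
# be split into a value half and a gate half, mirroring the GLU module:
#
#     pw = GLUPointWiseConv(input_dim=80, output_dim=256, kernel_size=3)
#     y = pw(torch.randn(2, 100, 80))  # y.shape == (2, 100, 256)
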
class DepthWiseSeperableConv1d(nn.Module):
    """DepthWiseSeperableConv1d module used in the Convnet module
    of the conformer; for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    Args:
        input_dim: int
            input channel size.
        depthwise_seperable_out_channel: int
            if non-zero, it is used as the channel_out
            of the second (pointwise) conv1d layer;
            if it equals 0, the second conv1d layer is skipped.
        kernel_size: int
            kernel_size
        depthwise_multiplier: int
            number of input_dim channel duplications. This value
            is used to compute the hidden channels of the Conv1D.
        padding: int, optional
            padding for the conv1d,
            default: 0.

    """

    def __init__(
        self,
        input_dim,
        depthwise_seperable_out_channel,
        kernel_size,
        depthwise_multiplier,
        padding=0,
    ):
        super().__init__()

        self.dw_conv = nn.Conv1d(
            input_dim,
            input_dim * depthwise_multiplier,
            kernel_size,
            1,
            padding=padding,
            groups=input_dim,
        )

        if depthwise_seperable_out_channel != 0:
            self.pw_conv = nn.Conv1d(
                input_dim * depthwise_multiplier,
                depthwise_seperable_out_channel,
                1,
                1,
                0,
            )
        else:
            self.pw_conv = nn.Identity()
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel

    def forward(self, x):
        """

        Args:
            x: torch.Tensor
                input tensor
        """
        x = self.dw_conv(x)
        if self.depthwise_seperable_out_channel != 0:
            x = self.pw_conv(x)
        return x


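# Illustrative note (added for exposition; not part of the released module):
# the depthwise + pointwise factorization replaces one dense convolution
# with two cheap ones. For 256 channels in/out, kernel size 3, and
# depthwise_multiplier=1:
#     dense conv1d:    256 * 256 * 3 ~= 197k weights
#     depthwise (k=3): 256 * 3       =  768 weights
#     pointwise (k=1): 256 * 256     ~=  66k weights
# so the factorized pair needs roughly a third of the parameters.
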
class ConvModule(nn.Module):
    """ConvModule Module for the conformer block.
    for more details see:
    https://arxiv.org/pdf/2005.08100v1.pdf

    Args:
        input_dim: int
            input channel size.
        ext_pw_out_channel: int
            if > 0, ext_pw_out_channel is the channel size
            of the last pointwise conv after the swish activation.
        depthwise_seperable_out_channel: int
            if non-zero, it is used as the channel_out of the second
            conv1d layer; if it equals 0, the second conv1d layer
            is skipped.
        ext_pw_kernel_size: int
            kernel size of the pointwise conv of the conformer.
        kernel_size: int
            kernel size.
        depthwise_multiplier: int
            number of input_dim channel duplications. This value
            is used to compute the hidden channels of the Conv1D.
        dropout_rate: float
            dropout rate.
        causal: bool, optional
            if set to True, the convolution has no access
            to future frames. default False.
        batch_norm: bool, optional
            if set to True, apply batchnorm before activation.
            default False
        chunk_se: int, optional
            0 for offline SE.
            1 for streaming SE, where the mean is computed
            from the accumulated history up to the current chunk.
            2 for streaming SE, where the mean is computed
            from the current chunk only.
        chunk_size: int, optional
            chunk size for cnn. default 18
        activation: str, optional
            activation function used in ConvModule,
            default: "relu".
        glu_type: str, optional
            activation function used for the glu,
            default: "sigmoid".
        bias_in_glu: bool, optional
            if set to True, use additive bias in the weight module
            before GLU.
        linear_glu_in_convm: bool, optional
            if set to True, use the GLULinear module;
            otherwise, use the GLUPointWiseConv module.
            default False.
        export: bool, optional,
            if set to True, padding is equal to 0. This is for inference
            or ONNX export. Typically this is set by the export program or
            the decoder program, and it isn't present in your config file.
            default False
    """

    def __init__(
        self,
        input_dim,
        ext_pw_out_channel,
        depthwise_seperable_out_channel,
        ext_pw_kernel_size,
        kernel_size,
        depthwise_multiplier,
        dropout_rate,
        causal=False,
        batch_norm=False,
        chunk_se=0,
        chunk_size=18,
        activation="relu",
        glu_type="sigmoid",
        bias_in_glu=True,
        linear_glu_in_convm=False,
        export=False,
    ):
        super().__init__()
        self.layer_norm = nn.LayerNorm(input_dim)
        self.input_dim = input_dim
        self.ext_pw_out_channel = ext_pw_out_channel
        self.ext_pw_kernel_size = ext_pw_kernel_size
        self.depthwise_seperable_out_channel = depthwise_seperable_out_channel
        self.glu_type = glu_type
        self.bias_in_glu = bias_in_glu
        self.linear_glu_in_convm = linear_glu_in_convm
        self.causal = causal

        self._add_ext_pw_layer()

        self.batch_norm = batch_norm
        self.kernel_size = kernel_size

        if batch_norm:
            self.bn_layer = nn.BatchNorm1d(input_dim)

        self.act = get_activation(activation)
        self.dropout = nn.Dropout(dropout_rate)
        self.export = export

        if causal:
            padding = 0 if export else kernel_size - 1
        else:
            padding = (kernel_size - 1) // 2

        self.dw_sep_conv_1d = DepthWiseSeperableConv1d(
            input_dim,
            depthwise_seperable_out_channel,
            kernel_size,
            depthwise_multiplier,
            padding=padding,
        )

        if depthwise_seperable_out_channel != 0:
            if input_dim != depthwise_seperable_out_channel:
                self.ln2 = nn.Linear(depthwise_seperable_out_channel,
                                     input_dim)
        else:
            if depthwise_multiplier != 1:
                self.ln2 = nn.Linear(input_dim * depthwise_multiplier,
                                     input_dim)

    def _add_ext_pw_layer(self):
        """
        This function is an extension of the __init__ function
        and is dedicated to creating the convolution module
        of the conformer.
        """
        self.ln1 = self.glu = self.bn_layer = self.ext_pw_conv_1d = (
            nn.Identity())  # jit hacks.
        self.squeeze_excitation = nn.Identity()  # jit.
        self.apply_ln1 = self.fix_len1 = False  # jit.

        if self.ext_pw_out_channel != 0:
            if self.causal:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1),
                )
                if self.ext_pw_kernel_size > 1:
                    self.fix_len1 = True
                else:
                    self.fix_len1 = False
            else:
                self.ext_pw_conv_1d = nn.Conv1d(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    1,
                    padding=(self.ext_pw_kernel_size - 1) // 2,
                )
                self.fix_len1 = False

            if self.linear_glu_in_convm:
                self.glu = GLULinear(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.glu_type,
                    self.bias_in_glu,
                )
            else:
                self.glu = GLUPointWiseConv(
                    self.input_dim,
                    self.ext_pw_out_channel,
                    self.ext_pw_kernel_size,
                    self.glu_type,
                    self.bias_in_glu,
                    self.causal,
                )

            if self.input_dim != self.ext_pw_out_channel:
                self.apply_ln1 = True
                self.ln1 = nn.Linear(self.ext_pw_out_channel, self.input_dim)
            else:
                self.apply_ln1 = False
        else:
            self.pw_conv_simplify_w = torch.nn.Parameter(torch.ones(3))
            self.pw_conv_simplify_b = torch.nn.Parameter(torch.zeros(3))

    def forward(self, x):
        """ConvModule Forward.

        Args:
            x: torch.Tensor
                input tensor.
        """
        x = self.layer_norm(x)

        if self.ext_pw_out_channel != 0:
            x = self.glu(x)
            if self.causal and self.ext_pw_kernel_size > 1:
                x = x[:, :-(self.ext_pw_kernel_size - 1), :]
            if self.apply_ln1:
                x = self.ln1(x)
        else:
            x_0 = x * self.pw_conv_simplify_w[0] + self.pw_conv_simplify_b[0]
            x_1 = x * self.pw_conv_simplify_w[1] + self.pw_conv_simplify_b[1]
            x = x_0 + x_1

        x = x.permute([0, 2, 1])

        x = self.dw_sep_conv_1d(x)
        if self.causal and self.kernel_size > 1:
            x = x[:, :, :-(self.kernel_size - 1)]
        if hasattr(self, "ln2"):
            x = x.permute([0, 2, 1])
            x = self.ln2(x)
            x = x.permute([0, 2, 1])
        if self.batch_norm:
            x = self.bn_layer(x)
        x = self.act(x)

        if self.ext_pw_out_channel != 0:
            x = self.ext_pw_conv_1d(x)
            if self.fix_len1:
                x = x[:, :, :-(self.ext_pw_kernel_size - 1)]

            if self.apply_ln1:
                x = x.permute([0, 2, 1])
                x = self.ln1(x)
                x = x.permute([0, 2, 1])

            x = x.permute([0, 2, 1])
        else:
            x = x.unsqueeze(1).permute([0, 1, 3, 2])
            x = x * self.pw_conv_simplify_w[2] + self.pw_conv_simplify_b[2]
            x = x.squeeze(1)

        x = self.dropout(x)
        return x


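# Illustrative summary (added for exposition; not part of the released
# module): with ext_pw_out_channel > 0 this follows the standard conformer
# convolution block,
#     LayerNorm -> GLU pointwise conv -> depthwise conv
#               -> (BatchNorm) -> activation -> pointwise conv -> dropout,
# operating on [batch, time, dim] tensors and permuting to
# [batch, dim, time] around the 1D convolutions.
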
class GLULinear(nn.Module):
    """Linear + GLU module

    Args:
        input_dim: int
            input size
        output_dim: int
            output size.
        glu_type:
            activation function name used in the glu module,
            default "sigmoid".
        bias_in_glu: bool, optional
            If True, an additive bias is added. Default True.
    """

    def __init__(
        self,
        input_dim,
        output_dim,
        glu_type="sigmoid",
        bias_in_glu=True,
    ):
        super().__init__()
        self.linear = nn.Linear(input_dim, output_dim * 2, bias_in_glu)
        self.glu_act = GLU(-1, glu_type)

    def forward(self, x):
        """GLULinear forward

        Args:
            x: torch.Tensor
                input tensor.
        """
        x = self.linear(x)
        return self.glu_act(x)


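# Illustrative usage sketch (added for exposition; not part of the released
# module): the linear layer doubles the width so that the GLU split
# restores output_dim.
#
#     proj = GLULinear(input_dim=512, output_dim=2048, glu_type="swish")
#     y = proj(torch.randn(2, 10, 512))  # y.shape == (2, 10, 2048)
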
class FeedForward(nn.Module):
    """FeedForward Module.
    For more details see the Conformer paper:
    https://arxiv.org/pdf/2005.08100.pdf

    Args:
        d_model: int
            input size.
        d_inner: int
            inner hidden size.
        dropout_rate: float
            dropout rate.
        activation: str
            activation function name,
            one of ["relu", "swish", "sigmoid"],
            the sigmoid activation is only used with "glu_in_fnn=True",
            default "sigmoid".
        bias_in_glu: bool, optional
    """

    def __init__(
        self,
        d_model,
        d_inner,
        dropout_rate,
        activation="sigmoid",
        bias_in_glu=True,
    ):
        super().__init__()
        self.d_model = d_model
        self.d_inner = d_inner

        self.layer_norm = nn.LayerNorm(d_model)
        module = GLULinear(d_model, d_inner, activation, bias_in_glu)
        self.net = nn.Sequential(
            module,
            nn.Dropout(dropout_rate),
            nn.Linear(d_inner, d_model),
            nn.Dropout(dropout_rate),
        )

    def forward(self, x):
        """FeedForward forward function.

        Args:
            x: torch.Tensor
                input tensor.
        """
        out = self.net(self.layer_norm(x))

        return out


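# Illustrative note (added for exposition; not part of the released module):
# this is a pre-norm block: LayerNorm(x) feeds a gated expansion
# d_model -> d_inner and a projection back to d_model. The residual
# connection is applied by the enclosing encoder layer, not here.
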
642
+ #### positional encoding starts here
643
+ def _pre_hook(
644
+ state_dict,
645
+ prefix,
646
+ local_metadata,
647
+ strict,
648
+ missing_keys,
649
+ unexpected_keys,
650
+ error_msgs,
651
+ ):
652
+ """Perform pre-hook in load_state_dict for backward compatibility.
653
+
654
+ Note:
655
+ We saved self.pe until v.0.5.2 but we have omitted it later.
656
+ Therefore, we remove the item "pe" from `state_dict` for backward
657
+ compatibility.
658
+
659
+ """
660
+ k = prefix + "pe"
661
+ if k in state_dict:
662
+ state_dict.pop(k)
663
+
664
+
665
+ class T5RelativeAttentionLogitBias(nn.Module):
666
+ """
667
+ This module implements the relative position bias described in Section
668
+ 2.1 of the T5 paper: https://arxiv.org/pdf/1910.10683.pdf
669
+
670
+ The Huggingface implementation is used as a reference
671
+ https://github.com/huggingface/transformers/blob/v4.30.0/src/
672
+ transformers/models/t5/modeling_t5.py#L435
673
+
674
+ Modifies attention as Q*K^T + B, where B is a learned scalar bias based
675
+ on relative position of the query and key. It is HxNxN, where H is the
676
+ number of heads, N is the sequence length.
677
+
678
+ I've made these modifications to the original T5 bias:
679
+ - Skipping of the bucketing step. Original T5 bias converted rel
680
+ position distances into logarithmically increasing buckets. This is
681
+ supposed to help with length generalization.
682
+ - I just directly use rel position index as bias values, as we don't
683
+ need length generalization (40s max is good enough for ASR encoder),
684
+ and it keeps ONNX export simple.
685
+ - I've also extended it so that biases can be asymmetric, the default
686
+ implementation treats L->R and R->L the same. Asymmetric was found to
687
+ yield better results in my experiments.
688
+
689
+ Args:
690
+ num_heads: int
691
+ Number of attention heads
692
+ num_buckets: int
693
+ Number of buckets to use for relative attention bias. This is the
694
+ size of the learnable bias parameter. Bucketing is not yet
695
+ supported, so this defaults to -1 which means no bucketing is
696
+ used (max_distance determines size of bias param).
697
+ max_distance: int
698
+ Maximum distance to use for relative attention bias. With
699
+ num_buckets=-1, this directly controls the max size of the bias
700
+ parameter. When num_buckets > 0 is supported, this will control
701
+ the maximum distance for logarithmic bucketing after which all
702
+ positions are in the same bucket.
703
+ symmetric: bool
704
+ Whether to use symmetric or asymmetric biases. symmetric=False uses
705
+ 2x number of bias params to distinguish L->R from R->L. This was
706
+ found to be better for the encoder.
707
+ """
708
+
709
+ def __init__(self,
710
+ num_heads,
711
+ num_buckets=-1,
712
+ max_distance=1000,
713
+ symmetric=False):
714
+ super().__init__()
715
+ self.num_heads = num_heads
716
+ self.num_buckets = num_buckets
717
+ self.max_distance = max_distance
718
+ self.symmetric = symmetric
719
+ self._skip_bucketing = self.num_buckets < 0
720
+ if self._skip_bucketing:
721
+ self.num_buckets = max_distance
722
+ else:
723
+ raise NotImplementedError(
724
+ "T5 attention bias with bucketed positions is not yet tested")
725
+ if not self.symmetric:
726
+ self.num_buckets *= 2
727
+ self.bias_values = nn.Embedding(self.num_buckets, self.num_heads)
728
+
729
+ def forward(self, x):
730
+ # instantiate bias compatible with shape of x
731
+ maxpos = x.size(1)
732
+ context_position = torch.arange(maxpos,
733
+ device=x.device,
734
+ dtype=torch.long)[:, None]
735
+ memory_position = torch.arange(maxpos,
736
+ device=x.device,
737
+ dtype=torch.long)[None, :]
738
+ relative_position = memory_position - context_position
739
+ # clipping to a maximum distance using ops that play well with ONNX
740
+ # export
741
+ relative_position = relative_position.masked_fill(
742
+ relative_position < -self.max_distance, -self.max_distance)
743
+ relative_position = relative_position.masked_fill(
744
+ relative_position > self.max_distance - 1, self.max_distance - 1)
745
+
746
+ # mapping from relative position to index in the bias parameter
747
+ if self._skip_bucketing:
748
+ bias_idx = relative_position
749
+ else:
750
+ bias_idx = self._bucket_relative_position(relative_position)
751
+ if self.symmetric:
752
+ bias_idx = bias_idx.abs()
753
+ else:
754
+ bias_idx += self.num_buckets // 2
755
+
756
+ t5_rel_att_bias = self.bias_values(bias_idx) # [L, L, H]
757
+ t5_rel_att_bias = t5_rel_att_bias.permute(2, 0, 1).unsqueeze(
758
+ 0) # [1, H, L, L]
759
+
760
+ return t5_rel_att_bias
761
+
762
+ def _bucket_relative_position(self, relative_position):
+ # This is a placeholder (untested, likely buggy) that uses the
+ # HuggingFace implementation as a reference; it still needs to be
+ # extended to support asymmetric +/- positions.
+ relative_buckets = 0
+ # use a local copy so repeated calls do not mutate module state
+ num_buckets = self.num_buckets
+ if not self.symmetric:
+ # asymmetric biases distinguish direction (HF's bidirectional
+ # case): spend half the buckets on the sign
+ num_buckets //= 2
+ relative_buckets += (relative_position > 0).to(
+ torch.long) * num_buckets
+ relative_position = torch.abs(relative_position)
+ else:
+ relative_position = -torch.min(relative_position,
+ torch.zeros_like(relative_position))
+ # now relative_position is in the range [0, inf)
+
+ # half of the buckets are for exact increments in positions
+ max_exact = num_buckets // 2
+ is_small = relative_position < max_exact
+
+ # The other half of the buckets are for logarithmically bigger bins in
+ # positions up to max_distance
+ relative_position_if_large = max_exact + (
+ torch.log(relative_position.float() / max_exact) /
+ math.log(self.max_distance / max_exact) *
+ (num_buckets - max_exact)).to(torch.long)
+ relative_position_if_large = torch.min(
+ relative_position_if_large,
+ torch.full_like(relative_position_if_large, num_buckets - 1),
+ )
+
+ relative_buckets += torch.where(is_small, relative_position,
+ relative_position_if_large)
+ return relative_buckets
+
+
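As a minimal illustration of the skip-bucketing path above (a standalone sketch, not part of the package; max_distance=4 and a length of 6 are assumed toy values; clamp is used in place of the ONNX-friendly masked_fill pair, with the same effect), the asymmetric bias index table can be computed as:

import torch

max_distance = 4
num_buckets = 2 * max_distance          # asymmetric doubles the bias params
pos = torch.arange(6)
rel = pos[None, :] - pos[:, None]       # memory_position - context_position
rel = rel.clamp(-max_distance, max_distance - 1)
bias_idx = rel + num_buckets // 2       # shift into [0, num_buckets)
# every index is a valid row of nn.Embedding(num_buckets, num_heads)
assert bias_idx.min() >= 0 and bias_idx.max() < num_buckets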
+ class AbsolutePositionalEncoding(nn.Module):
+ """Absolute positional encoding module.
+ This module implements the absolute sinusoidal positional encoding
+ from: https://arxiv.org/pdf/1706.03762.pdf
+
+ Args:
+ d_model: int
+ Input embedding size.
+ dropout_rate: float
+ dropout rate
+ max_len: int, optional
+ Maximum input sequence length. Default: 5000
+
+ """
+
+ def __init__(self, d_model, dropout_rate, max_len=5000):
+ """Construct a PositionalEncoding object."""
+ super().__init__()
+ self.d_model = d_model
+ self.xscale = math.sqrt(self.d_model)
+ self.dropout = torch.nn.Dropout(p=dropout_rate)
+ self.pe = None
+ self.extend_pe(torch.tensor(0.0).expand(1, max_len))
+ self._register_load_state_dict_pre_hook(_pre_hook)
+
+ def extend_pe(self, x):
+ """Reset the positional encodings.
+
+ Args:
+ x: torch.Tensor
+ """
+ if self.pe is not None and self.pe.size(1) >= x.size(1):
+ if self.pe.dtype != x.dtype or self.pe.device != x.device:
+ self.pe = self.pe.to(dtype=x.dtype, device=x.device)
+ return
+ pe = torch.zeros(x.size(1), self.d_model)
+ position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
+ div_term = torch.exp(
+ torch.arange(0, self.d_model, 2, dtype=torch.float32) *
+ -(math.log(10000.0) / self.d_model))
+ pe[:, 0::2] = torch.sin(position * div_term)
+ pe[:, 1::2] = torch.cos(position * div_term)
+ pe = pe.unsqueeze(0)
+ self.pe = pe.to(device=x.device, dtype=x.dtype)
+
+ def forward(self, x: torch.Tensor):
+ """Add positional encoding.
+
+ Args:
+ x: torch.Tensor
+ Input tensor. Shape is (batch, time, ...)
+
+ Returns:
+ torch.Tensor: Encoded tensor. Its shape is (batch, time, ...)
+
+ """
+ self.extend_pe(x)
+ x = x * self.xscale + self.pe[:, :x.size(1)]
+ return self.dropout(x)
+
+
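As a sanity check on the table built in extend_pe, a short standalone sketch (d_model=8 and max_len=16 are assumed toy sizes) reproducing one entry from the closed form PE[pos, 2i] = sin(pos / 10000^(2i/d_model)):

import math
import torch

d_model, max_len = 8, 16
pos = torch.arange(max_len, dtype=torch.float32).unsqueeze(1)
div = torch.exp(torch.arange(0, d_model, 2, dtype=torch.float32)
                * -(math.log(10000.0) / d_model))
pe = torch.zeros(max_len, d_model)
pe[:, 0::2] = torch.sin(pos * div)
pe[:, 1::2] = torch.cos(pos * div)
# spot-check dimension 2 at position 3 against the closed form
assert torch.isclose(pe[3, 2],
                     torch.sin(torch.tensor(3.0 / 10000 ** (2 / d_model))))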
+ #### forward embedding layers start here
+ class MeanVarianceNormLayer(nn.Module):
+ """Mean/variance normalization layer.
+
+ Subtracts the global mean and multiplies the input by the inverse
+ standard deviation. Typically used as the very first layer in a model.
+
+ Args:
+ input_size: int
+ layer input size.
+ """
+
+ def __init__(self, input_size):
+ super().__init__()
+ self.input_size = input_size
+ self.global_mean = nn.Parameter(torch.zeros(input_size))
+ self.global_invstd = nn.Parameter(torch.ones(input_size))
+
+ def forward(self, input_: Tensor) -> Tensor:
+ """MeanVarianceNormLayer forward
+
+ Args:
+ input_: torch.Tensor
+ input tensor.
+ """
+ return (input_ - self.global_mean) * self.global_invstd
+
+
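A hypothetical usage sketch: in practice global_mean and global_invstd hold precomputed corpus statistics loaded from a checkpoint; here they are filled in by hand just to show the arithmetic:

import torch

norm = MeanVarianceNormLayer(input_size=80)
with torch.no_grad():
    norm.global_mean.fill_(2.0)     # stand-in corpus mean
    norm.global_invstd.fill_(0.5)   # stand-in 1/std
feats = torch.full((1, 10, 80), 4.0)
out = norm(feats)                   # (4 - 2) * 0.5 == 1
assert torch.allclose(out, torch.ones_like(out))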
+ class CausalConv1D(nn.Conv1d):
+ """
+ A causal version of nn.Conv1d where each step has only limited access to
+ locations on its right or left.
+ All arguments are the same as nn.Conv1d except padding.
+
+ If padding is set to None, paddings are chosen automatically to make it a
+ causal convolution in which no location sees any steps to its right.
+
+ If padding is given as a list of size 2, padding[0] is used as the left
+ padding and padding[1] as the right padding.
+ This makes it possible to control how many steps are accessible on the
+ right and on the left.
+ This mode is not supported when stride > 1; padding[0] + padding[1] must
+ equal (kernel_size - 1).
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: Union[int, list, None] = 0,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = "zeros",
+ device=None,
+ dtype=None,
+ ) -> None:
+ self.cache_drop_size = None
+ if padding is None:
+ self._left_padding = kernel_size - 1
+ self._right_padding = stride - 1
+ else:
+ if stride != 1 and padding != kernel_size - 1:
+ raise ValueError(
+ "No striding allowed for non-symmetric convolutions!")
+ if isinstance(padding, int):
+ self._left_padding = padding
+ self._right_padding = padding
+ elif (isinstance(padding, list) and len(padding) == 2
+ and padding[0] + padding[1] == kernel_size - 1):
+ self._left_padding = padding[0]
+ self._right_padding = padding[1]
+ else:
+ raise ValueError(f"Invalid padding param: {padding}!")
+
+ self._max_cache_len = self._left_padding
+
+ super().__init__(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=0,
+ dilation=dilation,
+ groups=groups,
+ bias=bias,
+ padding_mode=padding_mode,
+ device=device,
+ dtype=dtype,
+ )
+
+ def update_cache(self, x, cache=None):
+ if cache is None:
+ new_x = F.pad(x, pad=(self._left_padding, self._right_padding))
+ next_cache = cache
+ else:
+ new_x = F.pad(x, pad=(0, self._right_padding))
+ new_x = torch.cat([cache, new_x], dim=-1)
+ if self.cache_drop_size > 0:
+ next_cache = new_x[:, :, :-self.cache_drop_size]
+ else:
+ next_cache = new_x
+ next_cache = next_cache[:, :, -cache.size(-1):]
+ return new_x, next_cache
+
+ def forward(self, x, cache=None):
+ x, cache = self.update_cache(x, cache=cache)
+ x = super().forward(x)
+ if cache is None:
+ return x
+ else:
+ return x, cache
+
+
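A quick standalone check (sizes assumed) that the automatic padding above is indeed causal: perturbing future timesteps must leave earlier outputs unchanged.

import torch

conv = CausalConv1D(in_channels=1, out_channels=1, kernel_size=3,
                    stride=1, padding=None)  # left pad = kernel_size - 1
x = torch.randn(1, 1, 10)
y1 = conv(x)
x2 = x.clone()
x2[..., 7:] += 1.0                  # change only timesteps >= 7
y2 = conv(x2)
assert torch.allclose(y1[..., :7], y2[..., :7])  # outputs before t=7 unchanged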
+ class CausalConv2D(nn.Conv2d):
+ """
+ A causal version of nn.Conv2d where each location in the 2D matrix has
+ no access to locations to its right or below.
+ All arguments are the same as nn.Conv2d except padding, which must be
+ set to None.
+ """
+
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ kernel_size: int,
+ stride: int = 1,
+ padding: Union[int, None] = 0,
+ dilation: int = 1,
+ groups: int = 1,
+ bias: bool = True,
+ padding_mode: str = "zeros",
+ device=None,
+ dtype=None,
+ ) -> None:
+ if padding is not None:
+ raise ValueError(
+ "Argument padding should be set to None for CausalConv2D.")
+ self._left_padding = kernel_size - 1
+ self._right_padding = stride - 1
+
+ padding = 0
+ super().__init__(
+ in_channels,
+ out_channels,
+ kernel_size,
+ stride,
+ padding,
+ dilation,
+ groups,
+ bias,
+ padding_mode,
+ device,
+ dtype,
+ )
+
+ def forward(
+ self,
+ x,
+ ):
+ x = F.pad(
+ x,
+ pad=(self._left_padding, self._right_padding, 0, 0),
+ )
+ x = super().forward(x)
+ return x
+
+
+ class NemoConvSubsampling(torch.nn.Module):
+ """Convolutional subsampling module, taken from NeMo ASR
+ (https://github.com/NVIDIA/NeMo/blob/b367413645d5c72db3c2c96e46e95a
+ 34501479cf/nemo/collections/asr/parts/submodules/subsampling.py)
+
+ Striding Subsampling: "Speech-Transformer: A No-Recurrence
+ Sequence-to-Sequence Model for Speech Recognition" by Linhao Dong
+ et al. (https://ieeexplore.ieee.org/document/8462506)
+
+
+ Compared with the EncoderConv2D (`input_layer: custom`), this is a
+ much simplified approach that uses no LayerNorm and far fewer Conv2Ds.
+ Moreover, depthwise convolutions are used to reduce FLOPs, but the first
+ layer is kept as a regular convolution so as not to degrade accuracy.
+
+ `Striding` and `dw_striding` are the same except that the latter uses
+ depthwise convolutions after the first layer, whereas the former does not.
+
+ Args:
+ subsampling_factor (int): Time reduction factor
+ feat_in (int): size of the input features
+ feat_out (int): size of the output features
+ subsampling (str): The subsampling technique, choose from
+ {"striding", "dw_striding", "striding_conv1d",
+ "dw_striding_conv1d"}
+ conv_channels (int): Number of channels for the convolution layers,
+ default is 256.
+ subsampling_conv_chunking_factor (int): Input chunking factor which
+ can be -1 (no chunking), 1 (auto), or a power of 2. Default is 1
+ activation (Module): activation function, default is nn.ReLU()
+ is_causal (bool): whether to use causal Conv1/2D, where each step will
+ have limited access to locations on its right or left
+ """
+
+ def __init__(
+ self,
+ feat_in,
+ feat_out,
+ subsampling_factor=4,
+ subsampling="dw_striding",
+ conv_channels=256,
+ subsampling_conv_chunking_factor=1,
+ activation=nn.ReLU(), # noqa: B008
+ is_causal=False,
+ ):
+ super().__init__()
+ self._subsampling = subsampling
+ self._conv_channels = conv_channels
+ self._feat_in = feat_in
+ self._feat_out = feat_out
+
+ if subsampling_factor % 2 != 0:
+ raise ValueError("Sampling factor should be a multiple of 2!")
+ self._sampling_num = int(math.log(subsampling_factor, 2))
+ self.subsampling_factor = subsampling_factor
+ self.is_causal = is_causal
+ self.subsampling_causal_cond = subsampling in (
+ "dw_striding",
+ "striding",
+ "striding_conv1d",
+ )
+
+ if (subsampling_conv_chunking_factor != -1
+ and subsampling_conv_chunking_factor != 1
+ and subsampling_conv_chunking_factor % 2 != 0):
+ raise ValueError(
+ "subsampling_conv_chunking_factor should be -1, 1, or a "\
+ "power of 2"
+ )
+ self.subsampling_conv_chunking_factor = \
+ subsampling_conv_chunking_factor
+
+ in_channels = 1
+ layers = []
+
+ if subsampling == "dw_striding":
+ self._stride = 2
+ self._kernel_size = 3
+ self._ceil_mode = False
+
+ if self.is_causal:
+ self._left_padding = self._kernel_size - 1
+ self._right_padding = self._stride - 1
+ self._max_cache_len = subsampling_factor + 1
+ else:
+ self._left_padding = (self._kernel_size - 1) // 2
+ self._right_padding = (self._kernel_size - 1) // 2
+ self._max_cache_len = 0
+
+ # Layer 1
+ if self.is_causal:
+ layers.append(
+ CausalConv2D(
+ in_channels=in_channels,
+ out_channels=conv_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=None,
+ ))
+ else:
+ layers.append(
+ torch.nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=conv_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=self._left_padding,
+ ))
+ in_channels = conv_channels
+ layers.append(activation)
+
+ for i in range(self._sampling_num - 1):
+ if self.is_causal:
+ layers.append(
+ CausalConv2D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=None,
+ groups=in_channels,
+ ))
+ else:
+ layers.append(
+ torch.nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=self._left_padding,
+ groups=in_channels,
+ ))
+
+ layers.append(
+ torch.nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=conv_channels,
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ ))
+ layers.append(activation)
+ in_channels = conv_channels
+
+ elif subsampling == "striding":
+ self._stride = 2
+ self._kernel_size = 3
+ self._ceil_mode = False
+
+ if self.is_causal:
+ self._left_padding = self._kernel_size - 1
+ self._right_padding = self._stride - 1
+ self._max_cache_len = subsampling_factor + 1
+ else:
+ self._left_padding = (self._kernel_size - 1) // 2
+ self._right_padding = (self._kernel_size - 1) // 2
+ self._max_cache_len = 0
+
+ for i in range(self._sampling_num):
+ if self.is_causal:
+ layers.append(
+ CausalConv2D(
+ in_channels=in_channels,
+ out_channels=conv_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=None,
+ ))
+ else:
+ layers.append(
+ torch.nn.Conv2d(
+ in_channels=in_channels,
+ out_channels=conv_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=self._left_padding,
+ ))
+ layers.append(activation)
+ in_channels = conv_channels
+
+ elif subsampling == "striding_conv1d":
+ in_channels = feat_in
+
+ self._stride = 2
+ self._kernel_size = 5
+ self._ceil_mode = False
+
+ if self.is_causal:
+ self._left_padding = self._kernel_size - 1
+ self._right_padding = self._stride - 1
+ self._max_cache_len = subsampling_factor + 1
+ else:
+ self._left_padding = (self._kernel_size - 1) // 2
+ self._right_padding = (self._kernel_size - 1) // 2
+ self._max_cache_len = 0
+
+ for i in range(self._sampling_num):
+ if self.is_causal:
+ layers.append(
+ CausalConv1D(
+ in_channels=in_channels,
+ out_channels=(feat_out if self._sampling_num == i +
+ 1 else conv_channels),
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=None,
+ ))
+ else:
+ layers.append(
+ torch.nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=(feat_out if self._sampling_num == i +
+ 1 else conv_channels),
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=self._left_padding,
+ ))
+ layers.append(activation)
+ in_channels = conv_channels
+
+ elif subsampling == "dw_striding_conv1d":
+ in_channels = feat_in
+
+ self._stride = 2
+ self._kernel_size = 5
+ self._ceil_mode = False
+
+ self._left_padding = (self._kernel_size - 1) // 2
+ self._right_padding = (self._kernel_size - 1) // 2
+
+ # Layer 1
+ layers.extend([
+ torch.nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=self._left_padding,
+ groups=in_channels,
+ ),
+ torch.nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=(feat_out if self._sampling_num == 1 else
+ conv_channels),
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ ),
+ ])
+ in_channels = conv_channels
+ layers.append(activation)
+
+ for i in range(self._sampling_num - 1):
+ layers.extend([
+ torch.nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ padding=self._left_padding,
+ groups=in_channels,
+ ),
+ torch.nn.Conv1d(
+ in_channels=in_channels,
+ out_channels=(feat_out if self._sampling_num == i +
+ 2 else conv_channels),
+ kernel_size=1,
+ stride=1,
+ padding=0,
+ groups=1,
+ ),
+ ])
+ layers.append(activation)
+ in_channels = conv_channels
+
+ else:
+ raise ValueError(f"Not valid sub-sampling: {subsampling}!")
+
+ if subsampling in ["dw_striding", "striding"]:
+ in_length = torch.tensor(feat_in, dtype=torch.float)
+ out_length = calc_length(
+ lengths=in_length,
+ all_paddings=self._left_padding + self._right_padding,
+ kernel_size=self._kernel_size,
+ stride=self._stride,
+ ceil_mode=self._ceil_mode,
+ repeat_num=self._sampling_num,
+ )
+ self.out = torch.nn.Linear(conv_channels * int(out_length),
+ feat_out)
+ self.conv2d_subsampling = True
+ elif subsampling in ["striding_conv1d", "dw_striding_conv1d"]:
+ self.out = None
+ self.conv2d_subsampling = False
+ else:
+ raise ValueError(f"Not valid sub-sampling: {subsampling}!")
+
+ self.conv = torch.nn.Sequential(*layers)
+
+ def get_sampling_frames(self):
+ return [1, self.subsampling_factor]
+
+ def get_streaming_cache_size(self):
+ return [0, self.subsampling_factor + 1]
+
+ def forward(self, x, mask):
+ """
+ Forward method for NeMo subsampling.
+
+ Args:
+ x: torch.Tensor
+ input tensor (Batch, Time, Filters)
+ mask: torch.Tensor
+ input mask
+
+ Returns:
+ x: torch.Tensor
+ Resulting tensor from subsampling (B, T //
+ time_reduction_factor, feat_out)
+ pad_mask: torch.Tensor
+ tensor of padded hidden state sequences (B, 1, T //
+ time_reduction_factor)
+ """
+ x = x.unsqueeze(1) if self.conv2d_subsampling else x.transpose(1, 2)
+
+ # split inputs if chunking_factor is set
+ if (self.subsampling_conv_chunking_factor != -1
+ and self.conv2d_subsampling):
+ if self.subsampling_conv_chunking_factor == 1:
+ # if subsampling_conv_chunking_factor is 1, we split only
+ # if needed.
+ # avoiding a bug / feature limiting indexing of tensors
+ # to 2**31.
+ # see https://github.com/pytorch/pytorch/issues/80020
+ x_ceil = (2**31 / self._conv_channels * self._stride *
+ self._stride)
+ need_to_split = torch.numel(x) > x_ceil
+ else:
+ # if subsampling_conv_chunking_factor > 1 we always split
+ need_to_split = True
+
+ if need_to_split:
+ x, success = self.conv_split_by_batch(x)
+ if not success: # if unable to split by batch, try by channel
+ if self._subsampling == "dw_striding":
+ x = self.conv_split_by_channel(x)
+ else:
+ x = self.conv(x) # try anyway
+ else:
+ x = self.conv(x)
+ else:
+ x = self.conv(x)
+
+ # Flatten Channel and Frequency Axes
+ if self.conv2d_subsampling:
+ b, c, t, f = x.size()
+ x = self.out(x.transpose(1, 2).reshape(b, t, -1))
+ # Transpose to Channel Last mode
+ else:
+ x = x.transpose(1, 2)
+
+ if mask is None:
+ return x, None
+
+ max_audio_length = x.shape[1]
+ feature_lens = mask.sum(1)
+ padding_length = torch.ceil(feature_lens / self.subsampling_factor)
+ if self.is_causal and self.subsampling_causal_cond:
+ feature_lens_remainder = feature_lens % self.subsampling_factor
+ padding_length[feature_lens_remainder != 1] += 1
+ pad_mask = torch.arange(0, max_audio_length, device=x.device).expand(
+ padding_length.size(0), -1) < padding_length.unsqueeze(1)
+ return x, pad_mask.unsqueeze(1)
+
+ def reset_parameters(self):
+ # initialize weights
+ if self._subsampling == "dw_striding":
+ with torch.no_grad():
+ # init conv
+ scale = 1.0 / self._kernel_size
+ dw_max = (self._kernel_size**2)**-0.5
+ pw_max = self._conv_channels**-0.5
+
+ torch.nn.init.uniform_(self.conv[0].weight, -scale, scale)
+ torch.nn.init.uniform_(self.conv[0].bias, -scale, scale)
+
+ for idx in range(2, len(self.conv), 3):
+ torch.nn.init.uniform_(self.conv[idx].weight, -dw_max,
+ dw_max)
+ torch.nn.init.uniform_(self.conv[idx].bias, -dw_max,
+ dw_max)
+ torch.nn.init.uniform_(self.conv[idx + 1].weight, -pw_max,
+ pw_max)
+ torch.nn.init.uniform_(self.conv[idx + 1].bias, -pw_max,
+ pw_max)
+
+ # init fc (80 * 64 = 5120 from https://github.com/kssteven418/
+ # Squeezeformer/blob/13c97d6cf92f2844d2cb3142b4c5bfa9ad1a8951/
+ # src/models/conformer_encoder.py#L487)
+ fc_scale = (self._feat_out * self._feat_in /
+ self._sampling_num)**-0.5
+ torch.nn.init.uniform_(self.out.weight, -fc_scale, fc_scale)
+ torch.nn.init.uniform_(self.out.bias, -fc_scale, fc_scale)
+
+ def conv_split_by_batch(self, x):
+ """Tries to split input by batch, run conv and concat results"""
+ b, _, _, _ = x.size()
+ if b == 1: # can't split if batch size is 1
+ return x, False
+
+ if self.subsampling_conv_chunking_factor > 1:
+ cf = self.subsampling_conv_chunking_factor
+ else:
+ # avoiding a bug / feature limiting indexing of tensors to 2**31
+ # see https://github.com/pytorch/pytorch/issues/80020
+ x_ceil = 2**31 / self._conv_channels * self._stride * self._stride
+ p = math.ceil(math.log(torch.numel(x) / x_ceil, 2))
+ cf = 2**p
+
+ new_batch_size = b // cf
+ if new_batch_size == 0: # input is too big
+ return x, False
+
+ return (
+ torch.cat([
+ self.conv(chunk)
+ for chunk in torch.split(x, new_batch_size, 0)
+ ]),
+ True,
+ )
+
+ def conv_split_by_channel(self, x):
+ """For dw convs, splits the depthwise convs by channel and the
+ pointwise convs by time, runs them, and concats the results"""
+ x = self.conv[0](x) # full conv2D
+ x = self.conv[1](x) # activation
+
+ for i in range(self._sampling_num - 1):
+ _, c, t, _ = x.size()
+
+ if self.subsampling_conv_chunking_factor > 1:
+ cf = self.subsampling_conv_chunking_factor
+ else:
+ # avoiding a bug / feature limiting indexing of tensors
+ # to 2**31
+ # see https://github.com/pytorch/pytorch/issues/80020
+ p = math.ceil(math.log(torch.numel(x) / 2**31, 2))
+ cf = 2**p
+
+ new_c = int(c // cf)
+ if new_c == 0:
+ new_c = 1
+
+ new_t = int(t // cf)
+ if new_t == 0:
+ new_t = 1
+
+ x = self.channel_chunked_conv(self.conv[i * 3 + 2], new_c,
+ x) # conv2D, depthwise
+
+ # splitting pointwise convs by time
+ x = torch.cat(
+ [
+ self.conv[i * 3 + 3](chunk)
+ for chunk in torch.split(x, new_t, 2)
+ ],
+ 2,
+ ) # conv2D, pointwise
+ x = self.conv[i * 3 + 4](x) # activation
+ return x
+
+ def channel_chunked_conv(self, conv, chunk_size, x):
+ """Performs channel chunked convolution"""
+
+ ind = 0
+ out_chunks = []
+ for chunk in torch.split(x, chunk_size, 1):
+ step = chunk.size()[1]
+
+ if self.is_causal:
+ chunk = nn.functional.pad(
+ chunk,
+ pad=(
+ self._kernel_size - 1,
+ self._stride - 1,
+ self._kernel_size - 1,
+ self._stride - 1,
+ ),
+ )
+ ch_out = nn.functional.conv2d(
+ chunk,
+ conv.weight[ind:ind + step, :, :, :],
+ bias=conv.bias[ind:ind + step],
+ stride=self._stride,
+ padding=0,
+ groups=step,
+ )
+ else:
+ ch_out = nn.functional.conv2d(
+ chunk,
+ conv.weight[ind:ind + step, :, :, :],
+ bias=conv.bias[ind:ind + step],
+ stride=self._stride,
+ padding=self._left_padding,
+ groups=step,
+ )
+ out_chunks.append(ch_out)
+ ind += step
+
+ return torch.cat(out_chunks, 1)
+
+ def change_subsampling_conv_chunking_factor(
+ self, subsampling_conv_chunking_factor: int):
+ if (subsampling_conv_chunking_factor != -1
+ and subsampling_conv_chunking_factor != 1
+ and subsampling_conv_chunking_factor % 2 != 0):
+ raise ValueError(
+ "subsampling_conv_chunking_factor should be -1, 1, or a "\
+ "power of 2"
+ )
+ self.subsampling_conv_chunking_factor = subsampling_conv_chunking_factor
+
+
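A shape walk-through under assumed hyperparameters (feat_in=80, feat_out=512, subsampling_factor=8, i.e. three stride-2 stages); the printed sizes follow from the calc_length arithmetic defined after the class:

import torch

sub = NemoConvSubsampling(feat_in=80, feat_out=512,
                          subsampling_factor=8,   # 2**3 -> three conv stages
                          subsampling="dw_striding",
                          conv_channels=256)
x = torch.randn(2, 200, 80)                       # (batch, time, mel bins)
mask = torch.ones(2, 200, dtype=torch.bool)
y, pad_mask = sub(x, mask)
print(y.shape)         # torch.Size([2, 25, 512]) == (batch, 200 // 8, feat_out)
print(pad_mask.shape)  # torch.Size([2, 1, 25])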
+ def calc_length(lengths,
+ all_paddings,
+ kernel_size,
+ stride,
+ ceil_mode,
+ repeat_num=1):
+ """Calculates the output length of a Tensor passed through a convolution or
+ max pooling layer"""
+ add_pad: float = all_paddings - kernel_size
+ one: float = 1.0
+ for i in range(repeat_num):
+ lengths = (torch.div(lengths.to(dtype=torch.float) + add_pad, stride) +
+ one)
+ lengths = torch.ceil(lengths) if ceil_mode else torch.floor(lengths)
+ return lengths.to(dtype=torch.int)
+
+
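The loop above is the standard conv output-length formula, L_out = floor((L_in + paddings - kernel_size) / stride + 1), applied repeat_num times. A worked check with assumed numbers:

import torch

lengths = torch.tensor([80.0])
out = calc_length(lengths, all_paddings=2, kernel_size=3, stride=2,
                  ceil_mode=False, repeat_num=3)
# 80 -> (80 + 2 - 3) / 2 + 1 = 40.5 -> 40 -> 20 -> 10
print(out.item())  # 10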
+ #### multihead attention starts here
+ class AttModule(nn.Module):
+ """Attention abstraction module"""
+
+ def __init__(self):
+ super().__init__()
+ self.export_mode = False
+
+ def set_export(self, mode=True):
+ """set the export mode"""
+ self.export_mode = mode
+
+ def forward(
+ self,
+ x: Tensor,
+ memory: Optional[Tensor] = None,
+ pos_emb: Optional[Tensor] = None,
+ att_mask: Optional[Tensor] = None,
+ ) -> Tuple[Tensor, Tensor, Optional[Tensor], Optional[Tensor]]:
+ """AttModule forward
+
+ Args:
+ x: torch.Tensor
+ input tensor.
+ memory: torch.Tensor, optional
+ memory tensor.
+ pos_emb: torch.Tensor, optional
+ positional encoder embedding.
+ att_mask: torch.Tensor, optional
+ attention mask tensor.
+ """
+ return x, memory, pos_emb, att_mask
+
+
+ class AttBlock(Block, AttModule):
+ """Attention Block module to support both Attention and Block module."""
+
+ def memory_dims(self, max_len=False):
+ """memory dimensions"""
+ return (1, self.input_size)
+
+
+ def masked_softmax(
+ scores,
+ mask: Optional[Tensor],
+ ):
+ if mask is not None:
+ mask = mask.unsqueeze(1).eq(0) # (batch, 1, time1, time2)
+ scores = scores.masked_fill(mask, -torch.inf)
+ attn = torch.softmax(scores, dim=-1).masked_fill(
+ mask, 0.0) # (batch, head, time1, time2)
+ else:
+ attn = torch.softmax(scores, dim=-1) # (batch, head, time1, time2)
+ return attn
+
+
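The convention in masked_softmax is that mask entries equal to 0 mark padded key positions: they receive -inf logits before the softmax and are zeroed again afterwards, so they contribute no attention weight. A small standalone sketch with assumed shapes:

import torch

scores = torch.zeros(1, 2, 3, 4)             # (batch, head, time1, time2)
mask = torch.tensor([[[1, 1, 1, 0],           # (batch, time1, time2); 0 = pad
                      [1, 1, 1, 0],
                      [1, 1, 1, 0]]])
attn = masked_softmax(scores, mask)
print(attn[0, 0, 0])                          # ~[0.333, 0.333, 0.333, 0.0]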
+ class MultiHeadedAttention(nn.Module):
+ """Multi-Head Attention layer with optional relative position embedding
+ and GLU.
+
+ Args:
+ n_head: int
+ the number of heads.
+ n_feat: int
+ input feature size.
+ dropout_rate: float
+ dropout rate.
+ glu_type: str, optional
+ activation type for the GLU, default "swish".
+ bias_in_glu: bool, optional
+ whether to use a bias in the GLU, default True.
+ attention_inner_dim: int, optional
+ the attention dimension used in the class,
+ it can be different from the input dimension n_feat.
+ default: -1 (equal to n_feat).
+ use_pt_scaled_dot_product_attention: bool, optional
+ if set to True, use PyTorch scaled dot product attention in
+ training.
+ NOTE: this will NOT be used in ONNX decoding due to a lack of
+ support. In that case, we use the original attention
+ implementation, which shows no regression.
+ default: False.
+ n_value: int, optional
+ if set to values other than -1, use a different dimension for
+ value. With the default value (i.e. -1), it is backward compatible.
+ group_size: int, optional. must divide `n_head`
+ if 1 < group_size < n_head: GQA
+ if group_size = 1: MHA
+ if group_size = n_head: MQA
+ """
+
+ inv_sqrt_d_k: torch.jit.Final[float]
+ h: torch.jit.Final[int]
+ h_k: torch.jit.Final[int]
+ g: torch.jit.Final[int]
+
+ def __init__(
+ self,
+ n_head,
+ n_feat,
+ dropout_rate,
+ attention_inner_dim=-1,
+ glu_type="swish",
+ bias_in_glu=True,
+ use_pt_scaled_dot_product_attention=False,
+ n_value=-1,
+ group_size: int = 1,
+ ):
+ super().__init__()
+ if n_value == -1:
+ n_value = n_feat
+ if attention_inner_dim == -1:
+ attention_inner_dim = n_feat
+ assert attention_inner_dim % n_head == 0
+
+ # We assume d_v always equals d_k
+ self.d_k = attention_inner_dim // n_head
+ self.inv_sqrt_d_k = 1.0 / math.sqrt(self.d_k)
+ self.h = n_head
+ assert n_head % group_size == 0, "group_size must divide n_head"
+ self.g = group_size
+ self.h_k = n_head // group_size
+
+ self.linear_q = nn.Linear(n_feat, attention_inner_dim)
+ self.linear_k = nn.Linear(n_feat, attention_inner_dim // group_size)
+ self.linear_v = nn.Linear(n_value, attention_inner_dim // group_size)
+ self.linear_out = nn.Linear(attention_inner_dim // group_size, n_value)
+
+ self.attn = torch.jit.Attribute(None, Optional[Tensor])
+ self.dropout = nn.Dropout(p=dropout_rate)
+ self.dropout_rate = dropout_rate
+ self.use_pt_scaled_dot_product_attention = (
+ use_pt_scaled_dot_product_attention)
+
+ if use_pt_scaled_dot_product_attention and group_size > 1:
+ raise ValueError("Cannot use PT Scaled Attention with GQA")
+
+ # Torchscript eager quantization. Note that these functions below are
+ # NOOPs and have very little impact on performance unless quantization
+ # is enabled.
+ self.quant_q = torch.ao.quantization.QuantStub()
+ self.quant_x = torch.ao.quantization.QuantStub()
+ self.dequant = torch.ao.quantization.DeQuantStub()
+ self.ffunc = torch.ao.nn.quantized.FloatFunctional()
+
+ def forward(
+ self,
+ query: Tensor,
+ key: Tensor,
+ value: Tensor,
+ pos_k: Tensor,
+ pos_v: Tensor,
+ mask: Optional[Tensor],
+ relative_attention_bias: Optional[Tensor] = None,
+ ):
+ """Compute 'Scaled Dot Product Attention'.
+
+ Args:
+ query: torch.Tensor
+ query tensor (batch, time1, size)
+ key: torch.Tensor
+ key tensor (batch, time2, size)
+ value: torch.Tensor
+ value tensor (batch, time2, size)
+ pos_k: torch.Tensor
+ key tensor used for relative positional embedding.
+ pos_v: torch.Tensor
+ value tensor used for relative positional embedding.
+ mask: torch.Tensor
+ mask tensor (batch, time1, time2)
+ relative_attention_bias: torch.Tensor
+ bias added to attention logits w.r.t. relative positions
+ (1, n_head, time1, time2)
+ """
+ n_batch = query.size(0)
+
+ q = self.linear_q(query).view(n_batch, -1, self.h,
+ self.d_k) # (b, t, d)
+ k = self.linear_k(key).view(n_batch, -1, self.h_k,
+ self.d_k) # (b, t, d)
+ v = self.linear_v(value).view(n_batch, -1, self.h_k, self.d_k)
+ q = (q.transpose(1, 2) if self.use_pt_scaled_dot_product_attention
+ and not torch.jit.is_scripting() else q.transpose(1, 2) *
+ self.inv_sqrt_d_k)
+ k = k.transpose(1, 2) # (batch, head_k, time2, d_k)
+ v = v.transpose(1, 2) # (batch, head_k, time2, d_k)
+
+ if (self.use_pt_scaled_dot_product_attention
+ and not torch.jit.is_scripting()):
+ attn_mask = None
+ if mask is not None:
+ mask = mask.unsqueeze(1)
+ if relative_attention_bias is not None:
+ attn_mask = mask + relative_attention_bias
+ else:
+ attn_mask = mask
+ if mask.dtype != q.dtype:
+ attn_mask = attn_mask.to(q.dtype)
+
+ with torch.nn.attention.sdpa_kernel([
+ torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+ torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+ torch.nn.attention.SDPBackend.MATH,
+ torch.nn.attention.SDPBackend.CUDNN_ATTENTION,
+ ]):
+ x = torch.nn.functional.scaled_dot_product_attention(
+ q,
+ k,
+ v,
+ attn_mask=attn_mask,
+ dropout_p=self.dropout_rate,
+ )
+ else:
+ if self.h != self.h_k:
+ q = q.reshape(n_batch, self.g, self.h_k, -1, self.d_k)
+ A = torch.einsum("b g h t d, b h s d -> b h t s", q, k)
+ else:
+ A = torch.matmul(q, k.transpose(-2, -1))
+ if pos_k is not None:
+ if self.h != self.h_k:
+ B = torch.einsum("b g h t d, t s d -> b h t s", q, pos_k)
+ else:
+ reshape_q = (q.contiguous().view(n_batch * self.h, -1,
+ self.d_k).transpose(0, 1)
+ ) # (t1,nh,dk)
+ B = torch.matmul(reshape_q,
+ pos_k.transpose(-2,
+ -1)) # pos_k: (t1,dk,t2)
+ B = B.transpose(0, 1).view(n_batch, self.h, pos_k.size(0),
+ pos_k.size(1))
+ scores = A + B
+ else:
+ scores = A
+
+ if relative_attention_bias is not None:
+ scores = scores + relative_attention_bias
+
+ attn = masked_softmax(scores, mask) # (batch, head, time1, time2)
+
+ self.attn = attn
+
+ p_attn = self.dropout(attn)
+ x = torch.matmul(p_attn.to(v.dtype),
+ v) # (batch, head, time1, d_k)
+ if pos_v is not None:
+ reshape_attn = (p_attn.contiguous().view(
+ n_batch * self.h, pos_v.size(0),
+ pos_v.size(1)).transpose(0, 1)) # (t1, bh, t2)
+
+ attn_v = (torch.matmul(reshape_attn, pos_v).transpose(
+ 0, 1).contiguous().view(n_batch, self.h, pos_v.size(0),
+ self.d_k))
+ x = x + attn_v
+ x = (x.transpose(1, 2).contiguous().view(n_batch, -1,
+ self.h_k * self.d_k)
+ ) # (batch, time1, d_model)
+
+ return self.linear_out(x) # (batch, time1, d_model)
+
+
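With group_size g the module keeps n_head query heads but only n_head // g key/value heads, so the K/V projections shrink by a factor of g. A hypothetical instantiation (sizes assumed) showing the resulting shapes:

mha = MultiHeadedAttention(n_head=8, n_feat=512, dropout_rate=0.0,
                           group_size=4)     # GQA: 8 query heads, 2 kv heads
print(mha.h, mha.h_k)                        # 8 2
print(mha.linear_q.weight.shape)             # (512, 512)
print(mha.linear_k.weight.shape)             # (128, 512) -> inner_dim // g
print(mha.linear_v.weight.shape)             # (128, 512)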
+ class MultiSequential(torch.nn.Sequential):
+ """Multi-input multi-output torch.nn.Sequential"""
+
+ @torch.jit.ignore
+ def forward(self, *args):
+ """Forward method implementation."""
+ for m in self:
+ args = m(*args)
+ return args
+
+
+ def get_offset(input_layer: str, time_reduction: int):
+ """Get an offset. We will use the offset for determining #frames of a
+ subsampled feature.
+
+ Args:
+ input_layer (str): Type of an input layer
+ time_reduction (int): time reduction factor for downsampling a feature
+ Returns:
+ int: offset
+ """
+ if input_layer in ("conv2d", "nemo_conv") and time_reduction == 4:
+ return 3
+ if input_layer in ("conv2d", ) and time_reduction == 6:
+ return 1
+ if input_layer in ("conv2d", "nemo_conv") and time_reduction == 8:
+ return 7
+ return 0
+
+
+ def unfold_tensor(xs_pad, max_seq_len):
+ """
+ For a given tensor with shape of (N, T, D), if sequence length T is
+ longer than max_seq_len, this function unfolds it into a
+ (NT', max_seq_len, D) tensor, where T' is T // max_seq_len.
+ Args:
+ xs_pad: N, T, D
+ """
+ _, _, D = xs_pad.shape
+ xs_pad = xs_pad.transpose(-1, -2) # convert to N, D, T
+ # N x D x 1 x T => N x (D x max_seq_len) x T'
+ xs_pad = F.unfold(
+ xs_pad[..., None, :],
+ kernel_size=(1, max_seq_len),
+ stride=(1, max_seq_len),
+ )
+ new_bsz, _, slen = xs_pad.shape
+ # N x D x max_seq_len x T'
+ xs_pad = xs_pad.view(new_bsz, -1, max_seq_len, slen)
+ # N x T' x max_seq_len x D
+ xs_pad = xs_pad.permute(0, 3, 2, 1).contiguous()
+ # NT' x max_seq_len x D
+ xs_pad = xs_pad.view(-1, max_seq_len, D)
+ return xs_pad
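A shape check for unfold_tensor under assumed sizes (T must be divisible by max_seq_len for the tiling to be exact):

import torch

xs = torch.arange(2 * 6 * 3, dtype=torch.float32).reshape(2, 6, 3)  # (N, T, D)
out = unfold_tensor(xs, max_seq_len=2)
print(out.shape)  # (6, 2, 3) == (N * T // max_seq_len, max_seq_len, D)
# chunks preserve order within each original sequence
assert torch.equal(out[0], xs[0, 0:2])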