onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1035 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+ import json
7
+ import logging
8
+ import os
9
+ from pathlib import Path
10
+
11
+ import numpy as np
12
+ import torch
13
+ from convert_generation import add_cache_indirection_to_mha, add_output_qk_to_mha, fix_past_sequence_length
14
+ from optimizer import optimize_model
15
+ from transformers import AutoTokenizer, WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
16
+ from whisper_decoder import WhisperDecoder
17
+ from whisper_encoder import WhisperEncoder
18
+ from whisper_encoder_decoder_init import WhisperEncoderDecoderInit
19
+ from whisper_jump_times import WhisperJumpTimes
20
+
21
+ from onnxruntime import InferenceSession
22
+
23
+ logger = logging.getLogger(__name__)
24
+
25
# Pretrained Whisper model size/variant names recognized by this helper
# (".en" suffix denotes the English-only checkpoints).
PRETRAINED_WHISPER_MODELS = [
    "whisper-tiny",
    "whisper-tiny.en",
    "whisper-base",
    "whisper-base.en",
    "whisper-small",
    "whisper-small.en",
    "whisper-medium",
    "whisper-medium.en",
    "whisper-large",
    "whisper-large-v2",
    "whisper-large-v3",
    "whisper-large-v3-turbo",
]
39
+
40
+
41
+ class WhisperHelper:
42
+ @staticmethod
43
+ def get_onnx_path(
44
+ output_dir: str,
45
+ model_name_or_path: str,
46
+ suffix: str = "",
47
+ new_folder: bool = False,
48
+ ) -> str:
49
+ """Build onnx path
50
+
51
+ Args:
52
+ output_dir (str): output directory
53
+ model_name_or_path (str): pretrained model name, or path to the model checkpoint
54
+ suffix (str, optional): suffix like "_encoder" or "_decoder_fp16" will be appended to file name. Defaults to None.
55
+ new_folder (bool, optional): create a new directory for the model. Defaults to False.
56
+ Returns:
57
+ str: path of onnx model
58
+ """
59
+ model_name = model_name_or_path
60
+ if os.path.isdir(model_name_or_path):
61
+ model_name = Path(model_name_or_path).parts[-1]
62
+ else:
63
+ model_name = model_name.split("/")[-1]
64
+
65
+ model_name += suffix
66
+
67
+ directory = os.path.join(output_dir, model_name) if new_folder else output_dir
68
+ return os.path.join(directory, model_name + ".onnx")
69
+
70
+ @staticmethod
71
+ def save_processing(
72
+ model_name_or_path: str,
73
+ provider: str,
74
+ separate_encoder_and_decoder_init: bool,
75
+ use_decoder_masked_mha: bool,
76
+ output_qk: bool,
77
+ encoder_path: str,
78
+ decoder_path: str,
79
+ output_dir: str,
80
+ cache_dir: str,
81
+ ) -> None:
82
+ config = WhisperConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)
83
+ config.save_pretrained(output_dir)
84
+
85
+ tokenizer = AutoTokenizer.from_pretrained(model_name_or_path, cache_dir=cache_dir)
86
+ tokenizer.save_pretrained(output_dir)
87
+
88
+ processor = WhisperProcessor.from_pretrained(model_name_or_path, cache_dir=cache_dir)
89
+ processor.save_pretrained(output_dir)
90
+
91
+ # Return early since the next files are for ONNX Runtime GenAI
92
+ if separate_encoder_and_decoder_init:
93
+ return
94
+
95
+ audio_processor_cfg = {
96
+ "feature_extraction": {
97
+ "sequence": [
98
+ {"operation": {"name": "audio_decoder", "type": "AudioDecoder"}},
99
+ {
100
+ "operation": {
101
+ "name": "STFT",
102
+ "type": "STFTNorm",
103
+ "attrs": {
104
+ "n_fft": 400,
105
+ "frame_length": 400,
106
+ "hop_length": 160,
107
+ "_comment": [
108
+ 0.0,
109
+ 0.0000616908073425293,
110
+ 0.0002467334270477295,
111
+ 0.0005550682544708252,
112
+ 0.000986635684967041,
113
+ 0.0015413463115692139,
114
+ 0.0022190213203430176,
115
+ 0.0030195116996765137,
116
+ 0.003942638635635376,
117
+ 0.004988163709640503,
118
+ 0.006155818700790405,
119
+ 0.007445335388183594,
120
+ 0.008856385946273804,
121
+ 0.010388582944869995,
122
+ 0.012041628360748291,
123
+ 0.013815045356750488,
124
+ 0.01570841670036316,
125
+ 0.01772129535675049,
126
+ 0.019853144884109497,
127
+ 0.022103488445281982,
128
+ 0.02447172999382019,
129
+ 0.026957333087921143,
130
+ 0.029559612274169922,
131
+ 0.03227800130844116,
132
+ 0.03511175513267517,
133
+ 0.03806024789810181,
134
+ 0.0411226749420166,
135
+ 0.044298380613327026,
136
+ 0.04758647084236145,
137
+ 0.05098623037338257,
138
+ 0.05449673533439636,
139
+ 0.058117181062698364,
140
+ 0.06184667348861694,
141
+ 0.0656842589378357,
142
+ 0.06962898373603821,
143
+ 0.07367992401123047,
144
+ 0.0778360664844513,
145
+ 0.08209633827209473,
146
+ 0.08645972609519958,
147
+ 0.09092515707015991,
148
+ 0.09549149870872498,
149
+ 0.10015767812728882,
150
+ 0.10492250323295593,
151
+ 0.1097848117351532,
152
+ 0.11474338173866272,
153
+ 0.11979702115058899,
154
+ 0.12494447827339172,
155
+ 0.13018447160720825,
156
+ 0.1355157196521759,
157
+ 0.14093685150146484,
158
+ 0.1464466154575348,
159
+ 0.15204361081123352,
160
+ 0.1577264666557312,
161
+ 0.16349375247955322,
162
+ 0.16934409737586975,
163
+ 0.1752760112285614,
164
+ 0.18128803372383118,
165
+ 0.18737870454788208,
166
+ 0.19354650378227234,
167
+ 0.1997898817062378,
168
+ 0.20610737800598145,
169
+ 0.21249738335609436,
170
+ 0.21895831823349,
171
+ 0.2254886031150818,
172
+ 0.23208662867546082,
173
+ 0.23875075578689575,
174
+ 0.24547931551933289,
175
+ 0.2522706985473633,
176
+ 0.25912320613861084,
177
+ 0.26603513956069946,
178
+ 0.27300477027893066,
179
+ 0.2800304591655731,
180
+ 0.2871103882789612,
181
+ 0.29424285888671875,
182
+ 0.30142611265182495,
183
+ 0.30865830183029175,
184
+ 0.31593772768974304,
185
+ 0.3232625722885132,
186
+ 0.3306310474872589,
187
+ 0.3380413055419922,
188
+ 0.34549152851104736,
189
+ 0.352979838848114,
190
+ 0.3605044484138489,
191
+ 0.3680635094642639,
192
+ 0.37565508484840393,
193
+ 0.38327735662460327,
194
+ 0.3909284174442291,
195
+ 0.39860638976097107,
196
+ 0.4063093662261963,
197
+ 0.41403549909591675,
198
+ 0.42178282141685486,
199
+ 0.4295494258403778,
200
+ 0.43733343482017517,
201
+ 0.44513291120529175,
202
+ 0.45294591784477234,
203
+ 0.46077051758766174,
204
+ 0.46860480308532715,
205
+ 0.4764467775821686,
206
+ 0.4842946231365204,
207
+ 0.492146372795105,
208
+ 0.5,
209
+ 0.5078536868095398,
210
+ 0.515705406665802,
211
+ 0.5235532522201538,
212
+ 0.5313953161239624,
213
+ 0.5392295718193054,
214
+ 0.5470541715621948,
215
+ 0.5548672080039978,
216
+ 0.562666654586792,
217
+ 0.5704506635665894,
218
+ 0.5782172679901123,
219
+ 0.5859646201133728,
220
+ 0.5936906933784485,
221
+ 0.6013936996459961,
222
+ 0.609071671962738,
223
+ 0.6167227625846863,
224
+ 0.6243450045585632,
225
+ 0.6319366097450256,
226
+ 0.6394955515861511,
227
+ 0.6470202207565308,
228
+ 0.6545085310935974,
229
+ 0.6619587540626526,
230
+ 0.6693689823150635,
231
+ 0.6767374277114868,
232
+ 0.6840623021125793,
233
+ 0.691341757774353,
234
+ 0.6985740065574646,
235
+ 0.7057572603225708,
236
+ 0.7128896713256836,
237
+ 0.719969630241394,
238
+ 0.7269952893257141,
239
+ 0.7339649796485901,
240
+ 0.7408769130706787,
241
+ 0.7477294206619263,
242
+ 0.7545207738876343,
243
+ 0.761249303817749,
244
+ 0.7679134607315063,
245
+ 0.774511456489563,
246
+ 0.7810417413711548,
247
+ 0.7875027060508728,
248
+ 0.7938927412033081,
249
+ 0.800210177898407,
250
+ 0.8064535856246948,
251
+ 0.8126214146614075,
252
+ 0.8187121152877808,
253
+ 0.8247240781784058,
254
+ 0.8306560516357422,
255
+ 0.8365063667297363,
256
+ 0.8422735929489136,
257
+ 0.8479564785957336,
258
+ 0.8535534143447876,
259
+ 0.8590631484985352,
260
+ 0.8644843101501465,
261
+ 0.8698155879974365,
262
+ 0.8750555515289307,
263
+ 0.8802030086517334,
264
+ 0.8852566480636597,
265
+ 0.8902152180671692,
266
+ 0.8950775265693665,
267
+ 0.899842381477356,
268
+ 0.9045084714889526,
269
+ 0.9090749025344849,
270
+ 0.9135403037071228,
271
+ 0.9179036617279053,
272
+ 0.9221639633178711,
273
+ 0.9263200759887695,
274
+ 0.9303710460662842,
275
+ 0.9343158006668091,
276
+ 0.9381533861160278,
277
+ 0.941882848739624,
278
+ 0.945503294467926,
279
+ 0.9490138292312622,
280
+ 0.9524135589599609,
281
+ 0.9557017087936401,
282
+ 0.9588773250579834,
283
+ 0.961939811706543,
284
+ 0.9648882746696472,
285
+ 0.9677220582962036,
286
+ 0.9704403877258301,
287
+ 0.9730427265167236,
288
+ 0.9755282998085022,
289
+ 0.9778965711593628,
290
+ 0.9801468849182129,
291
+ 0.9822787046432495,
292
+ 0.9842916131019592,
293
+ 0.9861849546432495,
294
+ 0.9879584312438965,
295
+ 0.9896113872528076,
296
+ 0.9911436438560486,
297
+ 0.9925546646118164,
298
+ 0.9938441514968872,
299
+ 0.9950118064880371,
300
+ 0.996057391166687,
301
+ 0.9969804883003235,
302
+ 0.997780978679657,
303
+ 0.9984586238861084,
304
+ 0.999013364315033,
305
+ 0.9994449615478516,
306
+ 0.9997532367706299,
307
+ 0.9999383091926575,
308
+ 1,
309
+ 0.9999383091926575,
310
+ 0.9997532367706299,
311
+ 0.9994449615478516,
312
+ 0.999013364315033,
313
+ 0.9984586238861084,
314
+ 0.997780978679657,
315
+ 0.9969804286956787,
316
+ 0.9960573315620422,
317
+ 0.9950118064880371,
318
+ 0.9938441514968872,
319
+ 0.9925546646118164,
320
+ 0.9911435842514038,
321
+ 0.9896113872528076,
322
+ 0.9879583716392517,
323
+ 0.9861849546432495,
324
+ 0.9842915534973145,
325
+ 0.9822787046432495,
326
+ 0.9801468253135681,
327
+ 0.9778964519500732,
328
+ 0.9755282402038574,
329
+ 0.9730426073074341,
330
+ 0.9704403877258301,
331
+ 0.9677219390869141,
332
+ 0.9648882150650024,
333
+ 0.9619396924972534,
334
+ 0.9588772654533386,
335
+ 0.9557015895843506,
336
+ 0.9524134397506714,
337
+ 0.9490137100219727,
338
+ 0.9455032348632812,
339
+ 0.9418827295303345,
340
+ 0.9381532669067383,
341
+ 0.9343156814575195,
342
+ 0.9303709268569946,
343
+ 0.9263200759887695,
344
+ 0.9221639633178711,
345
+ 0.9179036617279053,
346
+ 0.913540244102478,
347
+ 0.9090747833251953,
348
+ 0.9045084714889526,
349
+ 0.8998422622680664,
350
+ 0.8950774669647217,
351
+ 0.8902151584625244,
352
+ 0.8852565884590149,
353
+ 0.8802029490470886,
354
+ 0.8750554919242859,
355
+ 0.869815468788147,
356
+ 0.8644842505455017,
357
+ 0.8590630888938904,
358
+ 0.853553295135498,
359
+ 0.8479562997817993,
360
+ 0.842273473739624,
361
+ 0.836506187915802,
362
+ 0.8306558728218079,
363
+ 0.8247239589691162,
364
+ 0.8187118768692017,
365
+ 0.8126212358474731,
366
+ 0.8064534664154053,
367
+ 0.8002099990844727,
368
+ 0.793892502784729,
369
+ 0.7875025272369385,
370
+ 0.7810416221618652,
371
+ 0.7745113372802734,
372
+ 0.767913281917572,
373
+ 0.7612491846084595,
374
+ 0.7545205950737,
375
+ 0.7477291822433472,
376
+ 0.7408767342567444,
377
+ 0.7339648008346558,
378
+ 0.7269951105117798,
379
+ 0.7199694514274597,
380
+ 0.7128894925117493,
381
+ 0.7057570219039917,
382
+ 0.6985738277435303,
383
+ 0.6913415789604187,
384
+ 0.684062123298645,
385
+ 0.6767372488975525,
386
+ 0.6693688035011292,
387
+ 0.6619585752487183,
388
+ 0.6545083522796631,
389
+ 0.6470199823379517,
390
+ 0.6394953727722168,
391
+ 0.6319363117218018,
392
+ 0.6243447661399841,
393
+ 0.6167224645614624,
394
+ 0.6090714335441589,
395
+ 0.601393461227417,
396
+ 0.5936904549598694,
397
+ 0.5859643220901489,
398
+ 0.5782170295715332,
399
+ 0.5704504251480103,
400
+ 0.5626664161682129,
401
+ 0.5548669099807739,
402
+ 0.5470539331436157,
403
+ 0.5392293334007263,
404
+ 0.5313950181007385,
405
+ 0.5235530138015747,
406
+ 0.5157051682472229,
407
+ 0.507853627204895,
408
+ 0.5,
409
+ 0.4921463429927826,
410
+ 0.484294593334198,
411
+ 0.4764467477798462,
412
+ 0.46860471367836,
413
+ 0.4607704281806946,
414
+ 0.4529458284378052,
415
+ 0.4451328217983246,
416
+ 0.437333345413208,
417
+ 0.42954933643341064,
418
+ 0.4217827320098877,
419
+ 0.4140354096889496,
420
+ 0.4063093066215515,
421
+ 0.3986063003540039,
422
+ 0.39092832803726196,
423
+ 0.3832772672176361,
424
+ 0.37565499544143677,
425
+ 0.36806342005729675,
426
+ 0.3605043888092041,
427
+ 0.35297977924346924,
428
+ 0.3454914391040802,
429
+ 0.338041216135025,
430
+ 0.33063095808029175,
431
+ 0.3232625126838684,
432
+ 0.3159376382827759,
433
+ 0.3086581826210022,
434
+ 0.3014259934425354,
435
+ 0.2942427396774292,
436
+ 0.28711026906967163,
437
+ 0.2800303101539612,
438
+ 0.2730046510696411,
439
+ 0.2660350203514099,
440
+ 0.2591230869293213,
441
+ 0.25227057933807373,
442
+ 0.24547919631004333,
443
+ 0.2387506067752838,
444
+ 0.23208650946617126,
445
+ 0.22548848390579224,
446
+ 0.21895819902420044,
447
+ 0.2124972641468048,
448
+ 0.2061072587966919,
449
+ 0.19978976249694824,
450
+ 0.1935463547706604,
451
+ 0.18737855553627014,
452
+ 0.18128788471221924,
453
+ 0.17527586221694946,
454
+ 0.1693439483642578,
455
+ 0.16349363327026367,
456
+ 0.15772631764411926,
457
+ 0.15204349160194397,
458
+ 0.14644649624824524,
459
+ 0.1409367322921753,
460
+ 0.13551557064056396,
461
+ 0.1301843225955963,
462
+ 0.12494435906410217,
463
+ 0.11979690194129944,
464
+ 0.11474326252937317,
465
+ 0.10978469252586365,
466
+ 0.10492238402366638,
467
+ 0.10015755891799927,
468
+ 0.09549137949943542,
469
+ 0.09092503786087036,
470
+ 0.08645960688591003,
471
+ 0.08209621906280518,
472
+ 0.07783591747283936,
473
+ 0.07367980480194092,
474
+ 0.06962886452674866,
475
+ 0.06568413972854614,
476
+ 0.06184655427932739,
477
+ 0.0581170916557312,
478
+ 0.0544966459274292,
479
+ 0.05098611116409302,
480
+ 0.04758638143539429,
481
+ 0.044298261404037476,
482
+ 0.04112258553504944,
483
+ 0.038060128688812256,
484
+ 0.03511166572570801,
485
+ 0.03227788209915161,
486
+ 0.02955952286720276,
487
+ 0.02695724368095398,
488
+ 0.024471670389175415,
489
+ 0.02210339903831482,
490
+ 0.01985308527946472,
491
+ 0.017721205949783325,
492
+ 0.015708357095718384,
493
+ 0.0138150155544281,
494
+ 0.012041598558425903,
495
+ 0.010388582944869995,
496
+ 0.008856356143951416,
497
+ 0.007445335388183594,
498
+ 0.006155818700790405,
499
+ 0.004988163709640503,
500
+ 0.003942638635635376,
501
+ 0.0030195116996765137,
502
+ 0.0022190213203430176,
503
+ 0.0015413165092468262,
504
+ 0.000986635684967041,
505
+ 0.0005550682544708252,
506
+ 0.0002467334270477295,
507
+ 0.0000616908073425293,
508
+ ],
509
+ },
510
+ }
511
+ },
512
+ {
513
+ "operation": {
514
+ "name": "log_mel_spectrogram",
515
+ "type": "LogMelSpectrum",
516
+ "attrs": {"chunk_size": 30, "hop_length": 160, "n_fft": 400, "n_mel": config.num_mel_bins},
517
+ }
518
+ },
519
+ ]
520
+ }
521
+ }
522
+ audio_processor_json = json.dumps(audio_processor_cfg, indent=4)
523
+
524
+ with open(os.path.join(output_dir, "audio_processor_config.json"), "w") as f:
525
+ f.write(audio_processor_json)
526
+
527
+ provider_options = [] if "cpu" in provider else [{f"{provider}": {}}]
528
+ genai_config = {
529
+ "model": {
530
+ "bos_token_id": config.bos_token_id,
531
+ "context_length": config.max_length,
532
+ "decoder": {
533
+ "session_options": {
534
+ "log_id": "onnxruntime-genai",
535
+ "provider_options": provider_options,
536
+ },
537
+ "filename": os.path.basename(decoder_path),
538
+ "head_size": config.d_model // config.decoder_attention_heads,
539
+ "hidden_size": config.d_model,
540
+ "inputs": {
541
+ "input_ids": "input_ids",
542
+ "past_key_names": "past_key_self_%d",
543
+ "past_value_names": "past_value_self_%d",
544
+ "cross_past_key_names": "past_key_cross_%d",
545
+ "cross_past_value_names": "past_value_cross_%d",
546
+ },
547
+ "outputs": {
548
+ "logits": "logits",
549
+ "present_key_names": "present_key_self_%d",
550
+ "present_value_names": "present_value_self_%d",
551
+ },
552
+ "num_attention_heads": config.decoder_attention_heads,
553
+ "num_hidden_layers": config.decoder_layers,
554
+ "num_key_value_heads": config.decoder_attention_heads,
555
+ },
556
+ "encoder": {
557
+ "session_options": {
558
+ "log_id": "onnxruntime-genai",
559
+ "provider_options": provider_options,
560
+ },
561
+ "filename": os.path.basename(encoder_path),
562
+ "head_size": config.d_model // config.encoder_attention_heads,
563
+ "hidden_size": config.d_model,
564
+ "inputs": {"audio_features": "audio_features"},
565
+ "outputs": {
566
+ "encoder_hidden_states": "encoder_hidden_states",
567
+ "cross_present_key_names": "present_key_cross_%d",
568
+ "cross_present_value_names": "present_value_cross_%d",
569
+ },
570
+ "num_attention_heads": config.encoder_attention_heads,
571
+ "num_hidden_layers": config.encoder_layers,
572
+ "num_key_value_heads": config.encoder_attention_heads,
573
+ },
574
+ "eos_token_id": config.eos_token_id,
575
+ "pad_token_id": config.pad_token_id,
576
+ "type": "whisper",
577
+ "vocab_size": config.vocab_size,
578
+ },
579
+ "search": {
580
+ "diversity_penalty": 0.0,
581
+ "do_sample": False,
582
+ "early_stopping": True,
583
+ "length_penalty": 1.0,
584
+ "max_length": config.max_length,
585
+ "min_length": 0,
586
+ "no_repeat_ngram_size": 0,
587
+ "num_beams": 1,
588
+ "num_return_sequences": 1,
589
+ "past_present_share_buffer": use_decoder_masked_mha,
590
+ "repetition_penalty": 1.0,
591
+ "temperature": 1.0,
592
+ "top_k": 1,
593
+ "top_p": 1.0,
594
+ },
595
+ }
596
+
597
+ # Requirements for the DMMHA kernel:
598
+ # - Buffer sharing = true
599
+ # - New input: past_sequence_length
600
+ # - New input: cache_indirection
601
+ # Otherwise, buffer sharing should be false and the new inputs should not be added
602
+ # for beam search to work in ORT GenAI.
603
+ if use_decoder_masked_mha:
604
+ genai_config["model"]["decoder"]["inputs"].update(
605
+ {
606
+ "past_sequence_length": "past_sequence_length",
607
+ "cache_indirection": "cache_indirection",
608
+ }
609
+ )
610
+
611
+ if output_qk:
612
+ genai_config["model"]["decoder"]["outputs"].update(
613
+ {
614
+ "output_cross_qk_names": "output_cross_qk_%d",
615
+ }
616
+ )
617
+
618
+ with open(os.path.join(output_dir, "genai_config.json"), "w") as f:
619
+ json.dump(genai_config, f, indent=4)
620
+
621
@staticmethod
def load_model(
    model_name_or_path: str,
    model_impl: str,
    cache_dir: str,
    device: torch.device,
    dtype: torch.dtype,
    merge_encoder_and_decoder_init: bool = True,
    no_beam_search_op: bool = False,
    output_qk: bool = False,
) -> dict[str, torch.nn.Module]:
    """Load model given a pretrained name or path, then build models for ONNX conversion.

    Args:
        model_name_or_path (str): pretrained model name or path
        model_impl (str): library to load model from ("hf" for Hugging Face, otherwise OpenAI)
        cache_dir (str): cache directory
        device (torch.device): device to run the model
        dtype (torch.dtype): dtype to run the model
        merge_encoder_and_decoder_init (bool, optional): Whether merge encoder and decoder initialization into one ONNX model. Defaults to True.
        no_beam_search_op (bool, optional): Whether to use beam search op or not. Defaults to False.
        output_qk (bool, optional): Whether to output QKs to calculate batched jump times for word-level timestamps. Defaults to False.
    Returns:
        Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
    """
    from_hf = model_impl == "hf"

    # Load the PyTorch checkpoint from the selected library.
    if from_hf:
        model = WhisperForConditionalGeneration.from_pretrained(
            model_name_or_path, cache_dir=cache_dir, attn_implementation="eager"
        )
    else:
        import whisper  # noqa: PLC0415

        if os.path.exists(model_name_or_path):
            openai_name = model_name_or_path
        else:
            # Drop the repo prefix and the first 8 characters of the basename
            # (presumably a "whisper-" prefix) to get the OpenAI model name.
            openai_name = model_name_or_path.split("/")[-1][8:]
        model = whisper.load_model(openai_name, device, download_root=cache_dir, in_memory=True)

    # Put the model in eval mode on the target device; only HF checkpoints
    # are additionally cast to the requested dtype.
    model.eval().to(device=device)
    if from_hf:
        model.to(dtype=dtype)
    config = WhisperConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)

    # Wrap each exportable component of the model.
    decoder = WhisperDecoder(config, model, model_impl, no_beam_search_op).eval()
    components: dict[str, torch.nn.Module] = {"decoder": decoder}
    if merge_encoder_and_decoder_init:
        components["encoder"] = WhisperEncoderDecoderInit(config, model, model_impl, no_beam_search_op).eval()
    else:
        components["encoder"] = WhisperEncoder(config, model, model_impl).eval()
        components["decoder_init"] = decoder

    if output_qk:
        # Extra component used to compute batched jump times for word-level timestamps.
        components["jump_times"] = WhisperJumpTimes(config, device, cache_dir).eval()
    return components
682
+
683
@staticmethod
def export_onnx(
    model: WhisperEncoder | WhisperEncoderDecoderInit | WhisperDecoder | WhisperJumpTimes,
    onnx_model_path: str,
    provider: str,
    verbose: bool,
    use_external_data_format: bool,
    use_fp16_inputs: bool,
    use_int32_inputs: bool,
    use_encoder_hidden_states: bool,
    use_kv_cache_inputs: bool,
):
    """Export model component to ONNX.

    Dispatches to the component's own `export_onnx`, forwarding only the
    arguments that component's exporter accepts.

    Args:
        model (class): PyTorch class to export (one of the Whisper component wrappers)
        onnx_model_path (str): path to save ONNX model
        provider (str): provider to use for verifying parity on ONNX model
        verbose (bool): print verbose information.
        use_external_data_format (bool): use external data format or not.
        use_fp16_inputs (bool): use float16 inputs for the audio_features, encoder_hidden_states, logits, and KV caches.
        use_int32_inputs (bool): use int32 inputs for the decoder_input_ids.
        use_encoder_hidden_states (bool): use encoder_hidden_states as model input for decoder-init/decoder-without-past models.
        use_kv_cache_inputs (bool): use KV caches as model inputs for decoder-with-past models.

    Raises:
        ValueError: if `model` is not one of the recognized component types.
    """
    # NOTE: the annotation above now includes WhisperJumpTimes, which the body
    # always handled but the original type hint omitted.
    # The isinstance chain order is preserved from the original implementation.
    if isinstance(model, WhisperEncoder):
        # Encoder-only export: no decoder input ids, no KV caches.
        model.export_onnx(
            onnx_model_path,
            provider,
            verbose,
            use_external_data_format,
            use_fp16_inputs,
        )
    elif isinstance(model, WhisperEncoderDecoderInit):
        # Merged encoder + decoder-init export also needs the input-id dtype flag.
        model.export_onnx(
            onnx_model_path,
            provider,
            verbose,
            use_external_data_format,
            use_fp16_inputs,
            use_int32_inputs,
        )
    elif isinstance(model, WhisperDecoder):
        # Decoder export additionally controls encoder_hidden_states and KV-cache inputs.
        model.export_onnx(
            onnx_model_path,
            provider,
            verbose,
            use_external_data_format,
            use_fp16_inputs,
            use_int32_inputs,
            use_encoder_hidden_states,
            use_kv_cache_inputs,
        )
    elif isinstance(model, WhisperJumpTimes):
        model.export_onnx(
            onnx_model_path,
            provider,
            verbose,
            use_external_data_format,
            use_fp16_inputs,
            use_int32_inputs,
        )
    else:
        raise ValueError(f"Unknown instance for model detected: {type(model)}")
747
+
748
@staticmethod
def optimize_onnx(
    onnx_model_path: str,
    optimized_model_path: str,
    is_float16: bool,
    num_attention_heads: int,
    hidden_size: int,
    num_decoder_layers: int,
    use_external_data_format: bool = False,
    use_gpu: bool = False,
    provider: str = "cpu",
    is_decoder: bool = False,
    no_beam_search_op: bool = False,
    use_decoder_masked_mha: bool = False,
    output_qk: bool = False,
):
    """Optimize ONNX model with an option to convert it to use mixed precision."""

    from fusion_options import FusionOptions  # noqa: PLC0415

    # BART-style fusion configuration with multi-head attention fusion enabled.
    fusion_opts = FusionOptions("bart")
    fusion_opts.use_multi_head_attention = True
    fusion_opts.disable_multi_head_attention_bias = False

    opt_model = optimize_model(
        onnx_model_path,
        model_type="bart",
        num_heads=num_attention_heads,
        hidden_size=hidden_size,
        opt_level=0,
        optimization_options=fusion_opts,
        use_gpu=use_gpu,
        only_onnxruntime=False,
    )

    # Add `past_sequence_length`, `cache_indirection`, and `output_qk` to `MultiHeadAttention` ops
    if is_decoder and no_beam_search_op:
        if use_decoder_masked_mha:
            # FP16 CUDA, FP32 CUDA, and FP32 CPU use the `DecoderMaskedMultiHeadAttention` kernel
            # via `MultiHeadAttention`, which requires the `past_sequence_length` and
            # `cache_indirection` inputs
            opt_model, past_seq_len_name = fix_past_sequence_length(opt_model)
            opt_model = add_cache_indirection_to_mha(opt_model, past_seq_len_name)

        if output_qk:
            # Skip every other MHA node (the even-indexed ones) — presumably the
            # self-attention layers, leaving QK outputs on the cross-attention ones.
            even_mha_idxs = [idx for idx in range(2 * num_decoder_layers) if idx % 2 == 0]
            opt_model = add_output_qk_to_mha(opt_model, skip_node_idxs=even_mha_idxs)

    opt_model.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)
796
+
797
@staticmethod
def pt_transcription_for_verify_onnx(
    processor: WhisperProcessor,
    pt_model: torch.nn.Module,
    device: torch.device,
    batch_size: int = 1,
    prompt_mode: bool = False,
):
    """Run PyTorch transcription on a LibriSpeech dummy sample for ONNX parity checks.

    Builds the generation inputs, runs `pt_model.generate`, and returns
    `(inputs, pt_transcription, pt_outputs, prompt_ids)` for the caller
    (`verify_onnx`) to replay against the ONNX Runtime session.
    """
    # Try to import `datasets` pip package
    try:
        from datasets import load_dataset  # noqa: PLC0415
    except Exception as e:
        logger.error(f"An error occurred while importing `datasets`: {e}", exc_info=True)  # noqa: G201
        install_cmd = "pip install datasets"
        logger.warning(f"Could not import `datasets`. Attempting to install `datasets` via `{install_cmd}`.")
        # Best-effort auto-install, then retry the import (a second failure propagates).
        os.system(install_cmd)

        from datasets import load_dataset  # noqa: PLC0415

    ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
    input_features_ = []
    if batch_size == 1:
        input_features = processor([ds[0]["audio"]["array"]], return_tensors="pt").input_features
    else:
        # NOTE(review): the same sample (ds[3]) is used twice, and exactly two
        # feature tensors are built — so the assert below only holds for
        # batch_size == 2. Confirm larger batches are not expected here.
        input_features_ = [
            processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
            processor([ds[3]["audio"]["array"]], return_tensors="pt").input_features,
        ]
        assert len(input_features_) == batch_size
        input_features = torch.cat((input_features_[0], input_features_[1]))

    # Fixed generation settings shared by the PyTorch and ONNX runs.
    max_length, min_length, num_beams, num_return_sequences = 30, 0, 1, 1
    length_penalty, repetition_penalty = 1.0, 1.0
    inputs = {
        "input_features": input_features.to(device),
        "max_length": max_length,
        "min_length": min_length,
        "num_beams": num_beams,
        "num_return_sequences": num_return_sequences,
        "length_penalty": length_penalty,
        "repetition_penalty": repetition_penalty,
        "early_stopping": True,
        "use_cache": True,
    }

    if prompt_mode:
        prompts = ["John has doubts", "Maria has grave doubts"]
        prompt_ids = [processor.get_prompt_ids(p) for p in prompts]
        pt_transcription = []
        pt_outputs = []
        # The looping for model.generate is necessary here due to the limitation as per
        # https://huggingface.co/docs/transformers/model_doc/whisper#transformers.WhisperForConditionalGeneration.generate.prompt_ids
        # prompt_ids input requires a tensor of rank 1
        for i in range(batch_size):
            inputs["prompt_ids"] = torch.from_numpy(prompt_ids[i]).to(device=device)
            inputs["input_features"] = input_features_[i].to(device)
            pt_output = pt_model.generate(**inputs).detach().cpu().numpy()
            pt_outputs.append(pt_output)
            pt_transcription.append(processor.batch_decode(pt_output, skip_special_tokens=True)[0])
        # Restore the batched features and drop the per-sample prompt ids so the
        # returned `inputs` dict matches what the ONNX session expects.
        inputs["input_features"] = input_features
        del inputs["prompt_ids"]
    else:
        prompt_ids = []
        pt_outputs = pt_model.generate(**inputs).detach().cpu().numpy()
        pt_transcription = [processor.batch_decode(pt_outputs, skip_special_tokens=True)[0]]
        pt_outputs = list(pt_outputs)
        # These two keys are only meaningful to PyTorch generate(); remove them
        # before the dict is reused as ONNX session inputs.
        del inputs["early_stopping"]
        del inputs["use_cache"]
    return inputs, pt_transcription, pt_outputs, prompt_ids
866
+
867
+ @staticmethod
868
+ def select_transcription_options(
869
+ batch_size: int,
870
+ prompt_mode: bool,
871
+ ):
872
+ if batch_size > 1 and prompt_mode:
873
+ expected_transcription_no_comma_prompt1 = " John has doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky I"
874
+ expected_transcription_misspelled_prompt1 = " John has doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
875
+ expected_transcription_no_comma_prompt2 = " Maria has grave doubts whether Sir Frederick Layton's work is really Greek after all and can discover in it but little of Rocky"
876
+ expected_transcription_misspelled_prompt2 = " Maria has grave doubts whether Sir Frederick Latins work is really Greek after all and can discover in it but little of Rocky I"
877
+ expected_transcription_options = {
878
+ expected_transcription_no_comma_prompt1,
879
+ expected_transcription_no_comma_prompt2,
880
+ expected_transcription_misspelled_prompt1,
881
+ expected_transcription_misspelled_prompt2,
882
+ }
883
+ else:
884
+ expected_transcription_no_comma = (
885
+ " Mr. Quilter is the apostle of the middle classes and we are glad to welcome his gospel."
886
+ )
887
+ expected_transcription_with_comma = (
888
+ " Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel."
889
+ )
890
+ expected_transcription_with_quote_and_comma = (
891
+ ' "Mr. Quilter is the apostle of the middle classes, and we are glad to welcome his gospel.'
892
+ )
893
+ expected_transcription_options = {
894
+ expected_transcription_no_comma,
895
+ expected_transcription_with_comma,
896
+ expected_transcription_with_quote_and_comma,
897
+ }
898
+ return expected_transcription_options
899
+
900
+ @staticmethod
901
+ def get_outputs(
902
+ pt_outputs: np.ndarray,
903
+ ort_outputs: np.ndarray,
904
+ i: int,
905
+ ):
906
+ """Get PyTorch and ONNX Runtime output token ids at index i"""
907
+ pt_output, ort_output = pt_outputs[i], ort_outputs[i]
908
+ pt_shape, ort_shape = pt_output.shape, ort_output.shape
909
+
910
+ # Hugging Face impl. + Beam Search op: PyTorch = (26,) and ORT = (30,)
911
+ # OpenAI impl. + Beam Search op: PyTorch = (1, 30) and ORT = (30,)
912
+ if pt_shape != ort_shape:
913
+ if len(pt_shape) > 1:
914
+ pt_output = pt_output[0]
915
+ pt_shape = pt_output.shape
916
+ if len(ort_shape) > 1:
917
+ ort_output = ort_output[0]
918
+ ort_shape = ort_output.shape
919
+ if pt_shape[0] != ort_shape[0]:
920
+ min_len = min(pt_shape[0], ort_shape[0])
921
+ pt_output = pt_output[:min_len]
922
+ ort_output = ort_output[:min_len]
923
+
924
+ assert pt_output.shape == ort_output.shape
925
+ return pt_output, ort_output
926
+
927
@staticmethod
def verify_onnx(
    model_name_or_path: str,
    cache_dir: str,
    ort_session: InferenceSession,
    device: torch.device,
    batch_size: int = 1,
    prompt_mode: bool = False,
):
    """Compare the result from PyTorch and ONNX Runtime to verify the ONNX model is good.

    Runs the same audio through the Hugging Face PyTorch model and the exported
    ONNX Runtime session, then compares token ids (np.allclose) and decoded
    transcriptions against the expected reference set. Mismatches are logged as
    warnings.

    NOTE(review): this always returns 0, even when outputs differ — callers only
    see warnings. Confirm that is intended (a nonzero max_diff is computed but
    never returned).
    """
    pt_model = WhisperForConditionalGeneration.from_pretrained(
        model_name_or_path, cache_dir=cache_dir, attn_implementation="eager"
    ).to(device)
    processor = WhisperProcessor.from_pretrained(model_name_or_path, cache_dir=cache_dir)
    config = WhisperConfig.from_pretrained(model_name_or_path, cache_dir=cache_dir)

    # Reference run on the PyTorch model; also yields the shared input dict.
    inputs, pt_transcription, pt_outputs, decoder_prompt_ids = WhisperHelper.pt_transcription_for_verify_onnx(
        processor,
        pt_model,
        device,
        batch_size=batch_size,
        prompt_mode=prompt_mode,
    )

    start_id = [config.decoder_start_token_id]  # ex: [50258]
    prompt_ids = processor.get_decoder_prompt_ids(language="english", task="transcribe")
    prompt_ids = [token[1] for token in prompt_ids]  # ex: [50259, 50358, 50363]
    forced_decoder_ids = start_id + prompt_ids  # ex: [50258, 50259, 50358, 50363]

    # Build each session input in the dtype the session declares for it.
    ort_names = [entry.name for entry in ort_session.get_inputs()]
    ort_dtypes = [entry.type for entry in ort_session.get_inputs()]
    ort_to_np = {
        "tensor(float)": np.float32,
        "tensor(float16)": np.float16,
        "tensor(int64)": np.int64,
        "tensor(int32)": np.int32,
        "tensor(int8)": np.int8,
        "tensor(uint8)": np.uint8,
    }

    use_extra_decoding_ids = "extra_decoding_ids" in ort_names
    for name, dtype in zip(ort_names, ort_dtypes, strict=False):
        if name == "input_features":
            inputs[name] = inputs[name].detach().cpu().numpy()
        elif name == "vocab_mask":
            inputs[name] = np.ones(config.vocab_size, dtype=ort_to_np[dtype])
        elif name == "prefix_vocab_mask":
            inputs[name] = np.ones((batch_size, config.vocab_size), dtype=ort_to_np[dtype])
        elif name == "decoder_input_ids":
            if not prompt_mode:
                raw_input_ids = [start_id] if use_extra_decoding_ids else [forced_decoder_ids]
                inputs[name] = np.array(raw_input_ids, dtype=ort_to_np[dtype])
            else:
                # This logic handles the scenario for when prompts are not of the same size
                # For example if our prompt ids are [p1_id_1, p1_id_2] and [p2_id_1]
                # The final decoder_input_ids will look as such after padding
                # [prev_token, p1_id_1, p1_id_2, start_token, lang_token, transcribe_token]
                # [prev_token, p2_id_1, PAD_TOKEN, start_token, lang_token, transcribe_token]
                ort_prompts = []
                for i in range(batch_size):
                    ort_prompts.append(decoder_prompt_ids[i].tolist())
                max_len = max(len(p) for p in ort_prompts)
                padded_prompts = []
                for p in ort_prompts:
                    padded_prompt = [*p, *([config.pad_token_id] * (max_len - len(p)))]
                    padded_prompts.append(padded_prompt + forced_decoder_ids)
                inputs[name] = np.array(padded_prompts, dtype=ort_to_np[dtype])
        elif name == "logits_processor":
            inputs[name] = np.array([1], dtype=ort_to_np[dtype])
        elif name == "cross_qk_layer_head":
            inputs[name] = np.array([[0, 0]], dtype=ort_to_np[dtype])
        elif name == "extra_decoding_ids":
            inputs[name] = np.repeat(np.array([prompt_ids], dtype=ort_to_np[dtype]), batch_size, 0)
        elif name == "temperature":
            inputs[name] = np.array([1.0], dtype=ort_to_np[dtype])
        else:
            # Remaining names are scalar generation settings already present in
            # `inputs` (max_length, num_beams, ...) — wrap each as a 1-element array.
            inputs[name] = np.array([inputs[name]], dtype=ort_to_np[dtype])

    # First session output indexed as [:, 0, :] — presumably
    # (batch, num_return_sequences, seq_len), keeping return sequence 0. Verify.
    ort_outputs = ort_session.run(None, inputs)[0][:, 0, :]
    ort_transcription = processor.batch_decode(ort_outputs, skip_special_tokens=True)
    expected_transcription_options = WhisperHelper.select_transcription_options(batch_size, prompt_mode)

    # parity stays 1 only if every sample matches on both token ids and text.
    parity = 1
    for i in range(batch_size):
        pt_output, ort_output = WhisperHelper.get_outputs(pt_outputs, ort_outputs, i)

        # Check if token ids match
        parity *= np.allclose(pt_output, ort_output)

        # Check if transcribed outputs match
        parity *= (
            pt_transcription[i] in expected_transcription_options
            and ort_transcription[i] in expected_transcription_options
        )
    max_diff = 0

    if not parity:
        # Find the signed element difference with the largest magnitude.
        for i in range(batch_size):
            pt_output, ort_output = WhisperHelper.get_outputs(pt_outputs, ort_outputs, i)
            diff = pt_output - ort_output

            max_diff_i = max(diff.min(), diff.max(), key=abs)
            # NOTE(review): this max() compares signed values, so a large
            # negative diff can be masked by 0 — consider key=abs here too.
            max_diff = max(max_diff, max_diff_i)

    if max_diff != 0:
        logger.warning(f"PyTorch outputs: {pt_transcription}")
        logger.warning(f"ONNX Runtime outputs: {ort_transcription}")

    return 0