onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
+++ onnxruntime/transformers/models/whisper/whisper_decoder.py
@@ -0,0 +1,464 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import logging
+import os
+import tempfile
+from itertools import chain
+from pathlib import Path
+
+import numpy as np
+import onnx
+import torch
+from float16 import convert_float_to_float16
+from google.protobuf.internal.containers import RepeatedCompositeFieldContainer
+from onnx import ModelProto, ValueInfoProto
+from onnx_model import OnnxModel
+from past_helper import PastKeyValuesHelper
+from transformers import WhisperConfig
+from whisper_inputs import (
+    convert_inputs_for_ort,
+    get_model_dynamic_axes,
+    get_sample_decoder_inputs,
+    group_past_key_values,
+)
+
+from onnxruntime import InferenceSession
+
+logger = logging.getLogger(__name__)
+
+
+class WhisperDecoder(torch.nn.Module):
+    """A Whisper decoder with optional past key values"""
+
+    def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str, no_beam_search_op: bool = False):
+        super().__init__()
+        self.config = config
+        self.device = model.device
+        self.model_impl = model_impl
+        self.no_beam_search_op = no_beam_search_op
+
+        self.decoder = None if model_impl == "openai" else model.model.decoder
+        self.proj_out = None if model_impl == "openai" else model.proj_out
+        self.model = model if model_impl == "openai" else None
+
+        self.max_source_positions = self.config.max_source_positions
+        self.num_heads = self.config.decoder_attention_heads
+        self.head_size = self.config.d_model // self.num_heads
+
+    def hf_forward(
+        self,
+        decoder_input_ids: torch.Tensor,
+        encoder_hidden_states: torch.Tensor | None = None,
+        past_key_values: list[tuple[torch.Tensor]] | None = None,
+    ):
+        outputs = self.decoder(
+            encoder_hidden_states=encoder_hidden_states,
+            input_ids=decoder_input_ids,
+            past_key_values=past_key_values,
+            use_cache=True,
+        )
+        logits = self.proj_out(outputs.last_hidden_state)
+        present_key_values = outputs.past_key_values
+
+        if past_key_values is None:
+            # Return present_self_* and present_cross_* for decoder-init
+            return logits, present_key_values
+
+        # Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
+        #         (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
+        # After:  (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1), ...,
+        #         (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1), ...
+        present_self, present_cross = PastKeyValuesHelper.group_by_self_and_cross(present_key_values)
+
+        # Return present_self_* for decoder-with-past since past_cross_* and present_cross_* are identical
+        return logits, present_self
+
+    def oai_forward(
+        self,
+        decoder_input_ids: torch.Tensor,
+        encoder_hidden_states: torch.Tensor | None = None,
+        past_key_values: list[tuple[torch.Tensor]] | None = None,
+    ):
+        past_kv_cache = {}
+        if past_key_values is not None:
+            # Convert past KV caches (BxNxSxH --> BxSxNxH --> BxSxD) for OpenAI's forward pass
+            self_attn_kv_caches, cross_attn_kv_caches = group_past_key_values(past_key_values)
+            self_attn_kv_caches = [past_kv.transpose(1, 2) for past_kv in self_attn_kv_caches]
+            self_attn_kv_caches = [past_kv.reshape((*past_kv.shape[:2], -1)) for past_kv in self_attn_kv_caches]
+            cross_attn_kv_caches = [past_kv.transpose(1, 2) for past_kv in cross_attn_kv_caches]
+            cross_attn_kv_caches = [past_kv.reshape((*past_kv.shape[:2], -1)) for past_kv in cross_attn_kv_caches]
+
+            for idx, block in enumerate(self.model.decoder.blocks):
+                past_kv_cache[block.attn.key] = self_attn_kv_caches[2 * idx]
+                past_kv_cache[block.attn.value] = self_attn_kv_caches[2 * idx + 1]
+                past_kv_cache[block.cross_attn.key] = cross_attn_kv_caches[2 * idx]
+                past_kv_cache[block.cross_attn.value] = cross_attn_kv_caches[2 * idx + 1]
+
+        # Install OpenAI's hooks on the forward pass of each nn.Linear for key and value
+        # since the hooks will capture the output of the key and value MatMuls, which
+        # represent the current keys and values.
+        #
+        # For OpenAI's forward pass, the hook function will also perform the concat
+        # operation (past_kv + curr_kv --> pres_kv) if needed. However, the ONNX model
+        # will not contain this concat operation because the present KV caches aren't
+        # returned by OpenAI's forward pass.
+        kv_cache, hooks = self.model.install_kv_cache_hooks()
+
+        # Run forward pass
+        # NOTE: There is a bug with openai-whisper==20240930 with the introduction of SDPA.
+        # In the Whisper codebase, the following line
+        #
+        # is_causal = mask is not None and n_ctx > 1
+        #
+        # has been added where `mask` is a torch tensor. The right-hand side evaluates to `tensor(True/False)`
+        # but `is_causal` only accepts a boolean value. The fix is to apply `.item()` after the right-hand
+        # side has been evaluated. In other words, the line should be
+        #
+        # is_causal = (mask is not None and n_ctx > 1).item()
+        #
+        # instead.
+        logits = self.model.decoder(x=decoder_input_ids, xa=encoder_hidden_states, kv_cache=past_kv_cache)
+
+        # Re-do the concat operation on the self attention KV caches for the ONNX export (if past self attention KV caches exist)
+        if past_key_values is not None:
+            for block in self.model.decoder.blocks:
+                kv_cache[block.attn.key] = torch.cat(
+                    [past_kv_cache[block.attn.key], kv_cache[block.attn.key]], dim=1
+                ).detach()
+                kv_cache[block.attn.value] = torch.cat(
+                    [past_kv_cache[block.attn.value], kv_cache[block.attn.value]], dim=1
+                ).detach()
+
+        present_self, present_cross = [], []
+        for block in self.model.decoder.blocks:
+            # Group self and cross values
+            present_self.append(kv_cache[block.attn.key])
+            present_self.append(kv_cache[block.attn.value])
+            if past_key_values is None:
+                # Return present_self_* and present_cross_* for decoder-init
+                present_cross.append(kv_cache[block.cross_attn.key])
+                present_cross.append(kv_cache[block.cross_attn.value])
+
+        # Convert present KV caches (BxSxD --> BxSxNxH --> BxNxSxH) after OpenAI's forward pass
+        present_self = [
+            present_kv.reshape((*present_kv.shape[:2], -1, self.head_size)).transpose(1, 2)
+            for present_kv in present_self
+        ]
+        present_cross = [
+            present_kv.reshape((*present_kv.shape[:2], -1, self.head_size)).transpose(1, 2)
+            for present_kv in present_cross
+        ]
+
+        # Remove OpenAI's hooks since they can persist after this function completes
+        for hook in hooks:
+            hook.remove()
+
+        if past_key_values is None:
+            # Return present_self_* and present_cross_* for decoder-init
+            present_key_values = PastKeyValuesHelper.group_by_layer(
+                present_self + present_cross, len(present_self) // 2
+            )
+            return logits, present_key_values
+
+        # Return present_self_* for decoder-with-past since past_cross_* and present_cross_* are identical
+        return logits, present_self
+
+    def forward(
+        self,
+        decoder_input_ids: torch.Tensor,
+        encoder_hidden_states: torch.Tensor | None = None,
+        past_key_values: list[tuple[torch.Tensor]] | None = None,
+    ):
+        if self.model_impl == "openai":
+            return self.oai_forward(decoder_input_ids, encoder_hidden_states, past_key_values)
+        return self.hf_forward(decoder_input_ids, encoder_hidden_states, past_key_values)
+
+    def input_names(self):
+        if self.first_pass:
+            input_names = ["input_ids", "encoder_hidden_states"]
+        else:
+            input_names = [
+                "input_ids",
+                "encoder_hidden_states",
+                *list(
+                    chain.from_iterable(
+                        (f"past_key_self_{i}", f"past_value_self_{i}", f"past_key_cross_{i}", f"past_value_cross_{i}")
+                        for i in range(self.config.decoder_layers)
+                    )
+                ),
+            ]
+        return input_names
+
+    def output_names(self):
+        if self.first_pass:
+            output_names = [
+                "logits",
+                *list(
+                    chain.from_iterable(
+                        (
+                            f"present_key_self_{i}",
+                            f"present_value_self_{i}",
+                            f"present_key_cross_{i}",
+                            f"present_value_cross_{i}",
+                        )
+                        for i in range(self.config.decoder_layers)
+                    )
+                ),
+            ]
+        else:
+            output_names = [
+                "logits",
+                *list(
+                    chain.from_iterable(
+                        (f"present_key_self_{i}", f"present_value_self_{i}") for i in range(self.config.decoder_layers)
+                    )
+                ),
+            ]
+        return output_names
+
+    def dynamic_axes(self, input_names, output_names):
+        dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
+        if "input_ids" in dynamic_axes and not self.no_beam_search_op:
+            # When using the beam search op, set the dynamic axes for `input_ids` to {0: "batch_size"} only
+            del dynamic_axes["input_ids"][1]
+        return dynamic_axes
+
+    def inputs(self, use_fp16_inputs: bool, use_int32_inputs: bool, return_dict: bool = False):
+        inputs = get_sample_decoder_inputs(
+            self.config,
+            self.device,
+            batch_size=2,
+            past_sequence_length=(0 if self.first_pass else 6),
+            sequence_length=(6 if self.first_pass else 1),
+            use_fp16=use_fp16_inputs,
+            use_int32=use_int32_inputs,
+        )
+        if return_dict:
+            if self.first_pass:
+                del inputs["past_key_values"]
+            return inputs
+
+        if self.first_pass:
+            return (
+                inputs["decoder_input_ids"],
+                inputs["encoder_hidden_states"],
+            )
+        return (
+            inputs["decoder_input_ids"],
+            inputs["encoder_hidden_states"],
+            inputs["past_key_values"],
+        )
+
+    def fix_key_value_cache_dims(self, io: ValueInfoProto, is_cross: bool = False, is_output: bool = False):
+        # Shape should be (batch_size, num_heads, sequence_length, head_size) for self attention KV caches
+        # and (batch_size, num_heads, num_frames // 2, head_size) for cross attention KV caches
+        num_heads = io.type.tensor_type.shape.dim[1]
+        if "_dim_" in num_heads.dim_param:
+            num_heads.Clear()
+            num_heads.dim_value = self.num_heads
+        sequence_length = io.type.tensor_type.shape.dim[2]
+        if "_dim_" in sequence_length.dim_param:
+            sequence_length.Clear()
+            if is_cross:
+                sequence_length.dim_value = self.max_source_positions
+            else:
+                sequence_length.dim_param = "total_sequence_length" if is_output else "past_sequence_length"
+        head_size = io.type.tensor_type.shape.dim[3]
+        if "_dim_" in head_size.dim_param:
+            head_size.Clear()
+            head_size.dim_value = self.head_size
+        return io
+
+    def fix_io(self, io_list: RepeatedCompositeFieldContainer, is_output: bool = False):
+        # Fix the order of the inputs/outputs and each dim_value of the inputs/outputs
+        reordered_io = []
+        self_attn_kv_caches = []
+        cross_attn_kv_caches = []
+
+        for io in io_list:
+            if "past" not in io.name and "present" not in io.name:
+                reordered_io.append(io)
+            elif "self" in io.name:
+                # Self attention KV caches
+                new_io = self.fix_key_value_cache_dims(io, is_cross=False, is_output=is_output)
+                if self.no_beam_search_op:
+                    reordered_io.append(new_io)
+                else:
+                    self_attn_kv_caches.append(new_io)
+            else:
+                # Cross attention KV caches
+                new_io = self.fix_key_value_cache_dims(io, is_cross=True, is_output=is_output)
+                if self.no_beam_search_op:
+                    reordered_io.append(new_io)
+                else:
+                    cross_attn_kv_caches.append(new_io)
+
+        if not self.no_beam_search_op:
+            reordered_io += self_attn_kv_caches + cross_attn_kv_caches
+        return reordered_io
+
+    def fix_inputs_and_outputs(self, model: ModelProto):
+        # The ONNX exporter might mark dimensions like 'Transposepresent_value_self_1_dim_2' during shape inference.
+        # Change those dim_params to the correct dim_values here.
+        reordered_inputs = self.fix_io(model.graph.input, is_output=False)
+        while len(model.graph.input) > 0:
+            model.graph.input.pop()
+        model.graph.input.extend(reordered_inputs)
+
+        reordered_outputs = self.fix_io(model.graph.output, is_output=True)
+        while len(model.graph.output) > 0:
+            model.graph.output.pop()
+        model.graph.output.extend(reordered_outputs)
+        return model
+
+    def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
+        if self.model_impl == "openai" and use_fp16_inputs:
+            # Cast ONNX model to float16 to ensure LayerNorm weights are converted from
+            # float32 to float16 since exported model already has float16 weights everywhere
+            # except for LayerNorm ops. This happens because OpenAI always upcasts to float32
+            # when computing LayerNorm.
+            #
+            # Reference:
+            # https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
+            model = convert_float_to_float16(model)
+        return model
+
+    def export_onnx(
+        self,
+        onnx_model_path: str,
+        provider: str,
+        verbose: bool = True,
+        use_external_data_format: bool = False,
+        use_fp16_inputs: bool = False,
+        use_int32_inputs: bool = True,
+        use_encoder_hidden_states: bool = False,
+        use_kv_cache_inputs: bool = True,
+    ):
+        """Export decoder to ONNX
+
+        Args:
+            onnx_model_path (str): path to save ONNX model
+            provider (str): provider to use for verifying parity on ONNX model
+            verbose (bool, optional): print verbose information. Defaults to True.
+            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
+            use_fp16_inputs (bool, optional): use float16 inputs for the KV caches. Defaults to False.
+            use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids. Defaults to True.
+            use_encoder_hidden_states (bool, optional): use encoder_hidden_states as model input for decoder-init/decoder-without-past models. Defaults to False.
+            use_kv_cache_inputs (bool, optional): use KV caches as model inputs for decoder-with-past models. Defaults to True.
+        """
+        # Shape of decoder's tensors:
+        # Required inputs:
+        #     decoder_input_ids: (batch_size, sequence_length)
+        # Optional inputs:
+        #     encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
+        #     past_{key/value}_self_* (past self attention KV caches): (batch_size, num_heads, past_sequence_length, head_size)
+        #     past_{key/value}_cross_* (past cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
+        # Outputs:
+        #     logits: (batch_size, sequence_length, vocab_size)
+        #     present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
+        #     present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
+
+        # For the first pass through the decoder (i.e. decoder-init/decoder-without-past)
+        self.first_pass = use_encoder_hidden_states and not use_kv_cache_inputs
+
+        # For subsequent passes through the decoder (i.e. decoder-with-past)
+        self.later_pass = not use_encoder_hidden_states and use_kv_cache_inputs
+
+        assert self.first_pass or self.later_pass, (
+            "Exactly one of `use_encoder_hidden_states` and `use_kv_cache_inputs` must be true."
+        )
+
+        inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs)
+        input_names = self.input_names()
+        output_names = self.output_names()
+        dynamic_axes = self.dynamic_axes(input_names, output_names)
+
+        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
+        with tempfile.TemporaryDirectory() as tmp_dir_name:
+            temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
+            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
+            out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path
+
+            torch.onnx.export(
+                self,
+                args=inputs,
+                f=out_path,
+                export_params=True,
+                input_names=input_names,
+                output_names=output_names,
+                dynamic_axes=dynamic_axes,
+                opset_version=17,
+                do_constant_folding=True,
+                verbose=verbose,
+            )
+
+            model = onnx.load_model(out_path, load_external_data=use_external_data_format)
+            model = self.fix_inputs_and_outputs(model)
+            model = self.fix_layernorm_weights(model, use_fp16_inputs)
+            OnnxModel.save(
+                model,
+                onnx_model_path,
+                save_as_external_data=use_external_data_format,
+                all_tensors_to_one_file=True,
+            )
+
+        self.verify_onnx(onnx_model_path, provider, use_fp16_inputs, use_int32_inputs)
+
+    def verify_onnx(
+        self,
+        onnx_model_path: str,
+        provider: str,
+        use_fp16_inputs: bool,
+        use_int32_inputs: bool,
+    ):
+        """Verify that the ONNX model outputs and the PyTorch model outputs match
+
+        Args:
+            onnx_model_path (str): path to the saved ONNX model
+            provider (str): execution provider for ONNX model
+            use_fp16_inputs (bool): use float16 inputs for the KV caches
+            use_int32_inputs (bool): use int32 inputs for the decoder_input_ids
+        """
+        # Shape of decoder's tensors:
+        # Required inputs:
+        #     decoder_input_ids: (batch_size, sequence_length)
+        # Optional inputs:
+        #     encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
+        #     past_{key/value}_self_* (past self attention KV caches): (batch_size, num_heads, past_sequence_length, head_size)
+        #     past_{key/value}_cross_* (past cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
+        # Outputs:
+        #     logits: (batch_size, sequence_length, vocab_size)
+        #     present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
+        #     present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
+
+        # Run PyTorch model
+        inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, return_dict=True)
+        pt_outputs = []
+        if self.first_pass:
+            out = self.forward(**inputs)
+            pt_outputs.append(out[0].detach().cpu().numpy())
+            for present_key_value_layer in out[1]:
+                for present_key_value in present_key_value_layer:
+                    pt_outputs.append(present_key_value.detach().cpu().numpy())
+        else:
+            out = self.forward(**inputs)
+            pt_outputs.append(out[0].detach().cpu().numpy())
+            for present_self_key_value in out[1]:
+                pt_outputs.append(present_self_key_value.detach().cpu().numpy())
+
+        # Run ONNX model
+        sess = InferenceSession(onnx_model_path, providers=[provider])
+        ort_outputs = sess.run(None, convert_inputs_for_ort(inputs, sess))
+
+        # Calculate output difference
+        try:
+            for i, output_name in enumerate(self.output_names()):
+                diff = np.abs(pt_outputs[i] - ort_outputs[i])
+                logger.warning(f"Comparing {output_name}...")
+                logger.warning(f"Max diff: {np.max(diff)}")
+        except:  # noqa: E722
+            pass  # best-effort comparison; any shape or count mismatch is silently ignored
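
The decoder above is exported in two variants: decoder-init (fed input_ids and encoder_hidden_states, emitting self and cross attention present_* caches) and decoder-with-past (additionally fed the past_* caches, emitting only updated self attention caches). A minimal greedy-decode sketch of how the two exported models might be chained with onnxruntime follows; it is illustrative only and not part of the wheel. The model paths, token ids, and max_len are assumptions, and only the input/output names defined by input_names()/output_names() above are relied on; wiring present_* outputs back to past_* inputs by name keeps the loop independent of the cache reordering applied in fix_io().

    # Illustrative sketch only; paths and ids are assumed, not shipped in the wheel.
    import numpy as np

    from onnxruntime import InferenceSession

    def greedy_decode(encoder_hidden_states: np.ndarray, start_id: int, eos_id: int, max_len: int = 32):
        # This wheel also provides "DmlExecutionProvider" as an alternative to CPU.
        init_sess = InferenceSession("decoder_init.onnx", providers=["CPUExecutionProvider"])
        step_sess = InferenceSession("decoder_with_past.onnx", providers=["CPUExecutionProvider"])

        batch_size = encoder_hidden_states.shape[0]
        tokens = np.full((batch_size, 1), start_id, dtype=np.int32)

        # First pass (decoder-init): no KV cache inputs; outputs are logits plus
        # the present_* self and cross attention caches.
        outputs = init_sess.run(None, {"input_ids": tokens, "encoder_hidden_states": encoder_hidden_states})
        logits = outputs[0]
        # Map present_* outputs by name so their order in the graph doesn't matter.
        present = {o.name: v for o, v in zip(init_sess.get_outputs(), outputs) if o.name.startswith("present")}

        for _ in range(max_len):
            next_token = logits[:, -1, :].argmax(axis=-1).astype(np.int32)[:, None]
            tokens = np.concatenate([tokens, next_token], axis=1)
            if (next_token == eos_id).all():
                break
            # Later passes (decoder-with-past): feed every present_* back as past_*.
            feed = {"input_ids": next_token, "encoder_hidden_states": encoder_hidden_states}
            feed.update({name.replace("present", "past"): value for name, value in present.items()})
            outputs = step_sess.run(None, feed)
            logits = outputs[0]
            # Only self attention caches are returned here; the cross attention
            # caches from the first pass are reused unchanged.
            present.update({o.name: v for o, v in zip(step_sess.get_outputs(), outputs) if o.name.startswith("present")})
        return tokens
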
+++ onnxruntime/transformers/models/whisper/whisper_encoder.py
@@ -0,0 +1,164 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import logging
+import os
+import tempfile
+from pathlib import Path
+
+import numpy as np
+import onnx
+import torch
+from float16 import convert_float_to_float16
+from onnx import ModelProto
+from onnx_model import OnnxModel
+from transformers import WhisperConfig
+from whisper_inputs import get_model_dynamic_axes, get_sample_encoder_inputs
+
+from onnxruntime import InferenceSession
+
+logger = logging.getLogger(__name__)
+
+
+class WhisperEncoder(torch.nn.Module):
+    """Whisper encoder component"""
+
+    def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str):
+        super().__init__()
+        self.config = config
+        self.device = model.device
+        self.model_impl = model_impl
+
+        self.encoder = model.encoder if model_impl == "openai" else model.model.encoder
+
+    def forward(self, audio_features: torch.Tensor):
+        outputs = self.encoder(audio_features)
+        return outputs if self.model_impl == "openai" else outputs.last_hidden_state
+
+    def input_names(self):
+        input_names = ["audio_features"]
+        return input_names
+
+    def output_names(self):
+        output_names = ["encoder_hidden_states"]
+        return output_names
+
+    def dynamic_axes(self, input_names, output_names):
+        dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
+        return dynamic_axes
+
+    def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
+        if self.model_impl == "openai" and use_fp16_inputs:
+            # Cast ONNX model to float16 to ensure LayerNorm weights are converted from
+            # float32 to float16 since exported model already has float16 weights everywhere
+            # except for LayerNorm ops. This happens because OpenAI always upcasts to float32
+            # when computing LayerNorm.
+            #
+            # Reference:
+            # https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
+            model = convert_float_to_float16(model)
+        return model
+
+    def export_onnx(
+        self,
+        onnx_model_path: str,
+        provider: str,
+        verbose: bool = True,
+        use_external_data_format: bool = False,
+        use_fp16_inputs: bool = False,
+    ):
+        """Export encoder to ONNX
+
+        Args:
+            onnx_model_path (str): path to save ONNX model
+            provider (str): provider to use for verifying parity on ONNX model
+            verbose (bool, optional): print verbose information. Defaults to True.
+            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
+            use_fp16_inputs (bool, optional): use float16 inputs for the audio_features. Defaults to False.
+        """
+        # Shape of encoder's tensors:
+        # Inputs:
+        #     audio_features: (batch_size, num_mels, num_frames)
+        # Outputs:
+        #     encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
+
+        inputs = get_sample_encoder_inputs(
+            self.config,
+            self.device,
+            batch_size=2,
+            use_fp16=use_fp16_inputs,
+        )
+
+        input_names = self.input_names()
+        output_names = self.output_names()
+        dynamic_axes = self.dynamic_axes(input_names, output_names)
+
+        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
+        with tempfile.TemporaryDirectory() as tmp_dir_name:
+            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder.onnx")
+            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
+            out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path
+
+            torch.onnx.export(
+                self,
+                args=(inputs["audio_features"],),
+                f=out_path,
+                export_params=True,
+                input_names=input_names,
+                output_names=output_names,
+                dynamic_axes=dynamic_axes,
+                opset_version=17,
+                do_constant_folding=True,
+                verbose=verbose,
+            )
+
+            model = onnx.load_model(out_path, load_external_data=use_external_data_format)
+            model = self.fix_layernorm_weights(model, use_fp16_inputs)
+            OnnxModel.save(
+                model,
+                onnx_model_path,
+                save_as_external_data=use_external_data_format,
+                all_tensors_to_one_file=True,
+            )
+
+        self.verify_onnx(onnx_model_path, provider, use_fp16_inputs)
+
+    def verify_onnx(
+        self,
+        onnx_model_path: str,
+        provider: str,
+        use_fp16_inputs: bool,
+    ):
+        """Verify that the ONNX model outputs and the PyTorch model outputs match
+
+        Args:
+            onnx_model_path (str): path to the saved ONNX model
+            provider (str): execution provider for ONNX model
+            use_fp16_inputs (bool): use float16 inputs for the audio_features
+        """
+        # Shape of encoder's tensors:
+        # Inputs:
+        #     audio_features: (batch_size, num_mels, num_frames)
+        # Outputs:
+        #     encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
+        inputs = get_sample_encoder_inputs(
+            self.config,
+            self.device,
+            batch_size=2,
+            use_fp16=use_fp16_inputs,
+        )
+
+        # Run PyTorch model
+        pt_outputs = self.forward(inputs["audio_features"]).detach().cpu().numpy()
+
+        # Run ONNX model
+        sess = InferenceSession(onnx_model_path, providers=[provider])
+        ort_outputs = sess.run(None, {"audio_features": inputs["audio_features"].detach().cpu().numpy()})[0]
+
+        # Calculate output difference
+        diff = np.abs(pt_outputs - ort_outputs)
+        logger.warning("Comparing encoder_hidden_states...")
+        logger.warning(f"Max diff: {np.max(diff)}")