onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,371 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+ import logging
8
+ import os
9
+ import tempfile
10
+ from itertools import chain
11
+ from pathlib import Path
12
+
13
+ import numpy as np
14
+ import onnx
15
+ import torch
16
+ from float16 import convert_float_to_float16
17
+ from onnx import ModelProto, ValueInfoProto
18
+ from onnx_model import OnnxModel
19
+ from transformers import WhisperConfig
20
+ from whisper_decoder import WhisperDecoder
21
+ from whisper_encoder import WhisperEncoder
22
+ from whisper_inputs import (
23
+ convert_inputs_for_ort,
24
+ get_model_dynamic_axes,
25
+ get_sample_encoder_decoder_init_inputs,
26
+ group_past_key_values,
27
+ )
28
+
29
+ from onnxruntime import InferenceSession
30
+
31
+ logger = logging.getLogger(__name__)
32
+
33
+
34
+ class WhisperEncoderDecoderInit(torch.nn.Module):
35
+ """Whisper encoder component + first pass through Whisper decoder component to initialize KV caches"""
36
+
37
+ def __init__(self, config: WhisperConfig, model: torch.nn.Module, model_impl: str, no_beam_search_op: bool = False):
38
+ super().__init__()
39
+ self.config = config
40
+ self.device = model.device
41
+ self.model_impl = model_impl
42
+ self.no_beam_search_op = no_beam_search_op
43
+
44
+ self.encoder = WhisperEncoder(config, model, model_impl)
45
+ self.decoder = WhisperDecoder(config, model, model_impl, no_beam_search_op)
46
+
47
+ self.max_source_positions = self.config.max_source_positions
48
+ self.num_heads = self.config.decoder_attention_heads
49
+ self.head_size = self.config.d_model // self.num_heads
50
+
51
+ def hf_forward_for_beam_search_op(self, audio_features: torch.Tensor, decoder_input_ids: torch.Tensor):
52
+ encoder_hidden_states = self.encoder(audio_features)
53
+ logits, present_key_values = self.decoder(decoder_input_ids, encoder_hidden_states)
54
+ return logits, encoder_hidden_states, present_key_values
55
+
56
+ def hf_forward_for_no_beam_search_op(self, audio_features: torch.Tensor):
57
+ encoder_hidden_states = self.encoder(audio_features)
58
+
59
+ # Get cross attention KV caches and return them for this model
60
+ # We do this because these MatMuls are only run once before their outputs are being re-used in the decoder
61
+ present_cross_attention_key_value_caches = []
62
+ for layer in self.decoder.decoder.layers:
63
+ cross_attn_key_cache = (
64
+ layer.encoder_attn.k_proj(encoder_hidden_states)
65
+ .view(-1, self.max_source_positions, self.num_heads, self.head_size)
66
+ .transpose(1, 2)
67
+ )
68
+ cross_attn_value_cache = (
69
+ layer.encoder_attn.v_proj(encoder_hidden_states)
70
+ .view(-1, self.max_source_positions, self.num_heads, self.head_size)
71
+ .transpose(1, 2)
72
+ )
73
+ present_cross_attention_key_value_caches.append(cross_attn_key_cache)
74
+ present_cross_attention_key_value_caches.append(cross_attn_value_cache)
75
+
76
+ return encoder_hidden_states, present_cross_attention_key_value_caches
77
+
78
+ def oai_forward_for_beam_search_op(self, audio_features: torch.Tensor, decoder_input_ids: torch.Tensor):
79
+ encoder_hidden_states = self.encoder(audio_features)
80
+ logits, present_key_values = self.decoder(decoder_input_ids, encoder_hidden_states)
81
+ return logits, encoder_hidden_states, present_key_values
82
+
83
+ def oai_forward_for_no_beam_search_op(self, audio_features: torch.Tensor):
84
+ encoder_hidden_states = self.encoder(audio_features)
85
+
86
+ # Get cross attention KV caches and return them for this model
87
+ # We do this because these MatMuls are only run once before their outputs are being re-used in the decoder
88
+ present_cross_attention_key_value_caches = []
89
+ for block in self.decoder.model.decoder.blocks:
90
+ cross_attn_key_cache = (
91
+ block.cross_attn.key(encoder_hidden_states)
92
+ .view(-1, self.max_source_positions, self.num_heads, self.head_size)
93
+ .transpose(1, 2)
94
+ )
95
+ cross_attn_value_cache = (
96
+ block.cross_attn.value(encoder_hidden_states)
97
+ .view(-1, self.max_source_positions, self.num_heads, self.head_size)
98
+ .transpose(1, 2)
99
+ )
100
+ present_cross_attention_key_value_caches.append(cross_attn_key_cache)
101
+ present_cross_attention_key_value_caches.append(cross_attn_value_cache)
102
+
103
+ return encoder_hidden_states, present_cross_attention_key_value_caches
104
+
105
+ def forward(self, audio_features: torch.Tensor, decoder_input_ids: torch.Tensor | None = None):
106
+ if self.model_impl == "openai":
107
+ if self.no_beam_search_op:
108
+ return self.oai_forward_for_no_beam_search_op(audio_features)
109
+ return self.oai_forward_for_beam_search_op(audio_features, decoder_input_ids)
110
+
111
+ # Hugging Face implementation
112
+ if self.no_beam_search_op:
113
+ return self.hf_forward_for_no_beam_search_op(audio_features)
114
+ return self.hf_forward_for_beam_search_op(audio_features, decoder_input_ids)
115
+
116
+ def input_names(self):
117
+ if self.no_beam_search_op:
118
+ input_names = ["audio_features"]
119
+ else:
120
+ input_names = ["encoder_input_ids", "decoder_input_ids"]
121
+ return input_names
122
+
123
+ def output_names(self):
124
+ if self.no_beam_search_op:
125
+ output_names = [
126
+ "encoder_hidden_states",
127
+ *list(
128
+ chain.from_iterable(
129
+ (f"present_key_cross_{i}", f"present_value_cross_{i}")
130
+ for i in range(self.config.decoder_layers)
131
+ )
132
+ ),
133
+ ]
134
+ else:
135
+ output_names = [
136
+ "logits",
137
+ "encoder_hidden_states",
138
+ *list(
139
+ chain.from_iterable(
140
+ (
141
+ f"present_key_self_{i}",
142
+ f"present_value_self_{i}",
143
+ f"present_key_cross_{i}",
144
+ f"present_value_cross_{i}",
145
+ )
146
+ for i in range(self.config.decoder_layers)
147
+ )
148
+ ),
149
+ ]
150
+ return output_names
151
+
152
+ def dynamic_axes(self, input_names, output_names):
153
+ dynamic_axes = get_model_dynamic_axes(self.config, input_names, output_names)
154
+ return dynamic_axes
155
+
156
+ def inputs(self, use_fp16_inputs: bool, use_int32_inputs: bool, return_dict: bool = False):
157
+ inputs = get_sample_encoder_decoder_init_inputs(
158
+ self.config,
159
+ self.device,
160
+ batch_size=2,
161
+ decoder_sequence_length=6,
162
+ use_fp16=use_fp16_inputs,
163
+ use_int32=use_int32_inputs,
164
+ )
165
+ if return_dict:
166
+ if self.no_beam_search_op:
167
+ del inputs["decoder_input_ids"]
168
+ return inputs
169
+
170
+ if self.no_beam_search_op:
171
+ return (inputs["audio_features"],)
172
+ return (
173
+ inputs["audio_features"],
174
+ inputs["decoder_input_ids"],
175
+ )
176
+
177
    def fix_key_value_cache_dims(self, output: ValueInfoProto, is_cross: bool = False):
        """Replace exporter-generated placeholder dims on one KV-cache output.

        The ONNX exporter may leave symbolic dim_params containing "_dim_"
        (e.g. 'Transposepresent_value_self_1_dim_2'); those are rewritten to
        the real dimensions. Shape should be
        (batch_size, num_heads, sequence_length, head_size) for self attention
        KV caches and (batch_size, num_heads, num_frames // 2, head_size) for
        cross attention KV caches. Mutates `output` in place and returns it.
        """
        # dim[1]: number of attention heads — always a known constant.
        num_heads = output.type.tensor_type.shape.dim[1]
        if "_dim_" in num_heads.dim_param:
            num_heads.Clear()
            num_heads.dim_value = self.num_heads
        # dim[2]: fixed for cross attention (num_frames // 2 == max_source_positions),
        # symbolic for self attention (grows with decoding steps).
        sequence_length = output.type.tensor_type.shape.dim[2]
        if "_dim_" in sequence_length.dim_param:
            sequence_length.Clear()
            if is_cross:
                sequence_length.dim_value = self.max_source_positions
            else:
                sequence_length.dim_param = "total_sequence_length"
        # dim[3]: per-head size — always a known constant.
        head_size = output.type.tensor_type.shape.dim[3]
        if "_dim_" in head_size.dim_param:
            head_size.Clear()
            head_size.dim_value = self.head_size
        return output
196
+
197
+ def fix_outputs(self, model: ModelProto):
198
+ # ONNX exporter might mark dimensions like 'Transposepresent_value_self_1_dim_2' in shape inference.
199
+ # We now change the dim_values to the correct one.
200
+ reordered_outputs = []
201
+ self_attn_kv_caches = []
202
+ cross_attn_kv_caches = []
203
+
204
+ for output in model.graph.output:
205
+ if "present" not in output.name:
206
+ reordered_outputs.append(output)
207
+
208
+ elif "self" in output.name:
209
+ # Self attention KV caches
210
+ new_output = self.fix_key_value_cache_dims(output, is_cross=False)
211
+ if self.no_beam_search_op:
212
+ reordered_outputs.append(new_output)
213
+ else:
214
+ self_attn_kv_caches.append(new_output)
215
+ else:
216
+ # Cross attention KV caches
217
+ new_output = self.fix_key_value_cache_dims(output, is_cross=True)
218
+ if self.no_beam_search_op:
219
+ reordered_outputs.append(new_output)
220
+ else:
221
+ cross_attn_kv_caches.append(new_output)
222
+
223
+ if not self.no_beam_search_op:
224
+ reordered_outputs += self_attn_kv_caches + cross_attn_kv_caches
225
+
226
+ while len(model.graph.output) > 0:
227
+ model.graph.output.pop()
228
+ model.graph.output.extend(reordered_outputs)
229
+ return model
230
+
231
+ def fix_layernorm_weights(self, model: ModelProto, use_fp16_inputs: bool):
232
+ if self.model_impl == "openai" and use_fp16_inputs:
233
+ # Cast ONNX model to float16 to ensure LayerNorm weights are converted from
234
+ # float32 to float16 since exported model already has float16 weights everywhere
235
+ # except for LayerNorm ops. This happens because OpenAI always upcasts to float32
236
+ # when computing LayerNorm.
237
+ #
238
+ # Reference:
239
+ # https://github.com/openai/whisper/blob/90db0de1896c23cbfaf0c58bc2d30665f709f170/whisper/model.py#L41
240
+ model = convert_float_to_float16(model)
241
+ return model
242
+
243
    def export_onnx(
        self,
        onnx_model_path: str,
        provider: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_fp16_inputs: bool = False,
        use_int32_inputs: bool = True,
    ):
        """Export encoder-decoder-init to ONNX

        Args:
            onnx_model_path (str): path to save ONNX model
            provider (str): provider to use for verifying parity on ONNX model
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_fp16_inputs (bool, optional): use float16 inputs for the audio_features. Defaults to False.
            use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids. Defaults to True.
        """
        # Shape of encoder's tensors:
        #   Inputs:
        #     audio_features: (batch_size, num_mels, num_frames)
        #   Outputs:
        #     encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)

        # Shape of decoder's tensors:
        #   Inputs:
        #     decoder_input_ids: (batch_size, sequence_length)
        #     encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
        #   Outputs:
        #     logits: (batch_size, sequence_length, vocab_size)
        #     present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
        #     present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)

        # Build dummy inputs and the name/axis metadata used by the exporter.
        inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs)
        input_names = self.input_names()
        output_names = self.output_names()
        dynamic_axes = self.dynamic_axes(input_names, output_names)

        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder_decoder_init.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            # When external data is requested, export to a temporary location first;
            # the model is re-saved below so all tensors end up in one external data
            # file next to onnx_model_path. Otherwise export straight to the final path.
            out_path = temp_onnx_model_path if use_external_data_format else onnx_model_path

            torch.onnx.export(
                self,
                args=inputs,
                f=out_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=17,
                do_constant_folding=True,
                verbose=verbose,
            )

            # Post-process the exported graph (output order/dims, LayerNorm weight
            # dtype) before saving to the final destination.
            model = onnx.load_model(out_path, load_external_data=use_external_data_format)
            model = self.fix_outputs(model)
            model = self.fix_layernorm_weights(model, use_fp16_inputs)
            OnnxModel.save(
                model,
                onnx_model_path,
                save_as_external_data=use_external_data_format,
                all_tensors_to_one_file=True,
            )

        # Check parity between the PyTorch module and the saved ONNX model.
        self.verify_onnx(onnx_model_path, provider, use_fp16_inputs, use_int32_inputs)
312
+
313
+ def verify_onnx(
314
+ self,
315
+ onnx_model_path: str,
316
+ provider: str,
317
+ use_fp16_inputs: bool,
318
+ use_int32_inputs: bool,
319
+ ):
320
+ """Verify ONNX model outputs and PyTorch model outputs match
321
+
322
+ Args:
323
+ onnx_model_path (str): path to save ONNX model
324
+ provider (str): execution provider for ONNX model
325
+ use_fp16_inputs (bool, optional): use float16 inputs for the audio_features
326
+ use_int32_inputs (bool, optional): use int32 inputs for the decoder_input_ids
327
+ """
328
+ # Shape of encoder's tensors:
329
+ # Inputs:
330
+ # audio_features: (batch_size, num_mels, num_frames)
331
+ # Outputs:
332
+ # encoder_hidden_states: (batch_size, num_frames // 2, hidden_size)
333
+
334
+ # Shape of decoder's tensors:
335
+ # Inputs:
336
+ # decoder_input_ids: (batch_size, sequence_length)
337
+ # encoder_hidden_states (comes from encoder's outputs): (batch_size, num_frames // 2, hidden_size)
338
+ # Outputs:
339
+ # logits: (batch_size, sequence_length, vocab_size)
340
+ # present_{key/value}_self_* (present self attention KV caches): (batch_size, num_heads, past_sequence_length + sequence_length, head_size)
341
+ # present_{key/value}_cross_* (present cross attention KV caches): (batch_size, num_heads, num_frames // 2, head_size)
342
+
343
+ inputs = self.inputs(use_fp16_inputs=use_fp16_inputs, use_int32_inputs=use_int32_inputs, return_dict=True)
344
+
345
+ # Run PyTorch model
346
+ pt_outputs = []
347
+ if self.no_beam_search_op:
348
+ out = self.forward(**inputs)
349
+ pt_outputs.append(out[0].detach().cpu().numpy())
350
+ for present_cross_attn_cache in out[1]:
351
+ pt_outputs.append(present_cross_attn_cache.detach().cpu().numpy())
352
+ else:
353
+ out = self.forward(**inputs)
354
+ pt_outputs.append(out[0].detach().cpu().numpy())
355
+ pt_outputs.append(out[1].detach().cpu().numpy())
356
+
357
+ (self_attn_kv_caches, cross_attn_kv_caches) = group_past_key_values(out[2])
358
+ pt_outputs.extend([self_attn_kv_cache.detach().cpu().numpy() for self_attn_kv_cache in self_attn_kv_caches])
359
+ pt_outputs.extend(
360
+ [cross_attn_kv_cache.detach().cpu().numpy() for cross_attn_kv_cache in cross_attn_kv_caches]
361
+ )
362
+
363
+ # Run ONNX model
364
+ sess = InferenceSession(onnx_model_path, providers=[provider])
365
+ ort_outputs = sess.run(None, convert_inputs_for_ort(inputs, sess))
366
+
367
+ # Calculate output difference
368
+ for i, output_name in enumerate(self.output_names()):
369
+ diff = np.abs(pt_outputs[i] - ort_outputs[i])
370
+ logger.warning(f"Comparing {output_name}...")
371
+ logger.warning(f"Max diff: {np.max(diff)}")