onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,334 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+ import logging
8
+ import os
9
+
10
+ import onnx
11
+ from benchmark_helper import Precision
12
+ from convert_generation import (
13
+ get_shared_initializers,
14
+ update_decoder_subgraph_output_cross_attention,
15
+ update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha,
16
+ )
17
+ from onnx import TensorProto, helper
18
+ from transformers import WhisperConfig, WhisperTokenizer
19
+
20
+ logger = logging.getLogger(__name__)
21
+
22
+
23
+ def verify_inputs(beam_inputs, graph_inputs):
24
+ # Verify that ONNX graph's inputs match beam search op's inputs
25
+ beam_required_inputs = list(filter(lambda beam_input: beam_input, beam_inputs))
26
+ assert len(graph_inputs) == len(beam_required_inputs)
27
+ for graph_input, beam_input in zip(graph_inputs, beam_required_inputs, strict=False):
28
+ # Check if graph_input is in beam_input to handle beam_input names with the "_fp16" suffix
29
+ assert graph_input.name in beam_input
30
+
31
+
32
def clean_list(arr, remove_all_strings=True):
    """Remove empty-string placeholders from a list of op inputs/outputs.

    Args:
        arr: List whose entries may include "" placeholders for unused slots.
        remove_all_strings: When True, return a new list with every "" entry
            removed. When False, strip only the trailing "" entries, mutating
            ``arr`` in place.

    Returns:
        The cleaned list (a new list when ``remove_all_strings`` is True,
        otherwise ``arr`` itself after trailing placeholders are popped).
    """
    if remove_all_strings:
        # Drop every empty string, keeping the order of the remaining entries.
        return [elm for elm in arr if elm != ""]

    # Trim "" entries from the tail only; interior placeholders are positional
    # and must be preserved.
    while arr and arr[-1] == "":
        arr.pop()
    return arr
44
+
45
+
46
+ def chain_model(args):
47
+ # Load encoder/decoder and insert necessary (but unused) graph inputs expected by WhisperBeamSearch op
48
+ encoder_model = onnx.load_model(args.encoder_path, load_external_data=True)
49
+ encoder_model.graph.name = "encoderdecoderinit subgraph"
50
+
51
+ decoder_model = onnx.load_model(args.decoder_path, load_external_data=True)
52
+ decoder_model.graph.name = "decoder subgraph"
53
+
54
+ config = WhisperConfig.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
55
+ tokenizer = WhisperTokenizer.from_pretrained(args.model_name_or_path, cache_dir=args.cache_dir)
56
+
57
+ use_fp16_inputs = args.precision == Precision.FLOAT16 or (
58
+ args.precision in (Precision.INT8, Precision.INT4) and args.use_gpu
59
+ )
60
+ # Create inputs/outputs for WhisperBeamSearch op
61
+ temperature_name = "temperature_fp16" if use_fp16_inputs else "temperature"
62
+ beam_inputs = [
63
+ "input_features_fp16" if use_fp16_inputs else "input_features",
64
+ "max_length",
65
+ "min_length",
66
+ "num_beams",
67
+ "num_return_sequences",
68
+ "length_penalty_fp16" if use_fp16_inputs else "length_penalty",
69
+ "repetition_penalty_fp16" if use_fp16_inputs else "repetition_penalty",
70
+ "vocab_mask" if args.use_vocab_mask else "",
71
+ "prefix_vocab_mask" if args.use_prefix_vocab_mask else "",
72
+ "", # attention mask
73
+ "decoder_input_ids" if args.use_forced_decoder_ids else "",
74
+ "logits_processor" if args.use_logits_processor else "",
75
+ "cross_qk_layer_head" if args.collect_cross_qk else "",
76
+ "extra_decoding_ids" if args.extra_decoding_ids else "",
77
+ temperature_name if args.use_temperature else "",
78
+ ]
79
+
80
+ sequence_scores_name = "sequence_scores_fp16" if use_fp16_inputs else "sequence_scores"
81
+ scores_name = "scores_fp16" if use_fp16_inputs else "scores"
82
+ beam_outputs = [
83
+ "sequences",
84
+ sequence_scores_name if args.output_sequence_scores else "",
85
+ scores_name if args.output_scores else "",
86
+ "cross_qk" if args.collect_cross_qk else "",
87
+ "no_speech_probs_beam" if args.output_no_speech_probs else "",
88
+ ]
89
+
90
+ graph_nodes = []
91
+ if use_fp16_inputs:
92
+ input_features_cast_node = helper.make_node(
93
+ "Cast",
94
+ inputs=["input_features"],
95
+ outputs=["input_features_fp16"],
96
+ name="CastInputFeaturesToFp16",
97
+ to=TensorProto.FLOAT16,
98
+ )
99
+ len_pen_cast_node = helper.make_node(
100
+ "Cast",
101
+ inputs=["length_penalty"],
102
+ outputs=["length_penalty_fp16"],
103
+ name="CastLengthPenaltyToFp16",
104
+ to=TensorProto.FLOAT16,
105
+ )
106
+ rep_pen_cast_node = helper.make_node(
107
+ "Cast",
108
+ inputs=["repetition_penalty"],
109
+ outputs=["repetition_penalty_fp16"],
110
+ name="CastRepetitionPenaltyToFp16",
111
+ to=TensorProto.FLOAT16,
112
+ )
113
+ graph_nodes.extend([input_features_cast_node, len_pen_cast_node, rep_pen_cast_node])
114
+
115
+ if args.use_temperature:
116
+ temp_cast_node = helper.make_node(
117
+ "Cast",
118
+ inputs=["temperature"],
119
+ outputs=["temperature_fp16"],
120
+ name="temperature_to_fp16",
121
+ to=TensorProto.FLOAT16,
122
+ )
123
+ graph_nodes.append(temp_cast_node)
124
+
125
+ if args.output_sequence_scores:
126
+ output_sequence_scores_cast_node = helper.make_node(
127
+ "Cast",
128
+ inputs=["sequence_scores_fp16"],
129
+ outputs=["sequence_scores"],
130
+ name="CastOutputSequenceScoresToFp32",
131
+ to=TensorProto.FLOAT,
132
+ )
133
+ graph_nodes.append(output_sequence_scores_cast_node)
134
+
135
+ if args.output_scores:
136
+ output_scores_cast_node = helper.make_node(
137
+ "Cast",
138
+ inputs=["scores_fp16"],
139
+ outputs=["scores"],
140
+ name="CastScoresToFp32",
141
+ to=TensorProto.FLOAT,
142
+ )
143
+ graph_nodes.append(output_scores_cast_node)
144
+
145
+ # Create WhisperBeamSearch op
146
+ beam_search_attrs = [
147
+ helper.make_attribute("eos_token_id", config.eos_token_id),
148
+ helper.make_attribute("pad_token_id", config.pad_token_id),
149
+ helper.make_attribute(
150
+ "decoder_start_token_id", config.decoder_start_token_id
151
+ ), # same as tokenizer.convert_tokens_to_ids(['<|startoftranscript|>'])[0]
152
+ helper.make_attribute("translate_token_id", tokenizer.convert_tokens_to_ids(["<|translate|>"])[0]),
153
+ helper.make_attribute("transcribe_token_id", tokenizer.convert_tokens_to_ids(["<|transcribe|>"])[0]),
154
+ helper.make_attribute("start_of_lm_token_id", tokenizer.convert_tokens_to_ids(["<|startoflm|>"])[0]),
155
+ (
156
+ helper.make_attribute("no_speech_token_id", tokenizer.convert_tokens_to_ids(["<|nospeech|>"])[0])
157
+ if args.output_no_speech_probs
158
+ else ""
159
+ ),
160
+ helper.make_attribute("no_timestamps_token_id", tokenizer.convert_tokens_to_ids(["<|notimestamps|>"])[0]),
161
+ helper.make_attribute("beginning_timestamp_token_id", tokenizer.convert_tokens_to_ids(["<|0.00|>"])[0]),
162
+ helper.make_attribute("no_repeat_ngram_size", args.no_repeat_ngram_size),
163
+ helper.make_attribute("early_stopping", True),
164
+ helper.make_attribute("model_type", 2),
165
+ helper.make_attribute("decoder_output_cross_qk", 1) if args.collect_cross_qk else "",
166
+ ]
167
+ node = helper.make_node(
168
+ "WhisperBeamSearch",
169
+ inputs=clean_list(beam_inputs, remove_all_strings=False),
170
+ outputs=clean_list(beam_outputs, remove_all_strings=False),
171
+ name="BeamSearch",
172
+ domain="com.microsoft",
173
+ )
174
+ node.attribute.extend(clean_list(beam_search_attrs, remove_all_strings=True))
175
+
176
+ # Graph inputs
177
+ input_features = helper.make_tensor_value_info(
178
+ "input_features", TensorProto.FLOAT, ["batch_size", "feature_size", "sequence_length"]
179
+ )
180
+ max_length = helper.make_tensor_value_info("max_length", TensorProto.INT32, [1])
181
+ min_length = helper.make_tensor_value_info("min_length", TensorProto.INT32, [1])
182
+ num_beams = helper.make_tensor_value_info("num_beams", TensorProto.INT32, [1])
183
+ num_return_sequences = helper.make_tensor_value_info("num_return_sequences", TensorProto.INT32, [1])
184
+ length_penalty = helper.make_tensor_value_info("length_penalty", TensorProto.FLOAT, [1])
185
+ repetition_penalty = helper.make_tensor_value_info("repetition_penalty", TensorProto.FLOAT, [1])
186
+ vocab_mask = helper.make_tensor_value_info("vocab_mask", TensorProto.INT32, [config.vocab_size])
187
+ prefix_vocab_mask = helper.make_tensor_value_info(
188
+ "prefix_vocab_mask", TensorProto.INT32, ["batch_size", config.vocab_size]
189
+ )
190
+ decoder_input_ids = helper.make_tensor_value_info(
191
+ "decoder_input_ids", TensorProto.INT32, ["batch_size", "initial_sequence_length"]
192
+ )
193
+ logits_processor = helper.make_tensor_value_info("logits_processor", TensorProto.INT32, [1])
194
+ cross_qk_layer_head = helper.make_tensor_value_info("cross_qk_layer_head", TensorProto.INT32, ["num_layer_head", 2])
195
+ extra_decoding_ids = helper.make_tensor_value_info(
196
+ "extra_decoding_ids", TensorProto.INT32, ["batch_size", "extra_decoding_ids_len"]
197
+ )
198
+ temperature = helper.make_tensor_value_info("temperature", TensorProto.FLOAT, [1])
199
+
200
+ graph_inputs = clean_list(
201
+ [
202
+ input_features,
203
+ max_length,
204
+ min_length,
205
+ num_beams,
206
+ num_return_sequences,
207
+ length_penalty,
208
+ repetition_penalty,
209
+ vocab_mask if args.use_vocab_mask else "",
210
+ prefix_vocab_mask if args.use_prefix_vocab_mask else "",
211
+ decoder_input_ids if args.use_forced_decoder_ids else "",
212
+ logits_processor if args.use_logits_processor else "",
213
+ cross_qk_layer_head if args.collect_cross_qk else "",
214
+ extra_decoding_ids if args.extra_decoding_ids else "",
215
+ temperature if args.use_temperature else "",
216
+ ]
217
+ )
218
+
219
+ # Graph outputs
220
+ sequences = helper.make_tensor_value_info(
221
+ "sequences", TensorProto.INT32, ["batch_size", "num_return_sequences", "max_length"]
222
+ )
223
+ sequence_scores = helper.make_tensor_value_info("sequence_scores", TensorProto.FLOAT, ["batch_size"])
224
+ scores = helper.make_tensor_value_info("scores", TensorProto.FLOAT, ["batch_size"])
225
+ cross_qk = helper.make_tensor_value_info(
226
+ "cross_qk",
227
+ TensorProto.FLOAT,
228
+ ["batch_size", "num_return_sequences", "num_layer_head_cross_qk", "max_length", "frames"],
229
+ )
230
+ no_speech_probs = helper.make_tensor_value_info("no_speech_probs", TensorProto.FLOAT, ["batch_size"])
231
+
232
+ graph_outputs = clean_list(
233
+ [
234
+ sequences,
235
+ sequence_scores if args.output_sequence_scores else "",
236
+ scores if args.output_scores else "",
237
+ cross_qk if args.output_cross_qk or (not args.cross_qk_onnx_model and args.collect_cross_qk) else "",
238
+ no_speech_probs if args.output_no_speech_probs else "",
239
+ ]
240
+ )
241
+
242
+ # Replace MultiHeadAttention with DecoderMaskedMultiHeadAttention for CUDA EP inference
243
+ if hasattr(args, "use_gpu") and args.use_gpu:
244
+ if update_decoder_subgraph_share_buffer_and_use_decoder_masked_mha(decoder_model.graph):
245
+ logger.info("Updated whisper decoder subgraph to use DecoderMaskedMultiHeadAttention successfully!")
246
+ else:
247
+ logger.warning("DecoderMaskedMultiHeadAttention could not be applied to whisper decoder subgraph")
248
+ if hasattr(args, "collect_cross_qk") and args.collect_cross_qk:
249
+ update_decoder_subgraph_output_cross_attention(decoder_model.graph)
250
+
251
+ # Initializers/opsets
252
+ # Delete shared data between decoder/encoder and move to larger graph initializers
253
+ initializers = get_shared_initializers(encoder_model, decoder_model)
254
+ node.attribute.extend(
255
+ [
256
+ helper.make_attribute("decoder", decoder_model.graph),
257
+ helper.make_attribute("encoder", encoder_model.graph),
258
+ ]
259
+ )
260
+
261
+ opset_import = [helper.make_opsetid(domain="com.microsoft", version=1), helper.make_opsetid(domain="", version=17)]
262
+
263
+ graph_nodes.append(node)
264
+ if args.output_no_speech_probs:
265
+ prob_cast_node = helper.make_node(
266
+ "Cast",
267
+ inputs=["no_speech_probs_beam"],
268
+ outputs=["no_speech_probs"],
269
+ name="no_speech_probs_cast_to_fp32",
270
+ to=TensorProto.FLOAT,
271
+ )
272
+ graph_nodes.append(prob_cast_node)
273
+
274
+ # Make graph with WhisperBeamSearch op
275
+ beam_graph = helper.make_graph(
276
+ graph_nodes,
277
+ name="WhisperBeamSearch Graph",
278
+ inputs=graph_inputs,
279
+ outputs=graph_outputs,
280
+ initializer=initializers,
281
+ )
282
+ beam_graph_input_names = [gi.name for gi in graph_inputs]
283
+ beam_graph_output_names = [go.name for go in graph_outputs]
284
+
285
+ if args.cross_qk_onnx_model:
286
+ post_qk_model = onnx.load_model(args.cross_qk_onnx_model, load_external_data=True)
287
+ post_qk_graph = post_qk_model.graph
288
+ beam_graph.initializer.extend(post_qk_graph.initializer)
289
+ beam_graph.node.extend(post_qk_graph.node)
290
+ # If tensor from cross_qk_onnx_model has same name as tensor in beamsearch graph, treat them as same tensor.
291
+ # User should notice this rule when provide cross_qk_onnx_model to append to the beamsearch node.
292
+ for pgi in post_qk_graph.input:
293
+ if (
294
+ (pgi.name not in beam_graph_input_names)
295
+ and (pgi.name not in beam_graph_output_names)
296
+ and (pgi.name != "cross_qk")
297
+ ):
298
+ beam_graph.input.extend([pgi])
299
+ beam_graph.output.extend(post_qk_graph.output)
300
+
301
+ # Verify graph's inputs match beam search's inputs
302
+ verify_inputs(beam_inputs, graph_inputs)
303
+
304
+ assert decoder_model.ir_version == encoder_model.ir_version
305
+ logger.info(f"Using IR version {decoder_model.ir_version} for chained model")
306
+
307
+ # Set IR version of chained model to IR version of subgraphs in order to generate a working E2E model
308
+ beam_model = helper.make_model_gen_version(
309
+ beam_graph,
310
+ producer_name="onnxruntime.transformers",
311
+ opset_imports=opset_import,
312
+ ir_version=decoder_model.ir_version,
313
+ )
314
+
315
+ # Save WhisperBeamSearch graph and external data
316
+ if os.path.isfile(args.beam_model_output_dir):
317
+ logger.info(f"Overwriting {args.beam_model_output_dir} and {args.beam_model_output_dir + '.data'}")
318
+ if os.path.exists(args.beam_model_output_dir):
319
+ os.remove(args.beam_model_output_dir)
320
+ if os.path.exists(args.beam_model_output_dir + ".data"):
321
+ os.remove(args.beam_model_output_dir + ".data")
322
+
323
+ onnx.save(
324
+ beam_model,
325
+ args.beam_model_output_dir,
326
+ save_as_external_data=args.use_external_data_format,
327
+ all_tensors_to_one_file=True,
328
+ convert_attribute=True,
329
+ location=f"{os.path.basename(args.beam_model_output_dir)}.data",
330
+ )
331
+ try:
332
+ onnx.checker.check_model(args.beam_model_output_dir, full_check=True)
333
+ except Exception as e:
334
+ logger.error(f"An error occurred while running the ONNX checker: {e}", exc_info=True) # noqa: G201