onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,380 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+ import logging
8
+
9
+ import numpy as np
10
+ import torch
11
+ from transformers import WhisperConfig
12
+
13
+ from onnxruntime import InferenceSession
14
+
15
+ logger = logging.getLogger(__name__)
16
+
17
+
18
+ # Create audio_features for encoder
19
+ # Shape is (batch_size, feature_size, sequence_length) = (batch_size, num_mel_filters, num_frames)
20
+ # where num_mel_filters is a model attribute and num_frames = (chunk_length * sample_rate) // hop_length.
21
+ #
22
+ # Hard-coded audio hyperparameters:
23
+ # SAMPLE_RATE = 16000
24
+ # N_FFT = 400
25
+ # HOP_LENGTH = 160
26
+ # CHUNK_LENGTH = 30 (i.e. 30-second chunk of audio)
27
+ # N_SAMPLES = CHUNK_LENGTH * SAMPLE_RATE = 30 * 16000 = 480000 (i.e. 480,000 samples in a 30-second chunk of audio)
28
+ # N_FRAMES = N_SAMPLES // HOP_LENGTH = 480000 // 160 = 3000 (i.e. 3000 frames in a mel spectrogram input)
29
+ #
30
+ # N_SAMPLES_PER_TOKEN = HOP_LENGTH * 2 = 160 * 2 = 320
31
+ # FRAMES_PER_SECOND = SAMPLE_RATE // HOP_LENGTH = 16000 // 160 = 100 (i.e. 10 ms per audio frame)
32
+ # TOKENS_PER_SECOND = SAMPLE_RATE // N_SAMPLES_PER_TOKEN = 16000 // 320 = 50 (i.e. 20 ms per audio token)
33
def get_sample_audio_features(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int = 3000,
    use_fp16: bool = False,
):
    """Create random audio features to feed the encoder.

    Args:
        config: Model configuration; only ``num_mel_bins`` is read.
        device: Device on which the tensor is allocated.
        batch_size: Size of the leading dimension.
        sequence_length: Size of the trailing dimension.
        use_fp16: Produce float16 values when true, float32 otherwise.

    Returns:
        A ``torch.randn`` tensor of shape
        ``(batch_size, config.num_mel_bins, sequence_length)``.
    """
    if use_fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32
    shape = (batch_size, config.num_mel_bins, sequence_length)
    return torch.randn(*shape, device=device, dtype=dtype)
43
+
44
+
45
+ # Create input_ids for decoder
46
+ # Shape is (batch_size, sequence_length) where sequence_length is the initial decoder sequence length
47
def get_sample_decoder_input_ids(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int,
    use_int32: bool = True,
):
    """Create random token ids to feed the decoder.

    Args:
        config: Model configuration; only ``vocab_size`` is read and is used
            as the exclusive upper bound of the sampled ids.
        device: Device on which the tensor is allocated.
        batch_size: Number of rows.
        sequence_length: Number of ids per row.
        use_int32: Produce int32 ids when true, int64 otherwise.

    Returns:
        An integer tensor of shape ``(batch_size, sequence_length)`` with
        values drawn from ``[0, config.vocab_size)``.
    """
    if use_int32:
        id_dtype = torch.int32
    else:
        id_dtype = torch.int64
    id_shape = (batch_size, sequence_length)
    return torch.randint(low=0, high=config.vocab_size, size=id_shape, device=device, dtype=id_dtype)
59
+
60
+
61
+ # Create encoder_hidden_states for decoder-init
62
+ # Shape is (batch_size, num_frames // 2, hidden_size)
63
def get_sample_encoder_hidden_states(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    use_fp16: bool = False,
):
    """Create random encoder hidden states to feed the decoder.

    Args:
        config: Model configuration; ``max_source_positions`` and ``d_model``
            are read to size the last two dimensions.
        device: Device on which the tensor is allocated.
        batch_size: Size of the leading dimension.
        use_fp16: Produce float16 values when true, float32 otherwise.

    Returns:
        A ``torch.randn`` tensor of shape
        ``(batch_size, config.max_source_positions, config.d_model)``.
    """
    if use_fp16:
        dtype = torch.float16
    else:
        dtype = torch.float32
    shape = (batch_size, config.max_source_positions, config.d_model)
    return torch.randn(*shape, device=device, dtype=dtype)
74
+
75
+
76
+ # Create past_key_values
77
+ # Self-attention KV caches are of shape (batch_size, num_heads, past_sequence_length, head_size)
78
+ # Cross-attention KV caches are of shape (batch_size, num_heads, num_frames // 2, head_size)
79
def get_sample_past_key_values(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    past_seq_len: int,
    use_fp16: bool = False,
):
    """Create random self- and cross-attention KV caches for the decoder.

    Args:
        config: Model configuration; ``decoder_attention_heads``, ``d_model``,
            ``max_source_positions`` and ``decoder_layers`` are read.
        device: Device on which the tensors are allocated.
        batch_size: Size of the leading dimension of every cache.
        past_seq_len: Sequence dimension of the self-attention caches.
        use_fp16: Produce float16 values when true, float32 otherwise.

    Returns:
        The result of ``flatten_past_key_values`` applied to one
        (key, value) pair of self-attention caches and one (key, value) pair
        of cross-attention caches per decoder layer.
    """
    n_heads = config.decoder_attention_heads
    head_dim = config.d_model // n_heads
    # Cross-attention caches use the encoder output length as sequence dim.
    cross_seq_len = config.max_source_positions
    dtype = torch.float16 if use_fp16 else torch.float32

    def _random_kv_pair(seq_len):
        # Key is drawn before value, matching tuple evaluation order.
        key = torch.rand(batch_size, n_heads, seq_len, head_dim, device=device, dtype=dtype)
        value = torch.rand(batch_size, n_heads, seq_len, head_dim, device=device, dtype=dtype)
        return key, value

    # All self-attention caches are generated before any cross-attention
    # cache so that the random draws happen in a fixed order.
    self_caches = [_random_kv_pair(past_seq_len) for _ in range(config.decoder_layers)]
    cross_caches = [_random_kv_pair(cross_seq_len) for _ in range(config.decoder_layers)]
    return flatten_past_key_values(self_caches, cross_caches)
107
+
108
+
109
+ # Flatten KV caches into a list of 4-tuples (one per layer) where each tuple is defined as:
110
+ # (self_attn_key_cache, self_attn_value_cache, cross_attn_key_cache, cross_attn_value_cache)
111
+ def flatten_past_key_values(
112
+ self_attn_kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
113
+ cross_attn_kv_caches: list[tuple[torch.Tensor, torch.Tensor]],
114
+ ):
115
+ past_key_values = []
116
+ for (self_k_cache, self_v_cache), (cross_k_cache, cross_v_cache) in zip(
117
+ self_attn_kv_caches, cross_attn_kv_caches, strict=False
118
+ ):
119
+ layer_kv_caches = (self_k_cache, self_v_cache, cross_k_cache, cross_v_cache)
120
+ past_key_values.append(layer_kv_caches)
121
+ return past_key_values
122
+
123
+
124
+ # Group KV caches into two 1D lists where one list contains the self attention KV caches and
125
+ # one list contains the cross attention KV caches
126
def group_past_key_values(
    kv_caches: list[tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]],
):
    """Split per-layer 4-tuples into two flat lists of caches.

    Args:
        kv_caches: One ``(self_key, self_value, cross_key, cross_value)``
            tuple per layer.

    Returns:
        A ``(self_attn, cross_attn)`` pair of flat lists. Each list holds the
        key then the value of every layer, in layer order.
    """
    self_flat = []
    cross_flat = []
    for layer_caches in kv_caches:
        # Unpack before mutating so a malformed entry leaves both lists untouched.
        self_key, self_value, cross_key, cross_value = layer_caches
        self_flat.extend((self_key, self_value))
        cross_flat.extend((cross_key, cross_value))
    return self_flat, cross_flat
136
+
137
+
138
+ # Create alignment heads for timestamps
139
+ # Shape is (num_alignment_heads, 2)
140
def get_sample_alignment_heads(
    config: WhisperConfig,
    device: torch.device,
    num_alignment_heads: int = 6,
    use_int32: bool = True,
):
    """Create a placeholder alignment-heads tensor for the timestamps inputs.

    Args:
        config: Accepted for signature parity with the other helpers; it is
            not read by this function.
        device: Device on which the tensor is allocated.
        num_alignment_heads: Number of rows in the result.
        use_int32: Produce int32 values when true, int64 otherwise.

    Returns:
        A tensor of ones with shape ``(num_alignment_heads, 2)``.
    """
    if use_int32:
        dtype = torch.int32
    else:
        dtype = torch.int64
    shape = (num_alignment_heads, 2)
    return torch.ones(shape, device=device, dtype=dtype)
149
+
150
+
151
+ # Create length of start-of-transcription sequence for timestamps
152
+ # Shape is (1)
153
def get_sample_sot_sequence_length(
    device: torch.device,
    sot_sequence_length: int,
    use_int32: bool = False,
):
    """Wrap the start-of-transcription sequence length in a 1-element tensor.

    Args:
        device: Device on which the tensor is allocated.
        sot_sequence_length: Value to store.
        use_int32: Produce an int32 tensor when true, int64 otherwise.

    Returns:
        A tensor of shape ``(1,)`` holding ``sot_sequence_length``.
    """
    if use_int32:
        dtype = torch.int32
    else:
        dtype = torch.int64
    return torch.tensor([sot_sequence_length], device=device, dtype=dtype)
161
+
162
+
163
+ # Create segment length for timestamps
164
+ # Shape is (1)
165
def get_sample_segment_length(
    device: torch.device,
    segment_length: int,
    use_int32: bool = False,
):
    """Wrap the segment length in a 1-element tensor.

    Args:
        device: Device on which the tensor is allocated.
        segment_length: Value to store.
        use_int32: Produce an int32 tensor when true, int64 otherwise.

    Returns:
        A tensor of shape ``(1,)`` holding ``segment_length``.
    """
    if use_int32:
        dtype = torch.int32
    else:
        dtype = torch.int64
    return torch.tensor([segment_length], device=device, dtype=dtype)
173
+
174
+
175
+ # Create QKs for timestamps
176
+ # Shape is (batch_size, num_heads, sequence_length, num_frames // 2)
177
def get_sample_QKs(  # noqa: N802
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int,
    use_fp16: bool = False,
):
    """Create one random QK tensor per decoder layer for the timestamps inputs.

    Args:
        config: Model configuration; ``decoder_attention_heads``,
            ``max_source_positions`` and ``decoder_layers`` are read.
        device: Device on which the tensors are allocated.
        batch_size: Size of the leading dimension.
        sequence_length: Size of the third dimension.
        use_fp16: Produce float16 values when true, float32 otherwise.

    Returns:
        A list of ``config.decoder_layers`` ``torch.rand`` tensors, each of
        shape ``(batch_size, num_heads, sequence_length,
        config.max_source_positions)``.
    """
    n_heads = config.decoder_attention_heads
    dtype = torch.float16 if use_fp16 else torch.float32
    qk_per_layer = []
    for _ in range(config.decoder_layers):
        layer_qk = torch.rand(
            batch_size, n_heads, sequence_length, config.max_source_positions, device=device, dtype=dtype
        )
        qk_per_layer.append(layer_qk)
    return qk_per_layer
193
+
194
+
195
+ # Create inputs for encoder component of Whisper
196
def get_sample_encoder_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int = 3000,
    use_fp16: bool = False,
):
    """Build the sample input dictionary for the encoder.

    Args:
        config: Model configuration forwarded to ``get_sample_audio_features``.
        device: Device on which the tensor is allocated.
        batch_size: Batch dimension of the audio features.
        sequence_length: Trailing dimension of the audio features.
        use_fp16: Produce float16 features when true, float32 otherwise.

    Returns:
        A dict with the single key ``"audio_features"``.
    """
    return {
        "audio_features": get_sample_audio_features(
            config,
            device,
            batch_size,
            sequence_length,
            use_fp16,
        ),
    }
205
+
206
+
207
+ # Create inputs for encoder component + first pass through decoder component of Whisper
208
def get_sample_encoder_decoder_init_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    decoder_sequence_length: int,
    encoder_sequence_length: int = 3000,
    use_fp16: bool = False,
    use_int32: bool = True,
):
    """Build the sample inputs for the encoder plus the first decoder pass.

    Args:
        config: Model configuration forwarded to the sample generators.
        device: Device on which the tensors are allocated.
        batch_size: Batch dimension shared by both tensors.
        decoder_sequence_length: Sequence length of the decoder input ids.
        encoder_sequence_length: Trailing dimension of the audio features.
        use_fp16: Produce float16 audio features when true, float32 otherwise.
        use_int32: Produce int32 decoder ids when true, int64 otherwise.

    Returns:
        A dict with keys ``"audio_features"`` and ``"decoder_input_ids"``.
    """
    # The dict entries are evaluated top to bottom, so the audio features are
    # drawn from the random generator before the decoder ids.
    return {
        "audio_features": get_sample_audio_features(
            config, device, batch_size, encoder_sequence_length, use_fp16
        ),
        "decoder_input_ids": get_sample_decoder_input_ids(
            config, device, batch_size, decoder_sequence_length, use_int32
        ),
    }
220
+
221
+
222
+ # Create inputs for decoder component of Whisper
223
+ # Inputs for first pass through the decoder (i.e. decoder-init): decoder_input_ids, encoder_hidden_states
224
+ # Inputs for subsequent passes through the decoder (i.e. decoder-with-past): decoder_input_ids, past_key_values
225
def get_sample_decoder_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    past_sequence_length: int,
    sequence_length: int,
    use_fp16: bool = False,
    use_int32: bool = True,
):
    """Build the sample input dictionary for the decoder.

    Args:
        config: Model configuration forwarded to the sample generators.
        device: Device on which the tensors are allocated.
        batch_size: Batch dimension shared by all tensors.
        past_sequence_length: Sequence dimension of the self-attention caches.
        sequence_length: Sequence length of the decoder input ids.
        use_fp16: Produce float16 hidden states and caches when true,
            float32 otherwise.
        use_int32: Produce int32 decoder ids when true, int64 otherwise.

    Returns:
        A dict with keys ``"decoder_input_ids"``, ``"encoder_hidden_states"``
        and ``"past_key_values"``.
    """
    # The dict entries are evaluated top to bottom, which keeps the random
    # draws in the order: ids, hidden states, KV caches.
    return {
        "decoder_input_ids": get_sample_decoder_input_ids(
            config, device, batch_size, sequence_length, use_int32
        ),
        "encoder_hidden_states": get_sample_encoder_hidden_states(config, device, batch_size, use_fp16),
        "past_key_values": get_sample_past_key_values(
            config, device, batch_size, past_sequence_length, use_fp16
        ),
    }
242
+
243
+
244
+ # Create inputs for timestamps component of Whisper
245
def get_sample_jump_times_inputs(
    config: WhisperConfig,
    device: torch.device,
    batch_size: int,
    sequence_length: int,
    num_alignment_heads: int,
    sot_sequence_length: int,
    segment_length: int,
    use_fp16: bool = False,
    use_int32: bool = True,
):
    """Assemble sample inputs for the timestamps (jump times) component.

    Returns a dict with the keys "alignment_heads", "sot_sequence_length",
    "segment_length" and "QKs", each produced by the matching
    get_sample_* helper. use_int32 is forwarded only to the alignment-heads
    helper and use_fp16 only to the QKs helper.
    """
    heads = get_sample_alignment_heads(config, device, num_alignment_heads, use_int32)

    # lengths need to be int64 because subsequent 'Slice' ops only take int64 inputs,
    # which is why these two helpers are not given the use_int32 flag.
    sot_length_sample = get_sample_sot_sequence_length(device, sot_sequence_length)
    segment_length_sample = get_sample_segment_length(device, segment_length)

    cross_qks = get_sample_QKs(config, device, batch_size, sequence_length, use_fp16)

    sample_inputs = {}
    sample_inputs["alignment_heads"] = heads
    sample_inputs["sot_sequence_length"] = sot_length_sample
    sample_inputs["segment_length"] = segment_length_sample
    sample_inputs["QKs"] = cross_qks
    return sample_inputs
267
+
268
+
269
+ # Convert PyTorch inputs to ONNX Runtime inputs
270
def convert_inputs_for_ort(
    inputs: dict,
    model: InferenceSession,
):
    """Convert PyTorch sample inputs into a NumPy feed dict for an ONNX Runtime session.

    The function walks the input names reported by ``model.get_inputs()`` and
    fills one entry per name from ``inputs``.

    Args:
        inputs: Dict of PyTorch inputs. Keys read here are "audio_features",
            "encoder_hidden_states", "decoder_input_ids", "past_key_values",
            "alignment_heads", "sot_sequence_length", "segment_length" and
            "QKs"; only those needed by the model's input names are accessed.
        model: Object exposing ``get_inputs()``, whose items have a ``name``
            attribute.

    Returns:
        Dict mapping each model input name to a NumPy array, in the order the
        model reports its inputs.

    Raises:
        ValueError: If the model declares an input name that is not recognized.
        KeyError: If a recognized model input needs a key missing from ``inputs``.
        IndexError: If the model declares more "cross_qk" inputs than there are
            entries in ``inputs["QKs"]`` (or more KV-cache inputs than caches).

    The caller's ``inputs`` dict and its "QKs" list are left unmodified, so the
    same ``inputs`` can be converted repeatedly.
    """
    self_attn_kv_caches, cross_attn_kv_caches = None, None
    batch_size, num_heads, past_seq_len, head_size = 0, 0, 0, 0
    # Beam count and maximum decoder length are fixed here; they size the
    # shared KV-cache buffers and the cache_indirection tensor below.
    num_beams, max_seq_len = 1, 448
    if "past_key_values" in inputs:
        (self_attn_kv_caches, cross_attn_kv_caches) = group_past_key_values(inputs["past_key_values"])
        batch_size, num_heads, past_seq_len, head_size = self_attn_kv_caches[0].shape

    # Shallow copy of inputs["QKs"], created on first use. Popping from the copy
    # instead of the caller's list keeps `inputs` reusable across calls.
    pending_qks = None

    ort_inputs = {}
    model_inputs = [model_input.name for model_input in model.get_inputs()]
    # A "cache_indirection" input signals that self-attention KV caches must be
    # padded out to max_seq_len.
    use_buffer_sharing = "cache_indirection" in model_inputs
    for name in model_inputs:
        if name in {"audio_features", "encoder_input_ids"}:
            # Encoder input
            ort_inputs[name] = inputs["audio_features"].detach().cpu().numpy()
        elif name == "encoder_hidden_states":
            # Encoder output
            ort_inputs[name] = inputs["encoder_hidden_states"].detach().cpu().numpy()
        elif name in {"decoder_input_ids", "input_ids"}:
            # Decoder input
            ort_inputs[name] = inputs["decoder_input_ids"].detach().cpu().numpy()
        elif "past_key_self" in name or "past_value_self" in name:
            # Decoder input
            # NOTE(review): pops from the list returned by group_past_key_values;
            # assumes that helper returns fresh lists - confirm against its definition.
            orig_kv_cache = self_attn_kv_caches.pop(0).detach().cpu().numpy()
            if use_buffer_sharing:
                # Zero-filled buffer of the maximum length with the existing
                # cache copied into its leading past_seq_len positions.
                new_kv_cache = np.zeros((batch_size, num_heads, max_seq_len, head_size), dtype=orig_kv_cache.dtype)
                new_kv_cache[:batch_size, :num_heads, :past_seq_len, :head_size] = orig_kv_cache
                ort_inputs[name] = new_kv_cache
            else:
                ort_inputs[name] = orig_kv_cache
        elif "past_key_cross" in name or "past_value_cross" in name:
            # Decoder input
            orig_kv_cache = cross_attn_kv_caches.pop(0).detach().cpu().numpy()
            ort_inputs[name] = orig_kv_cache
        elif name == "past_sequence_length":
            # Decoder input
            ort_inputs[name] = np.array([past_seq_len], dtype=np.int32)
        elif name == "cache_indirection":
            # Decoder input
            ort_inputs[name] = np.zeros((batch_size, num_beams, max_seq_len), dtype=np.int32)
        elif name == "alignment_heads":
            # Jump times input
            ort_inputs[name] = inputs["alignment_heads"].detach().cpu().numpy()
        elif name == "sot_sequence_length":
            # Jump times input
            ort_inputs[name] = inputs["sot_sequence_length"].detach().cpu().numpy()
        elif name == "segment_length":
            # Jump times input
            ort_inputs[name] = inputs["segment_length"].detach().cpu().numpy()
        elif "cross_qk" in name:
            # Jump times input; entries are consumed in model-input order.
            if pending_qks is None:
                pending_qks = list(inputs["QKs"])
            ort_inputs[name] = pending_qks.pop(0).detach().cpu().numpy()
        else:
            raise ValueError(f"Unknown name not recognized: {name}")

    return ort_inputs
329
+
330
+
331
+ # Get dynamic axes for all inputs and outputs to the model
332
def get_model_dynamic_axes(
    config: WhisperConfig,
    input_names: list[str],
    output_names: list[str],
):
    """Build the dynamic-axes mapping for every model input and output name.

    Args:
        config: Whisper configuration. Not read by this function; kept for
            interface compatibility.
        input_names: Names of the model inputs.
        output_names: Names of the model outputs.

    Returns:
        Dict mapping each name to a ``{axis_index: axis_label}`` dict. Names
        with no dynamic axis ("sot_sequence_length", "segment_length") get no
        entry at all.

    Raises:
        ValueError: If a name matches none of the known input/output patterns.
            ValueError is a subclass of Exception, so callers catching
            Exception keep working.
    """
    dynamic_axes = {}
    for name in input_names + output_names:
        if name in {"audio_features", "encoder_input_ids"}:
            # shape is (batch_size, num_mels, num_frames)
            dynamic_axes[name] = {0: "batch_size"}
        elif name in {"input_ids", "decoder_input_ids"}:
            # shape is (batch_size, sequence_length)
            dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
        elif name == "alignment_heads":
            # shape is (num_alignment_heads, 2)
            dynamic_axes[name] = {0: "num_alignment_heads"}
        elif name in {"sot_sequence_length", "segment_length"}:
            # shape is (1), so there is nothing dynamic to record
            pass
        elif name == "logits":
            # shape is (batch_size, sequence_length, vocab_size)
            dynamic_axes[name] = {0: "batch_size", 1: "sequence_length"}
        elif name == "encoder_hidden_states":
            # shape is (batch_size, num_frames // 2, hidden_size)
            dynamic_axes[name] = {0: "batch_size"}
        elif "past_key_self" in name or "past_value_self" in name:
            # shape is (batch_size, num_heads, past_sequence_length, head_size)
            dynamic_axes[name] = {0: "batch_size", 2: "past_sequence_length"}
        elif "present_key_self" in name or "present_value_self" in name:
            # shape is (batch_size, num_heads, past_sequence_length + sequence_length, head_size),
            # which is equal to (batch_size, num_heads, total_sequence_length, head_size)
            dynamic_axes[name] = {0: "batch_size", 2: "total_sequence_length"}
        elif (
            "past_key_cross" in name
            or "past_value_cross" in name
            or "present_key_cross" in name
            or "present_value_cross" in name
        ):
            # shape is (batch_size, num_heads, num_frames // 2, head_size)
            dynamic_axes[name] = {0: "batch_size"}
        elif "cross_qk" in name:
            # shape is (batch_size, num_heads, source_sequence_length, target_sequence_length)
            dynamic_axes[name] = {0: "batch_size", 2: "sequence_length"}
        elif "jump_times" in name:
            # shape is (batch_size, max_length)
            dynamic_axes[name] = {0: "batch_size", 1: "max_length"}
        else:
            # ValueError matches the error style used by convert_inputs_for_ort
            # for unrecognized names.
            raise ValueError(f"Unknown input or output name found: {name}")
    return dynamic_axes