onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,437 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+
7
+ import logging
8
+ import os
9
+ import tempfile
10
+ from pathlib import Path
11
+
12
+ import numpy
13
+ import onnx
14
+ import torch
15
+ from io_binding_helper import TypeHelper
16
+ from onnx_model import OnnxModel
17
+ from past_helper import PastKeyValuesHelper
18
+ from t5_encoder import T5EncoderInputs
19
+ from torch_onnx_export_helper import torch_onnx_export
20
+ from transformers import MT5Config, T5Config
21
+
22
+ from onnxruntime import InferenceSession
23
+
24
+ logger = logging.getLogger(__name__)
25
+
26
+
27
class T5DecoderInit(torch.nn.Module):
    """T5 decoder stack plus LM head used for the very first decoding step.

    It does not take any past key/value states; instead it produces the initial
    self-attention and cross-attention states that later steps consume.
    """

    def __init__(
        self,
        decoder: torch.nn.Module,
        lm_head: torch.nn.Module,
        config: T5Config | MT5Config,
        decoder_start_token_id: int | None = None,
    ):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.config = config

        # An explicit start token id wins; otherwise fall back to the config's value.
        if decoder_start_token_id is None:
            decoder_start_token_id = self.config.decoder_start_token_id
        self.decoder_start_token_id = decoder_start_token_id

        # Configs lacking the attribute are treated as having tied word embeddings.
        self.tie_word_embeddings = getattr(self.config, "tie_word_embeddings", True)

    def forward(
        self,
        decoder_input_ids: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        encoder_hidden_states: torch.FloatTensor,
    ):
        """Run one decoder pass and return logits plus the initial past states.

        Args:
            decoder_input_ids: decoder input ids. When None, a (batch_size, 1) long tensor
                filled with ``self.decoder_start_token_id`` is created on the mask's device,
                where batch_size is ``encoder_attention_mask.shape[0]``.
            encoder_attention_mask: attention mask for the encoder outputs.
            encoder_hidden_states: hidden states produced by the encoder.

        Returns:
            tuple: (lm_logits, past_self, past_cross), where the two past groups are the
            decoder's present key/values split by ``PastKeyValuesHelper.group_by_self_or_cross``.
        """
        if decoder_input_ids is None:
            start_ids = torch.ones(
                (encoder_attention_mask.shape[0], 1),
                dtype=torch.long,
                device=encoder_attention_mask.device,
            )
            # Multiply (rather than torch.full) to keep the traced graph unchanged.
            decoder_input_ids = start_ids * self.decoder_start_token_id

        outputs = self.decoder(
            input_ids=decoder_input_ids,
            encoder_hidden_states=encoder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            return_dict=True,
        )

        hidden = outputs.last_hidden_state
        presents = outputs.past_key_values

        # With tied embeddings the hidden states are rescaled by d_model ** -0.5 before
        # projecting to the vocabulary.
        if self.tie_word_embeddings:
            hidden = hidden * (self.config.d_model**-0.5)

        logits = self.lm_head(hidden)
        self_states, cross_states = PastKeyValuesHelper.group_by_self_or_cross(presents)
        return logits, self_states, cross_states
84
+
85
+
86
class T5Decoder(torch.nn.Module):
    """T5 decoder stack plus LM head for decoding steps that consume past key/values."""

    def __init__(self, decoder, lm_head, config):
        super().__init__()
        self.decoder = decoder
        self.lm_head = lm_head
        self.config = config
        # Configs lacking the attribute are treated as having tied word embeddings.
        self.tie_word_embeddings = getattr(self.config, "tie_word_embeddings", True)

    def forward(self, decoder_input_ids, encoder_attention_mask, *past):
        """Run one incremental decoding step.

        Args:
            decoder_input_ids: decoder input ids for the current step.
            encoder_attention_mask: attention mask for the encoder outputs.
            *past: flat sequence of past key/value tensors; they are regrouped per layer
                with ``PastKeyValuesHelper.group_by_layer`` using ``config.num_decoder_layers``.

        Returns:
            tuple: (lm_logits, present_self). Only the self-attention presents are returned.
        """
        layered_past = PastKeyValuesHelper.group_by_layer(past, self.config.num_decoder_layers)

        # Hack carried over from the original implementation: only the third dimension of
        # encoder_hidden_states is used here, so an unsqueezed attention mask stands in for it.
        placeholder_hidden_states = encoder_attention_mask.unsqueeze(2)

        outputs = self.decoder(
            input_ids=decoder_input_ids,
            past_key_values=layered_past,
            encoder_hidden_states=placeholder_hidden_states,
            encoder_attention_mask=encoder_attention_mask,
            use_cache=True,
            return_dict=True,
        )

        hidden = outputs.last_hidden_state
        # With tied embeddings the hidden states are rescaled by d_model ** -0.5.
        if self.tie_word_embeddings:
            hidden = hidden * (self.config.d_model**-0.5)

        logits = self.lm_head(hidden)
        self_presents, _ = PastKeyValuesHelper.group_by_self_or_cross(outputs.past_key_values)

        # present_cross is omitted because it is identical to the corresponding past_cross input.
        return logits, self_presents
124
+
125
+
126
class T5DecoderInputs:
    """Inputs for an exported T5 decoder step.

    Holds the decoder input ids, the encoder attention mask and an optional flat list of
    past key/value tensors. ``to_list`` returns them in that order.
    """

    def __init__(
        self,
        decoder_input_ids,
        encoder_attention_mask,
        past_key_values=None,
    ):
        self.decoder_input_ids: torch.LongTensor = decoder_input_ids
        self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
        self.past_key_values: list[torch.FloatTensor] | list[torch.HalfTensor] | None = past_key_values

    @staticmethod
    def create_dummy(
        config: "T5Config | MT5Config",
        batch_size: int,
        encode_sequence_length: int,
        past_decode_sequence_length: int,
        device: torch.device,
        float16: bool = False,
        use_int32_inputs: bool = False,
    ) -> "T5DecoderInputs":
        """Create dummy inputs for T5Decoder.

        Args:
            config (T5Config | MT5Config): model config; num_heads, num_decoder_layers,
                vocab_size and d_kv are read from it.
            batch_size (int): batch size
            encode_sequence_length (int): sequence length of input_ids for encoder
            past_decode_sequence_length (int): past sequence length of input_ids for decoder.
                When it is not positive, no past key/value tensors are created.
            device (torch.device): device of output tensors
            float16 (bool): use float16 (True) or float32 (False) for the past tensors
            use_int32_inputs (bool): use int32 instead of int64 for the id/mask inputs

        Returns:
            T5DecoderInputs: dummy inputs for the decoder. When past tensors are created, the
            list holds 2 * num_decoder_layers self-attention tensors of shape
            (batch_size, num_heads, past_decode_sequence_length, d_kv), followed by
            2 * num_decoder_layers cross-attention tensors of shape
            (batch_size, num_heads, encode_sequence_length, d_kv).
        """
        num_attention_heads: int = config.num_heads
        num_layers: int = config.num_decoder_layers
        vocab_size: int = config.vocab_size

        # Do not use head_size = hidden_size / num_attention_heads here.
        # For example, mt5-small, d_model=512 and num_heads=6
        head_size: int = config.d_kv

        sequence_length: int = 1  # fixed for decoding
        # torch.randint's `high` is exclusive, so `vocab_size` allows every valid id
        # 0 .. vocab_size - 1 (the previous `vocab_size - 1` never produced the last id).
        decoder_input_ids = torch.randint(
            low=0,
            high=vocab_size,
            size=(batch_size, sequence_length),
            dtype=(torch.int32 if use_int32_inputs else torch.int64),
            device=device,
        )

        encoder_inputs = T5EncoderInputs.create_dummy(
            batch_size,
            encode_sequence_length,
            vocab_size,
            device,
            use_int32_inputs=use_int32_inputs,
        )

        float_type = torch.float16 if float16 else torch.float32

        past = None
        if past_decode_sequence_length > 0:
            self_attention_past_shape = [
                batch_size,
                num_attention_heads,
                past_decode_sequence_length,
                head_size,
            ]
            cross_attention_past_shape = [
                batch_size,
                num_attention_heads,
                encode_sequence_length,
                head_size,
            ]
            # All self-attention tensors come first, then all cross-attention tensors.
            past = [
                torch.rand(self_attention_past_shape, dtype=float_type, device=device)
                for _ in range(2 * num_layers)
            ]
            past += [
                torch.rand(cross_attention_past_shape, dtype=float_type, device=device)
                for _ in range(2 * num_layers)
            ]

        return T5DecoderInputs(decoder_input_ids, encoder_inputs.attention_mask, past)

    def to_list(self) -> list:
        """Return [decoder_input_ids, encoder_attention_mask, *past_key_values].

        Past tensors are appended only when past_key_values is non-empty.
        """
        input_list = [
            self.decoder_input_ids,
            self.encoder_attention_mask,
        ]
        if self.past_key_values:
            input_list.extend(self.past_key_values)
        return input_list

    def to_fp32(self) -> "T5DecoderInputs":
        """Return a copy whose past tensors are cast to float32.

        The id and mask tensors are cloned. Past tensors are converted with
        ``.to(dtype=torch.float32)``. An empty or missing past becomes None.
        """
        past = [p.to(dtype=torch.float32) for p in self.past_key_values] if self.past_key_values else None
        return T5DecoderInputs(
            self.decoder_input_ids.clone(),
            self.encoder_attention_mask.clone(),
            past,
        )
229
+
230
+
231
+ class T5DecoderHelper:
232
+ @staticmethod
233
+ def export_onnx(
234
+ decoder: T5Decoder | T5DecoderInit,
235
+ device: torch.device,
236
+ onnx_model_path: str,
237
+ verbose: bool = True,
238
+ use_external_data_format: bool = False,
239
+ use_int32_inputs: bool = False,
240
+ ):
241
+ """Export decoder to ONNX
242
+
243
+ Args:
244
+ decoder (Union[T5Decoder, T5DecoderNoPastState]): decoder object
245
+ device (torch.device): device of decoder object
246
+ onnx_model_path (str): onnx path
247
+ verbose (bool, optional): print verbose information. Defaults to True.
248
+ use_external_data_format (bool, optional): use external data format or not. Defaults to False.
249
+ use_int32_inputs (bool, optional): use int32 inputs
250
+ """
251
+ assert isinstance(decoder, (T5Decoder, T5DecoderInit))
252
+
253
+ inputs = T5DecoderInputs.create_dummy(
254
+ decoder.config,
255
+ batch_size=2,
256
+ encode_sequence_length=3,
257
+ past_decode_sequence_length=5 if isinstance(decoder, T5Decoder) else 0,
258
+ device=device,
259
+ use_int32_inputs=use_int32_inputs,
260
+ )
261
+ input_list = inputs.to_list()
262
+
263
+ num_decoder_layers = decoder.config.num_decoder_layers
264
+
265
+ past_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=False)
266
+ present_names = PastKeyValuesHelper.get_past_names(num_decoder_layers, present=True)
267
+ present_self_names = present_names[: 2 * num_decoder_layers]
268
+
269
+ input_past_names = past_names if isinstance(decoder, T5Decoder) else []
270
+ output_present_names = present_self_names if isinstance(decoder, T5Decoder) else present_names
271
+ output_names = ["logits", *output_present_names]
272
+
273
+ # Shape of input tensors (sequence_length==1):
274
+ # input_ids: (batch_size, sequence_length)
275
+ # encoder_attention_mask: (batch_size, encode_sequence_length)
276
+ # past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
277
+ # past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
278
+
279
+ # Shape of output tensors:
280
+ # logits: (batch_size, sequence_length, vocab_size)
281
+ # past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
282
+ # past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
283
+
284
+ input_names = ["input_ids"]
285
+ input_names.append("encoder_attention_mask")
286
+ input_names.extend(input_past_names)
287
+
288
+ dynamic_axes = {
289
+ "input_ids": {
290
+ 0: "batch_size",
291
+ # 1: 'sequence_length'
292
+ },
293
+ "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
294
+ "encoder_hidden_states": {0: "batch_size", 1: "encode_sequence_length"},
295
+ "logits": {
296
+ 0: "batch_size",
297
+ # 1: 'sequence_length'
298
+ },
299
+ }
300
+
301
+ for name in input_past_names:
302
+ dynamic_axes[name] = {
303
+ 0: "batch_size",
304
+ 2: "past_decode_sequence_length" if "self" in name else "encode_sequence_length",
305
+ }
306
+
307
+ for name in output_present_names:
308
+ if "cross" in name:
309
+ dynamic_axes[name] = {0: "batch_size", 2: "encode_sequence_length"}
310
+ else: # self attention past state
311
+ if isinstance(decoder, T5Decoder):
312
+ dynamic_axes[name] = {
313
+ 0: "batch_size",
314
+ 2: "past_decode_sequence_length + 1",
315
+ }
316
+ else:
317
+ dynamic_axes[name] = {
318
+ 0: "batch_size",
319
+ # 2: 'sequence_length'
320
+ }
321
+
322
+ Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
323
+
324
+ with tempfile.TemporaryDirectory() as tmp_dir_name:
325
+ temp_onnx_model_path = os.path.join(tmp_dir_name, "decoder.onnx")
326
+ Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
327
+ torch_onnx_export(
328
+ decoder,
329
+ args=tuple(input_list),
330
+ f=temp_onnx_model_path if use_external_data_format else onnx_model_path,
331
+ export_params=True,
332
+ input_names=input_names,
333
+ output_names=output_names,
334
+ dynamic_axes=dynamic_axes,
335
+ opset_version=12,
336
+ do_constant_folding=True,
337
+ use_external_data_format=use_external_data_format,
338
+ verbose=verbose,
339
+ )
340
+
341
+ if use_external_data_format:
342
+ model = onnx.load_model(temp_onnx_model_path, load_external_data=True)
343
+ OnnxModel.save(
344
+ model,
345
+ onnx_model_path,
346
+ save_as_external_data=True,
347
+ all_tensors_to_one_file=True,
348
+ )
349
+
350
@staticmethod
def onnxruntime_inference(ort_session, inputs: T5DecoderInputs):
    """Run inference of ONNX model.

    Args:
        ort_session: ONNX Runtime session of the exported decoder.
        inputs (T5DecoderInputs): decoder inputs. When ``past_key_values`` is falsy
            (e.g. empty), no past inputs are fed to the session.

    Returns:
        list: all session outputs, as returned by ``ort_session.run(None, ...)``.

    Raises:
        ValueError: if the number of past tensors is not a multiple of 4.
    """
    logger.debug("start onnxruntime_inference")

    ort_inputs = {
        "input_ids": numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy()),
        "encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
    }

    past_key_values = inputs.past_key_values
    if past_key_values:
        # Each decoder layer contributes 4 past tensors (self key/value and cross key/value).
        num_layers, remainder = divmod(len(past_key_values), 4)
        if remainder:
            # Explicit check instead of `assert`, which is stripped under `python -O`.
            raise ValueError(
                f"expected a multiple of 4 past tensors, got {len(past_key_values)}"
            )
        past_names = PastKeyValuesHelper.get_past_names(num_layers)
        # strict=True surfaces any name/tensor count mismatch instead of silently truncating.
        for name, past_tensor in zip(past_names, past_key_values, strict=True):
            ort_inputs[name] = numpy.ascontiguousarray(past_tensor.cpu().numpy())

    return ort_session.run(None, ort_inputs)
369
+
370
@staticmethod
def verify_onnx(
    model: T5Decoder | T5DecoderInit,
    ort_session: InferenceSession,
    device: torch.device,
    use_int32_inputs: bool,
    max_cases: int = 4,
):
    """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good.

    Runs up to ``max_cases`` predefined (batch_size, encode_sequence_length,
    past_decode_sequence_length) shapes through both the fp32 PyTorch model and the
    ONNX Runtime session. For each run it measures the largest absolute difference
    over the logits and the present key/value outputs.

    Args:
        model (Union[T5Decoder, T5DecoderInit]): PyTorch decoder used as the fp32 baseline.
        ort_session (InferenceSession): session of the exported ONNX decoder.
        device (torch.device): device on which the dummy inputs are created.
        use_int32_inputs (bool): create int32 instead of int64 integer inputs.
        max_cases (int, optional): number of predefined test shapes to run
            (4 are defined). Defaults to 4.

    Returns:
        The maximum absolute difference over all outputs of ALL test cases that were run.

    Raises:
        ValueError: if ``max_cases`` is less than 1 (no test case would run).
    """
    if max_cases < 1:
        raise ValueError(f"max_cases must be at least 1, got {max_cases}")

    # NOTE(review): T5DecoderInit models are exported without past_* inputs, so a session
    # of such a model may not expose "past_key_self_0" -- confirm how TypeHelper handles that.
    float16: bool = TypeHelper.get_input_type(ort_session, "past_key_self_0") == "tensor(float16)"

    def _max_abs_diff(torch_tensor, ort_array):
        # Largest element-wise |torch - ort| difference for one output.
        return numpy.amax(numpy.abs(torch_tensor.cpu().numpy() - ort_array))

    is_decoder_init = isinstance(model, T5DecoderInit)
    num_decoder_layers = model.config.num_decoder_layers
    # Self- and cross-attention states each hold a key and a value per layer.
    states_per_attention = 2 * num_decoder_layers

    test_cases = [(4, 11, 3), (1, 2, 5), (3, 1, 1), (8, 5, 2)]
    test_cases_max_diff = []
    for batch_size, encode_sequence_length, past_decode_sequence_length in test_cases[:max_cases]:
        # T5DecoderInit takes no past_* inputs, so its past length is always 0.
        past_length = 0 if is_decoder_init else past_decode_sequence_length

        inputs = T5DecoderInputs.create_dummy(
            model.config,
            batch_size,
            encode_sequence_length,
            past_length,
            device=device,
            float16=float16,
            use_int32_inputs=use_int32_inputs,
        )

        # We use the fp32 PyTorch model as baseline even when the ONNX model is fp16.
        input_list = inputs.to_fp32().to_list()

        # Run inference of PyTorch model
        with torch.no_grad():
            torch_outputs = model(*input_list)

        ort_outputs = T5DecoderHelper.onnxruntime_inference(ort_session, inputs)

        # torch_outputs: (logits, self states, cross states); ort_outputs is flat:
        # [logits, *self states, *cross states (init decoder only)].
        max_diff = _max_abs_diff(torch_outputs[0], ort_outputs[0])
        max_diff_all = max_diff
        logger.debug(f"logits max_diff={max_diff}")

        for i in range(states_per_attention):
            max_diff = _max_abs_diff(torch_outputs[1][i], ort_outputs[1 + i])
            logger.debug(f"self attention past state {i} max_diff={max_diff}")
            max_diff_all = max(max_diff_all, max_diff)

        if is_decoder_init:
            for i in range(states_per_attention):
                max_diff = _max_abs_diff(torch_outputs[2][i], ort_outputs[1 + states_per_attention + i])
                logger.debug(f"cross attention past state {i} max_diff={max_diff}")
                max_diff_all = max(max_diff_all, max_diff)

        test_cases_max_diff.append(max_diff_all)
        logger.info(
            "batch_size=%s, encode_sequence_length=%s, past_decode_sequence_length=%s, max_diff=%s",
            batch_size,
            encode_sequence_length,
            past_length,
            max_diff_all,
        )

    # Report the worst case over all runs; returning only the last case's diff
    # would hide a mismatch found in an earlier case.
    return max(test_cases_max_diff)
@@ -0,0 +1,70 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # -------------------------------------------------------------------------
5
+
6
+ import logging
7
+ import random
8
+
9
+ import torch
10
+ from transformers import MT5Config, T5Config
11
+
12
+ logger = logging.getLogger(__name__)
13
+
14
+
15
+ class T5Encoder(torch.nn.Module):
16
+ """T5 encoder outputs only the last hidden state"""
17
+
18
+ def __init__(self, encoder, config: T5Config | MT5Config):
19
+ super().__init__()
20
+ self.encoder = encoder
21
+ self.config = config
22
+
23
+ def forward(self, input_ids, attention_mask):
24
+ return self.encoder(input_ids, attention_mask)[0]
25
+
26
+
27
+ class T5EncoderInputs:
28
+ def __init__(self, input_ids, attention_mask):
29
+ self.input_ids: torch.LongTensor = input_ids
30
+ self.attention_mask: torch.LongTensor = attention_mask
31
+
32
+ @staticmethod
33
+ def create_dummy(
34
+ batch_size: int,
35
+ sequence_length: int,
36
+ vocab_size: int,
37
+ device: torch.device,
38
+ use_int32_inputs: bool = False,
39
+ ): # -> T5EncoderInputs
40
+ """Create dummy inputs for T5 encoder.
41
+
42
+ Args:
43
+ batch_size (int): batch size
44
+ sequence_length (int): sequence length
45
+ vocab_size (int): vocabulary size
46
+ device (torch.device): device of output tensors
47
+
48
+ Returns:
49
+ T5EncoderInputs: dummy inputs for encoder
50
+ """
51
+ dtype = torch.int32 if use_int32_inputs else torch.int64
52
+
53
+ input_ids = torch.randint(
54
+ low=0,
55
+ high=vocab_size - 1,
56
+ size=(batch_size, sequence_length),
57
+ dtype=dtype,
58
+ device=device,
59
+ )
60
+
61
+ attention_mask = torch.ones([batch_size, sequence_length], dtype=dtype, device=device)
62
+ if sequence_length >= 2:
63
+ for i in range(batch_size):
64
+ padding_position = random.randint(0, sequence_length - 1)
65
+ attention_mask[i, :padding_position] = 0
66
+ return T5EncoderInputs(input_ids, attention_mask)
67
+
68
+ def to_list(self) -> list:
69
+ input_list = [v for v in [self.input_ids, self.attention_mask] if v is not None]
70
+ return input_list