onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,307 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+ from __future__ import annotations
7
+
8
+ import logging
9
+ from pathlib import Path
10
+
11
+ import onnx
12
+
13
+ from ...fusions import FusionGelu, FusionLayerNormalization
14
+ from ...onnx_model import ONNXModel
15
+ from .fusion_lpnorm import FusionLpNormalization
16
+
17
+
18
def qnn_preprocess_model(
    model_input: str | Path | onnx.ModelProto,
    model_output: str | Path,
    fuse_layernorm: bool = False,
    save_as_external_data: bool = False,
    all_tensors_to_one_file: bool = False,
    external_data_location: str | None = None,
    external_data_size_threshold: int = 1024,
    external_data_convert_attribute: bool = False,
    inputs_to_make_channel_last: list[str] | None = None,
    outputs_to_make_channel_last: list[str] | None = None,
) -> bool:
    """
    If necessary, this method creates a new "pre-processed" model in preparation for
    quantization of a model to be used in QNN EP. Returns true if a new model was created.

    This method performs the following operations:
    - Fuse Erf sequence into a single Gelu node.
    - Fuse ReduceL2 sequence into a single LpNormalization node (p == 2).
    - (Optional) Fuse ReduceMean sequence into a single LayerNormalization node.

    Args:
        model_input: Path to the input model file or ModelProto.
        model_output: Path to the output model file, which is only created if this method returns True.
        fuse_layernorm: True if ReduceMean sequences should be fused into LayerNormalization nodes.
            Defaults to False.
        save_as_external_data: True if output model should be saved with external data. Defaults to false.
        all_tensors_to_one_file: Effective only if save_as_external_data is true. Defaults to false.
            If true, save all tensors to one external file specified by external_data_location.
            If false, save each tensor to a file named with the tensor name.
        external_data_location: Effective only if save_as_external_data is true. Defaults to None.
            Specify the external file to which all tensors are saved. Path is relative
            to the model path. If not specified, the model's name is used.
        external_data_size_threshold: Effective only if save_as_external_data is true. Defaults to 1024.
            Tensors with a data size >= external_data_size_threshold are converted to external data.
            To convert every tensor with raw data to external data, set to 0.
        external_data_convert_attribute: Effective only if save_as_external_data is true. Defaults to false.
            If true, convert all tensors to external data.
            If false, convert only non-attribute tensors to external data.
        inputs_to_make_channel_last: List of graph input names to transpose to be "channel-last". For example,
            if "input0" originally has the shape (N, C, D1, D2, ..., Dn), the resulting model will change input0's
            shape to (N, D1, D2, ..., Dn, C) and add a transpose node after it.

            Original:
                input0 (N, C, D1, D2, ..., Dn) --> <Nodes>

            Updated:
                input0 (N, D1, D2, ..., Dn, C) --> Transpose --> input0_chanfirst (N, C, D1, D2, ..., Dn) --> <Nodes>

            This can potentially improve inference latency for QDQ models running on QNN EP because the
            additional transpose node may allow other transpose nodes inserted during ORT layout transformation
            to cancel out.
        outputs_to_make_channel_last: List of graph output names to transpose to be "channel-last". For example,
            if "output0" originally has the shape (N, C, D1, D2, ..., Dn), the resulting model will change output0's
            shape to (N, D1, D2, ..., Dn, C) and add a transpose node before it.

            Original:
                <Nodes> --> output0 (N, C, D1, D2, ..., Dn)

            Updated:
                <Nodes> --> output0_chanfirst (N, C, D1, D2, ..., Dn) --> Transpose --> output0 (N, D1, D2, ..., Dn, C)

            This can potentially improve inference latency for QDQ models running on QNN EP because the
            additional transpose node may allow other transpose nodes inserted during ORT layout transformation
            to cancel out.
    """
    modified = False
    model = model_input if isinstance(model_input, onnx.ModelProto) else onnx.load_model(model_input)
    onnx_model = ONNXModel(model)

    # Fuse Erf sequence into a single Gelu
    fusion_gelu = FusionGelu(onnx_model)
    if fusion_gelu.apply():
        modified = True

    # Fuse ReduceL2 sequence into a single LpNormalization node with p == 2.
    fusion_lpnorm = FusionLpNormalization(onnx_model)
    if fusion_lpnorm.apply():
        modified = True

    # Optionally, fuse ReduceMean sequence into a single LayerNormalization node.
    if fuse_layernorm:
        # Guard against a (malformed) model that lacks a default-domain opset import.
        # The previous bare next(...) raised an opaque StopIteration in that case; now
        # the fusion is skipped with an actionable warning instead.
        onnx_opset = next((x for x in model.opset_import if x.domain in ("", "ai.onnx")), None)

        if onnx_opset is None:
            logging.warning(
                "Unable to fuse ReduceMean sequence into a LayerNormalization node. "
                "The model does not declare a default-domain (ai.onnx) opset import."
            )
        elif onnx_opset.version < 17:
            # Need opset >= 17 to use LayerNormalization.
            logging.warning(
                "Unable to fuse ReduceMean sequence into a LayerNormalization node. "
                "ONNX model must use an opset >= 17 in order to use LayerNormalization, "
                "but found version %s. Please use onnx.version_converter to update your model.",
                onnx_opset.version,
            )
        else:
            fusion_layernorm = FusionLayerNormalization(onnx_model)
            if fusion_layernorm.apply():
                modified = True

    # Optionally, transpose inputs and/or outputs to make them "channel-last".
    if inputs_to_make_channel_last or outputs_to_make_channel_last:
        transpose_node_prefix = "Transpose_channel_"
        transpose_node_suffix: int = onnx_model.get_largest_node_name_suffix(transpose_node_prefix) + 1
        update_io_to_channel_last(
            onnx_model.model,
            inputs_to_make_channel_last,
            outputs_to_make_channel_last,
            transpose_node_name_prefix=transpose_node_prefix,
            transpose_node_name_start_suffix=transpose_node_suffix,
        )
        modified = True

    # Make sure all nodes have a name (QNN EP requires named nodes; Constant nodes are
    # folded away and so are exempt).
    unnamed_node_prefix = "qnn_preproc_node_"
    available_suffix = onnx_model.get_largest_node_name_suffix(unnamed_node_prefix) + 1
    for node in onnx_model.model.graph.node:
        if node.op_type != "Constant" and not node.name:
            new_node_name = f"{unnamed_node_prefix}{available_suffix!s}"
            available_suffix += 1
            node.name = new_node_name
            modified = True
            # Lazy %-style args avoid building the message when the log level filters it out.
            logging.warning("Node of type %s does not have a name. Renamed to %s.", node.op_type, new_node_name)

    if modified:
        onnx_model.topological_sort()
        onnx.save_model(
            model,
            model_output,
            save_as_external_data=save_as_external_data,
            all_tensors_to_one_file=all_tensors_to_one_file,
            location=external_data_location,
            size_threshold=external_data_size_threshold,
            convert_attribute=external_data_convert_attribute,
        )

    return modified
151
+
152
+
153
class InputOutputNameMap:
    """Hands out unique replacement tensor names for transposed graph inputs/outputs.

    For each original input/output name, the first call to ``get_new_name`` mints a
    fresh name of the form ``<orig>_channel_first_<n>`` that does not collide with any
    existing tensor name, records a cloned ``ValueInfoProto`` for it in
    ``new_value_infos``, and caches the mapping so later calls return the same name.
    """

    def __init__(
        self,
        orig_tensor_names: set[str],
        orig_graph_inputs: dict[str, onnx.ValueInfoProto],
        orig_graph_outputs: dict[str, onnx.ValueInfoProto],
    ):
        self.orig_tensor_names = orig_tensor_names
        self.orig_graph_inputs = orig_graph_inputs
        self.orig_graph_outputs = orig_graph_outputs
        # orig name -> minted replacement name (filled lazily by get_new_name).
        self.updated_io_names = {}
        # value_info entries for minted names, to be appended to the graph by the caller.
        self.new_value_infos = []

    def get_new_name(self, orig_name: str):
        """Return the (cached or newly minted) unique replacement name for *orig_name*."""
        cached = self.updated_io_names.get(orig_name)
        if cached is not None:
            return cached

        # Pick the smallest numeric suffix that no existing tensor already uses.
        prefix: str = f"{orig_name}_channel_first_"
        used_suffixes = (
            int(tensor_name[len(prefix) :])
            for tensor_name in self.orig_tensor_names
            if tensor_name.startswith(prefix) and tensor_name[len(prefix) :].isdigit()
        )
        next_suffix = max(used_suffixes, default=-1) + 1
        new_name = f"{prefix}{next_suffix!s}"

        # Clone the original tensor's value_info under the new name so shape/type
        # information is preserved for the renamed tensor.
        source_value_info = self.orig_graph_inputs.get(orig_name) or self.orig_graph_outputs[orig_name]
        cloned_value_info = onnx.ValueInfoProto()
        cloned_value_info.CopyFrom(source_value_info)
        cloned_value_info.name = new_name
        self.new_value_infos.append(cloned_value_info)

        self.updated_io_names[orig_name] = new_name
        return new_name
190
+
191
+
192
def update_io_to_channel_last(
    model: onnx.ModelProto,
    inputs_to_update: list[str] | None,
    outputs_to_update: list[str] | None,
    transpose_node_name_prefix: str = "Transpose_channel_",
    transpose_node_name_start_suffix: int = 0,
):
    """Mutate *model* in place so the listed graph inputs/outputs become channel-last.

    For each name in *inputs_to_update*: the graph input's declared shape is rotated
    from (N, C, D1..Dn) to (N, D1..Dn, C), and a Transpose node is appended that
    converts the channel-last input back to channel-first for the rest of the graph.
    For each name in *outputs_to_update*: the graph output's declared shape is rotated
    the same way, and a Transpose node is appended that converts the graph's internal
    channel-first tensor into the channel-last output.

    Inserted Transpose nodes are named f"{transpose_node_name_prefix}{suffix}" with
    suffixes counting up from *transpose_node_name_start_suffix*.

    Raises:
        ValueError: if a listed name is not actually a graph input/output, if a listed
            tensor has no tensor_type with a shape, or if its rank is < 3.

    NOTE(review): nodes are appended at the end of graph.node, so the caller is
    expected to re-sort topologically afterwards (qnn_preprocess_model does).
    """
    inputs_to_update = set(inputs_to_update or [])
    outputs_to_update = set(outputs_to_update or [])

    if not inputs_to_update and not outputs_to_update:
        return

    graph = model.graph
    orig_graph_inputs = {ginput.name: ginput for ginput in graph.input}
    orig_graph_outputs = {goutput.name: goutput for goutput in graph.output}

    # Check that the user passed in actual input and output names.
    for input_name in inputs_to_update:
        if input_name not in orig_graph_inputs:
            raise ValueError(f"{input_name} is not a graph input")

    for output_name in outputs_to_update:
        if output_name not in orig_graph_outputs:
            raise ValueError(f"{output_name} is not a graph output")

    # Collect every tensor name already present in the graph (inputs, outputs, and all
    # node inputs) so that minted replacement names are guaranteed unique.
    orig_tensor_names = set()
    orig_tensor_names.update(set(orig_graph_inputs))
    orig_tensor_names.update(set(orig_graph_outputs))
    orig_tensor_names.update(input_name for node in graph.node for input_name in node.input if input_name)

    # Maps original input (or output) name to its updated name used within the graph.
    io_map = InputOutputNameMap(orig_tensor_names, orig_graph_inputs, orig_graph_outputs)

    # Update each node's inputs/outputs to use the transposed versions.
    # Node inputs referencing a graph output are also renamed (second branch) because
    # the original output name now belongs to the channel-last tensor produced by the
    # inserted Transpose, not the internal channel-first tensor.
    for node in graph.node:
        for i in range(len(node.input)):
            if node.input[i] and node.input[i] in inputs_to_update:
                node.input[i] = io_map.get_new_name(node.input[i])
            elif node.input[i] and node.input[i] in outputs_to_update:
                node.input[i] = io_map.get_new_name(node.input[i])

        for i in range(len(node.output)):
            if node.output[i] in outputs_to_update:
                node.output[i] = io_map.get_new_name(node.output[i])

    # Update graph inputs to channel-last and a Transpose (to channel-first) after each.
    for g_input_name in inputs_to_update:
        g_input = orig_graph_inputs[g_input_name]

        if not g_input.type.HasField("tensor_type") or not g_input.type.tensor_type.HasField("shape"):
            raise ValueError(f"Expected input {g_input.name} to have a tensor_type with a shape")

        input_shape = g_input.type.tensor_type.shape
        input_rank = len(input_shape.dim)

        # Rank >= 3 is required for a (batch, channel, spatial...) interpretation.
        if input_rank < 3:
            raise ValueError(f"Expected input {g_input.name} to be of rank >= 3")

        # Rotate the declared shape left by one starting at dim 1:
        # (N, C, D1..Dn) -> (N, D1..Dn, C). The channel dim is saved first because
        # the in-place CopyFrom loop overwrites it.
        channel_dim = onnx.TensorShapeProto.Dimension()
        channel_dim.CopyFrom(input_shape.dim[1])
        for i in range(1, input_rank - 1):
            input_shape.dim[i].CopyFrom(input_shape.dim[i + 1])
        input_shape.dim[input_rank - 1].CopyFrom(channel_dim)

        # Build perm = [0, rank-1, 1, 2, ..., rank-2]: maps the channel-last input
        # back to channel-first for the rest of the graph.
        transpose_perm = list(range(input_rank))
        for i in range(input_rank):
            transpose_perm[i] = i if i < 1 else i - 1
        transpose_perm[1] = input_rank - 1

        transpose_node = onnx.helper.make_node(
            "Transpose",
            name=f"{transpose_node_name_prefix}{transpose_node_name_start_suffix!s}",
            inputs=[g_input.name],
            outputs=[io_map.get_new_name(g_input.name)],
            perm=transpose_perm,
        )
        transpose_node_name_start_suffix += 1

        graph.node.extend([transpose_node])

    # Update graph outputs to channel-last and a Transpose (from channel-first) before each.
    for g_output_name in outputs_to_update:
        g_output = orig_graph_outputs[g_output_name]
        if not g_output.type.HasField("tensor_type") or not g_output.type.tensor_type.HasField("shape"):
            raise ValueError(f"Expected output {g_output.name} to have a tensor_type with a shape")

        output_shape = g_output.type.tensor_type.shape
        output_rank = len(output_shape.dim)

        if output_rank < 3:
            raise ValueError(f"Expected output {g_output.name} to be of rank >= 3")

        # Same left-rotation of the declared shape as for inputs:
        # (N, C, D1..Dn) -> (N, D1..Dn, C).
        channel_dim = onnx.TensorShapeProto.Dimension()
        channel_dim.CopyFrom(output_shape.dim[1])
        for i in range(1, output_rank - 1):
            output_shape.dim[i].CopyFrom(output_shape.dim[i + 1])
        output_shape.dim[output_rank - 1].CopyFrom(channel_dim)

        # Build perm = [0, 2, 3, ..., rank-1, 1]: maps the internal channel-first
        # tensor to the channel-last graph output.
        transpose_perm = list(range(output_rank))
        for i in range(output_rank):
            transpose_perm[i] = i if i == 0 else i + 1
        transpose_perm[output_rank - 1] = 1

        transpose_node = onnx.helper.make_node(
            "Transpose",
            name=f"{transpose_node_name_prefix}{transpose_node_name_start_suffix!s}",
            inputs=[io_map.get_new_name(g_output.name)],
            outputs=[g_output.name],
            perm=transpose_perm,
        )
        transpose_node_name_start_suffix += 1

        graph.node.extend([transpose_node])

    # Register shape/type info for all minted tensor names.
    graph.value_info.extend(io_map.new_value_infos)