onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,364 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ import json
7
+ from argparse import ArgumentParser
8
+
9
+ import onnx
10
+ from onnx import TensorProto, helper
11
+
12
+
13
class QnnTensorStruct:
    """Record describing one QNN graph tensor.

    Holds the tensor's name, its mapped ONNX data type, whether it is a
    quantized (fixed-point) tensor, its per-tensor scale/offset, its shape
    (``dim``) and its graph id used for I/O ordering.
    """

    def __init__(
        self, name="", onnx_data_type=TensorProto.FLOAT, is_quantized=False, scale=0.0, offset=0, dim=None, id=None
    ):
        self.name = name
        self.onnx_data_type = onnx_data_type
        self.is_quantized = is_quantized
        self.scale = scale
        self.offset = offset
        # Fresh list per instance; a mutable default would be shared across calls.
        self.dim = dim if dim is not None else []
        self.id = id
24
+
25
+
26
def is_quantized_data_type(qnn_data_type, is_converter_json):
    """Return True if *qnn_data_type* denotes a quantized (fixed-point) QNN type.

    Args:
        qnn_data_type: Numeric QNN type code when the value comes from the
            QNN converter JSON, or a ``"QNN_DATATYPE_*"`` string when it
            comes from the context-binary JSON.
        is_converter_json: True for the numeric (converter JSON)
            representation, False for the string representation.

    Returns:
        bool: True for the UFIXED_POINT_8/16 and FIXED_POINT_8/16 types.
    """
    if is_converter_json:
        # QNN_DATATYPE_UFIXED_POINT_8, QNN_DATATYPE_UFIXED_POINT_16,
        # QNN_DATATYPE_FIXED_POINT_8, QNN_DATATYPE_FIXED_POINT_16
        return qnn_data_type in (0x0408, 0x0416, 0x0308, 0x0316)
    return qnn_data_type in (
        "QNN_DATATYPE_UFIXED_POINT_8",
        "QNN_DATATYPE_UFIXED_POINT_16",
        "QNN_DATATYPE_FIXED_POINT_8",
        "QNN_DATATYPE_FIXED_POINT_16",
    )
37
+
38
+
39
def qnn_data_type_to_onnx_data_type(qnn_data_type, is_converter_json):
    """Map a QNN data type to the corresponding ONNX ``TensorProto`` data type.

    Args:
        qnn_data_type: Numeric QNN type code (converter JSON) or a
            ``"QNN_DATATYPE_*"`` string (context-binary JSON).
        is_converter_json: Selects which representation *qnn_data_type* uses.

    Returns:
        The matching ``TensorProto`` enum value, or ``TensorProto.UNDEFINED``
        for any unrecognized type.
    """
    if is_converter_json:
        # Quantized (U)FIXED_POINT codes map to the same ONNX integer type
        # as their plain (U)INT counterparts.
        code_to_onnx = {
            0x0408: TensorProto.UINT8,  # QNN_DATATYPE_UFIXED_POINT_8
            0x0108: TensorProto.UINT8,  # QNN_DATATYPE_UINT_8
            0x0416: TensorProto.UINT16,  # QNN_DATATYPE_UFIXED_POINT_16
            0x0116: TensorProto.UINT16,  # QNN_DATATYPE_UINT_16
            0x0432: TensorProto.UINT32,  # QNN_DATATYPE_UFIXED_POINT_32
            0x0132: TensorProto.UINT32,  # QNN_DATATYPE_UINT_32
            0x0164: TensorProto.UINT64,  # QNN_DATATYPE_UINT_64
            0x0308: TensorProto.INT8,  # QNN_DATATYPE_FIXED_POINT_8
            0x0008: TensorProto.INT8,  # QNN_DATATYPE_INT_8
            0x0316: TensorProto.INT16,  # QNN_DATATYPE_FIXED_POINT_16
            0x0016: TensorProto.INT16,  # QNN_DATATYPE_INT_16
            0x0332: TensorProto.INT32,  # QNN_DATATYPE_FIXED_POINT_32
            0x0032: TensorProto.INT32,  # QNN_DATATYPE_INT_32
            0x0064: TensorProto.INT64,  # QNN_DATATYPE_INT_64
            0x0216: TensorProto.FLOAT16,  # QNN_DATATYPE_FLOAT_16
            0x0232: TensorProto.FLOAT,  # QNN_DATATYPE_FLOAT_32
            0x0508: TensorProto.BOOL,  # QNN_DATATYPE_BOOL_8
        }
        return code_to_onnx.get(qnn_data_type, TensorProto.UNDEFINED)

    name_to_onnx = {
        "QNN_DATATYPE_UFIXED_POINT_8": TensorProto.UINT8,
        "QNN_DATATYPE_UINT_8": TensorProto.UINT8,
        "QNN_DATATYPE_UFIXED_POINT_16": TensorProto.UINT16,
        "QNN_DATATYPE_UINT_16": TensorProto.UINT16,
        "QNN_DATATYPE_UFIXED_POINT_32": TensorProto.UINT32,
        "QNN_DATATYPE_UINT_32": TensorProto.UINT32,
        "QNN_DATATYPE_UINT_64": TensorProto.UINT64,
        "QNN_DATATYPE_FIXED_POINT_8": TensorProto.INT8,
        "QNN_DATATYPE_INT_8": TensorProto.INT8,
        "QNN_DATATYPE_FIXED_POINT_16": TensorProto.INT16,
        "QNN_DATATYPE_INT_16": TensorProto.INT16,
        "QNN_DATATYPE_FIXED_POINT_32": TensorProto.INT32,
        "QNN_DATATYPE_INT_32": TensorProto.INT32,
        "QNN_DATATYPE_INT_64": TensorProto.INT64,
        "QNN_DATATYPE_FLOAT_16": TensorProto.FLOAT16,
        "QNN_DATATYPE_FLOAT_32": TensorProto.FLOAT,
        "QNN_DATATYPE_BOOL_8": TensorProto.BOOL,
    }
    return name_to_onnx.get(qnn_data_type, TensorProto.UNDEFINED)
112
+
113
+
114
def parse_qnn_converter_json_file(qnn_convert_json, qnn_input_tensor_dic, qnn_output_tensor_dic):
    """Fill the input/output tensor dictionaries from a QNN converter JSON.

    Walks ``qnn_convert_json["graph"]["tensors"]``, keeps only graph inputs
    (type 0) and graph outputs (type 1), and inserts a ``QnnTensorStruct``
    per kept tensor into the corresponding dictionary (mutated in place).
    """
    for tensor_name, attrs in qnn_convert_json["graph"]["tensors"].items():
        # Validate that every field read below is present in the JSON entry.
        required_keys = ("type", "data_type", "dims", "id", "quant_params")
        assert all(key in attrs for key in required_keys), (
            "QNN converted json file not valid. Can't find some keys from tensors"
        )

        tensor_kind = attrs["type"]  # type:0 - QNN input tensor, type:1 - QNN output tensor
        if tensor_kind not in (0, 1):
            # Intermediate tensor — only graph I/O is of interest here.
            continue

        tensor = QnnTensorStruct(
            name=tensor_name,
            onnx_data_type=qnn_data_type_to_onnx_data_type(attrs["data_type"], True),
            is_quantized=is_quantized_data_type(attrs["data_type"], True),
            dim=attrs["dims"],
            id=attrs["id"],
        )

        quant_params = attrs["quant_params"]
        if quant_params["definition"] == 1 and quant_params["encoding"] == 0:
            # Per-tensor scale/offset encoding; the offset sign is flipped
            # relative to QNN's representation.
            tensor.scale = quant_params["scale_offset"]["scale"]
            tensor.offset = -quant_params["scale_offset"]["offset"]

        destination = qnn_input_tensor_dic if tensor_kind == 0 else qnn_output_tensor_dic
        destination[tensor_name] = tensor

    assert len(qnn_input_tensor_dic) >= 1 and len(qnn_output_tensor_dic) >= 1, (
        "Converted QNN model not valid. It should have at least 1 input & 1 output."
    )
154
+
155
+
156
def generate_wrapper_onnx_file(
    grap_name,
    model_file_name,
    qnn_input_tensor_dic,
    qnn_output_tensor_dic,
    disable_embed_mode,
    qnn_ctx_file,
    quantized_IO,
    qnn_sdk_version="unknown",
):
    """Write an ONNX model that wraps a QNN context binary in an EPContext node.

    Args:
        grap_name: Name given to the EPContext node (sic — parameter name kept
            for backward compatibility).
        model_file_name: Path of the ONNX model file to write.
        qnn_input_tensor_dic: Mapping of tensor name -> QnnTensorStruct for
            graph inputs.
        qnn_output_tensor_dic: Mapping of tensor name -> QnnTensorStruct for
            graph outputs.
        disable_embed_mode: If truthy, store only the context file path in the
            node (embed_mode=0); otherwise embed the binary content
            (embed_mode=1).
        qnn_ctx_file: Path to the QNN context binary file.
        quantized_IO: If truthy, expose quantized tensors directly as model
            I/O; otherwise wrap them with QuantizeLinear/DequantizeLinear so
            the model's external I/O is float.
        qnn_sdk_version: Recorded in the node's ep_sdk_version attribute.
    """
    graph_nodes = []
    ini_list = []  # initializers: per-tensor scale/zero-point constants
    value_infos = []  # type/shape info for the internal (quantized) tensors

    model_inputs = []
    # Iterate in id order so the model's input order matches the QNN graph.
    for qnn_input in sorted(qnn_input_tensor_dic.values(), key=lambda inp: inp.id):
        if qnn_input.is_quantized and not quantized_IO:
            # Float model input -> QuantizeLinear -> quantized QNN input.
            q_scale_input_name = qnn_input.name + "_scale"
            q_offset_input_name = qnn_input.name + "_zp"
            q_scale = helper.make_tensor(q_scale_input_name, TensorProto.FLOAT, [], [qnn_input.scale])
            ini_list.append(q_scale)
            # Zero-point initializer uses the tensor's own (integer) data type.
            q_offset = helper.make_tensor(q_offset_input_name, qnn_input.onnx_data_type, [], [qnn_input.offset])
            ini_list.append(q_offset)
            input_name = qnn_input.name + "_dq"

            # Node output keeps the original QNN tensor name so it connects
            # to the EPContext node's input of the same name.
            q_node = helper.make_node(
                "QuantizeLinear",
                name=qnn_input.name,
                inputs=[input_name, q_scale_input_name, q_offset_input_name],
                outputs=[qnn_input.name],
            )

            graph_nodes.append(q_node)
            model_inputs.append(helper.make_tensor_value_info(input_name, TensorProto.FLOAT, qnn_input.dim))
            value_infos.append(helper.make_tensor_value_info(qnn_input.name, qnn_input.onnx_data_type, qnn_input.dim))
        else:
            # Expose the QNN tensor directly as a model input.
            model_inputs.append(helper.make_tensor_value_info(qnn_input.name, qnn_input.onnx_data_type, qnn_input.dim))

    if disable_embed_mode:
        # Store just the path; the runtime loads the context file itself.
        ep_cache_context_content = qnn_ctx_file
        ctx_embed_mode = 0
    else:
        # Embed the whole context binary inside the ONNX model.
        with open(qnn_ctx_file, "rb") as file:
            ep_cache_context_content = file.read()
        ctx_embed_mode = 1

    # Single EPContext node carrying the QNN context; its I/O names are the
    # QNN tensor names (dict insertion order).
    qnn_ep_context_node = helper.make_node(
        "EPContext",
        name=grap_name,
        inputs=qnn_input_tensor_dic.keys(),
        outputs=qnn_output_tensor_dic.keys(),
        ep_cache_context=ep_cache_context_content,
        embed_mode=ctx_embed_mode,
        ep_sdk_version=qnn_sdk_version,
        source="Qnn",
        domain="com.microsoft",
    )
    graph_nodes.append(qnn_ep_context_node)

    model_outputs = []
    # Iterate in id order so the model's output order matches the QNN graph.
    for qnn_output in sorted(qnn_output_tensor_dic.values(), key=lambda out: out.id):
        if qnn_output.is_quantized and not quantized_IO:
            # Quantized QNN output -> DequantizeLinear -> float model output.
            dq_scale_input_name = qnn_output.name + "_scale"
            dq_offset_input_name = qnn_output.name + "_zp"
            dq_scale = helper.make_tensor(dq_scale_input_name, TensorProto.FLOAT, [], [qnn_output.scale])
            ini_list.append(dq_scale)
            dq_offset = helper.make_tensor(dq_offset_input_name, qnn_output.onnx_data_type, [], [qnn_output.offset])
            ini_list.append(dq_offset)
            output_name = qnn_output.name + "_dq"

            dq_node = helper.make_node(
                "DequantizeLinear",
                name=output_name,
                inputs=[qnn_output.name, dq_scale_input_name, dq_offset_input_name],
                outputs=[output_name],
            )

            graph_nodes.append(dq_node)
            model_outputs.append(helper.make_tensor_value_info(output_name, TensorProto.FLOAT, qnn_output.dim))
            value_infos.append(
                helper.make_tensor_value_info(qnn_output.name, qnn_output.onnx_data_type, qnn_output.dim)
            )
        else:
            # Expose the QNN tensor directly as a model output.
            model_outputs.append(
                helper.make_tensor_value_info(qnn_output.name, qnn_output.onnx_data_type, qnn_output.dim)
            )

    graph_def = helper.make_graph(graph_nodes, "qnn-onnx-model", model_inputs, model_outputs, ini_list, "", value_infos)

    model_def = helper.make_model(graph_def, producer_name="MS")

    onnx.save(model_def, model_file_name)
248
+
249
+
250
# Parse one QNN graph entry from the JSON file that was extracted from a
# QNN context binary file, filling the given input/output tensor dictionaries.
def parse_qnn_graph(qnn_graph, qnn_input_tensor_dic, qnn_output_tensor_dic):
    """Extract the graph name and input/output tensor metadata from a QNN graph.

    Args:
        qnn_graph: One entry of qnn_json["info"]["graphs"] from a context-binary JSON dump.
        qnn_input_tensor_dic: Dict populated in place, tensor name -> QnnTensorStruct, for graph inputs.
        qnn_output_tensor_dic: Dict populated in place, tensor name -> QnnTensorStruct, for graph outputs.

    Returns:
        The graph name stored in the JSON.
    """
    # This JSON comes from a context binary dump, not from the QNN converter,
    # so the data-type helpers must use the context-binary encoding.
    is_qnn_converter_json = False
    graph_name = qnn_graph["info"]["graphName"]

    def tensor_from_info(tensor_info):
        # Build a QnnTensorStruct from one tensor's "info" JSON record.
        # (Shared by the input and output loops below, which previously
        # duplicated this logic.)
        qnn_tensor = QnnTensorStruct()
        qnn_tensor.name = tensor_info["name"]
        qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type(tensor_info["dataType"], is_qnn_converter_json)
        qnn_tensor.is_quantized = is_quantized_data_type(tensor_info["dataType"], is_qnn_converter_json)
        qnn_tensor.dim = tensor_info["dimensions"]
        quantize_params = tensor_info["quantizeParams"]
        if (
            quantize_params["definition"] == "QNN_DEFINITION_DEFINED"
            and quantize_params["quantizationEncoding"] == "QNN_QUANTIZATION_ENCODING_SCALE_OFFSET"
        ):
            qnn_tensor.scale = quantize_params["scaleOffset"]["scale"]
            # QNN stores a negated offset; the ONNX zero point is its negation.
            qnn_tensor.offset = 0 - quantize_params["scaleOffset"]["offset"]
        return qnn_tensor

    for raw_input in qnn_graph["info"]["graphInputs"]:
        qnn_tensor = tensor_from_info(raw_input["info"])
        qnn_input_tensor_dic[qnn_tensor.name] = qnn_tensor

    for raw_output in qnn_graph["info"]["graphOutputs"]:
        qnn_tensor = tensor_from_info(raw_output["info"])
        qnn_output_tensor_dic[qnn_tensor.name] = qnn_tensor

    assert len(qnn_input_tensor_dic) >= 1 and len(qnn_output_tensor_dic) >= 1, (
        "Converted QNN model not valid. It should have at least 1 input & 1 output."
    )

    return graph_name
292
+
293
+
294
# Onnxruntime QNN EP can support context binary file generated by QNN tool chain. However QNN generated context binary file
# uses channel last data layout and 8 bits or 16 bits for input and output.
# This script gets the QNN model input & output information from QNN converted model_net.json file, compare them with Onnx model
# and inserts Cast, Transpose nodes to Onnx model if required
def main():
    """Command-line entry point: wrap a QNN context binary into an ONNX model.

    Supports two JSON flavors: the model_net.json produced by the QNN
    converter, and the JSON dump extracted from a QNN context binary file
    (which may contain multiple graphs, each producing one ONNX file).
    """
    parser = ArgumentParser("Generate Onnx model which includes the QNN context binary.")
    parser.add_argument("-b", "--qnn_bin", help="Required. Path to Qnn context binary file.", required=True, type=str)
    parser.add_argument(
        "-q", "--qnn_json", help="Required. Path to Qnn converted model_net.json file.", required=True, type=str
    )
    parser.add_argument(
        "--disable_embed_mode",
        action="store_true",
        default=False,
        # Help text fixed: this flag sets embed_mode=0 (store the context
        # binary file path in the model) — the old text wrongly said it set
        # embed_mode=1, which is the default (embedded) behavior.
        help="Set embed_mode=0, which means the Qnn context binary file path is set in the onnx model. Otherwise, embed the Qnn context binary content into the onnx model.",
    )
    parser.add_argument(
        "--quantized_IO",
        action="store_true",
        default=False,
        help="QNN converted context binary use quantized data as graph inputs and outputs. Will keep it if quantized_IO=True, otherwise, will insert Q and DQ nodes accordingly to make the graph inputs & outputs as float32 data type.",
    )
    args = parser.parse_args()

    # Parse the Qnn JSON file to get the graph input/output information.
    # Explicit encoding so the parse does not depend on the platform default;
    # the handle is closed before the (potentially slow) model generation.
    with open(args.qnn_json, encoding="utf-8") as qnn_json_file:
        qnn_json_obj = json.load(qnn_json_file)

    if "graph" in qnn_json_obj and "tensors" in qnn_json_obj["graph"]:
        print("This json file is from Qnn converter")
        qnn_input_tensor_dic = {}
        qnn_output_tensor_dic = {}
        parse_qnn_converter_json_file(qnn_json_obj, qnn_input_tensor_dic, qnn_output_tensor_dic)

        generate_wrapper_onnx_file(
            "QnnContext",
            args.qnn_json.replace(".json", "_qnn_ctx.onnx"),
            qnn_input_tensor_dic,
            qnn_output_tensor_dic,
            args.disable_embed_mode,
            args.qnn_bin,
            args.quantized_IO,
        )
    elif "info" in qnn_json_obj and "graphs" in qnn_json_obj["info"]:
        print("This json file is extracted from QNN context binary file")
        qnn_version = qnn_json_obj["info"]["buildId"]
        # A context binary may hold several graphs; emit one ONNX file each.
        for qnn_graph in qnn_json_obj["info"]["graphs"]:
            qnn_input_tensor_dic = {}
            qnn_output_tensor_dic = {}
            graph_name = parse_qnn_graph(qnn_graph, qnn_input_tensor_dic, qnn_output_tensor_dic)

            ctx_file_name = graph_name + "_qnn_ctx.onnx"
            if not args.quantized_IO:
                ctx_file_name = ctx_file_name.replace(".onnx", "_fp32_io.onnx")

            generate_wrapper_onnx_file(
                graph_name,
                ctx_file_name,
                qnn_input_tensor_dic,
                qnn_output_tensor_dic,
                args.disable_embed_mode,
                args.qnn_bin,
                args.quantized_IO,
                qnn_version,
            )
    else:
        # Typo fixed: "unrecoginized" -> "unrecognized".
        print("json file unrecognized.")


if __name__ == "__main__":
    main()
@@ -0,0 +1,165 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+ """Provide entry point to preprocess ONNX model especially for QNN."""
7
+
8
+ import argparse
9
+ import pathlib
10
+
11
+ import onnx
12
+
13
+ from onnxruntime.quantization.execution_providers import qnn
14
+
15
+
16
+ def _parse_arguments():
17
+ """Parse cmdline arguments."""
18
+ parser = argparse.ArgumentParser(description="Arguments for QNN model preprocess.")
19
+
20
+ parser.add_argument("--input_model_path", "-i", required=True, help="Path to the input ONNX model.")
21
+ parser.add_argument("--output_model_path", "-o", required=True, help="Path to the output ONNX model.")
22
+
23
+ # Save preprocessed model with external data.
24
+ parser.add_argument(
25
+ "--save_as_external_data",
26
+ action="store_true",
27
+ help="Whether the output model would be saved with external data.",
28
+ )
29
+ parser.add_argument(
30
+ "--all_tensors_to_one_file",
31
+ action="store_true",
32
+ help="Whether to save all external data in one file or save each tensor to a file named with the tensor name.",
33
+ )
34
+ parser.add_argument(
35
+ "--external_data_location",
36
+ help="Filename of the external file where all tensors are saved. The path is relative to the model path.",
37
+ )
38
+ parser.add_argument(
39
+ "--external_data_size_threshold",
40
+ default=1024,
41
+ type=int,
42
+ help="Tensors with data size larger than this threshold are converted to external data.",
43
+ )
44
+ parser.add_argument(
45
+ "--external_data_convert_attribute",
46
+ action="store_true",
47
+ help="Whether to save all tensors, including attribute tensors, to external data.",
48
+ )
49
+
50
+ # Preprocess options.
51
+ parser.add_argument(
52
+ "--fuse_layernorm",
53
+ action="store_true",
54
+ help="Whether to fuse matched sequences into LayerNormalization nodes if possible.",
55
+ )
56
+
57
+ # I/O layouts.
58
+ parser.add_argument(
59
+ "--inputs_to_make_channel_last",
60
+ nargs="+",
61
+ default=None,
62
+ help="List of graph input names to be transposed into channel-last.",
63
+ )
64
+
65
+ parser.add_argument(
66
+ "--outputs_to_make_channel_last",
67
+ nargs="+",
68
+ default=None,
69
+ help="List of graph output names to be transposed into channel-last.",
70
+ )
71
+
72
+ # Fix dynamic input shapes.
73
+ parser.add_argument(
74
+ "--dynamic_input_shapes",
75
+ nargs=2,
76
+ action="append",
77
+ type=str,
78
+ default=None,
79
+ help="Model input name and desired static shape in comma seprated format, for example: 'input' 1,3,256,256",
80
+ )
81
+
82
+ # Exclude initializer from input
83
+ parser.add_argument(
84
+ "--exclude_initializer_from_input",
85
+ action="store_true",
86
+ help="Whether to exclude initializer from input if model.ir_version >= 4",
87
+ )
88
+
89
+ return parser.parse_args()
90
+
91
+
92
+ def qnn_preprocess_model(
93
+ model_input: str | pathlib.Path | onnx.ModelProto,
94
+ model_output: str | pathlib.Path,
95
+ fuse_layernorm: bool = False,
96
+ save_as_external_data: bool = False,
97
+ all_tensors_to_one_file: bool = False,
98
+ external_data_location: str | None = None,
99
+ external_data_size_threshold: int = 1024,
100
+ external_data_convert_attribute: bool = False,
101
+ inputs_to_make_channel_last: list[str] | None = None,
102
+ outputs_to_make_channel_last: list[str] | None = None,
103
+ dynamic_input_shapes: list[tuple[str, str]] | None = None,
104
+ exclude_initializer_from_input: bool = False,
105
+ ) -> bool:
106
+ """Preprocess ONNX model for QNN.
107
+
108
+ Args:
109
+ model_input: A path or ONNX ModelProto specifiying the model to be preprocessed.
110
+ model_output: A path specifying where the preprocessed model to be saved.
111
+ fuse_layernorm: A bool specifying whether to fuse the matched sequence into a single LayerNormalization node.
112
+ Defaults to False.
113
+ save_as_external_data: A bool specifying whether to save model with external data. Defaults to False.
114
+ all_tensors_to_one_file: A bool specifying whether to save all external data in one file or save each tensor to
115
+ a file named with the tensor name. This argument is effective only when `save_as_external_data` is True.
116
+ Defaults to False.
117
+ external_data_location: A str specifying where to save the external data. The path is relative to the model
118
+ path. This argument is effective only when `save_as_external_data` is True. Defaults to the model name.
119
+ external_data_size_threshold: An int specifying the threshold of data size for tensors be saved as external
120
+ data. This argument is effective only when `save_as_external_data` is True. Defaults to 1024.
121
+ external_data_convert_attribute: A bool specifying whether to save all tensors including attributes as external
122
+ data. This argument is effective only when `save_as_external_data` is True. Defaults to False.
123
+ inputs_to_make_channel_last: A list of strs specifying graph input names to be transposed into channel-last.
124
+ Defaults to None.
125
+ outputs_to_make_channel_last: A list of strs specifying graph output names to be transposed into channel-last.
126
+ Defaults to None.
127
+ dynamic_input_shapes: A list of tuples specifying model input name to and its static shape in comma seprated
128
+ format, for example: [('input', '1,3,256,256')]. Defaults to None.
129
+ exclude_initializer_from_input: A bool specifying whether to exclude initializer from input. Defaults to False.
130
+
131
+ Returns:
132
+ A bool indicating whether the model is modified.
133
+ """
134
+ return qnn.qnn_preprocess_model(
135
+ model_input,
136
+ model_output,
137
+ fuse_layernorm=fuse_layernorm,
138
+ save_as_external_data=save_as_external_data,
139
+ all_tensors_to_one_file=all_tensors_to_one_file,
140
+ external_data_location=external_data_location,
141
+ external_data_size_threshold=external_data_size_threshold,
142
+ external_data_convert_attribute=external_data_convert_attribute,
143
+ inputs_to_make_channel_last=inputs_to_make_channel_last,
144
+ outputs_to_make_channel_last=outputs_to_make_channel_last,
145
+ dynamic_input_shapes=dynamic_input_shapes,
146
+ exclude_initializer_from_input=exclude_initializer_from_input,
147
+ )
148
+
149
+
150
if __name__ == "__main__":
    # The argparse dest names match qnn_preprocess_model's keyword parameters
    # exactly, so the parsed namespace can be forwarded wholesale: the two
    # positional paths are popped out and everything else goes by keyword.
    cli_options = vars(_parse_arguments())
    qnn_preprocess_model(
        cli_options.pop("input_model_path"),
        cli_options.pop("output_model_path"),
        **cli_options,
    )