onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,520 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License. See License.txt in the project root for
4
+ # license information.
5
+ # --------------------------------------------------------------------------
6
+ from __future__ import annotations
7
+
8
+ import json
9
+ from collections.abc import MutableMapping
10
+ from dataclasses import dataclass
11
+ from typing import Any
12
+
13
+ import onnx
14
+
15
+ from .quant_utils import QuantType
16
+
17
+
18
@dataclass
class QuantTypeInfo:  # noqa: PLW1641
    """
    Quantization type settings attached to a single tensor override.

    Optional fields left as None mean "not specified": for ``symmetric`` and
    ``reduce_range`` the default is used, and an ``axis`` of None means
    per-tensor (rather than per-channel) quantization.
    """

    quant_type: QuantType
    symmetric: bool | None = None  # None -> the default setting applies.
    reduce_range: bool | None = None  # None -> the default setting applies.
    axis: int | None = None  # None -> per-tensor quantization.

    def __eq__(self, other: object):
        """
        Loose equality used when comparing override settings.

        ``quant_type`` and ``axis`` must match exactly, while ``symmetric`` and
        ``reduce_range`` are treated as wildcards when either side is None.
        Because of the wildcards this relation is not transitive. Defining
        ``__eq__`` without ``__hash__`` leaves instances unhashable.
        """
        if not isinstance(other, QuantTypeInfo):
            return NotImplemented

        def _loosely_equal(lhs, rhs) -> bool:
            # An unspecified (None) value on either side matches anything.
            return lhs is None or rhs is None or lhs == rhs

        return all(
            (
                self.quant_type == other.quant_type,
                _loosely_equal(self.symmetric, other.symmetric),
                _loosely_equal(self.reduce_range, other.reduce_range),
                self.axis == other.axis,
            )
        )

    @staticmethod
    def load_from_dict(
        raw_dict: dict[str, Any],
        default_qtype: QuantType | None = None,
        default_symmetric: bool | None = None,
        default_reduce_range: bool | None = None,
    ) -> QuantTypeInfo:
        """
        Build a QuantTypeInfo from a raw override dictionary.

        Keys missing from ``raw_dict`` fall back to the matching ``default_*``
        argument. ``axis`` has no default and is None when absent.
        """
        return QuantTypeInfo(
            quant_type=raw_dict.get("quant_type", default_qtype),
            symmetric=raw_dict.get("symmetric", default_symmetric),
            reduce_range=raw_dict.get("reduce_range", default_reduce_range),
            axis=raw_dict.get("axis"),
        )

    def save_to_dict(self, raw_dict: dict[str, Any]):
        """
        Write these settings into ``raw_dict`` in place.

        ``quant_type`` is always written. The optional fields are written only
        when they are not None. Existing keys are never removed.
        """
        raw_dict["quant_type"] = self.quant_type
        for field_name in ("symmetric", "reduce_range", "axis"):
            value = getattr(self, field_name)
            if value is not None:
                raw_dict[field_name] = value
61
+
62
+
63
+ class TensorQuantOverridesHelper(MutableMapping):
64
+ """
65
+ Utility wrapper over the tensor quantization overrides passed via extra_options.
66
+ """
67
+
68
def __init__(self, raw_overrides: dict[str, list[dict[str, Any]]]):
    """
    Wrap the raw tensor-override mapping (tensor name -> list of override dicts).

    ``quant_types`` starts as None and is populated lazily by ``get_quant_types``.
    """
    # Option keys that conflict with explicitly provided 'scale'/'zero_point' values.
    self.keys_unsupported_with_scale_zp = set(("rmax", "rmin", "reduce_range", "symmetric"))
    self.quant_types = None
    self.overrides = raw_overrides
72
+
73
def has_per_tensor_overrides(self, tensor_name: str) -> bool:
    """
    Return True if ``tensor_name`` has overrides and they are per-tensor.

    Overrides count as per-tensor when the first override dict has no 'axis' key.
    Returns False when the tensor has no overrides (missing or empty list).
    """
    overrides_list = self.overrides.get(tensor_name)
    # bool() keeps the declared return type; a bare `and` would leak None or [] here.
    return bool(overrides_list) and "axis" not in overrides_list[0]
76
+
77
def has_per_channel_overrides(self, tensor_name: str) -> bool:
    """
    Return True if ``tensor_name`` has overrides and they are per-channel.

    Overrides count as per-channel when the first override dict contains an 'axis' key.
    Returns False when the tensor has no overrides (missing or empty list).
    """
    overrides_list = self.overrides.get(tensor_name)
    # bool() keeps the declared return type; a bare `and` would leak None or [] here.
    return bool(overrides_list) and "axis" in overrides_list[0]
80
+
81
def overrides_scale_zp(self, tensor_name: str) -> bool:
    """
    Return True if the first override dict for ``tensor_name`` sets both 'scale' and 'zero_point'.

    Returns False when the tensor has no overrides (missing or empty list).
    """
    overrides_list = self.overrides.get(tensor_name)
    if not overrides_list:
        # Explicit False instead of leaking None or [] through a bare `and`.
        return False
    first_overrides = overrides_list[0]
    return "scale" in first_overrides and "zero_point" in first_overrides
84
+
85
def get_per_tensor_overrides(
    self,
    tensor_name: str,
    default_val: dict[str, Any] | None = None,
) -> dict[str, Any] | None:
    """
    Return the single per-tensor override dict for ``tensor_name``.

    If the tensor has no entry in the overrides, ``default_val`` is used instead.
    Returns None when there is nothing to return.

    Raises:
        ValueError: if the overrides for the tensor are per-channel (the first dict has 'axis').
    """
    fallback = None if default_val is None else [default_val]
    candidates = self.overrides.get(tensor_name, fallback)
    if not candidates:
        return None

    first = candidates[0]
    if "axis" in first:
        raise ValueError(
            f"Expected tensor '{tensor_name}' to use per-tensor quantization overrides, "
            "but found per-channel overrides."
        )
    return first
99
+
100
def get_per_channel_overrides(
    self,
    tensor_name: str,
    default_val: list[dict[str, Any]] | None = None,
) -> list[dict[str, Any]] | None:
    """
    Return the list of per-channel override dicts for ``tensor_name``.

    If the tensor has no entry in the overrides, ``default_val`` is used instead.
    Returns None when the resulting list is missing or empty.

    Raises:
        ValueError: if the first override dict lacks the 'axis' key.
    """
    channel_overrides = self.overrides.get(tensor_name, default_val)
    if channel_overrides:
        if "axis" in channel_overrides[0]:
            return channel_overrides
        raise ValueError(
            f"Expected tensor '{tensor_name}' to have per-channel quantization overrides (axis value is missing).",
        )
    return None
116
+
117
def get_quant_types(self) -> set[QuantType]:
    """
    Return every quant_type referenced by the overrides, including 'convert' targets.

    The set is computed on first use and cached in ``self.quant_types``. Later calls
    return the cached set even if ``self.overrides`` has changed since then.
    """
    if self.quant_types is None:
        # Publish the (initially empty) set before filling it, so the cache is set even
        # if iteration over a malformed override raises midway.
        self.quant_types = collected = set()
        per_tensor_lists = self.overrides.values() if self.overrides else ()
        for overrides_list in per_tensor_lists:
            for entry in overrides_list:
                if "quant_type" in entry:
                    collected.add(entry["quant_type"])
                if "convert" in entry and "quant_type" in entry["convert"]:
                    collected.add(entry["convert"]["quant_type"])
    return self.quant_types
133
+
134
def _is_valid_per_tensor(
    self,
    initializers,
    default_activation_qtype,
    tensor_name: str,
    quant_overrides: dict[str, Any],
) -> tuple[bool, str | None]:
    """
    Validate the per-tensor quantization override dict for a single tensor.

    Side effect: quant types encountered are added to ``self.quant_types``, which is
    expected to already be a set (assumption: the caller initializes it; not shown here).
    The 'convert' quant type is added only after the 'convert' options validate.

    Args:
        initializers: Container supporting ``in``; membership of ``tensor_name`` decides
            whether the tensor is treated as an initializer or an activation.
        default_activation_qtype: Quant type assumed for the activation when the override
            does not set 'quant_type'; used to check that 'convert' changes the type.
        tensor_name: Name of the tensor being validated (used in error messages).
        quant_overrides: The override dict for the tensor.

    Returns:
        ``(True, None)`` if valid, otherwise ``(False, reason)``.
    """
    if not isinstance(quant_overrides, dict):
        return (
            False,
            f"Tensor quantization overrides for '{tensor_name}' are not in a dict",
        )

    is_initializer = tensor_name in initializers

    quant_type = quant_overrides.get("quant_type")
    if quant_type:
        self.quant_types.add(quant_type)

    has_scale = "scale" in quant_overrides
    has_zero_point = "zero_point" in quant_overrides

    # 'scale' and 'zero_point' must be given together or not at all.
    if has_scale != has_zero_point:
        return (
            False,
            "Must provide both 'scale' and 'zero_point' if one of the overrides is provided",
        )

    if has_scale:
        keys = self.keys_unsupported_with_scale_zp.intersection(quant_overrides)
        if keys:
            # sorted() makes the message deterministic (set iteration order is not).
            return (
                False,
                f"Tensor override option(s) [{', '.join(sorted(keys))}] are invalid with 'scale' and 'zero_point'",
            )

    if "reduce_range" in quant_overrides and not is_initializer:
        return (
            False,
            f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}",
        )

    if "convert" in quant_overrides:
        if is_initializer:
            return False, "Cannot use 'convert' override for initializers"

        convert_overrides = quant_overrides["convert"]
        # Guard against malformed input: the membership tests below would otherwise raise
        # (e.g. for None) or do substring matching (for a str) instead of reporting an error.
        if not isinstance(convert_overrides, dict):
            return False, f"'convert' options (tensor '{tensor_name}') must be a dict"

        if "quant_type" not in convert_overrides:
            return False, f"'convert' options (tensor '{tensor_name}') must specify a 'quant_type'"

        if "reduce_range" in convert_overrides:
            return (
                False,
                f"Option 'reduce_range' is only supported for initializers, not for activation {tensor_name}",
            )

        convert_quant_type = convert_overrides["quant_type"]
        original_quant_type = quant_type if quant_type is not None else default_activation_qtype
        if convert_quant_type == original_quant_type:
            return (
                False,
                f"'convert' quant_type must differ from original quant_type (tensor '{tensor_name}')",
            )

        convert_has_scale = "scale" in convert_overrides
        convert_has_zero_point = "zero_point" in convert_overrides

        if convert_has_scale != convert_has_zero_point:
            return (
                False,
                f"Must provide both 'scale' and 'zero_point' if one of the overrides is provided (tensor '{tensor_name}')",
            )

        if convert_has_scale:
            keys = self.keys_unsupported_with_scale_zp.intersection(convert_overrides)
            if keys:
                return (
                    False,
                    f"Tensor override option(s) [{', '.join(sorted(keys))}] are invalid with 'scale' and 'zero_point' "
                    f"(tensor '{tensor_name}')",
                )

        self.quant_types.add(convert_quant_type)

    return True, None
218
+
219
+ def _is_valid_per_channel(
220
+ self,
221
+ initializers,
222
+ tensor_name: str,
223
+ quant_overrides_list: list[dict[str, Any]],
224
+ ) -> tuple[bool, str | None]:
225
+ is_initializer = tensor_name in initializers
226
+
227
+ if not is_initializer:
228
+ return (
229
+ False,
230
+ f"Tensor '{tensor_name}' has per-channel overrides, but is not an initializer",
231
+ )
232
+
233
+ axis = quant_overrides_list[0].get("axis")
234
+
235
+ if axis is None:
236
+ return (
237
+ False,
238
+ f"Per-channel overrides for tensor {tensor_name} is missing an 'axis' value in "
239
+ "the first channel dictionary.",
240
+ )
241
+
242
+ weight_shape = list(initializers[tensor_name].dims)
243
+ weight_rank = len(weight_shape)
244
+ norm_axis = axis
245
+ if norm_axis < 0:
246
+ norm_axis += weight_rank
247
+
248
+ if norm_axis < 0 or norm_axis >= len(weight_shape):
249
+ return (
250
+ False,
251
+ f"Axis override value is out-of-bounds for tensor {tensor_name} (rank {len(weight_shape)})",
252
+ )
253
+
254
+ if len(quant_overrides_list) > 1 and len(quant_overrides_list) != weight_shape[norm_axis]:
255
+ return (
256
+ False,
257
+ f"Incorrect number of channel overrides for tensor {tensor_name} (axis {axis}), "
258
+ f"expected {weight_shape[axis]}, but found {len(quant_overrides_list)}.",
259
+ )
260
+
261
+ if "convert" in quant_overrides_list[0]:
262
+ return False, f"Cannot use 'convert' override for initializers, such as {tensor_name}."
263
+
264
+ quant_type = quant_overrides_list[0].get("quant_type")
265
+ if quant_type:
266
+ self.quant_types.add(quant_type)
267
+
268
+ symmetric = quant_overrides_list[0].get("symmetric")
269
+ reduce_range = quant_overrides_list[0].get("reduce_range")
270
+
271
+ has_scale = "scale" in quant_overrides_list[0]
272
+ has_zero_point = "zero_point" in quant_overrides_list[0]
273
+ has_scale_zp = has_scale and has_zero_point
274
+
275
+ if (has_scale and not has_zero_point) or (has_zero_point and not has_scale):
276
+ return (
277
+ False,
278
+ "Must provide both 'scale' and 'zero_point' if one of the overrides is provided",
279
+ )
280
+
281
+ if has_scale_zp:
282
+ keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides_list[0]))
283
+ if keys:
284
+ return (
285
+ False,
286
+ f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point'",
287
+ )
288
+
289
+ has_rmin = "rmin" in quant_overrides_list[0]
290
+ has_rmax = "rmax" in quant_overrides_list[0]
291
+ has_rmin_rmax = has_rmin and has_rmax
292
+ if (has_rmin and not has_rmax) or (not has_rmin and has_rmax):
293
+ return (
294
+ False,
295
+ "Must provide both 'rmin' and 'rmax' if one is provided",
296
+ )
297
+
298
+ for index, quant_overrides in enumerate(quant_overrides_list[1:]):
299
+ if not isinstance(quant_overrides, dict):
300
+ return (
301
+ False,
302
+ f"Tensor quantization overrides at index {index} for '{tensor_name}' are not in a dict",
303
+ )
304
+
305
+ if "convert" in quant_overrides:
306
+ return False, f"Cannot use 'convert' override for initializers, such as {tensor_name}."
307
+
308
+ # For per-channel quantization, all channels must use the same quantization type, axis, symmetric
309
+ # and reduce_range values. And, if specified, they must be present in the first channel dict
310
+ # (i.e., quant_overrides_list[0]).
311
+ if "quant_type" in quant_overrides and quant_type != quant_overrides["quant_type"]:
312
+ return (
313
+ False,
314
+ "Channel quantization types for tensor '{tensor_name}' do not match at index {index}.",
315
+ )
316
+ if "axis" in quant_overrides and axis != quant_overrides["axis"] and norm_axis != quant_overrides["axis"]:
317
+ return (
318
+ False,
319
+ "Channel axis for tensor '{tensor_name}' does not match at index {index}.",
320
+ )
321
+ if "symmetric" in quant_overrides and symmetric != quant_overrides["symmetric"]:
322
+ return (
323
+ False,
324
+ "Channel symmetric value for tensor '{tensor_name}' does not match at index {index}.",
325
+ )
326
+ if "reduce_range" in quant_overrides and reduce_range != quant_overrides["reduce_range"]:
327
+ return (
328
+ False,
329
+ "Channel reduce_range value for tensor '{tensor_name}' does not match at index {index}.",
330
+ )
331
+
332
+ # If override scale/zp, must do so for all channels.
333
+ chan_has_scale_zp = "scale" in quant_overrides and "zero_point" in quant_overrides
334
+
335
+ if has_scale_zp and not chan_has_scale_zp:
336
+ return (
337
+ False,
338
+ "Per-channel overrides that specify scale/zero_point must do so for all channels, "
339
+ f"but tensor '{tensor_name}' is missing them at index {index}.",
340
+ )
341
+
342
+ if chan_has_scale_zp:
343
+ keys = self.keys_unsupported_with_scale_zp.intersection(set(quant_overrides))
344
+ if keys:
345
+ return (
346
+ False,
347
+ f"Tensor override option(s) [{', '.join(keys)}] are invalid with 'scale' and 'zero_point'",
348
+ )
349
+
350
+ # If override rmin/rmax, must do so for all channels.
351
+ chan_has_rmin_rmax = "rmin" in quant_overrides and "rmax" in quant_overrides
352
+ if has_rmin_rmax and not chan_has_rmin_rmax:
353
+ return (
354
+ False,
355
+ "Per-channel overrides that specify rmin/rmax must do so for all channels, "
356
+ f"but tensor '{tensor_name}' is missing them at index {index}.",
357
+ )
358
+
359
+ return True, None
360
+
361
def is_valid(
    self,
    initializers: dict[str, onnx.TensorProto],
    activation_names: set[str],
    default_activation_qtype,
) -> tuple[bool, str | None]:
    """Check that every tensor's quantization overrides are compatible with the model.

    Resets ``self.quant_types`` to an empty set before validating; the per-channel validator
    adds a tensor's 'quant_type' to it when one is set (the per-tensor validator is defined
    outside this view and presumably does the same -- TODO confirm).

    Args:
        initializers: Initializer tensors of the model, keyed by name.
        activation_names: Names of the model's activation (non-initializer) tensors. Every
            overridden tensor must appear here or in ``initializers``.
        default_activation_qtype: Forwarded unchanged to ``_is_valid_per_tensor``.

    Returns:
        ``(True, None)`` when all overrides are valid; otherwise the ``(False, message)``
        result for the first invalid tensor encountered.
    """
    self.quant_types = set()

    if not self.overrides:
        return True, None

    for tensor_name, quant_overrides_list in self.overrides.items():
        if tensor_name not in initializers and tensor_name not in activation_names:
            return False, f"Tensor '{tensor_name}' in TensorQuantOverrides is not present in the model"

        if not isinstance(quant_overrides_list, list):
            return False, f"Tensor quantization overrides for '{tensor_name}' are not in a list"

        # An empty list, or an empty first dict, has nothing to validate and is skipped.
        if not quant_overrides_list:
            continue

        if not isinstance(quant_overrides_list[0], dict):
            return False, f"Tensor quantization overrides at index 0 for '{tensor_name}' are not in a dict"

        if not quant_overrides_list[0]:
            continue

        # More than one dict, or an explicit 'axis' in the first dict, selects per-channel validation.
        axis = quant_overrides_list[0].get("axis")
        is_per_channel = len(quant_overrides_list) > 1 or axis is not None

        if is_per_channel:
            result = self._is_valid_per_channel(initializers, tensor_name, quant_overrides_list)
        else:
            result = self._is_valid_per_tensor(
                initializers, default_activation_qtype, tensor_name, quant_overrides_list[0]
            )

        # Stop at the first failure, but keep iterating on success so that every tensor is
        # validated (and recorded in self.quant_types), not only the first one with overrides.
        if not result[0]:
            return result

    return True, None
398
+
399
def update_tensor_overrides(
    self,
    tensor_name: str,
    new_vals: dict[str, Any],
    channels: list[int] | None = None,
    overwrite: bool = True,
) -> bool:
    """Merge ``new_vals`` into the override dict(s) of ``tensor_name``.

    Args:
        tensor_name: Tensor whose overrides are updated. If it has no overrides yet (absent or
            an empty list), a single empty override dict is created first.
        new_vals: Key/value pairs merged into each selected override dict.
        channels: Indices of the override dicts to update; ``None`` selects all of them.
        overwrite: When False, the whole update is skipped if any selected dict already holds
            one of the keys in ``new_vals`` (partial overwrites are never performed).

    Returns:
        True if the update was applied; False if ``new_vals`` is empty or the update was
        skipped to avoid overwriting existing values.
    """
    if not new_vals:
        return False

    selected = None if channels is None else set(channels)

    def _selected_dicts(override_list):
        # Only the dicts whose list index is selected (every dict when unrestricted).
        return (
            entry
            for idx, entry in enumerate(override_list)
            if selected is None or idx in selected
        )

    existing = self.overrides.get(tensor_name)

    # Without `overwrite`, refuse the entire update if any selected dict already has one of the keys.
    if not overwrite and existing:
        incoming_keys = set(new_vals)
        if any(incoming_keys & set(entry) for entry in _selected_dicts(existing)):
            return False

    if not existing:
        self.overrides[tensor_name] = [{}]

    for entry in _selected_dicts(self.overrides[tensor_name]):
        entry.update(new_vals)

    return True
433
+
434
def get_node_output_qtype_info(
    self,
    output_name: str,
    default_qtype: QuantType | None,
    default_symmetric: bool | None = None,
) -> QuantTypeInfo:
    """Return the quantization type info for a node output tensor.

    Args:
        output_name: Name of the output tensor.
        default_qtype: Quant type used when the tensor has no 'quant_type' override.
        default_symmetric: Symmetric flag used when the tensor has no 'symmetric' override.

    Returns:
        A QuantTypeInfo built from the tensor's first override dict, falling back to the
        defaults for missing keys, or entirely to the defaults when the tensor has no overrides.
    """
    # Outputs are activations, which do not support 'reduce_range' or 'axis'.
    # An empty override list passes is_valid(), so treat it like "no overrides" rather than
    # indexing [0] and raising IndexError (same guard as get_node_input_qtype_info).
    if output_name not in self.overrides or not self.overrides[output_name]:
        return QuantTypeInfo(default_qtype, default_symmetric)

    tensor_overrides = self.overrides[output_name][0]

    return QuantTypeInfo(
        tensor_overrides.get("quant_type", default_qtype),
        tensor_overrides.get("symmetric", default_symmetric),
    )
450
+
451
def get_node_input_qtype_info(
    self,
    input_name: str,
    node_name: str,
    default_qtype: QuantType | None,
    default_symmetric: bool | None = None,
    default_reduce_range: bool | None = None,
) -> QuantTypeInfo:
    """Return the quantization type info that node ``node_name`` uses for input ``input_name``.

    Falls back to the given defaults when the tensor has no overrides. When the tensor's first
    override dict contains a 'convert' entry, the result starts from the producer's quant type and
    the convert dict's 'symmetric' value. Its quant type is then replaced by the convert dict's
    'quant_type' if the convert dict has no 'recv_nodes' entry or ``node_name`` appears in it.
    """
    has_overrides = input_name in self.overrides and bool(self.overrides[input_name])
    if not has_overrides:
        return QuantTypeInfo(default_qtype, default_symmetric, default_reduce_range)

    # The first dict is representative for per-channel lists too: all channels share one quant type.
    first = self.overrides[input_name][0]
    producer_qtype = first.get("quant_type", default_qtype)

    if "convert" in first:
        conversion = first["convert"]
        # Converted tensors are not initializers, so no 'axis' or 'reduce_range' is carried over.
        info = QuantTypeInfo(producer_qtype, conversion.get("symmetric", default_symmetric))

        # No 'recv_nodes' means every consumer gets the converted type; otherwise only listed nodes do.
        receives_converted = "recv_nodes" not in conversion or node_name in conversion["recv_nodes"]
        if receives_converted:
            info.quant_type = conversion["quant_type"]
        return info

    return QuantTypeInfo(
        producer_qtype,
        first.get("symmetric", default_symmetric),
        first.get("reduce_range", default_reduce_range),
        first.get("axis"),
    )
489
+
490
def pprint_str(self, indent=None) -> str:
    """Return the overrides serialized as JSON text.

    Values JSON cannot encode natively are converted with ``str`` (``default=str``).

    Args:
        indent: Passed straight to ``json.dumps``; ``None`` produces a single-line string.
    """
    text = json.dumps(self.overrides, indent=indent, default=str)
    return text
492
+
493
def empty(self) -> bool:
    """Return True when no tensor overrides are configured."""
    return False if self.overrides else True
495
+
496
def get_dict(self) -> dict[str, list[dict[str, Any]]]:
    """Return the underlying overrides mapping itself (not a copy), so mutations are shared."""
    backing = self.overrides
    return backing
498
+
499
+ # Required implementations of abstract methods in collections.abc.MutableMapping
500
+ # so that this class can be used like a dict.
501
def __setitem__(self, key: str, value: list[dict]):
    """Store ``value`` as the override list for tensor ``key`` (MutableMapping protocol)."""
    target = self.overrides
    target[key] = value
503
+
504
def __getitem__(self, key: str) -> list[dict]:
    """Return the override list stored for tensor ``key`` (MutableMapping protocol)."""
    store = self.overrides
    return store[key]
506
+
507
def __delitem__(self, key: str):
    """Remove tensor ``key`` and its override list (MutableMapping protocol)."""
    store = self.overrides
    del store[key]
509
+
510
def __iter__(self):
    """Iterate over the names of tensors that have overrides (MutableMapping protocol)."""
    tensor_names = self.overrides
    return iter(tensor_names)
512
+
513
def __len__(self):
    """Return how many tensors have overrides (MutableMapping protocol)."""
    count = len(self.overrides)
    return count
515
+
516
def __str__(self) -> str:
    """Return the string form of the underlying overrides mapping."""
    return f"{self.overrides}"
518
+
519
def __repr__(self) -> str:
    """Return the parent class's repr followed by the wrapped overrides mapping."""
    base_repr = super().__repr__()
    return f"{base_repr}, TensorQuantOverridesHelper({self.overrides})"
@@ -0,0 +1,10 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ # appended to the __init__.py in the onnxruntime module's 'tools' folder from /tools/python/util/__init__append.py
6
+ import importlib.util
7
+
8
# Expose the PyTorch export helpers only when torch is installed. find_spec() returns a
# ModuleSpec (or None when absent) for the top-level "torch" package without importing it.
have_torch = importlib.util.find_spec("torch")
if have_torch is not None:
    from .pytorch_export_helpers import infer_input_info  # noqa: F401
@@ -0,0 +1,47 @@
1
+ # Copyright (c) Microsoft Corporation. All rights reserved.
2
+ # Licensed under the MIT License.
3
+
4
+ import argparse
5
+ import logging
6
+ import pathlib
7
+
8
# Configure root logging before importing the mobile helpers. logging.basicConfig() is a no-op
# when the root logger already has handlers, so presumably an imported helper module would
# otherwise configure logging first and this format would be ignored -- TODO confirm which import
# made this ordering necessary.
logging.basicConfig(format="%(levelname)s: %(message)s")

# Deliberately imported after basicConfig() above, hence the E402 (import not at top) suppression.
from .mobile_helpers import usability_checker  # noqa: E402
12
+
13
+
14
def check_usability(argv=None):
    """Analyze an ONNX model from the command line and log a mobile execution-provider recommendation.

    Parses ``--log_level`` and ``model_path``, calls ``usability_checker.analyze_model`` on the model
    with ``skip_optimize=False``, and logs whether to compare the NNAPI/CoreML EPs against the CPU EP.

    Args:
        argv: Optional argument list to parse instead of the process command line. ``None`` (the
            default) keeps argparse's normal behavior of reading ``sys.argv[1:]``.
    """
    # Each accepted --log_level choice maps directly to a logging level, so every choice is reachable.
    log_levels = {
        "debug": logging.DEBUG,
        "info": logging.INFO,
        "warning": logging.WARNING,
        "error": logging.ERROR,
    }

    parser = argparse.ArgumentParser(
        description="""Analyze an ONNX model to determine how well it will work in mobile scenarios.""",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument("--log_level", choices=list(log_levels), default="info", help="Logging level")
    parser.add_argument("model_path", help="Path to ONNX model to check", type=pathlib.Path)

    args = parser.parse_args(argv)
    logger = logging.getLogger("check_usability")
    logger.setLevel(log_levels[args.log_level])

    # A truthy result is treated as "the mobile EPs are worth trying".
    try_eps = usability_checker.analyze_model(args.model_path, skip_optimize=False, logger=logger)

    if try_eps:
        logger.info(
            "As NNAPI or CoreML may provide benefits with this model it is recommended to compare the "
            "performance of the model using the NNAPI EP on Android, and the CoreML EP on iOS, "
            "against the performance using the CPU EP."
        )
    else:
        logger.info("For optimal performance the model should be used with the CPU EP. ")
44
+
45
+
46
if __name__ == "__main__":
    # Script entry point: run the usability check with arguments taken from the command line.
    check_usability()