onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
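
The file content shown below is onnxruntime/tools/mobile_helpers/usability_checker.py (file 93 in the list above; 738 added lines, matching the hunk header).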
@@ -0,0 +1,738 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from __future__ import annotations

import argparse
import logging
import os
import pathlib
import tempfile
from collections import deque
from enum import IntEnum

import onnx

from ..onnx_model_utils import ModelProtoWithShapeInfo, get_producer_consumer_maps, is_fixed_size_tensor, optimize_model


class _SupportedOpsChecker:
    """
    Class to process the md file with list of supported ops and caveats for an execution provider.
    e.g. /tools/ci_build/github/android/nnapi_supported_ops.md
         /tools/ci_build/github/apple/coreml_supported_mlprogram_ops.md
         /tools/ci_build/github/apple/coreml_supported_neuralnetwork_ops.md
    """

    def __init__(self, filename):
        self._filename = filename
        self._ops = {}  # op to caveats
        self._ops_seen = set()

        with open(filename) as f:
            for line in f:
                # we're looking for a markdown table with 2 columns. first is op name. second is caveats
                # op name is domain:op
                if line.startswith("|"):
                    pieces = line.strip().split("|")
                    if len(pieces) == 4:  # pre-first '|'. op, caveat, post-last '|'
                        domain_op = pieces[1]
                        caveat = pieces[2]
                        caveat = caveat.replace("<br/>", " ")  # remove some HTML tags
                        # skip lines that don't have the ':' which separates the domain and op
                        # e.g. the table header will fail this check
                        if ":" in domain_op:
                            self._ops[domain_op] = caveat

    def is_op_supported(self, node):
        domain = node.domain if node.domain else "ai.onnx"
        domain_op = domain + ":" + node.op_type

        is_supported = domain_op in self._ops
        if is_supported:
            self._ops_seen.add(domain_op)

        return is_supported

    def get_caveats(self):
        caveats = []
        for op in sorted(self._ops_seen):
            caveat = self._ops[op]
            if caveat:
                caveats.append(f"{op}:{caveat}")

        return caveats
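
# ---------------------------------------------------------------------------
# Editor's note (illustrative, not part of the original source): the supported-ops
# markdown files parsed above contain one table row per operator. For example, a
# hypothetical row such as
#     |ai.onnx:Conv|Only 2D Conv is supported.|
# splits on '|' into 4 pieces (empty, op, caveat, empty), so is_op_supported()
# would report ai.onnx:Conv as supported and get_caveats() would surface the note.
# ---------------------------------------------------------------------------
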

class PartitioningInfo:
    class TryWithEP(IntEnum):
        NO = 0
        MAYBE = 1
        YES = 2

    def __init__(
        self,
        num_nodes: int,
        num_supported_nodes: int,
        num_partitions: int,
        supported_ops_checker: _SupportedOpsChecker,
        supported_groups: list[onnx.NodeProto],
        unsupported_ops: set[str],
        nodes_unsupported_due_to_op: int,
        nodes_unsupported_due_to_dynamic_input: int,
        num_unsupported_nodes_due_to_rank: int,
        ops_with_unsupported_rank: set[str],
    ):
        self.num_nodes = num_nodes
        self.num_supported_nodes = num_supported_nodes
        self.num_partitions = num_partitions
        self.supported_ops_checker = supported_ops_checker
        self.supported_groups = supported_groups
        self.unsupported_ops = unsupported_ops
        self.nodes_unsupported_due_to_op = nodes_unsupported_due_to_op
        self.nodes_unsupported_due_to_dynamic_input = nodes_unsupported_due_to_dynamic_input
        self.num_unsupported_nodes_due_to_rank = num_unsupported_nodes_due_to_rank
        self.ops_with_unsupported_rank = ops_with_unsupported_rank

        self.num_subgraphs = 0
        self.num_nodes_in_subgraphs = 0

    def merge(self, other: PartitioningInfo):
        """
        Merge the information from another PartitioningInfo instance into this one.
        """
        self.num_nodes += other.num_nodes
        self.num_supported_nodes += other.num_supported_nodes
        self.num_partitions += other.num_partitions
        self.supported_groups.extend(other.supported_groups)
        self.unsupported_ops.update(other.unsupported_ops)
        self.nodes_unsupported_due_to_op += other.nodes_unsupported_due_to_op
        self.nodes_unsupported_due_to_dynamic_input += other.nodes_unsupported_due_to_dynamic_input
        self.num_unsupported_nodes_due_to_rank += other.num_unsupported_nodes_due_to_rank
        self.ops_with_unsupported_rank.update(other.ops_with_unsupported_rank)

        # hard assumption that we merge into the main graph partitioning info
        self.num_subgraphs += 1
        self.num_nodes_in_subgraphs += other.num_nodes

    def suitability(self):
        # semi-arbitrary choices that err on the side of MAYBE.
        # having 1 partition is always preferred, but if that is small it may not be useful.
        # having 2 partitions may be okay if they cover most nodes.
        # more than 2 partitions and the device copy cost is almost guaranteed to outweigh the benefit of using the NPU
        # NOTE: This assumes the EP is not CPU based and there is device copy overhead to consider
        pct_supported = self.num_supported_nodes / self.num_nodes * 100
        if self.num_partitions == 1:
            if pct_supported > 75:
                return PartitioningInfo.TryWithEP.YES
            elif pct_supported > 50:
                return PartitioningInfo.TryWithEP.MAYBE
            else:
                return PartitioningInfo.TryWithEP.NO

        if self.num_partitions == 2:
            if pct_supported > 75:
                return PartitioningInfo.TryWithEP.MAYBE
            else:
                return PartitioningInfo.TryWithEP.NO

        return PartitioningInfo.TryWithEP.NO

    def print_analysis(self, logger: logging.Logger, ep_name: str):
        """
        Analyze the partitioning information and log the analysis.
        :param logger: Logger to use
        :param ep_name: Execution provider name to use in the log messages
        """

        logger.info(
            f"{self.num_partitions} partitions with a total of {self.num_supported_nodes}/{self.num_nodes} "
            f"nodes can be handled by the {ep_name} EP."
        )

        if self.supported_groups:
            logger.info(
                f"\tPartition sizes: [{', '.join([str(len(partition)) for partition in self.supported_groups])}]"
            )

            # dump full groups if debug output is enabled
            for group in self.supported_groups:
                logger.debug(f"Nodes in group: {','.join([f'{node.op_type}:{node.name}' for node in group])}")

        logger.info(f"Unsupported nodes due to operator={self.nodes_unsupported_due_to_op}")
        if self.unsupported_ops:
            logger.info(f"\tUnsupported ops: {','.join(sorted(self.unsupported_ops))}")

        caveats = self.supported_ops_checker.get_caveats()
        if caveats:
            indent = " " * 5
            logger.info(
                "\tCaveats that have not been checked and may result in a node not actually being supported: "
                f"{''.join([os.linesep + indent + caveat for caveat in caveats])}"
            )

        if self.nodes_unsupported_due_to_dynamic_input:
            logger.info(
                "Unsupported nodes due to input having a dynamic shape=%d",
                self.nodes_unsupported_due_to_dynamic_input,
            )

        if self.num_unsupported_nodes_due_to_rank:
            logger.info(f"Unsupported nodes due to rank of input data={self.num_unsupported_nodes_due_to_rank}")
            logger.info(f"\tOps with unsupported rank: {','.join(sorted(self.ops_with_unsupported_rank))}")

        if self.num_subgraphs > 0:
            # TODO: CoreML has a flag. NNAPI doesn't. Either should be able to support a subgraph when treated as a
            # separate graph (only extra detail would be making sure implicit inputs are handled).
            # Merging the subgraph into the parent graph would be more complex.
            # e.g. for CoreML we could potentially convert Loop to while_loop and If to cond if the subgraphs in the
            # control flow node are fully supported.
            # NNAPI also has While and If.

            # It most likely will be necessary to support merging in If nodes with fully supported subgraphs,
            # as the subgraphs in those are often very simple, so the performance cost of going to the CPU EP and back
            # is high.
            logger.info(
                f"{self.num_nodes_in_subgraphs} nodes are in {self.num_subgraphs} subgraphs. "
                "Check EP as to whether subgraphs are supported."
            )

        pct_nodes_using_ep = self.num_supported_nodes / self.num_nodes * 100
        if self.num_partitions == 0:
            logger.info(f"{ep_name} cannot run any nodes in this model.")
        elif self.num_partitions == 1:
            if pct_nodes_using_ep > 75:
                logger.info(
                    f"{ep_name} should work well for this model as there is one partition "
                    f"covering {pct_nodes_using_ep:.1f}% of the nodes in the model."
                )
            elif pct_nodes_using_ep > 50:
                logger.info(
                    f"{ep_name} may work well for this model, however only {pct_nodes_using_ep:.1f}% of nodes "
                    "will use it. Performance testing is required to validate."
                )
            else:
                logger.info(
                    f"{ep_name} will probably not work well for this model as only {pct_nodes_using_ep:.1f}% "
                    "of nodes will use it."
                )

        elif self.num_partitions == 2 and pct_nodes_using_ep > 75:
            logger.info(
                f"{ep_name} can be considered for this model as there are two partitions "
                f"covering {pct_nodes_using_ep:.1f}% of the nodes. "
                "Performance testing is required to validate."
            )
        else:
            logger.info(
                f"{ep_name} is not recommended with this model as there are {self.num_partitions} partitions "
                f"covering {pct_nodes_using_ep:.1f}% of the nodes in the model. "
                "This will most likely result in worse performance than just using the CPU EP."
            )
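
# ---------------------------------------------------------------------------
# Editor's note (illustrative, not part of the original source): reading
# suitability() with concrete numbers: for a 100-node model, 80 supported nodes
# in one partition -> YES; the same 80 nodes split across two partitions ->
# MAYBE; across three or more partitions -> NO, since each partition boundary
# adds device<->CPU copy overhead.
# ---------------------------------------------------------------------------
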

def _check_partitioning_for_graph(
    graph: onnx.GraphProto,
    node_to_producers: dict[onnx.NodeProto, set[onnx.NodeProto]],
    node_to_consumers: dict[onnx.NodeProto, set[onnx.NodeProto]],
    supported_ops_checker: _SupportedOpsChecker,
    outer_scope_initializers: set[str],
    require_fixed_input_sizes: bool,
    value_info: dict[str, onnx.ValueInfoProto],
    max_rank: int = 999,  # max rank if EP has a limitation
):
    # initializers have fixed sizes.
    initializers = [i.name for i in graph.initializer]

    def _is_fixed_shape_value(value):
        if value in value_info:
            return is_fixed_size_tensor(value_info[value])

        if value in initializers or value in outer_scope_initializers:
            return True

        # if something has an unknown shape (e.g. something downstream of a Reshape with dynamic input for the shape)
        # it won't have an entry in value_info
        return False

    #
    # Replicate logic from /onnxruntime/core/providers/partitioning_utils.cc:CreateSupportedPartitionNodeGroups
    # to roughly estimate number of partitions for nodes that is_node_supported_fn returns true for.
    #
    # We keep the structure and variable names as close as possible to the C++ implementation to simplify keeping them
    # in sync if future updates are needed.
    #
    # NOTE: CreateSupportedPartitionNodeGroups was recently updated to be QDQ aware so that partitions did not split
    # QDQ node groups. This code does not need to be QDQ aware as splitting a QDQ node group does not affect the total
    # number of partitions or supported nodes.
    #

    # we don't currently support a callback for additional group closure checks in the python implementation
    on_group_closed_fn = None

    supported_groups = []
    # number of inputs from unprocessed nodes (in-degree) per node
    in_degree = {}
    # nodes that are ready to process
    nodes_to_process = deque()  # deque of Node instances
    # nodes that will be processed when considering the next partition node group
    nodes_to_process_with_next_group = deque()

    # initialize in-degrees and find root nodes
    for node in graph.node:
        node_input_edge_count = len(node_to_producers[node]) if node in node_to_producers else 0
        in_degree[node] = node_input_edge_count
        if node_input_edge_count == 0:
            # node is only dependent on graph input or initializers
            nodes_to_process.append(node)

    supported_group = []
    # the partition node group's border is the aggregate of its nodes' output nodes
    supported_group_border = set()
    num_supported_nodes = 0
    num_unsupported_nodes_due_to_op = 0
    num_unsupported_nodes_due_to_dynamic_input = 0
    num_unsupported_nodes_due_to_rank = 0
    unsupported_ops = set()
    ops_with_unsupported_rank = set()

    def close_group():
        if supported_group:
            keep_partition = not on_group_closed_fn or on_group_closed_fn(supported_group)

            if keep_partition:
                supported_groups.append(supported_group.copy())

            supported_group.clear()
            supported_group_border.clear()

    while nodes_to_process or nodes_to_process_with_next_group:
        if not nodes_to_process:
            close_group()
            nodes_to_process = nodes_to_process_with_next_group
            nodes_to_process_with_next_group = deque()
            continue

        node = nodes_to_process.popleft()

        is_op_supported = supported_ops_checker.is_op_supported(node)
        is_input_shape_supported = not require_fixed_input_sizes or all(_is_fixed_shape_value(i) for i in node.input)

        is_rank_supported = True
        if value_info:
            for node_input in node.input:
                if node_input and node_input in value_info and value_info[node_input].type.HasField("tensor_type"):
                    input_rank = len(value_info[node_input].type.tensor_type.shape.dim)
                    if input_rank > max_rank:
                        is_rank_supported = False
                        break

        # special-case if we can infer the rank from the length of the 'perms' Transpose attribute
        # e.g. this works with SegmentAnything where dynamic Reshape operators result in no shape info.
        if node.op_type == "Transpose" and len(node.attribute[0].ints) > max_rank:
            is_rank_supported = False

        is_node_supported = is_op_supported and is_input_shape_supported and is_rank_supported

        if not is_node_supported:
            if node in supported_group_border:
                # an unsupported node on the border will be processed after the current partition node group
                # so skip any additional processing/counting here
                nodes_to_process_with_next_group.append(node)
                continue

            if not is_op_supported:
                unsupported_ops.add(f"{node.domain if node.domain else 'ai.onnx'}:{node.op_type}")
                num_unsupported_nodes_due_to_op += 1

            if not is_input_shape_supported:
                num_unsupported_nodes_due_to_dynamic_input += 1

            if not is_rank_supported:
                num_unsupported_nodes_due_to_rank += 1
                ops_with_unsupported_rank.add(f"{node.domain if node.domain else 'ai.onnx'}:{node.op_type}")

        if is_node_supported:
            num_supported_nodes += 1

            # add node to the partition node group
            supported_group.append(node)

            # remove node from the border and add its outputs to the border
            if node in supported_group_border:  # noqa: FURB132
                supported_group_border.remove(node)

            # for each consumer node add to supported_group_border
            if node in node_to_consumers:
                for consumer in node_to_consumers[node]:
                    supported_group_border.add(consumer)

        # adjust in-degrees of the node outputs and add any new nodes to process
        if node in node_to_consumers:
            for consumer in node_to_consumers[node]:
                consumer_node_in_degree = in_degree[consumer]
                consumer_node_in_degree -= 1
                if consumer_node_in_degree == 0:
                    nodes_to_process.append(consumer)

                in_degree[consumer] = consumer_node_in_degree

    close_group()

    num_nodes = len(graph.node)
    num_partitions = len(supported_groups)

    info = PartitioningInfo(
        num_nodes,
        num_supported_nodes,
        num_partitions,
        supported_ops_checker,
        supported_groups,
        unsupported_ops,
        num_unsupported_nodes_due_to_op,
        num_unsupported_nodes_due_to_dynamic_input,
        num_unsupported_nodes_due_to_rank,
        ops_with_unsupported_rank,
    )

    return info


def check_partitioning(
    main_graph: onnx.GraphProto,
    supported_ops_checker: _SupportedOpsChecker,
    require_fixed_input_sizes: bool,
    max_rank: int = 999,
) -> PartitioningInfo:
    """
    Estimate the partitions the graph will be split into for nodes that is_node_supported_fn returns true for.

    The check on whether a node is supported is purely based on the operator type. Additional limitations
    (e.g. NNAPI EP only supports 2D Conv) are not checked, so partitions may not be 100% accurate. The limitations
    for operators in the partitions are printed so the user can manually check.
    :param main_graph: Graph to process
    :param supported_ops_checker: Checker with info on supported ops.
    :param require_fixed_input_sizes: If True, require that the inputs to a potentially supported node are fixed size
                                      tensors for it to be considered as supported. This requires
                                      onnx.shape_inference.infer_shapes to have been run on the model to populate the
                                      shape information.
                                      If False, shapes are ignored during the check.
    :param max_rank: Set if EP has a limitation on the rank of tensors it supports.
    :return: PartitioningInfo instance with details
    """

    if require_fixed_input_sizes and len(main_graph.value_info) == 0 and len(main_graph.node) > 1:
        raise ValueError("Run onnx.shape_inference.infer_shapes on the model to populate the shape information.")

    # create lookup map from ValueInfo for efficiency
    def _update_value_info(graph: onnx.GraphProto, value_to_shape: dict[str, onnx.ValueInfoProto]):
        for v in graph.input:
            value_to_shape[v.name] = v
        for v in graph.output:
            value_to_shape[v.name] = v
        for v in graph.value_info:
            value_to_shape[v.name] = v

    # the producer/consumer maps are for the entire model
    node_to_producers, node_to_consumers = get_producer_consumer_maps(main_graph)

    def _check_graph(
        graph: onnx.GraphProto,
        outer_scope_value_info: dict[str, onnx.ValueInfoProto] | None,
        outer_scope_initializers: set[str] | None = None,
        partitioning_info: PartitioningInfo | None = None,
    ) -> PartitioningInfo:
        if outer_scope_value_info is not None:
            # extend value info if we're using it. we replace any value shadowed with a local one
            value_info = outer_scope_value_info.copy()
            _update_value_info(graph, value_info)
        else:
            value_info = {}

        if outer_scope_initializers is None:
            outer_scope_initializers = set()

        info = _check_partitioning_for_graph(
            graph,
            node_to_producers,
            node_to_consumers,
            supported_ops_checker,
            outer_scope_initializers,
            require_fixed_input_sizes,
            value_info,
            max_rank,
        )

        if partitioning_info:
            # merge in subgraph info
            partitioning_info.merge(info)
        else:
            # main graph info
            partitioning_info = info

        # setup outer scope initializers. we copy the input set as a model may have multiple subgraphs
        # on multiple levels, so we need to keep the set for each descent separate
        subgraph_outer_scope_initializers = set(outer_scope_initializers)
        for initializer in graph.initializer:
            subgraph_outer_scope_initializers.add(initializer.name)

        for node in graph.node:
            # recurse into nodes with subgraphs
            for attr in node.attribute:
                if attr.HasField("g"):
                    subgraph = attr.g
                    partitioning_info = _check_graph(
                        subgraph, value_info, subgraph_outer_scope_initializers, partitioning_info
                    )

        return partitioning_info

    aggregated_partitioning_info = _check_graph(main_graph, {} if require_fixed_input_sizes else None)

    return aggregated_partitioning_info


def _check_ep_partitioning(
    model: onnx.ModelProto, supported_ops_config: pathlib.Path, require_fixed_input_sizes: bool, max_rank: int = 999
):
    supported_ops = _SupportedOpsChecker(supported_ops_config)
    partition_info = check_partitioning(model.graph, supported_ops, require_fixed_input_sizes, max_rank)
    return partition_info


def check_nnapi_partitions(model, require_fixed_input_sizes: bool):
    # if we're running in the ORT python package the file should be local. otherwise assume we're running from the
    # ORT repo
    script_dir = pathlib.Path(__file__).parent
    local_config = script_dir / "nnapi_supported_ops.md"
    if local_config.exists():
        config_path = local_config
    else:
        ort_root = script_dir.parents[3]
        config_path = ort_root / "tools" / "ci_build" / "github" / "android" / "nnapi_supported_ops.md"

    return _check_ep_partitioning(model, config_path, require_fixed_input_sizes)


def check_coreml_partitions(model: onnx.ModelProto, require_fixed_input_sizes: bool, config_filename: str):
    # if we're running in the ORT python package the file should be local. otherwise assume we're running from the
    # ORT repo
    script_dir = pathlib.Path(__file__).parent
    local_config = script_dir / config_filename
    if local_config.exists():
        config_path = local_config
    else:
        ort_root = script_dir.parents[3]
        config_path = ort_root / "tools" / "ci_build" / "github" / "apple" / config_filename

    max_rank = 5
    return _check_ep_partitioning(model, config_path, require_fixed_input_sizes, max_rank)


def check_shapes(graph: onnx.GraphProto, logger: logging.Logger | None = None):
    """
    Check the shapes of graph inputs, values and graph outputs to determine if they have static or dynamic sizes.
    NNAPI does not support dynamically sized values. CoreML does, but it will most likely cost performance.
    :param graph: Graph to check. If shape inferencing has been run the checks on values will be meaningful.
    :param logger: Optional logger for diagnostic information.
    :return: Tuple of List of inputs with dynamic shapes, Number of dynamic values found
    """

    # it's OK if the input is dynamically sized and we do a Resize early to a fixed size.
    # it's not good if lots of ops have dynamic inputs

    num_fixed_values = 0
    num_dynamic_values = 0

    dynamic_inputs = []
    for i in graph.input:
        if not is_fixed_size_tensor(i):
            dynamic_inputs.append(i)
            # split/join to remove repeated whitespace and newlines from str(i)
            if logger:
                logger.info(f"Input is not a fixed size tensor: {' '.join(str(i).split())}")
            num_dynamic_values += 1
        else:
            num_fixed_values += 1

    dynamic_outputs = []
    for o in graph.output:
        if not is_fixed_size_tensor(o):
            dynamic_outputs.append(o)
            if logger:
                logger.info(f"Output is not a fixed size tensor: {' '.join(str(o).split())}")
            num_dynamic_values += 1
        else:
            num_fixed_values += 1

    # check we have value info.
    # special case some test graphs with a single node which only have graph input and output values, and
    # a model where all inputs are dynamic (results in no value_info)
    if not graph.value_info and not (len(graph.node) == 1 or len(dynamic_inputs) == len(graph.input)):
        if logger:
            logger.warning(
                "Unable to check shapes within model. "
                "ONNX shape inferencing should be run on the model prior to checking."
            )

    for vi in graph.value_info:
        if is_fixed_size_tensor(vi):
            num_fixed_values += 1
        else:
            num_dynamic_values += 1

    if logger:
        logger.info(
            f"Num values with fixed shape={num_fixed_values}. Num values with dynamic shape={num_dynamic_values}"
        )

        if dynamic_inputs:
            if dynamic_outputs:
                logger.info(
                    "Model has dynamic inputs and outputs. Consider re-exporting model with fixed sizes "
                    "if NNAPI or CoreML can be used with this model."
                )
            else:
                logger.info(
                    """Model has dynamically sized inputs but fixed sized outputs.
                       If the sizes become fixed early in the model (e.g. pre-processing of a dynamic input size
                       results in a fixed input size for the majority of the model) performance with NNAPI and CoreML,
                       if applicable, should not be significantly impacted."""
                )

    return dynamic_inputs, num_dynamic_values
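
# ---------------------------------------------------------------------------
# Editor's note (illustrative, not part of the original source): check_shapes()
# is usable standalone, e.g.
#     dynamic_inputs, num_dynamic = check_shapes(model.graph, logger)
# A non-empty dynamic_inputs list is what triggers the fixed-shape re-check in
# checker() below.
# ---------------------------------------------------------------------------
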

def checker(model_path: pathlib.Path, logger: logging.Logger):
    model_with_shape_info_wrapper = ModelProtoWithShapeInfo(model_path)
    model_with_shape_info = model_with_shape_info_wrapper.model_with_shape_info

    dynamic_inputs, num_dynamic_values = check_shapes(model_with_shape_info.graph)

    def check_ep(ep_name, checker_func):
        logger.info(f"Checking {ep_name}")

        # check with shape info first so the set of supported nodes takes values with dynamic shapes into account
        require_fixed_input_sizes = True
        partition_info = checker_func(model_with_shape_info, require_fixed_input_sizes)
        if logger.getEffectiveLevel() <= logging.INFO:
            partition_info.print_analysis(logger, ep_name)

        suitability = partition_info.suitability()
        logger.info(f"Model should perform well with {ep_name} as is: {suitability.name}")

        if suitability != PartitioningInfo.TryWithEP.YES and dynamic_inputs:
            logger.info("--------")
            logger.info("Checking if model will perform better if the dynamic shapes are fixed...")
            require_fixed_input_sizes = False
            partition_info_with_fixed_shapes = checker_func(model_with_shape_info, require_fixed_input_sizes)

            if logger.getEffectiveLevel() <= logging.INFO:
                # analyze and log detailed info
                logger.info("Partition information if the model was updated to make the shapes fixed:")
                partition_info_with_fixed_shapes.print_analysis(logger, ep_name)

            fixed_shape_suitability = partition_info_with_fixed_shapes.suitability()
            logger.info(
                f"Model should perform well with {ep_name} if modified to have fixed input shapes: "
                f"{fixed_shape_suitability.name}"
            )

            if fixed_shape_suitability != PartitioningInfo.TryWithEP.NO:
                logger.info("Shapes can be altered using python -m onnxruntime.tools.make_dynamic_shape_fixed")

            if fixed_shape_suitability.value > suitability.value:
                suitability = fixed_shape_suitability

        logger.info("================")
        logger.info("")

        return suitability

    nnapi_suitability = check_ep("NNAPI", check_nnapi_partitions)

    # Check for NeuralNetwork CoreML model
    def check_nn_coreml(model: onnx.ModelProto, require_fixed_input_sizes):
        return check_coreml_partitions(model, require_fixed_input_sizes, "coreml_supported_neuralnetwork_ops.md")

    # Check for MLProgram CoreML model
    def check_mlprogram_coreml(model: onnx.ModelProto, require_fixed_input_sizes):
        return check_coreml_partitions(model, require_fixed_input_sizes, "coreml_supported_mlprogram_ops.md")

    coreml_nn_suitability = check_ep("CoreML NeuralNetwork", check_nn_coreml)
    coreml_mlprogram_suitability = check_ep("CoreML MLProgram", check_mlprogram_coreml)

    if (
        nnapi_suitability != PartitioningInfo.TryWithEP.YES
        or coreml_nn_suitability != PartitioningInfo.TryWithEP.YES
        or coreml_mlprogram_suitability != PartitioningInfo.TryWithEP.YES
    ) and logger.getEffectiveLevel() > logging.INFO:
        logger.info("Re-run with log level of INFO for more details on the NNAPI/CoreML issues.")

    return (
        nnapi_suitability != PartitioningInfo.TryWithEP.NO
        or coreml_nn_suitability != PartitioningInfo.TryWithEP.NO
        or coreml_mlprogram_suitability != PartitioningInfo.TryWithEP.NO
    )


def analyze_model(model_path: pathlib.Path, skip_optimize: bool = False, logger: logging.Logger | None = None):
    """
    Analyze the provided model to determine if it's likely to work well with the NNAPI or CoreML Execution Providers.
    :param model_path: Model to analyze.
    :param skip_optimize: Skip optimizing to BASIC level before checking. When exporting to ORT format we will do this
                          optimization.
    :param logger: Logger for output
    :return: True if either the NNAPI or CoreML Execution Providers may work well with this model.
    """
    if not logger:
        logger = logging.getLogger("usability_checker")
        logger.setLevel(logging.INFO)

    logger.info(f"Checking {model_path} for usability with ORT Mobile.")

    with tempfile.TemporaryDirectory() as tmp:
        if not skip_optimize:
            tmp_path = pathlib.Path(tmp) / model_path.name
            optimize_model(model_path, tmp_path, use_external_initializers=True)
            model_path = tmp_path

        try_eps = checker(model_path.resolve(strict=True), logger)

    return try_eps


def parse_args():
    parser = argparse.ArgumentParser(
        os.path.basename(__file__), description="""Analyze an ONNX model for usage with ORT Mobile."""
    )

    parser.add_argument("--log_level", choices=["debug", "info"], default="info", help="Logging level")
    parser.add_argument(
        "--skip_optimize",
        action="store_true",
        help="Don't optimize the model to BASIC level prior to analyzing. "
        "Optimization will occur when exporting the model to ORT format, so in general "
        "this should not be skipped unless you have a specific reason to do so.",
    )
    parser.add_argument("model_path", type=pathlib.Path, help="Path to the ONNX model")

    return parser.parse_args()


def run_analyze_model():
    args = parse_args()
    logger = logging.getLogger("default")

    if args.log_level == "debug":
        logger.setLevel(logging.DEBUG)
    elif args.log_level == "info":
        logger.setLevel(logging.INFO)
    elif args.log_level == "warning":
        logger.setLevel(logging.WARNING)
    else:
        logger.setLevel(logging.ERROR)

    model_path = args.model_path.resolve()
    analyze_model(model_path, args.skip_optimize, logger)


if __name__ == "__main__":
    run_analyze_model()
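
For reference, a minimal sketch of driving the checker above from Python. This is an editorial example, not part of the package; "model.onnx" is a placeholder path, and the import path follows the package layout shown in the file list.

    import logging
    import pathlib

    from onnxruntime.tools.mobile_helpers.usability_checker import analyze_model

    # route the checker's log output to stdout at INFO level
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger("usability_checker")

    # placeholder path; point this at a real ONNX model
    model_path = pathlib.Path("model.onnx")

    # returns True if either the NNAPI or CoreML EPs may work well with the model
    may_work_well = analyze_model(model_path, skip_optimize=False, logger=logger)
    print(f"Worth trying NNAPI/CoreML: {may_work_well}")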