onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/onnx_model.py
@@ -0,0 +1,1569 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import itertools
+import logging
+import os
+import sys
+from collections import deque
+from pathlib import Path
+from typing import Dict, List, Optional, Tuple
+
+from float16 import convert_float_to_float16
+from onnx import (
+    AttributeProto,
+    GraphProto,
+    ModelProto,
+    NodeProto,
+    TensorProto,
+    ValueInfoProto,
+    helper,
+    numpy_helper,
+    save_model,
+)
+from onnx.external_data_helper import load_external_data_for_tensor, uses_external_data
+from shape_infer_helper import SymbolicShapeInferenceHelper
+
+logger = logging.getLogger(__name__)
+
+
+class OnnxModel:
+    def __init__(self, model):
+        self.initialize(model)
+
+    def initialize(self, model):
+        self.model: ModelProto = model
+        self._node_name_suffix: Dict[str, int] = {}  # key is node name prefix, value is the last suffix generated
+        self.shape_infer_helper: SymbolicShapeInferenceHelper = None
+        self.enable_shape_infer: bool = True
+        self.all_graphs: Optional[List[GraphProto]] = None
+
+        # Cache of shape and data type from onnx graph to speed up optimization.
+        # Be careful that fusion shall not reuse node output name for different shape/type (in adding/removing nodes)
+        # Note that these do not cache the symbolic shape inference result.
+        self._dtype_dict: Optional[Dict[str, int]] = None
+        self._shape_dict: Optional[Dict[str, List]] = None
+
+    def disable_shape_inference(self):
+        self.enable_shape_infer = False
+
+    def infer_runtime_shape(self, dynamic_axis_mapping={}, update=False):  # noqa: B006
+        if self.enable_shape_infer:
+            if self.shape_infer_helper is None or update:
+                self.shape_infer_helper = SymbolicShapeInferenceHelper(self.model)
+
+            try:
+                if self.shape_infer_helper.infer(dynamic_axis_mapping):
+                    return self.shape_infer_helper
+            except Exception:
+                self.enable_shape_infer = False  # disable shape inference to suppress same error message.
+                print("failed in shape inference", sys.exc_info()[0])
+
+        return None
+
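A minimal usage sketch of the shape-inference entry point above; the model path and axis names are hypothetical, and the import assumes this module is available on sys.path as onnx_model:

    import onnx

    from onnx_model import OnnxModel

    # Load a transformer model (path is illustrative).
    model = OnnxModel(onnx.load("bert.onnx"))

    # Bind symbolic axes to concrete values before querying inferred shapes.
    shape_helper = model.infer_runtime_shape({"batch_size": 1, "seq_len": 128})
    if shape_helper is None:
        print("symbolic shape inference unavailable; continuing without it")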
+    def input_name_to_nodes(self, exclude_subgraphs=False):
+        input_name_to_nodes = {}
+        nodes_to_search = self.nodes() if not exclude_subgraphs else self.model.graph.node
+        for node in nodes_to_search:
+            for input_name in node.input:
+                if input_name:  # could be empty when it is optional
+                    if input_name not in input_name_to_nodes:
+                        input_name_to_nodes[input_name] = [node]
+                    else:
+                        input_name_to_nodes[input_name].append(node)
+        return input_name_to_nodes
+
+    def output_name_to_node(self, exclude_subgraphs=False):
+        output_name_to_node = {}
+        nodes_to_search = self.nodes() if not exclude_subgraphs else self.model.graph.node
+        for node in nodes_to_search:
+            for output_name in node.output:
+                if output_name:  # could be empty when it is optional
+                    output_name_to_node[output_name] = node
+        return output_name_to_node
+
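Both lookup tables above are rebuilt by a full graph scan on each call, so optimization passes typically build them once and pass them into the traversal helpers defined later in this file. A sketch, assuming model is an OnnxModel:

    # Build each map once per pass; rebuilding inside a per-node loop would be quadratic.
    input_name_to_nodes = model.input_name_to_nodes()
    output_name_to_node = model.output_name_to_node()

    for node in model.get_nodes_by_op_type("Cast"):
        parent = model.get_parent(node, 0, output_name_to_node)
        children = model.get_children(node, input_name_to_nodes)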
+    def functions(self):
+        all_functions = [list(self.model.functions)]
+        return all_functions
+
+    def nodes(self):
+        all_nodes = []
+        for graph in self.graphs():
+            for node in graph.node:
+                all_nodes.append(node)  # noqa: PERF402
+        return all_nodes
+
+    def graph(self):
+        return self.model.graph
+
+    def graphs(self):
+        if self.all_graphs is not None:
+            return self.all_graphs
+        self.all_graphs = []
+        graph_queue = [self.model.graph]
+        while graph_queue:
+            graph = graph_queue.pop(0)
+            self.all_graphs.append(graph)
+            for node in graph.node:
+                for attr in node.attribute:
+                    if attr.type == AttributeProto.AttributeType.GRAPH:
+                        assert isinstance(attr.g, GraphProto)
+                        graph_queue.append(attr.g)
+                    if attr.type == AttributeProto.AttributeType.GRAPHS:
+                        for g in attr.graphs:
+                            assert isinstance(g, GraphProto)
+                            graph_queue.append(g)
+        return self.all_graphs
+
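Because graphs() follows GRAPH and GRAPHS attributes breadth-first, control-flow bodies (If branches, Loop and Scan bodies) are included in the traversal. A small sanity check, assuming model is an OnnxModel:

    # nodes() flattens exactly what graphs() reaches: the main graph plus all subgraphs.
    node_count = sum(len(g.node) for g in model.graphs())
    assert node_count == len(model.nodes())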
+    def get_graphs_input_names(self):
+        input_names = []
+        for graph in self.graphs():
+            for input in graph.input:
+                input_names.append(input.name)
+        return input_names
+
+    def get_graphs_output_names(self):
+        output_names = []
+        for graph in self.graphs():
+            for output in graph.output:
+                output_names.append(output.name)
+        return output_names
+
+    def get_graph_by_node(self, node):
+        for graph in self.graphs():
+            if node in graph.node:
+                return graph
+        return None
+
+    def get_graph_by_name(self, graph_name):
+        for graph in self.graphs():
+            if graph_name == graph.name:
+                return graph
+        return None
+
+    def get_topological_insert_id(self, graph, outputs):
+        for idx, node in enumerate(graph.node):
+            for input in node.input:
+                if input in outputs:
+                    return idx
+        return len(graph.node)
+
+    def remove_node(self, node):
+        for graph in self.graphs():
+            if node in graph.node:
+                graph.node.remove(node)
+                return
+        logger.warning("Failed to remove node %s", node)  # It might be a bug to hit this line.
+
+    def remove_nodes(self, nodes_to_remove):
+        for node in nodes_to_remove:
+            self.remove_node(node)
+
+    def add_node(self, node, graph_name=None):
+        if graph_name is None or graph_name == self.model.graph.name:
+            self.model.graph.node.extend([node])
+        else:
+            graph = self.get_graph_by_name(graph_name)
+            insert_idx = self.get_topological_insert_id(graph, node.output)
+            graph.node.insert(insert_idx, node)
+
+    def add_nodes(self, nodes_to_add, node_name_to_graph_name=None):
+        if node_name_to_graph_name is None:
+            self.model.graph.node.extend(nodes_to_add)
+        else:
+            for node in nodes_to_add:
+                graph_name = node_name_to_graph_name[node.name]
+                self.add_node(node, graph_name)
+
+    def add_initializer(self, tensor, graph_name=None):
+        if graph_name is None or graph_name == self.model.graph.name:
+            self.model.graph.initializer.extend([tensor])
+        else:
+            graph = self.get_graph_by_name(graph_name)
+            graph.initializer.extend([tensor])
+
+    def add_input(self, input, graph_name=None):
+        if graph_name is None or graph_name == self.model.graph.name:
+            self.model.graph.input.extend([input])
+        else:
+            graph = self.get_graph_by_name(graph_name)
+            graph.input.extend([input])
+
+    @staticmethod
+    def replace_node_input(node, old_input_name, new_input_name):
+        assert isinstance(old_input_name, str) and isinstance(new_input_name, str)
+        for j in range(len(node.input)):
+            if node.input[j] == old_input_name:
+                node.input[j] = new_input_name
+
+    def replace_input_of_all_nodes(self, old_input_name, new_input_name):
+        for node in self.model.graph.node:
+            OnnxModel.replace_node_input(node, old_input_name, new_input_name)
+
+    @staticmethod
+    def replace_node_output(node, old_output_name, new_output_name):
+        assert isinstance(old_output_name, str) and isinstance(new_output_name, str)
+        for j in range(len(node.output)):
+            if node.output[j] == old_output_name:
+                node.output[j] = new_output_name
+
+    def replace_output_of_all_nodes(self, old_output_name, new_output_name):
+        # This function shall be used carefully. For example:
+        #       Add --[old_name]--> Cast ---> [new_name]
+        #        |
+        #        +----[old_name]--> Transpose -->
+        # If we want to remove the Cast node, replacing the output of Add with new_name is not enough;
+        # the input of Transpose shall also be updated to new_name.
+        for node in self.model.graph.node:
+            OnnxModel.replace_node_output(node, old_output_name, new_output_name)
+
+    def get_initializer(self, name):
+        for graph in self.graphs():
+            for tensor in graph.initializer:
+                if tensor.name == name:
+                    return tensor
+        return None
+
+    def get_nodes_by_op_type(self, op_type):
+        nodes = []
+        for node in self.nodes():
+            if node.op_type == op_type:
+                nodes.append(node)
+        return nodes
+
+    def get_children(self, node, input_name_to_nodes=None):
+        if input_name_to_nodes is None:
+            input_name_to_nodes = self.input_name_to_nodes()
+
+        children = []
+        for output in node.output:
+            if output in input_name_to_nodes:
+                for node in input_name_to_nodes[output]:
+                    children.append(node)  # noqa: PERF402
+        return children
+
+    def get_parents(self, node, output_name_to_node=None):
+        if output_name_to_node is None:
+            output_name_to_node = self.output_name_to_node()
+
+        parents = []
+        for input in node.input:
+            if input in output_name_to_node:
+                parents.append(output_name_to_node[input])
+        return parents
+
+    def get_parent(self, node, i, output_name_to_node=None):
+        if output_name_to_node is None:
+            output_name_to_node = self.output_name_to_node()
+
+        if len(node.input) <= i:
+            return None
+
+        input = node.input[i]
+        if input not in output_name_to_node:
+            return None
+
+        return output_name_to_node[input]
+
+    def match_first_parent(self, node, parent_op_type, output_name_to_node, exclude=[]):  # noqa: B006
+        """
+        Find parent node based on constraints on op_type.
+
+        Args:
+            node (NodeProto): current node.
+            parent_op_type (str): constraint of parent node op_type.
+            output_name_to_node (dict): dictionary with output name as key, and node as value.
+            exclude (list): list of nodes that are excluded (not allowed to match as parent).
+
+        Returns:
+            parent: The matched parent node. None if not found.
+            index: The input index of matched parent node. None if not found.
+        """
+        for i, input in enumerate(node.input):
+            if input in output_name_to_node:
+                parent = output_name_to_node[input]
+                if parent.op_type == parent_op_type and parent not in exclude:
+                    return parent, i
+                else:
+                    logger.debug(f"To find first {parent_op_type}, current {parent.op_type}")
+        return None, None
+
+    def match_parent(
+        self,
+        node,
+        parent_op_type,
+        input_index=None,
+        output_name_to_node=None,
+        exclude=[],  # noqa: B006
+        return_indice=None,
+    ):
+        """
+        Find parent node based on constraints on op_type and index.
+        When input_index is None, we will find the first parent node based on constraints,
+        and the corresponding input index will be appended to return_indice.
+
+        Args:
+            node (NodeProto): current node.
+            parent_op_type (str): constraint of parent node op_type.
+            input_index (int or None): only check the parent at the given input index of current node.
+            output_name_to_node (dict): dictionary with output name as key, and node as value.
+            exclude (list): list of nodes that are excluded (not allowed to match as parent).
+            return_indice (list): a list to append the input index when input_index is None.
+
+        Returns:
+            parent: The matched parent node.
+        """
+        assert node is not None
+        assert input_index is None or input_index >= 0
+
+        if output_name_to_node is None:
+            output_name_to_node = self.output_name_to_node()
+
+        if input_index is None:
+            parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude)
+            if return_indice is not None:
+                return_indice.append(index)
+            return parent
+
+        if input_index >= len(node.input):
+            logger.debug(f"input_index {input_index} >= node inputs {len(node.input)}")
+            return None
+
+        parent = self.get_parent(node, input_index, output_name_to_node)
+        if parent is not None and parent.op_type == parent_op_type and parent not in exclude:
+            return parent
+
+        if parent is not None:
+            logger.debug(f"Expect {parent_op_type}, Got {parent.op_type}")
+
+        return None
+
+    def match_parent_paths(self, node, paths, output_name_to_node):
+        for i, path in enumerate(paths):
+            assert isinstance(path, (List, Tuple))
+            return_indice = []
+            matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice)
+            if matched:
+                return i, matched, return_indice
+        return -1, None, None
+
+    def match_parent_paths_all(self, node, paths, output_name_to_node):
+        match_i, matches, return_indices = [], [], []
+        for i, path in enumerate(paths):
+            assert isinstance(path, (List, Tuple))
+            return_indice = []
+            matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice)
+            if matched:
+                match_i.append(i)
+                matches.append(matched)
+                return_indices.append(return_indice)
+        return match_i, matches, return_indices
+
+    def match_parent_path(
+        self,
+        node,
+        parent_op_types,
+        parent_input_index=None,
+        output_name_to_node=None,
+        return_indice=None,
+    ):
+        """
+        Find a sequence of input edges based on constraints on parent op_type and index.
+        When parent_input_index is None, we will find the first parent node based on constraints,
+        and the corresponding input index will be appended to return_indice.
+
+        Args:
+            node (NodeProto): current node.
+            parent_op_types (list): constraint of parent node op_type of each input edge.
+            parent_input_index (list): constraint of input index of each input edge. None means no constraint.
+            output_name_to_node (dict): dictionary with output name as key, and node as value.
+            return_indice (list): a list to append the input index when there is
+                no constraint on input index of an edge.
+
+        Returns:
+            parents: a list of matched parent nodes.
+        """
+        if parent_input_index is not None:
+            assert len(parent_input_index) == len(parent_op_types)
+
+        if output_name_to_node is None:
+            output_name_to_node = self.output_name_to_node()
+
+        current_node = node
+        matched_parents = []
+        for i, op_type in enumerate(parent_op_types):
+            matched_parent = self.match_parent(
+                current_node,
+                op_type,
+                parent_input_index[i] if parent_input_index is not None else None,
+                output_name_to_node,
+                exclude=[],
+                return_indice=return_indice,
+            )
+            if matched_parent is None:
+                if parent_input_index is not None:
+                    logger.debug(
+                        f"Failed to match index={i} parent_input_index={parent_input_index[i]} op_type={op_type}",
+                        stack_info=True,
+                    )
+                else:
+                    logger.debug(f"Failed to match index={i} op_type={op_type}", stack_info=True)
+                return None
+
+            matched_parents.append(matched_parent)
+            current_node = matched_parent
+
+        return matched_parents
+
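A sketch of how fusion passes typically drive match_parent_path; the Add/MatMul pattern and the node variable are illustrative, not taken from this file:

    output_name_to_node = model.output_name_to_node()

    # Walk upward from `node`: its input 0 must be produced by an Add,
    # and that Add's input 1 must be produced by a MatMul.
    parents = model.match_parent_path(node, ["Add", "MatMul"], [0, 1], output_name_to_node)
    if parents is not None:
        add_node, matmul_node = parents  # matched in order, nearest parent first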
+    def find_first_child_by_type(self, node, child_type, input_name_to_nodes=None, recursive=True):
+        children = self.get_children(node, input_name_to_nodes)
+        dq = deque(children)
+        while len(dq) > 0:
+            current_node = dq.pop()
+            if current_node.op_type == child_type:
+                return current_node
+
+            if recursive:
+                children = self.get_children(current_node, input_name_to_nodes)
+                for child in children:
+                    dq.appendleft(child)
+
+        return None
+
+    def match_child_path(
+        self,
+        node,
+        child_op_types,
+        child_output_index=None,
+        return_indice=None,
+        exclude=[],  # noqa: B006
+    ):
+        """
+        Find a sequence of output edges based on constraints on child op_type and index.
+
+        Args:
+            node (NodeProto): current node.
+            child_op_types (list): constraint of child node op_type of each output edge.
+            child_output_index (list): constraint of output index of each output edge. None means no constraint.
+            return_indice (list): reserved for symmetry with match_parent_path; currently unused.
+            exclude (list): list of nodes that are excluded (not allowed to match as child).
+
+        Returns:
+            children: a list of matched child nodes.
+        """
+        if child_output_index is not None:
+            assert len(child_output_index) == len(child_op_types)
+
+        current_node = node
+        matched_children = []
+        for i, op_type in enumerate(child_op_types):
+            matched_child = None
+            node_children = self.get_children(current_node)
+            for child_i, child in enumerate(node_children):
+                if child.op_type == op_type and child not in exclude:
+                    if child_output_index is not None and child_output_index[i] != child_i:
+                        logger.debug(
+                            f"Failed to match index={i} child_output_index={child_output_index[i]} op_type={op_type}",
+                            stack_info=True,
+                        )
+                        return None
+                    matched_child = child
+            if matched_child is None:
+                logger.debug(f"Failed to match child op_type={op_type}", stack_info=True)
+                return None
+
+            matched_children.append(matched_child)
+            current_node = matched_child
+        return matched_children
+
+    def find_first_parent_by_type(self, node, parent_type, output_name_to_node=None, recursive=True):
+        if output_name_to_node is None:
+            output_name_to_node = self.output_name_to_node()
+
+        parents = self.get_parents(node, output_name_to_node)
+        dq = deque(parents)
+        while len(dq) > 0:
+            current_node = dq.pop()
+            if current_node.op_type == parent_type:
+                return current_node
+
+            if recursive:
+                parents = self.get_parents(current_node, output_name_to_node)
+                for parent in parents:
+                    dq.appendleft(parent)
+
+        return None
+
+    def get_constant_value(self, output_name):
+        for node in self.get_nodes_by_op_type("Constant"):
+            if node.output[0] == output_name:
+                for att in node.attribute:
+                    if att.name == "value":
+                        return numpy_helper.to_array(att.t)
+
+        # Fall back to initializer since constant folding might have been applied.
+        initializer = self.get_initializer(output_name)
+        if initializer is not None:
+            return numpy_helper.to_array(initializer)
+
+        return None
+
+    def get_constant_input(self, node):
+        for i, input in enumerate(node.input):
+            value = self.get_constant_value(input)
+            if value is not None:
+                return i, value
+
+        return None, None
+
+    def find_constant_input(self, node, expected_value, delta=0.000001):
+        i, value = self.get_constant_input(node)
+        if value is not None and value.size == 1 and abs(value - expected_value) < delta:
+            return i
+
+        return -1
+
+    def is_constant_with_specified_dimension(self, output_name, dimensions, description):
+        value = self.get_constant_value(output_name)
+        if value is None:
+            logger.debug(f"{description} {output_name} is not initializer.")
+            return False
+
+        if len(value.shape) != dimensions:
+            logger.debug(f"{description} {output_name} shall have {dimensions} dimensions. Got shape {value.shape}")
+            return False
+
+        return True
+
+    def has_constant_input(self, node, expected_value, delta=0.000001):
+        return self.find_constant_input(node, expected_value, delta) >= 0
+
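Fusion code uses the constant helpers above to verify scalar operands, for example the division by sqrt(2) inside an erf-based Gelu subgraph. A sketch with an illustrative div_node:

    import math

    # Returns the input index holding the scalar constant, or -1 if no input
    # is a single-element constant within `delta` of the expected value.
    index = model.find_constant_input(div_node, math.sqrt(2.0), delta=0.001)
    if index >= 0:
        pass  # pattern confirmed; the other input is the data path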
+    def get_children_subgraph_nodes(self, root_node, stop_nodes, input_name_to_nodes=None):
+        if input_name_to_nodes is None:
+            input_name_to_nodes = self.input_name_to_nodes()
+
+        children = input_name_to_nodes[root_node.output[0]]
+
+        unique_nodes = []
+
+        dq = deque(children)
+        while len(dq) > 0:
+            current_node = dq.pop()
+            if current_node in stop_nodes:
+                continue
+
+            if current_node not in unique_nodes:
+                unique_nodes.append(current_node)
+
+                for output in current_node.output:
+                    if output in input_name_to_nodes:
+                        children = input_name_to_nodes[output]
+                        for child in children:
+                            dq.appendleft(child)
+
+        return unique_nodes
+
+    def tensor_shape_to_list(self, tensor_type):
+        """Convert tensor shape to list"""
+        shape_list = []
+        for d in tensor_type.shape.dim:
+            if d.HasField("dim_value"):
+                shape_list.append(d.dim_value)  # known dimension
+            elif d.HasField("dim_param"):
+                shape_list.append(d.dim_param)  # unknown dimension with symbolic name
+            else:
+                shape_list.append("?")  # shall not happen
+        return shape_list
+
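For a graph input with symbolic axes, tensor_shape_to_list yields a mixed list of ints and strings. A sketch with a hypothetical input name:

    graph_input = model.find_graph_input("input_ids")  # hypothetical name
    if graph_input is not None:
        shape = model.tensor_shape_to_list(graph_input.type.tensor_type)
        # e.g. ["batch_size", "seq_len"] while dynamic, or [1, 128] after
        # the dynamic axes have been fixed.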
+    def get_dtype(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None):
+        """Try get data type given a name (could be initializer, input or output of graph or node)."""
+
+        if self._dtype_dict is None:
+            self._dtype_dict = {}
+            for value_info in itertools.chain(
+                self.model.graph.value_info,
+                self.model.graph.input,
+                self.model.graph.output,
+            ):
+                self._dtype_dict[value_info.name] = value_info.type.tensor_type.elem_type
+
+            for initializer in self.model.graph.initializer:
+                if initializer.name not in self._dtype_dict:
+                    self._dtype_dict[initializer.name] = initializer.data_type
+
+        if name in self._dtype_dict:
+            return self._dtype_dict[name]
+
+        if symbolic_shape_helper is not None and name in symbolic_shape_helper.known_vi_:
+            value_info = symbolic_shape_helper.known_vi_[name]
+            return value_info.type.tensor_type.elem_type
+
+        return None
+
+    def get_shape(self, name: str, symbolic_shape_helper: Optional[SymbolicShapeInferenceHelper] = None):
+        """Try get shape given a name (could be initializer, input or output of graph or node)."""
+
+        if self._shape_dict is None:
+            self._shape_dict = {}
+            for value_info in itertools.chain(
+                self.model.graph.value_info,
+                self.model.graph.input,
+                self.model.graph.output,
+            ):
+                if value_info.type.tensor_type.HasField("shape"):
+                    shape = []
+                    for dim in value_info.type.tensor_type.shape.dim:
+                        if dim.dim_param:
+                            shape.append(dim.dim_param)
+                        else:
+                            shape.append(dim.dim_value)
+                    self._shape_dict[value_info.name] = shape
+
+            for initializer in self.model.graph.initializer:
+                if initializer.name not in self._shape_dict:
+                    self._shape_dict[initializer.name] = initializer.dims
+
+        if name in self._shape_dict:
+            return self._shape_dict[name]
+
+        if symbolic_shape_helper is not None and name in symbolic_shape_helper.known_vi_:
+            value_info = symbolic_shape_helper.known_vi_[name]
+            # Return the shape as a list (int or symbolic str per dimension), like the cached branch above.
+            return self.tensor_shape_to_list(value_info.type.tensor_type)
+
+        return None
+
+    @staticmethod
+    def get_node_attribute(node: NodeProto, attribute_name: str):
+        for attr in node.attribute:
+            if attr.name == attribute_name:
+                value = helper.get_attribute_value(attr)
+                return value
+        return None
+
+    def remove_cascaded_cast_nodes(self):
+        """Remove a Cast node that is followed by another Cast node, like --> Cast --> Cast -->
+        Note that this shall be used carefully since it might introduce a semantic change.
+        For example, float -> int -> float could get a different value than the original float value.
+        So it is recommended to use it only in post-processing of mixed precision conversion.
+        """
+        output_name_to_node = self.output_name_to_node()
+        removed_count = 0
+        for node in self.nodes():
+            if node.op_type == "Cast":
+                parent = self.get_parent(node, 0, output_name_to_node=output_name_to_node)
+                if parent and parent.op_type == "Cast":
+                    node.input[0] = parent.input[0]
+                    removed_count += 1
+
+        if removed_count > 0:
+            logger.info("Removed %d cascaded Cast nodes", removed_count)
+            self.prune_graph()
+
+    def remove_useless_cast_nodes(self):
+        """Remove Cast nodes that are not needed: input and output have the same data type."""
+        shape_infer = self.infer_runtime_shape(update=True)
+        if self.enable_shape_infer and shape_infer is None:
+            logger.warning("shape inference failed which might impact useless cast node detection.")
+
+        nodes_to_remove = []
+        for node in self.nodes():
+            if node.op_type == "Cast":
+                input_dtype = self.get_dtype(node.input[0], shape_infer)
+                output_dtype = self.get_dtype(node.output[0], shape_infer)
+                if input_dtype and input_dtype == output_dtype:
+                    nodes_to_remove.append(node)
+
+        if nodes_to_remove:
+            graph_input_names = set(self.get_graphs_input_names())
+            graph_output_names = set(self.get_graphs_output_names())
+            for node in nodes_to_remove:
+                if bool(set(node.output) & graph_output_names):
+                    if (not bool(set(node.input) & graph_input_names)) and len(
+                        self.input_name_to_nodes()[node.input[0]]
+                    ) == 1:
+                        self.replace_output_of_all_nodes(node.input[0], node.output[0])
+                    else:
+                        continue
+                else:
+                    self.replace_input_of_all_nodes(node.output[0], node.input[0])
+                self.remove_node(node)
+
+            logger.info(
+                "Removed %d Cast nodes with output type same as input",
+                len(nodes_to_remove),
+            )
+
+    def convert_model_float32_to_float16(self, cast_input_output=True):
+        logger.warning(
+            "The function convert_model_float32_to_float16 is deprecated. Use convert_float_to_float16 instead!"
+        )
+        self.convert_float_to_float16(use_symbolic_shape_infer=True, keep_io_types=cast_input_output)
+
+    def convert_float_to_float16(self, use_symbolic_shape_infer=True, **kwargs):
+        """Convert a model to half (default) or mixed precision.
+        To use mixed precision, the user needs to specify which graph inputs, outputs, operator types
+        or nodes shall be kept in float32.
+
+        Note that the conversion might not proceed without type information for the whole graph.
+
+        By default, we use symbolic shape inference to get type information. The benefit of symbolic shape inference
+        is that it can handle fused operators in the com.microsoft domain. Those operators cannot be handled by onnx
+        shape inference, so symbolic shape inference is recommended for an optimized model.
+
+        When symbolic shape inference is used (even if it failed), ONNX shape inference will be disabled.
+
+        Note that onnx shape inference will fail for a model larger than 2GB. For a large model, you have to enable
+        symbolic shape inference. If your model is not optimized, you can also use a model path to call
+        convert_float_to_float16 in float16.py (see https://github.com/microsoft/onnxruntime/pull/15067) to
+        avoid the 2GB limit.
+
+        Args:
+            use_symbolic_shape_infer (bool, optional): use symbolic shape inference instead of onnx shape inference.
+                Defaults to True.
+            keep_io_types (Union[bool, List[str]], optional): boolean or a list of float32 input/output names.
+                If True, model inputs/outputs should be left as float32.
+                Defaults to True.
+            op_block_list (List[str], optional): List of operator types to leave as float32.
+                Defaults to None, which will use `float16.DEFAULT_OP_BLOCK_LIST`.
+            node_block_list (List[str], optional): List of node names to leave as float32. Defaults to None.
+            force_fp16_initializers (bool): force converting all float initializers to float16.
+                Defaults to False.
+            min_positive_val (float, optional): minimal positive value. Defaults to 1e-7.
+            max_finite_val (float, optional): maximal finite value. Defaults to 1e4.
+            force_fp16_inputs (Dict[str, List[int]]): Force the conversion of the inputs of some operators to
+                float16, even if this script's preference is to keep them in float32.
+        """
+        if "keep_io_types" not in kwargs:
+            kwargs["keep_io_types"] = True
+
+        model = self.model
+        if use_symbolic_shape_infer:
+            # Use symbolic shape inference since custom operators (like Gelu, SkipLayerNormalization etc.)
+            # are not recognized by onnx shape inference.
+            shape_infer_helper = SymbolicShapeInferenceHelper(model)
+            try:
+                model_with_shape = shape_infer_helper.infer_shapes(model, auto_merge=True, guess_output_rank=False)
+
+                # auto_merge might cause issues (see https://github.com/microsoft/onnxruntime/issues/15521),
+                # so we only merge tensor data types, but not shape information, back into the original onnx model.
+                # Note that float16 conversion needs data type but not shape information.
+                if model_with_shape is not None:
+                    name_vi = {}
+                    for vi in model_with_shape.graph.value_info:
+                        if (
+                            hasattr(vi.type, "tensor_type")
+                            and hasattr(vi.type.tensor_type, "elem_type")
+                            and vi.type.tensor_type.elem_type != TensorProto.UNDEFINED
+                            and vi.name
+                        ):
+                            vi_copy = ValueInfoProto()
+                            vi_copy.CopyFrom(vi)
+                            if hasattr(vi_copy.type.tensor_type, "shape"):
+                                vi_copy.type.tensor_type.ClearField("shape")
+                            name_vi[vi.name] = vi_copy
+                    for vi in model.graph.value_info:
+                        if vi.name in name_vi:
+                            del name_vi[vi.name]
+                    for vi in name_vi.values():
+                        model.graph.value_info.append(vi)
+            except Exception:
+                logger.warning(
+                    "Failed to run symbolic shape inference. Please file an issue in https://github.com/microsoft/onnxruntime."
+                )
+
+        parameters = {"disable_shape_infer": use_symbolic_shape_infer}
+        parameters.update(
+            {
+                key: kwargs[key]
+                for key in [
+                    "keep_io_types",
+                    "min_positive_val",
+                    "max_finite_val",
+                    "op_block_list",
+                    "node_block_list",
+                    "force_fp16_initializers",
+                    "force_fp16_inputs",
+                    "use_bfloat16_as_blocked_nodes_dtype",
+                ]
+                if key in kwargs
+            }
+        )
+
+        fp16_model = convert_float_to_float16(model, **parameters)
+        self.initialize(fp16_model)
+
+        self.remove_cascaded_cast_nodes()
+
+        self.remove_useless_cast_nodes()
+
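A typical mixed-precision call of the method above; the blocked operator type is an illustrative choice, not a recommendation from this file:

    model.convert_float_to_float16(
        use_symbolic_shape_infer=True,
        keep_io_types=True,  # graph inputs/outputs stay float32; boundary Casts are inserted
        op_block_list=["SimplifiedLayerNormalization"],  # illustrative: keep a sensitive op in float32
    )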
+    def create_node_name(self, op_type, name_prefix=None):
+        """Create a unique node name that starts with a prefix (default is operator type).
+        The name will not duplicate any name that has been generated or that exists in the current graphs.
+
+        Args:
+            op_type (str): operator type
+            name_prefix (str, optional): prefix of node name. Defaults to None.
+
+        Returns:
+            str: node name
+        """
+
+        if name_prefix:
+            prefix = name_prefix if name_prefix.endswith("_") else (name_prefix + "_")
+        else:
+            prefix = op_type + "_"
+
+        suffix: int = 0
+        if prefix in self._node_name_suffix:
+            suffix = self._node_name_suffix[prefix] + 1
+        else:
+            # Check existing node names only once for a prefix,
+            # as we assume create_node_name is called for every new node in fusion.
+            for node in self.nodes():
+                if node.name and node.name.startswith(prefix):
+                    try:
+                        index = int(node.name[len(prefix) :])
+                        suffix = max(index + 1, suffix)
+                    except ValueError:
+                        continue
+
+        # Record the generated suffix so that we can avoid generating duplicated names.
+        self._node_name_suffix[prefix] = suffix
+
+        return prefix + str(suffix)
+
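The suffix bookkeeping above yields names like the following, assuming a model with no existing nodes whose names start with these prefixes:

    model.create_node_name("MatMul")             # "MatMul_0"
    model.create_node_name("MatMul")             # "MatMul_1" (cached suffix advances)
    model.create_node_name("Gelu", "GeluFused")  # "GeluFused_0"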
+    def find_graph_input(self, input_name):
+        for input in self.model.graph.input:
+            if input.name == input_name:
+                return input
+        return None
+
+    def find_graph_output(self, output_name):
+        for output in self.model.graph.output:
+            if output.name == output_name:
+                return output
+        return None
+
+    def get_parent_subgraph_nodes(self, node, stop_nodes, output_name_to_node=None):
+        if output_name_to_node is None:
+            output_name_to_node = self.output_name_to_node()
+
+        unique_nodes = []
+
+        parents = self.get_parents(node, output_name_to_node)
+        dq = deque(parents)
+        while len(dq) > 0:
+            current_node = dq.pop()
+            if current_node in stop_nodes:
+                continue
+
+            if current_node not in unique_nodes:
+                unique_nodes.append(current_node)
+
+                for input in current_node.input:
+                    if input in output_name_to_node:
+                        dq.appendleft(output_name_to_node[input])
+
+        return unique_nodes
+
+    def get_graph_inputs(self, current_node, recursive=False):
+        """
+        Find graph inputs that are linked to the current node.
+        """
+        graph_inputs = []
+        for input in current_node.input:
+            if self.find_graph_input(input) and input not in graph_inputs:
+                graph_inputs.append(input)
+
+        if recursive:
+            parent_nodes = self.get_parent_subgraph_nodes(current_node, [])
+            for node in parent_nodes:
+                for input in node.input:
+                    if self.find_graph_input(input) and input not in graph_inputs:
+                        graph_inputs.append(input)
+        return graph_inputs
+
+    @staticmethod
+    def input_index(node_output, child_node):
+        for index, input in enumerate(child_node.input):
+            if input == node_output:
+                return index
+        return -1
+
+    def remove_unused_constant(self):
+        input_name_to_nodes = self.input_name_to_nodes()
+
+        # remove unused constant
+        unused_nodes = []
+        nodes = self.nodes()
+        for node in nodes:
+            if node.op_type == "Constant" and node.output[0] not in input_name_to_nodes:
+                unused_nodes.append(node)
+
+        self.remove_nodes(unused_nodes)
+
+        if len(unused_nodes) > 0:
+            logger.debug(f"Removed unused constant nodes: {len(unused_nodes)}")
+
+    def _get_subgraph_inputs_of_node(self, node):
+        """
+        Get inputs to all nodes in all subgraphs of a node.
+        """
+        # Note: This function only handles one-level subgraphs of child nodes.
+        subgraph_nodes_inputs = set()
+        for attr in node.attribute:
+            if attr.type == AttributeProto.GRAPH:
+                child_nodes = attr.g.node
+                for child_node in child_nodes:
+                    subgraph_nodes_inputs.update(child_node.input)
+        return subgraph_nodes_inputs
+
+    def _get_subgraph_nodes_and_inputs(self, ops_with_graph_attrs):
+        """
+        Get input names to all nodes in all subgraphs, where the subgraphs are
+        graph attributes of a node in the main graph.
+        """
+        subgraph_nodes = list(filter(lambda node: node.op_type in ops_with_graph_attrs, self.model.graph.node))
+        subgraph_nodes_inputs = set()
+        for parent_node in subgraph_nodes:
+            subgraph_inputs_of_parent_node = self._get_subgraph_inputs_of_node(parent_node)
+            subgraph_nodes_inputs.update(subgraph_inputs_of_parent_node)
+        return subgraph_nodes, subgraph_nodes_inputs
+
936
+    def prune_graph(self, outputs=None, allow_remove_graph_inputs=True):
+        """
+        Prune graph to keep only required outputs. It removes unnecessary nodes that are not linked
+        (directly or indirectly) to any required output.
+
+        There is also an option to remove graph inputs that are not used to generate any required output.
+
+        Args:
+            outputs (list): a list of graph outputs to retain. If it is None, all graph outputs will be kept.
+            allow_remove_graph_inputs (bool): allow removing graph inputs.
+        """
+
+        keep_outputs = [output.name for output in self.model.graph.output] if outputs is None else outputs
+
+        input_name_to_nodes_for_main_graph = self.input_name_to_nodes(exclude_subgraphs=True)
+        output_name_to_node = self.output_name_to_node()
+
+        def get_first_output(node):
+            if node.output[0]:
+                return node.output[0]
+            return next((o for o in node.output if o), None)
+
+        if len(self.graphs()) > 1:
+            # Get input names for all nodes in all subgraphs
+            subgraph_nodes, subgraph_nodes_inputs = self._get_subgraph_nodes_and_inputs(
+                ops_with_graph_attrs={"Loop", "Scan", "If"}
+            )
+            if len(subgraph_nodes) == 0:
+                # TODO: support other ops such as `BeamSearch` that have subgraphs as op attributes
+                logger.debug("Skip prune_graph since the graph has subgraphs that are not attributes of Loop/Scan/If nodes")
+                return
+
+            # For graphs with subgraphs, add dangling outputs from parent graph nodes to the list of outputs to keep
+            for node in self.model.graph.node:
+                # TODO: This for-loop logic currently assumes that Loop/Scan/If nodes will not be
+                # pruned because their subgraphs are needed for computations. This might not be
+                # true in all cases.
+                if node in subgraph_nodes:
+                    continue
+
+                # Check if a node output is an input of a subgraph node and not an input to a node in the main graph
+                for output in node.output:
+                    if output in subgraph_nodes_inputs and output not in input_name_to_nodes_for_main_graph:
+                        keep_outputs += [output]
+
+        # Keep track of nodes to keep. The key is the first output of a node, and the value is the node.
+        output_to_node = {}
+
+        # Start from graph outputs, find parent nodes recursively, and add them to the output_to_node dictionary.
+        dq = deque()
+        for output in keep_outputs:
+            if output in output_name_to_node:
+                dq.append(output_name_to_node[output])
+        while len(dq) > 0:
+            node = dq.pop()
+            first_output = get_first_output(node)
+            if first_output and (first_output not in output_to_node):
+                output_to_node[first_output] = node
+                for name in node.input:
+                    if len(name) > 0 and (name in output_name_to_node) and (name not in output_to_node):
+                        dq.appendleft(output_name_to_node[name])
+
+        # Keep only those nodes in the output_to_node dictionary.
+        nodes_to_keep = []
+        num_nodes_removed = 0
+        for node in self.model.graph.node:
+            first_output = get_first_output(node)
+            kept_node = output_to_node.get(first_output)
+
+            # Double-check the node, since a fused node might reuse the output name of nodes to be removed.
+            # Comparing whole nodes is slow, so compare op_type first to skip the full comparison in most cases.
+            if kept_node and kept_node.op_type == node.op_type and kept_node == node:
+                nodes_to_keep.append(node)
+            else:
+                num_nodes_removed += 1
+        self.model.graph.ClearField("node")
+        self.model.graph.node.extend(nodes_to_keep)
+
+        # Remove graph outputs not in the list.
+        output_to_remove = []
+        if outputs is not None:
+            for output in self.model.graph.output:
+                if output.name not in outputs:
+                    output_to_remove.append(output)
+            for output in output_to_remove:
+                self.model.graph.output.remove(output)
+
+        # Remove graph inputs not used by any node.
+        input_to_remove = []
+        if allow_remove_graph_inputs:
+            input_name_to_nodes = self.input_name_to_nodes()
+            input_to_remove = [input for input in self.model.graph.input if input.name not in input_name_to_nodes]
+            for name in input_to_remove:
+                self.model.graph.input.remove(name)
+
+        if input_to_remove or output_to_remove or num_nodes_removed > 0:
+            removed = []
+            if input_to_remove:
+                removed.append(f"{len(input_to_remove)} inputs")
+            if output_to_remove:
+                removed.append(f"{len(output_to_remove)} outputs")
+            if num_nodes_removed > 0:
+                removed.append(f"{num_nodes_removed} nodes")
+            logger.info("Removed %s", ", ".join(removed))
+
+        self.update_graph()
+
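+    # Illustrative usage sketch (not part of the original file; the model path and
+    # output name "logits" are assumed, and `onnx` is assumed to be imported):
+    #
+    #     model = OnnxModel(onnx.load("model.onnx"))
+    #     model.prune_graph(outputs=["logits"])  # drop nodes not needed by "logits"
+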
+    def update_graph(self, verbose=False, allow_remove_graph_inputs=False):
+        graph = self.model.graph
+
+        remaining_input_names = set()
+        for node in graph.node:
+            if node.op_type in ["Loop", "Scan", "If"]:
+                # Add input names of nodes in subgraphs
+                subgraph_inputs_of_node = self._get_subgraph_inputs_of_node(node)
+                remaining_input_names.update(subgraph_inputs_of_node)
+
+            if node.op_type != "Constant":
+                remaining_input_names.update(node.input)
+        if verbose:
+            logger.debug(f"remaining input names: {remaining_input_names}")
+
+        # Remove graph inputs that are not used.
+        inputs_to_remove = []
+        if allow_remove_graph_inputs:
+            for input in graph.input:
+                if input.name not in remaining_input_names:
+                    inputs_to_remove.append(input)
+            for input in inputs_to_remove:
+                graph.input.remove(input)
+
+        names_to_remove = [input.name for input in inputs_to_remove]
+        logger.debug(f"remove {len(inputs_to_remove)} unused inputs: {names_to_remove}")
+
+        # Remove weights that are not used.
+        weights_to_remove = []
+        weights_to_keep = []
+        for initializer in graph.initializer:
+            if initializer.name not in remaining_input_names and not self.find_graph_output(initializer.name):
+                weights_to_remove.append(initializer)
+            else:
+                weights_to_keep.append(initializer.name)
+        for initializer in weights_to_remove:
+            graph.initializer.remove(initializer)
+
+        names_to_remove = [initializer.name for initializer in weights_to_remove]
+        logger.debug(f"remove {len(weights_to_remove)} unused initializers: {names_to_remove}")
+        if verbose:
+            logger.debug(f"remaining initializers: {weights_to_keep}")
+
+        self.remove_unused_constant()
+
+    def is_safe_to_fuse_nodes(self, nodes_to_remove, keep_outputs, input_name_to_nodes, output_name_to_node):
+        for node_to_remove in nodes_to_remove:
+            for output_to_remove in node_to_remove.output:
+                if output_to_remove in keep_outputs:
+                    continue
+
+                if output_to_remove in input_name_to_nodes:
+                    for impacted_node in input_name_to_nodes[output_to_remove]:
+                        if impacted_node not in nodes_to_remove:
+                            logger.debug(
+                                "it is not safe to remove nodes since output %s is used by %s",
+                                output_to_remove,
+                                impacted_node,
+                            )
+                            return False
+        return True
+
+    @staticmethod
+    def graph_topological_sort(graph, is_deterministic=False):
+        deps_set = set()  # names whose producers have been resolved (inputs, initializers, emitted outputs)
+        sorted_node_set = set()  # indices of nodes that have been sorted
+        sorted_nodes = []  # nodes in topological order
+
+        initializer_names = [init.name for init in graph.initializer]
+        graph_input_names = [input.name for input in graph.input]
+        input_names = initializer_names + graph_input_names
+
+        if is_deterministic:
+            input_names.sort()
+
+        for input_name in input_names:
+            deps_set.add(input_name)
+
+        sorted_node_set_len = -1
+        graph_nodes = graph.node if not is_deterministic else sorted(graph.node, key=lambda x: x.name)
+
+        last_node_name = None
+        while len(sorted_node_set) != len(graph_nodes):
+            if len(sorted_node_set) == sorted_node_set_len:
+                break
+            sorted_node_set_len = len(sorted_node_set)
+            for node_idx, node in enumerate(graph_nodes):
+                if node_idx in sorted_node_set:
+                    continue
+                input_count = sum(1 for _ in node.input if _)
+                if input_count == 0:
+                    sorted_nodes.append(node)
+                    sorted_node_set.add(node_idx)
+                    for output in node.output:
+                        if output:
+                            deps_set.add(output)
+                    continue
+                failed = False
+                for input_name in node.input:
+                    if input_name and input_name not in deps_set:
+                        failed = True
+                        last_node_name = node.name
+                if not failed:
+                    sorted_nodes.append(node)
+                    sorted_node_set.add(node_idx)
+                    for output in node.output:
+                        if output:
+                            deps_set.add(output)
+
+        if len(sorted_node_set) != len(graph.node):
+            raise RuntimeError(
+                f"Graph is not a DAG: len(sorted_node_set)={len(sorted_node_set)}, len(graph.node)={len(graph.node)}, failed at node {last_node_name}"
+            )
+
+        graph.ClearField("node")
+        graph.node.extend(sorted_nodes)
+
+    def topological_sort(self, is_deterministic=False):
+        # TODO: support graph_topological_sort() in subgraphs
+        # for graph in self.graphs():
+        #     self.graph_topological_sort(graph)
+        OnnxModel.graph_topological_sort(self.model.graph, is_deterministic)
+
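+    # Note (not part of the original file): graph_topological_sort above is a
+    # repeated-sweep variant of Kahn's algorithm; a node is emitted once all of its
+    # non-empty inputs appear in deps_set. A minimal usage sketch:
+    #
+    #     model.topological_sort()  # sorts the main graph in place (subgraphs: see TODO)
+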
+    @staticmethod
+    def save(
+        model,
+        output_path,
+        save_as_external_data=False,
+        all_tensors_to_one_file=True,
+        size_threshold=1024,
+        convert_attribute=False,
+    ):
+        Path(output_path).parent.mkdir(parents=True, exist_ok=True)
+
+        # Add ms domain if needed
+        ms_opset = [opset for opset in model.opset_import if opset.domain == "com.microsoft"]
+        # Check whether there is a custom op in the top-level graph (our fusion is on the top level right now).
+        # May need to extend this to subgraphs if our fusions are extended to subgraphs.
+        ms_node = [node for node in model.graph.node if node.domain == "com.microsoft"]
+        if ms_node and not ms_opset:
+            opset = model.opset_import.add()
+            opset.version = 1
+            opset.domain = "com.microsoft"
+
+        if save_as_external_data:
+            # Save model to external data, which is needed for model size > 2GB
+            output_dir = Path(output_path).parent
+            output_dir.mkdir(parents=True, exist_ok=True)
+            external_data_path = output_path + ".data"
+            location = Path(external_data_path).name if all_tensors_to_one_file else None
+
+            if os.path.exists(output_path):
+                logger.info(f"Delete the existing onnx file: {output_path}")
+                os.remove(output_path)
+
+            if all_tensors_to_one_file:
+                if os.path.exists(external_data_path):
+                    # Delete the external data file. Otherwise, data will be appended to the existing file.
+                    logger.info(f"Delete the existing external data file: {external_data_path}")
+                    os.remove(external_data_path)
+            else:
+                if os.listdir(output_dir):
+                    raise RuntimeError(f"Output directory ({output_dir}) for external data is not empty.")
+
+            save_model(
+                model,
+                output_path,
+                save_as_external_data=True,
+                all_tensors_to_one_file=all_tensors_to_one_file,
+                location=location,
+                size_threshold=size_threshold,
+                convert_attribute=convert_attribute,
+            )
+        else:
+            save_model(model, output_path)
+
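+    # Illustrative usage sketch (not part of the original file; the path is assumed).
+    # External data is required once a model exceeds the 2GB protobuf limit; tensors
+    # larger than size_threshold bytes are stored in "large.onnx.data":
+    #
+    #     OnnxModel.save(model, "large.onnx", save_as_external_data=True)
+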
+    def save_model_to_file(self, output_path, use_external_data_format=False, all_tensors_to_one_file=True):
+        logger.info("Sort graphs in topological order")
+        self.topological_sort()
+
+        # Note: After the model is saved to another directory with external data,
+        # you need to reload the ONNX model if you want to read tensors from the self.model object.
+        # The base directory is not updated for the self.model object, so attempts to read tensor data
+        # might fail since the external data cannot be located.
+        OnnxModel.save(self.model, output_path, use_external_data_format, all_tensors_to_one_file)
+        logger.info(f"Model saved to {output_path}")
+
+    def get_graph_inputs_excluding_initializers(self):
+        """
+        Returns the real graph inputs (excluding initializers, which older ONNX models also list as inputs).
+        """
+        graph_inputs = []
+        for input in self.model.graph.input:
+            if self.get_initializer(input.name) is None:
+                graph_inputs.append(input)
+        return graph_inputs
+
+    def get_opset_version(self):
+        """Get the opset version of the default ONNX domain.
+
+        Raises:
+            RuntimeError: ONNX model has no opset for the default domain.
+
+        Returns:
+            int: opset version of the default ONNX domain.
+        """
+        for opset in self.model.opset_import:
+            if opset.domain in ["", "ai.onnx"]:
+                return opset.version
+        raise RuntimeError("ONNX model has no opset for default domain")
+
+    def get_operator_statistics(self, include_domain=False):
+        """
+        Returns the node count per operator.
+        """
+        op_count = {}
+        for node in self.nodes():
+            op = (node.domain + ":" if include_domain and node.domain else "") + node.op_type
+            op_count[op] = 1 if op not in op_count else (op_count[op] + 1)
+
+        # Sorted by count in descending order, then by key in alphabetical order.
+        logger.info(f"Operators: {sorted(op_count.items(), key=lambda kv: (-kv[1], kv[0]))}")
+
+        return op_count
+
+    @staticmethod
+    def to_data_hash(tensor: TensorProto, base_dir: str = "") -> int:
+        """Converts a tensor def object to a hash for data comparison purposes.
+
+        Args:
+            tensor: a TensorProto object.
+            base_dir: directory used to locate the tensor's external data file, if any.
+        Returns:
+            hash: a hash of the data.
+        """
+        if tensor.HasField("segment"):
+            raise ValueError("Currently not supporting loading segments.")
+        if tensor.data_type == TensorProto.UNDEFINED:
+            raise TypeError("The element type in the input tensor is not defined.")
+        tensor_dtype = tensor.data_type
+        storage_field = helper.tensor_dtype_to_field(tensor_dtype)
+
+        if tensor.data_type == TensorProto.STRING:
+            utf8_strings = getattr(tensor, storage_field)
+            return hash(tuple(s.decode("utf-8") for s in utf8_strings))
+        # Load raw data from the external tensor if it exists
+        if uses_external_data(tensor):
+            load_external_data_for_tensor(tensor, base_dir)
+        if tensor.HasField("raw_data"):
+            return hash(tensor.raw_data)
+        else:
+            np_data = numpy_helper.to_array(tensor)
+            return hash(np_data.tobytes())
+
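+    # Illustrative sketch (not part of the original file; t1 and t2 are assumed
+    # TensorProto initializers). The hash is a cheap signature for payload comparison;
+    # equal hashes should still be confirmed, as has_same_value below does:
+    #
+    #     if OnnxModel.to_data_hash(t1) == OnnxModel.to_data_hash(t2):
+    #         ...  # likely duplicates; verify with OnnxModel.has_same_value(t1, t2)
+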
+    @staticmethod
+    def has_same_value(
+        tensor1: TensorProto,
+        tensor2: TensorProto,
+        signature_cache1: Optional[dict] = None,
+        signature_cache2: Optional[dict] = None,
+    ) -> bool:
+        """Returns True when two tensors have the same value.
+        Note that the names can be different.
+
+        Args:
+            tensor1 (TensorProto): initializer 1
+            tensor2 (TensorProto): initializer 2
+            signature_cache1 (dict): Optional dictionary to store data signatures of tensor1 in order to speed up comparison.
+            signature_cache2 (dict): Optional dictionary to store data signatures of tensor2 in order to speed up comparison.
+        Returns:
+            bool: True when the two initializers have the same value.
+        """
+        sig1 = (
+            signature_cache1[tensor1.name]
+            if signature_cache1 and tensor1.name in signature_cache1
+            else OnnxModel.to_data_hash(tensor1)
+        )
+        sig2 = (
+            signature_cache2[tensor2.name]
+            if signature_cache2 and tensor2.name in signature_cache2
+            else OnnxModel.to_data_hash(tensor2)
+        )
+        if signature_cache1 is not None:
+            signature_cache1[tensor1.name] = sig1
+        if signature_cache2 is not None:
+            signature_cache2[tensor2.name] = sig2
+        if sig1 == sig2 and tensor1.data_type == tensor2.data_type and tensor1.dims == tensor2.dims:
+            # Same signature, now do the expensive check to confirm the data is the same
+            return (numpy_helper.to_array(tensor1) == numpy_helper.to_array(tensor2)).all()
+
+        return False
+
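+    # Illustrative sketch (not part of the original file; t1 and t2 are assumed):
+    # passing cache dictionaries avoids re-hashing the same tensors across repeated
+    # comparisons:
+    #
+    #     cache1, cache2 = {}, {}
+    #     same = OnnxModel.has_same_value(t1, t2, cache1, cache2)
+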
+    def remove_duplicated_initializer(self, cache: Optional[dict] = None):
+        """Remove initializers with duplicated values, and only keep the first one.
+        It could help reduce the size of models (like ALBERT) with shared weights.
+        An optional cache dictionary can be passed to store data signatures and speed up comparison.
+        Note: this function does not process subgraphs.
+        """
+        if len(self.graphs()) > 1:
+            logger.warning("remove_duplicated_initializer does not process subgraphs.")
+
+        initializer_count = len(self.model.graph.initializer)
+
+        same = [-1] * initializer_count
+        for i in range(initializer_count - 1):
+            if same[i] >= 0:
+                continue
+            for j in range(i + 1, initializer_count):
+                if OnnxModel.has_same_value(
+                    self.model.graph.initializer[i],
+                    self.model.graph.initializer[j],
+                    cache,
+                    cache,
+                ):
+                    same[j] = i
+
+        count = 0
+        for i in range(initializer_count):
+            if same[i] >= 0:
+                count += 1
+                self.replace_input_of_all_nodes(
+                    self.model.graph.initializer[i].name,
+                    self.model.graph.initializer[same[i]].name,
+                )
+
+        if count > 0:
+            self.update_graph()
+            logger.info(f"Removed {count} initializers with duplicated value")
+
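+    # Illustrative usage sketch (not part of the original file): a shared cache
+    # dictionary speeds up the pairwise scan on models with many initializers:
+    #
+    #     model.remove_duplicated_initializer(cache={})
+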
+    def add_prefix_to_names(self, prefix: str):
+        """Add a prefix to initializers and intermediate outputs in the graph. Main graph inputs and outputs are excluded.
+        It could help avoid name conflicts of node args when merging two graphs.
+        Note: this function does not process subgraphs.
+        """
+        if len(self.graphs()) > 1:
+            logger.warning("add_prefix_to_names does not process subgraphs.")
+
+        # Exclude the names of inputs and outputs of the main graph (but not subgraphs)
+        # and empty names ("") as they have special meaning to denote missing optional inputs
+        excluded = [i.name for i in self.model.graph.input] + [o.name for o in self.model.graph.output] + [""]
+
+        for initializer in self.model.graph.initializer:
+            if initializer.name not in excluded and prefix + initializer.name not in excluded:
+                initializer.name = prefix + initializer.name
+
+        for node in self.model.graph.node:
+            # update names of node inputs
+            for j in range(len(node.input)):
+                if node.input[j] not in excluded and prefix + node.input[j] not in excluded:
+                    node.input[j] = prefix + node.input[j]
+
+            # update names of node outputs
+            for j in range(len(node.output)):
+                if node.output[j] not in excluded and prefix + node.output[j] not in excluded:
+                    node.output[j] = prefix + node.output[j]
+
+        for value_info in self.model.graph.value_info:
+            if value_info.name not in excluded:
+                value_info.name = prefix + value_info.name
+
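+    # Illustrative usage sketch (not part of the original file; the prefix is
+    # assumed): prefix names in one model before merging it with another so that
+    # node args do not collide:
+    #
+    #     decoder_model.add_prefix_to_names("decoder_")
+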
+    def clean_shape_infer(self):
+        self.model.graph.ClearField("value_info")
+
+    def use_float16(self):
+        """Check whether the model uses float16."""
+        queue = [self.model.graph]  # queue for BFS over the main graph and its subgraphs
+        while queue:
+            sub_graphs = []
+            for graph in queue:
+                if not isinstance(graph, GraphProto):
+                    continue
+
+                for v in itertools.chain(graph.input, graph.output, graph.value_info):
+                    if v.type.tensor_type.elem_type == TensorProto.FLOAT16:
+                        return True
+                    if v.type.HasField("sequence_type"):
+                        if v.type.sequence_type.elem_type.tensor_type.elem_type == TensorProto.FLOAT16:
+                            return True
+
+                for t in graph.initializer:
+                    if t.data_type == TensorProto.FLOAT16:
+                        return True
+
+                for node in graph.node:
+                    if node.op_type == "Cast":
+                        for attr in node.attribute:
+                            if attr.name == "to" and attr.i == TensorProto.FLOAT16:
+                                return True
+
+                    for attr in node.attribute:
+                        if attr.type == AttributeProto.GRAPH:
+                            sub_graphs.append(attr.g)
+
+                        sub_graphs.extend(attr.graphs)
+
+                        if isinstance(attr.t, TensorProto) and attr.t.data_type == TensorProto.FLOAT16:
+                            return True
+
+                        for t in attr.tensors:
+                            if isinstance(t, TensorProto) and t.data_type == TensorProto.FLOAT16:
+                                return True
+
+            queue = sub_graphs
+
+        return False
+
+    def change_graph_input_type(
+        self,
+        graph_input: ValueInfoProto,
+        new_type: int,
+    ):
+        """Change the type of a graph input, and add a Cast node if needed.
+
+        Args:
+            graph_input (ValueInfoProto): input of the graph
+            new_type (int): new data type like TensorProto.INT32.
+
+        Returns:
+            NodeProto: the new Cast node that was added. None if no Cast node was added.
+            List[NodeProto]: Cast nodes that have been removed.
+        """
+        assert isinstance(graph_input, ValueInfoProto)
+        assert self.find_graph_input(graph_input.name)
+
+        if graph_input.type.tensor_type.elem_type == int(new_type):
+            return None, []
+
+        graph = self.graph()
+        new_cast_node = None
+        nodes_to_remove = []
+
+        input_name_to_nodes = self.input_name_to_nodes()
+        if graph_input.name in input_name_to_nodes:
+            nodes = input_name_to_nodes[graph_input.name]
+
+            # For children that are not Cast nodes, insert a Cast node to convert from the new type back to the original data type.
+            nodes_not_cast = [node for node in nodes if node.op_type != "Cast"]
+            if nodes_not_cast:
+                node_name = self.create_node_name("Cast")
+                output_name = node_name + "_" + graph_input.name
+                new_value_info = graph.value_info.add()
+                new_value_info.CopyFrom(graph_input)
+                new_value_info.name = output_name
+                new_cast_node = helper.make_node(
+                    "Cast",
+                    [graph_input.name],
+                    [output_name],
+                    to=int(graph_input.type.tensor_type.elem_type),
+                    name=node_name,
+                )
+                graph.node.extend([new_cast_node])
+
+                for node in nodes_not_cast:
+                    OnnxModel.replace_node_input(node, graph_input.name, output_name)
+
+            # For children that are Cast nodes, there is no need to insert a Cast.
+            # When a child casts to the new type, we can remove that Cast node since the input type is now the new type.
+            nodes_cast = [node for node in nodes if node.op_type == "Cast"]
+            for node in nodes_cast:
+                if OnnxModel.get_node_attribute(node, "to") == int(new_type):
+                    self.replace_input_of_all_nodes(node.output[0], graph_input.name)
+                    if not self.find_graph_output(node.output[0]):
+                        nodes_to_remove.append(node)
+            if nodes_to_remove:
+                self.remove_nodes(nodes_to_remove)
+
+        graph_input.type.tensor_type.elem_type = int(new_type)
+        return new_cast_node, nodes_to_remove
+
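+    # Illustrative usage sketch (not part of the original file; the input name
+    # "input_ids" is assumed): convert a graph input to int32 while preserving
+    # its consumers via an inserted Cast:
+    #
+    #     graph_input = model.find_graph_input("input_ids")
+    #     cast_node, removed = model.change_graph_input_type(graph_input, TensorProto.INT32)
+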
+    def change_graph_output_type(
+        self,
+        graph_output: ValueInfoProto,
+        new_type: int,
+    ):
+        """Change the type of a graph output, and add a Cast node if needed.
+
+        Args:
+            graph_output (ValueInfoProto): output of the graph
+            new_type (int): new data type.
+
+        Returns:
+            NodeProto: the new Cast node that was added. None if no Cast node was added.
+        """
+        assert isinstance(graph_output, ValueInfoProto)
+        assert self.find_graph_output(graph_output.name)
+
+        if graph_output.type.tensor_type.elem_type == int(new_type):
+            return None
+
+        graph = self.graph()
+
+        # Add a cast node
+        node_name = self.create_node_name("Cast")
+        input_name = node_name + "_" + graph_output.name
+        # Rename the tensor at both consumers and producer so that the Cast node
+        # becomes the sole producer of the graph output.
+        self.replace_input_of_all_nodes(graph_output.name, input_name)
+        self.replace_output_of_all_nodes(graph_output.name, input_name)
+        new_value_info = graph.value_info.add()
+        new_value_info.CopyFrom(graph_output)
+        new_value_info.name = input_name
+        cast_node = helper.make_node(
+            "Cast",
+            [input_name],
+            [graph_output.name],
+            to=int(new_type),
+            name=node_name,
+        )
+        graph.node.extend([cast_node])
+        graph_output.type.tensor_type.elem_type = int(new_type)
+        return cast_node
+
+    def rename_graph_output(self, old_name: str, new_name: str):
+        if new_name in self.output_name_to_node():
+            raise RuntimeError(f"{new_name} exists in graph")
+
+        graph = self.graph()
+        for output in graph.output:
+            if output.name == old_name:
+                logger.debug("replace output name from %s to %s", old_name, new_name)
+                self.replace_input_of_all_nodes(old_name, new_name)
+                self.replace_output_of_all_nodes(old_name, new_name)
+                output.name = new_name