onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/fusion_skiplayernorm.py
@@ -0,0 +1,209 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ from logging import getLogger
+
+ from fusion_base import Fusion
+ from fusion_utils import NumpyHelper
+ from onnx import helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class FusionSkipLayerNormalization(Fusion):
+     """
+     Fuse Add + LayerNormalization into one node: SkipLayerNormalization
+     Note: This fusion does not check the input shapes of Add and LayerNormalization.
+     """
+
+     def __init__(
+         self,
+         model: OnnxModel,
+         fused_op_type: str = "SkipLayerNormalization",
+         search_op_types: str = "LayerNormalization",
+         shape_infer: bool = True,
+     ):
+         super().__init__(model, fused_op_type, search_op_types)
+         if shape_infer:
+             # Updating shape inference is needed since other fusions might add new edges that do not have shape info yet.
+             self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True)
+             if self.shape_infer_helper is None:
+                 # TODO(tianleiwu): support subgraph in shape inference or add broadcasting in SkipLayerNormalization op.
+                 logger.warning("symbolic shape inference disabled or failed.")
+
+     def fuse(self, node, input_name_to_nodes, output_name_to_node):
+         add = self.model.get_parent(node, 0, output_name_to_node)
+
+         # In some models there is input_ids->Gather->Add->LayerNorm, and one input of the
+         # Add node is an initializer with fixed shape which should not be fused into SkipLayerNorm.
+         if add is None or add.op_type != "Add":
+             return
+
+         # The number of inputs of Add should be 2.
+         if len(add.input) != 2:
+             return
+
+         for add_input in add.input:
+             if self.model.get_initializer(add_input) is not None:
+                 return
+
+         # To avoid an Add node having two LayerNormalization children, we only fuse one SkipLayerNormalization.
+         if add in self.nodes_to_remove:
+             return
+
+         # Root Mean Square Layer Normalization
+         simplified = node.op_type == "SimplifiedLayerNormalization"
+
+         if hasattr(self, "shape_infer_helper"):
+             if self.shape_infer_helper is not None:
+                 if (
+                     self.shape_infer_helper.get_edge_shape(add.input[0])
+                     and len(self.shape_infer_helper.get_edge_shape(add.input[0])) != 3
+                 ):
+                     logger.debug("skip SkipLayerNormalization fusion since shape of input %s is not 3D", add.input[0])
+                     return
+
+                 # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size)
+                 if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
+                     logger.debug(
+                         "skip SkipLayerNormalization fusion since shapes of inputs (%s, %s) are not the same",
+                         add.input[0],
+                         add.input[1],
+                     )
+                     return
+             else:
+                 logger.debug("skip SkipLayerNormalization fusion since symbolic shape inference failed")
+                 return
+
+         gather_path = self.model.match_parent_path(add, ["Gather"], [None])
+         if gather_path is not None and self.model.find_graph_input(gather_path[0].input[1]) is None:
+             if self.model.match_parent_path(gather_path[0], ["ConstantOfShape"], [1]) is None:
+                 return
+
+         # This means that the residual Add before the LayerNormalization produces an output
+         # that is consumed by some other nodes or graph outputs besides the LayerNormalization itself.
+         # We can still go ahead with the SkipLayerNormalization fusion, but we need to
+         # preserve the output of Add, and that needs to be produced by SkipLayerNormalization.
+         add_has_graph_output = self.model.find_graph_output(add.output[0]) is not None
+         residual_add_has_multiple_consumers = (
+             add_has_graph_output or len(self.model.get_children(add, input_name_to_nodes)) > 1
+         )
+
+         outputs_to_keep = node.output
+
+         if residual_add_has_multiple_consumers:
+             outputs_to_keep.extend([add.output[0]])
+
+         outputs = [node.output[0]]
+
+         # Skip the other optional outputs of SkipLayerNormalization before adding the Add's output.
+         if residual_add_has_multiple_consumers:
+             outputs.extend(["", "", add.output[0]])
+
+         if self.model.is_safe_to_fuse_nodes([add, node], outputs_to_keep, input_name_to_nodes, output_name_to_node):
+             self.nodes_to_remove.extend([add, node])
+
+             inputs = (
+                 [add.input[0], add.input[1], node.input[1], node.input[2]]
+                 if not simplified
+                 else [add.input[0], add.input[1], node.input[1]]
+             )
+             normalize_node = helper.make_node(
+                 self.fused_op_type,
+                 inputs=inputs,
+                 outputs=outputs,
+                 name=self.model.create_node_name(self.fused_op_type, name_prefix="SkipLayerNorm"),
+             )
+             normalize_node.domain = "com.microsoft"
+
+             # Pass attribute "epsilon" from the LayerNormalization node to SkipLayerNormalization.
+             for att in node.attribute:
+                 if att.name == "epsilon":
+                     normalize_node.attribute.extend([att])
+
+             # Set a default epsilon if the LayerNormalization node has none.
+             if len(normalize_node.attribute) == 0:
+                 normalize_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])
+
+             self.nodes_to_add.append(normalize_node)
+             self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
+
+
+ class FusionBiasSkipLayerNormalization(Fusion):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model, "SkipLayerNormalization", "SkipLayerNormalization", "add bias")
+
+     def fuse(self, node, input_name_to_nodes, output_name_to_node):
+         if len(node.input) != 4:
+             return
+
+         return_indice = []
+         nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [None, None], output_name_to_node, return_indice)
+         if nodes is not None:
+             (add, _matmul) = nodes
+         else:
+             # In case of fp16, we could have a Cast between the MatMul and the bias Add.
+             return_indice = []
+             nodes = self.model.match_parent_path(
+                 node, ["Add", "Cast", "MatMul"], [None, None, None], output_name_to_node, return_indice
+             )
+             if nodes is not None:
+                 (add, _cast, _matmul) = nodes
+             else:
+                 return
+
+         assert len(return_indice) == 2 or len(return_indice) == 3
+         add_input_index = return_indice[0]
+         if add_input_index >= 2:
+             return
+         sln_input = add.input[return_indice[1]]
+         bias_input = add.input[1 - return_indice[1]]
+         skip_input = node.input[1 - add_input_index]
+
+         # The bias should be one-dimensional.
+         initializer = self.model.get_initializer(bias_input)
+         if initializer is None:
+             return
+         bias_weight = NumpyHelper.to_array(initializer)
+         if bias_weight is None:
+             logger.debug("Bias weight not found")
+             return
+         if len(bias_weight.shape) != 1:
+             logger.debug("Bias weight is not 1D")
+             return
+
+         subgraph_nodes = [node, add]
+         if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, node.output, input_name_to_nodes, output_name_to_node):
+             logger.debug("Skip fusing SkipLayerNormalization with Bias since it is not safe")
+             return
+
+         self.nodes_to_remove.extend(subgraph_nodes)
+         inputs = [
+             sln_input,
+             skip_input,
+             node.input[2],
+             node.input[3],
+             bias_input,
+         ]
+         new_node = helper.make_node(
+             "SkipLayerNormalization",
+             inputs=inputs,
+             outputs=node.output,
+             name=self.model.create_node_name("SkipLayerNormalization", "SkipLayerNorm_AddBias_"),
+         )
+         new_node.domain = "com.microsoft"
+
+         # Pass attribute "epsilon" from the SkipLayerNormalization node to SkipLayerNormalization (add bias).
+         for att in node.attribute:
+             if att.name == "epsilon":
+                 new_node.attribute.extend([att])
+
+         # Set a default epsilon if the SkipLayerNormalization node has none.
+         if len(new_node.attribute) == 0:
+             new_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])
+
+         self.nodes_to_add.append(new_node)
+         self.node_name_to_graph_name[new_node.name] = self.this_graph_name
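For context: this fusion is not usually invoked on its own; it runs as one pass when a model is optimized with the transformers optimizer shipped in this wheel. A minimal sketch of that entry point follows (the model path and the `num_heads`/`hidden_size` values are illustrative assumptions, not part of this diff):

```python
# Hedged sketch: optimize a BERT-style model so that Add + LayerNormalization
# pairs are fused into SkipLayerNormalization nodes by the pass above.
from onnxruntime.transformers import optimizer

opt_model = optimizer.optimize_model(
    "bert_base.onnx",  # hypothetical input path
    model_type="bert",
    num_heads=12,      # illustrative values for a bert-base model
    hidden_size=768,
)
# Counts of fused operators, including SkipLayerNormalization.
print(opt_model.get_fused_operator_statistics())
opt_model.save_model_to_file("bert_base_opt.onnx")
```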
onnxruntime/transformers/fusion_transpose.py
@@ -0,0 +1,168 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ from logging import getLogger
+ from typing import Dict, List
+
+ from fusion_base import Fusion
+ from fusion_utils import FusionUtils
+ from onnx import NodeProto, TensorProto, helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class FusionTranspose(Fusion):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model, "Transpose", "Transpose")
+
+     def fuse(
+         self,
+         transpose_node: NodeProto,
+         input_name_to_nodes: Dict[str, List[NodeProto]],
+         output_name_to_node: Dict[str, NodeProto],
+     ):
+         """
+         Note that onnxruntime does comprehensive transpose optimization after loading a model.
+         The purpose of this fusion is to clean up the graph before running onnxruntime.
+
+         Case 1:
+               (input)-->Transpose(perm=a)-->Transpose(perm=b)-->
+         After:
+               (input)-->Transpose(perm=a)-->  (this path can be removed if the output is not used anymore)
+                      |
+                      +----->Transpose(perm=a*b)-->
+
+         Case 2 (Cast has only one child):
+               (input)-->Transpose(perm=a)--> Cast -->Transpose(perm=b)-->
+         After:
+               (input)-->Transpose(perm=a)-->  (this path can be removed if the output is not used anymore)
+                      |
+                      +----->Cast --> Transpose(perm=a*b)-->
+         """
+         transpose_b = transpose_node
+         if transpose_b.input[0] not in output_name_to_node:
+             return
+
+         transpose_a = output_name_to_node[transpose_b.input[0]]
+         if transpose_a.op_type != "Cast":
+             cast_node = None
+         else:
+             cast_node = transpose_a
+
+             cast_children = self.model.get_children(cast_node, input_name_to_nodes)
+             if cast_children and len(cast_children) > 1:
+                 return
+
+             if cast_node.input[0] not in output_name_to_node:
+                 return
+
+             transpose_a = output_name_to_node[cast_node.input[0]]
+
+         if transpose_a.op_type != "Transpose":
+             return
+
+         permutation = OnnxModel.get_node_attribute(transpose_b, "perm")
+         assert isinstance(permutation, list)
+
+         parent_permutation = OnnxModel.get_node_attribute(transpose_a, "perm")
+         assert isinstance(parent_permutation, list)
+
+         assert len(parent_permutation) == len(permutation)
+
+         output_permutation = []
+         for _j, index in enumerate(permutation):
+             output_permutation.append(parent_permutation[index])
+
+         if cast_node is None:
+             if FusionUtils.skip_parent(self.model, transpose_b, transpose_a, input_name_to_nodes):
+                 self.nodes_to_remove.append(transpose_a)
+         else:
+             if FusionUtils.skip_parent(self.model, cast_node, transpose_a, input_name_to_nodes):
+                 self.nodes_to_remove.append(transpose_a)
+         transpose_b.ClearField("attribute")
+         transpose_b.attribute.extend([helper.make_attribute("perm", output_permutation)])
+
+
+ class FusionInsertTranspose(Fusion):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model, "", "GroupNorm")
+
+     def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
+         """Append a Transpose node after an input."""
+         node_name = self.model.create_node_name("Transpose")
+         if output_name is None:
+             output_name = node_name + "_out" + "-" + input_name
+         transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
+         transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
+         return transpose_node
+
+     def fuse(
+         self,
+         group_norm_node: NodeProto,
+         input_name_to_nodes: Dict[str, List[NodeProto]],
+         output_name_to_node: Dict[str, NodeProto],
+     ):
+         """
+         This optimization inserts a Transpose, and the onnxruntime transpose optimizer will remove it together with
+         another Transpose, so the net effect is one fewer Transpose after onnxruntime optimization.
+         Before:
+             --> Gemm --> Unsqueeze(axes=[2]) --> Unsqueeze(axes=[3]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
+         After:
+             --> Gemm --> Unsqueeze(axes=[1]) --> Unsqueeze(axes=[2]) --> Transpose([0,3,1,2]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
+         """
+         gemm_path = self.model.match_parent_path(
+             group_norm_node, ["Transpose", "Add", "Unsqueeze", "Unsqueeze", "Gemm"], [0, 0, None, 0, 0]
+         )
+         if gemm_path is None:
+             return
+         transpose, add, unsqueeze_3, unsqueeze_2, gemm = gemm_path
+         if self.model.find_graph_output(unsqueeze_3.output[0]):
+             return
+
+         permutation = OnnxModel.get_node_attribute(transpose, "perm")
+         assert isinstance(permutation, list)
+         if permutation != [0, 2, 3, 1]:
+             return
+
+         if not (
+             len(unsqueeze_3.input) == 2
+             and self.model.get_constant_value(unsqueeze_3.input[1]) == 3
+             and len(unsqueeze_2.input) == 2
+             and self.model.get_constant_value(unsqueeze_2.input[1]) == 2
+             and len(self.model.get_children(gemm, input_name_to_nodes)) == 1
+             and len(self.model.get_children(unsqueeze_3, input_name_to_nodes)) == 1
+             and len(self.model.get_children(unsqueeze_2, input_name_to_nodes)) == 1
+         ):
+             return
+
+         # Here we use hard-coded names so that the initializers can be shared across the whole model.
+         axes_1 = "ort_const_unsqueeze_axes_1"
+         if self.model.get_initializer(axes_1) is None:
+             self.add_initializer(
+                 name=axes_1,
+                 data_type=TensorProto.INT64,
+                 dims=[1],
+                 vals=[1],
+                 raw=False,
+             )
+
+         axes_2 = "ort_const_unsqueeze_axes_2"
+         if self.model.get_initializer(axes_2) is None:
+             self.add_initializer(
+                 name=axes_2,
+                 data_type=TensorProto.INT64,
+                 dims=[1],
+                 vals=[2],
+                 raw=False,
+             )
+
+         unsqueeze_3.input[1] = "ort_const_unsqueeze_axes_2"
+         unsqueeze_2.input[1] = "ort_const_unsqueeze_axes_1"
+         transpose_output_name = self.model.create_node_name("Transpose") + "_NCHW"
+         self.model.replace_input_of_all_nodes(unsqueeze_3.output[0], transpose_output_name)
+         new_transpose = self.create_transpose_node(unsqueeze_3.output[0], [0, 3, 1, 2], transpose_output_name)
+         self.model.add_node(new_transpose, self.this_graph_name)
+         self.increase_counter("Insert Transpose")
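The `perm` composition in `FusionTranspose.fuse` follows ONNX `Transpose` semantics: output axis `i` of `Transpose(perm=p)` reads input axis `p[i]`, so chaining `perm=a` then `perm=b` is equivalent to a single `Transpose` with `c[j] = a[b[j]]`, which is exactly what the `output_permutation` loop computes. A standalone NumPy check of that identity (a sketch, not part of the package):

```python
import numpy as np

x = np.random.rand(2, 3, 4, 5)
a = [0, 2, 3, 1]  # perm of the first (parent) Transpose: NCHW -> NHWC
b = [0, 3, 1, 2]  # perm of the second Transpose: NHWC -> NCHW

# Same rule as the output_permutation loop above: c[j] = a[b[j]]
c = [a[j] for j in b]

assert np.array_equal(np.transpose(np.transpose(x, a), b), np.transpose(x, c))
assert c == [0, 1, 2, 3]  # these two perms cancel out, which is what FusionInsertTranspose relies on
```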
onnxruntime/transformers/fusion_utils.py
@@ -0,0 +1,307 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ from logging import getLogger
+ from typing import Optional, Tuple
+
+ import numpy
+ from numpy import array_equal, ndarray
+ from onnx import NodeProto, TensorProto, helper, numpy_helper
+ from onnx import onnx_pb as onnx_proto
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class FusionUtils:
+     def __init__(self, model: OnnxModel):
+         self.model: OnnxModel = model
+
+     def cast_graph_input_to_int32(self, input_name: str) -> Tuple[bool, str]:
+         graph_input = self.model.find_graph_input(input_name)
+         if graph_input is not None and graph_input.type.tensor_type.elem_type != TensorProto.INT32:
+             cast_output, cast_node = self.cast_input_to_int32(input_name)
+             logger.debug(f"Casted graph input {input_name} to int32")
+             return True, cast_output
+
+         logger.debug(f"Did not cast graph input {input_name} to int32: found {graph_input is not None}")
+         return False, input_name
+
+     def cast_input(self, input_name: str, target_type="int32"):
+         output_name = input_name + "_" + target_type
+
+         if target_type == "int32":
+             to_type = int(TensorProto.INT32)
+         elif target_type == "float32":
+             to_type = int(TensorProto.FLOAT)
+         elif target_type == "float16":
+             to_type = int(TensorProto.FLOAT16)
+         else:
+             raise ValueError(f"Invalid target_type: {target_type}")
+
+         cast_node = self.add_cast_node(input_name, to_type, output_name)
+
+         return output_name, cast_node
+
+     def add_cast_node(
+         self,
+         input_name: str,
+         to_type: int,
+         output_name: Optional[str] = None,
+         output_name_to_node=None,
+         graph_name: Optional[str] = None,
+     ):
+         if output_name is None:
+             output_name = input_name + f"_cast_to_{to_type}"
+
+         # Avoid consecutive Cast nodes.
+         inputs = [input_name]
+         if output_name_to_node is None:
+             output_name_to_node = self.model.output_name_to_node()
+         if input_name in output_name_to_node:
+             parent_node = output_name_to_node[input_name]
+             if parent_node and parent_node.op_type == "Cast":
+                 inputs = [parent_node.input[0]]
+
+         cast_node = helper.make_node("Cast", inputs=inputs, outputs=[output_name])
+
+         cast_node.attribute.extend([helper.make_attribute("to", to_type)])
+         self.model.add_node(cast_node, graph_name=graph_name)
+
+         return cast_node
+
+     def cast_input_to_int32(self, input_name: str):
+         return self.cast_input(input_name, "int32")
+
+     def remove_cast_int32(self, input_name: str):
+         input_name_to_nodes = self.model.input_name_to_nodes()
+         nodes = input_name_to_nodes[input_name]
+         for node in nodes:
+             if node.op_type == "Cast":
+                 is_int32 = False
+                 for att in node.attribute:
+                     if att.name == "to" and att.i == int(TensorProto.INT32):
+                         is_int32 = True
+                         break
+                 if is_int32:
+                     output_name = node.output[0]
+                     self.model.remove_node(node)
+                     self.model.replace_input_of_all_nodes(output_name, input_name)
+
+     @staticmethod
+     def update_node_input(node, i, new_input_name, input_name_to_nodes):
+         old_input_reference = 0
+         if (node.input[i] in input_name_to_nodes) and node in input_name_to_nodes[node.input[i]]:
+             input_name_to_nodes[node.input[i]].remove(node)
+             old_input_reference = len(input_name_to_nodes[node.input[i]])
+
+         node.input[i] = new_input_name
+
+         if new_input_name in input_name_to_nodes:
+             input_name_to_nodes[new_input_name].append(node)
+         else:
+             input_name_to_nodes[new_input_name] = [node]
+
+         return old_input_reference
+
+     @staticmethod
+     def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes, node_input_index=0, parent_input_index=0):
+         """
+         Before:
+               (input)-->parent-->node-->(output)
+         After:
+               (input)-->parent-->
+                      |
+                      +----->node-->(output)
+
+         This function returns a flag indicating whether the parent node can be removed.
+         """
+
+         old_input_name = node.input[node_input_index]
+         new_input_name = parent_node.input[parent_input_index]
+         old_input_reference = FusionUtils.update_node_input(node, node_input_index, new_input_name, input_name_to_nodes)
+
+         # We can remove the parent node if its output is not used (by a graph output or other nodes) anymore.
+         parent_can_be_removed = (old_input_reference == 0) and not model.find_graph_output(old_input_name)
+
+         return parent_can_be_removed
+
+     @staticmethod
+     def check_node_attribute(node, attribute_name: str, expected_value, default_value=None):
+         """Verify that a node has the expected value for an attribute.
+
+         Args:
+             node (NodeProto): a node to check
+             attribute_name (str): name of the attribute
+             expected_value (Any): expected value of the attribute
+             default_value (Any, optional): default value if the attribute does not exist. Defaults to None.
+
+         Returns:
+             bool: whether the check passed
+         """
+         value = default_value
+         for attr in node.attribute:
+             if attr.name == attribute_name:
+                 value = helper.get_attribute_value(attr)
+
+         if isinstance(expected_value, list):
+             return isinstance(value, (ndarray, list)) and array_equal(expected_value, value, equal_nan=False)
+         else:
+             return value == expected_value
+
+     @staticmethod
+     def transpose_2d_int8_tensor(tensor: onnx_proto.TensorProto):
+         """Transpose a 2-D INT8 TensorProto.
+         Args:
+             tensor (TensorProto): tensor to be transposed
+         Returns:
+             tensor (TensorProto): transposed tensor
+         """
+         if not isinstance(tensor, onnx_proto.TensorProto):
+             raise ValueError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")
+
+         if len(tensor.dims) != 2 or tensor.data_type != onnx_proto.TensorProto.INT8:
+             raise ValueError("Only INT8 2-D tensors can be transposed")
+
+         if tensor.raw_data:
+             int8_data = numpy.reshape(numpy.frombuffer(tensor.raw_data, dtype="int8"), tensor.dims)
+             int8_transposed_data = numpy.transpose(int8_data, [1, 0])
+             tensor.raw_data = int8_transposed_data.tobytes()
+         else:
+             raise ValueError("only raw buffer supported")
+
+         return tensor
+
+     @staticmethod
+     def check_qdq_node_for_fusion(node: NodeProto, model: OnnxModel, allow_per_tensor_quantization_only=True):
+         """Verify whether a provided QuantizeLinear (Q) / DequantizeLinear (DQ) node is a good candidate for fusion.
+         It is a good candidate for fusion if:
+         (1) The Q/DQ node is for per-tensor quantization when allow_per_tensor_quantization_only is `True`
+         (2) The Q/DQ node has a constant scale
+         (3) The Q/DQ node has a zero point of 0
+         Args:
+             node (NodeProto): a Q/DQ node to check
+         Returns:
+             bool: whether the check passed
+         """
+         if node.op_type not in {"QuantizeLinear", "DequantizeLinear"}:
+             logger.debug(f"Provided node is not a Q/DQ node. Op Type: {node.op_type}")
+             return False
+
+         scale = model.get_constant_value(node.input[1])
+
+         # Scale is not constant
+         if scale is None:
+             return False
+
+         # Not per-tensor quantization
+         scale_has_single_element = scale.ndim == 0 or (scale.ndim == 1 and scale.shape[0] == 1)
+         if allow_per_tensor_quantization_only and not scale_has_single_element:
+             return False
+
+         # If the Q/DQ node has no zero point input, it is assumed to be 0 (per ONNX spec)
+         if len(node.input) == 2:
+             return True
+
+         # Zero point should be constant and should have a value of 0
+         zero_point = model.get_constant_value(node.input[2])
+
+         # Zero point is not constant
+         if zero_point is None:
+             return False
+
+         # Zero point and scale should have the same number of dims
+         if scale.ndim != zero_point.ndim:
+             return False
+
+         # Zero point should be zero
+         return numpy.all(zero_point == 0)
+
+     def check_node_input_value(self, node, input_index: int, expected_value):
+         """Verify that a node has the expected input value.
+
+         Args:
+             node (NodeProto): a node to check
+             input_index (int): index of its input to be verified
+             expected_value (Any): expected value of the input
+
+         Returns:
+             bool: whether the check passed
+         """
+         assert len(node.input) > input_index
+
+         value = self.model.get_constant_value(node.input[input_index])
+
+         if isinstance(expected_value, list):
+             return isinstance(value, (ndarray, list)) and array_equal(expected_value, value, equal_nan=False)
+         else:
+             return value == expected_value
+
+     def remove_identity_nodes(self):
+         """Remove Identity nodes, except those right before a graph output."""
+         nodes_to_remove = []
+         graph_output_names = self.model.get_graphs_output_names()
+         for node in self.model.nodes():
+             if node.op_type == "Identity":
+                 if node.output[0] not in graph_output_names:
+                     self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
+                     nodes_to_remove.append(node)
+
+         if nodes_to_remove:
+             self.model.remove_nodes(nodes_to_remove)
+             logger.info(f"Removed {len(nodes_to_remove)} Identity nodes")
+
+     def remove_cascaded_cast_nodes(self):
+         self.model.remove_cascaded_cast_nodes()
+
+     def remove_useless_cast_nodes(self):
+         self.model.remove_useless_cast_nodes()
+
+     def remove_useless_reshape_nodes(self):
+         """Remove Reshape nodes that are not needed based on symbolic shape inference: input and output have the same shape."""
+         shape_infer = self.model.infer_runtime_shape(update=True)
+         if shape_infer is None:
+             return
+
+         nodes_to_remove = []
+         for node in self.model.nodes():
+             if node.op_type == "Reshape":
+                 input_shape = shape_infer.get_edge_shape(node.input[0])
+                 output_shape = shape_infer.get_edge_shape(node.output[0])
+                 if input_shape and output_shape and input_shape == output_shape:
+                     logger.info(
+                         f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}"
+                     )
+                     nodes_to_remove.append(node)
+
+         if nodes_to_remove:
+             graph_input_names = set(self.model.get_graphs_input_names())
+             graph_output_names = set(self.model.get_graphs_output_names())
+             for node in nodes_to_remove:
+                 if bool(set(node.output) & graph_output_names):
+                     if (
+                         not bool(set(node.input) & graph_input_names)
+                         and len(self.model.input_name_to_nodes()[node.input[0]]) == 1  # parent has only one child
+                     ):
+                         self.model.replace_output_of_all_nodes(node.input[0], node.output[0])
+                     else:
+                         continue
+                 else:
+                     self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
+                 self.model.remove_node(node)
+
+
+ class NumpyHelper:
+     @staticmethod
+     def to_array(tensor: TensorProto, fill_zeros: bool = False) -> ndarray:
+         # When weights are in external data format but not present, we can still test the optimizer with two changes:
+         # (1) set fill_zeros=True; (2) set load_external_data=False in optimizer.py.
+         if fill_zeros:
+             from onnx import mapping
+
+             return ndarray(
+                 shape=tensor.dims,
+                 dtype=mapping.TENSOR_TYPE_TO_NP_TYPE[tensor.data_type],
+             )
+
+         return numpy_helper.to_array(tensor)
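A standalone sketch of the `check_node_attribute` semantics above (it assumes the `onnx` package is installed and, since these modules use flat imports such as `from fusion_utils import FusionUtils`, that the `onnxruntime/transformers` directory is on `sys.path`):

```python
from onnx import helper
from fusion_utils import FusionUtils  # requires onnxruntime/transformers on sys.path

node = helper.make_node("Transpose", inputs=["x"], outputs=["y"], perm=[0, 2, 1])

# List-valued attributes are compared element-wise via numpy.array_equal.
assert FusionUtils.check_node_attribute(node, "perm", [0, 2, 1])

# A missing attribute falls back to default_value before the comparison.
assert FusionUtils.check_node_attribute(node, "axis", 1, default_value=1)
assert not FusionUtils.check_node_attribute(node, "axis", 0, default_value=1)
```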