onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/fusion_skiplayernorm.py
@@ -0,0 +1,209 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ from logging import getLogger
+
+ from fusion_base import Fusion
+ from fusion_utils import NumpyHelper
+ from onnx import helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class FusionSkipLayerNormalization(Fusion):
+     """
+     Fuse Add + LayerNormalization into one node: SkipLayerNormalization
+     Note: This fusion does not check the input shapes of Add and LayerNormalization.
+     """
+
+     def __init__(
+         self,
+         model: OnnxModel,
+         fused_op_type: str = "SkipLayerNormalization",
+         search_op_types: str = "LayerNormalization",
+         shape_infer: bool = True,
+     ):
+         super().__init__(model, fused_op_type, search_op_types)
+         if shape_infer:
+             # Updating shape inference is needed since other fusions might add new edges that do not have shape info yet.
+             self.shape_infer_helper = self.model.infer_runtime_shape({"batch_size": 4, "seq_len": 7}, update=True)
+             if self.shape_infer_helper is None:
+                 # TODO(tianleiwu): support subgraph in shape inference or add broadcasting in SkipLayerNormalization op.
+                 logger.warning("symbolic shape inference disabled or failed.")
+
+     def fuse(self, node, input_name_to_nodes, output_name_to_node):
+         add = self.model.get_parent(node, 0, output_name_to_node)
+
+         # In some models there is input_ids->Gather->Add->LayerNorm, and one input of the
+         # Add node is an initializer with fixed shape, which should not be fused into SkipLayerNorm.
+         if add is None or add.op_type != "Add":
+             return
+
+         # The number of inputs of Add should be 2.
+         if len(add.input) != 2:
+             return
+
+         for add_input in add.input:
+             if self.model.get_initializer(add_input) is not None:
+                 return
+
+         # To avoid an Add node having two LayerNormalization children, we shall only fuse one SkipLayerNormalization.
+         if add in self.nodes_to_remove:
+             return
+
+         # Root Mean Square Layer Normalization
+         simplified = node.op_type == "SimplifiedLayerNormalization"
+
+         if hasattr(self, "shape_infer_helper"):
+             if self.shape_infer_helper is not None:
+                 if (
+                     self.shape_infer_helper.get_edge_shape(add.input[0])
+                     and len(self.shape_infer_helper.get_edge_shape(add.input[0])) != 3
+                 ):
+                     logger.debug("skip SkipLayerNormalization fusion since shape of input %s is not 3D", add.input[0])
+                     return
+
+                 # TODO(tianleiwu): support broadcasting Skip shape (1, sequence_length, hidden_size) or (sequence_length, hidden_size)
+                 if not self.shape_infer_helper.compare_shape(add.input[0], add.input[1]):
+                     logger.debug(
+                         "skip SkipLayerNormalization fusion since shapes of inputs (%s, %s) are not the same",
+                         add.input[0],
+                         add.input[1],
+                     )
+                     return
+             else:
+                 logger.debug("skip SkipLayerNormalization fusion since symbolic shape inference failed")
+                 return
+
+         gather_path = self.model.match_parent_path(add, ["Gather"], [None])
+         if gather_path is not None and self.model.find_graph_input(gather_path[0].input[1]) is None:
+             if self.model.match_parent_path(gather_path[0], ["ConstantOfShape"], [1]) is None:
+                 return
+
+         # This means that the residual Add before the LayerNormalization produces an output
+         # that is consumed by some other nodes or a graph output other than the LayerNormalization itself.
+         # We can still go ahead with the SkipLayerNormalization fusion, but we need to
+         # preserve the output of Add, and it needs to be produced by SkipLayerNormalization.
+         add_has_graph_output = self.model.find_graph_output(add.output[0]) is not None
+         residual_add_has_multiple_consumers = (
+             add_has_graph_output or len(self.model.get_children(add, input_name_to_nodes)) > 1
+         )
+
+         outputs_to_keep = node.output
+
+         if residual_add_has_multiple_consumers:
+             outputs_to_keep.extend([add.output[0]])
+
+         outputs = [node.output[0]]
+
+         # Skip the other optional outputs of SkipLayerNormalization before adding the Add's output.
+         if residual_add_has_multiple_consumers:
+             outputs.extend(["", "", add.output[0]])
+
+         if self.model.is_safe_to_fuse_nodes([add, node], outputs_to_keep, input_name_to_nodes, output_name_to_node):
+             self.nodes_to_remove.extend([add, node])
+
+             inputs = (
+                 [add.input[0], add.input[1], node.input[1], node.input[2]]
+                 if not simplified
+                 else [add.input[0], add.input[1], node.input[1]]
+             )
+             normalize_node = helper.make_node(
+                 self.fused_op_type,
+                 inputs=inputs,
+                 outputs=outputs,
+                 name=self.model.create_node_name(self.fused_op_type, name_prefix="SkipLayerNorm"),
+             )
+             normalize_node.domain = "com.microsoft"
+
+             # Pass attribute "epsilon" from the LayerNormalization node to SkipLayerNormalization.
+             for att in node.attribute:
+                 if att.name == "epsilon":
+                     normalize_node.attribute.extend([att])
+
+             # Set a default epsilon if the LayerNormalization node has none.
+             if len(normalize_node.attribute) == 0:
+                 normalize_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])
+
+             self.nodes_to_add.append(normalize_node)
+             self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
+
+
+ class FusionBiasSkipLayerNormalization(Fusion):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model, "SkipLayerNormalization", "SkipLayerNormalization", "add bias")
+
+     def fuse(self, node, input_name_to_nodes, output_name_to_node):
+         if len(node.input) != 4:
+             return
+
+         return_indice = []
+         nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [None, None], output_name_to_node, return_indice)
+         if nodes is not None:
+             (add, _matmul) = nodes
+         else:
+             # In case of fp16, we could have a Cast between the MatMul and the bias Add.
+             return_indice = []
+             nodes = self.model.match_parent_path(
+                 node, ["Add", "Cast", "MatMul"], [None, None, None], output_name_to_node, return_indice
+             )
+             if nodes is not None:
+                 (add, _cast, _matmul) = nodes
+             else:
+                 return
+
+         assert len(return_indice) == 2 or len(return_indice) == 3
+         add_input_index = return_indice[0]
+         if add_input_index >= 2:
+             return
+         sln_input = add.input[return_indice[1]]
+         bias_input = add.input[1 - return_indice[1]]
+         skip_input = node.input[1 - add_input_index]
+
+         # The bias should be one-dimensional.
+         initializer = self.model.get_initializer(bias_input)
+         if initializer is None:
+             return
+         bias_weight = NumpyHelper.to_array(initializer)
+         if bias_weight is None:
+             logger.debug("Bias weight not found")
+             return
+         if len(bias_weight.shape) != 1:
+             logger.debug("Bias weight is not 1D")
+             return
+
+         subgraph_nodes = [node, add]
+         if not self.model.is_safe_to_fuse_nodes(subgraph_nodes, node.output, input_name_to_nodes, output_name_to_node):
+             logger.debug("Skip fusing SkipLayerNormalization with Bias since it is not safe")
+             return
+
+         self.nodes_to_remove.extend(subgraph_nodes)
+         inputs = [
+             sln_input,
+             skip_input,
+             node.input[2],
+             node.input[3],
+             bias_input,
+         ]
+         new_node = helper.make_node(
+             "SkipLayerNormalization",
+             inputs=inputs,
+             outputs=node.output,
+             name=self.model.create_node_name("SkipLayerNormalization", "SkipLayerNorm_AddBias_"),
+         )
+         new_node.domain = "com.microsoft"
+
+         # Pass attribute "epsilon" from the SkipLayerNormalization node to the fused node.
+         for att in node.attribute:
+             if att.name == "epsilon":
+                 new_node.attribute.extend([att])
+
+         # Set a default epsilon if the SkipLayerNormalization node has none.
+         if len(new_node.attribute) == 0:
+             new_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])
+
+         self.nodes_to_add.append(new_node)
+         self.node_name_to_graph_name[new_node.name] = self.this_graph_name
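The two classes above produce the com.microsoft SkipLayerNormalization contrib op from an Add + LayerNormalization pair, optionally folding in a MatMul bias. For reference, a minimal sketch of driving these fusions through the package's optimizer entry point; the model path, head count, and hidden size are placeholder assumptions for a BERT-like model:

```python
# Sketch only: run the transformer fusions (including SkipLayerNormalization).
# "model.onnx", num_heads=12 and hidden_size=768 are assumed placeholder values.
from onnxruntime.transformers import optimizer
from onnxruntime.transformers.fusion_options import FusionOptions

options = FusionOptions("bert")
options.enable_skip_layer_norm = True  # the fusion implemented above

opt_model = optimizer.optimize_model(
    "model.onnx",
    model_type="bert",
    num_heads=12,
    hidden_size=768,
    optimization_options=options,
)
print(opt_model.get_fused_operator_statistics())  # counts fused ops such as SkipLayerNormalization
opt_model.save_model_to_file("model_opt.onnx")
```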
onnxruntime/transformers/fusion_transpose.py
@@ -0,0 +1,167 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ from logging import getLogger
+
+ from fusion_base import Fusion
+ from fusion_utils import FusionUtils
+ from onnx import NodeProto, TensorProto, helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class FusionTranspose(Fusion):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model, "Transpose", "Transpose")
+
+     def fuse(
+         self,
+         transpose_node: NodeProto,
+         input_name_to_nodes: dict[str, list[NodeProto]],
+         output_name_to_node: dict[str, NodeProto],
+     ):
+         """
+         Note that onnxruntime will do comprehensive transpose optimization after loading the model.
+         The purpose of this fusion is to clean up the graph before running onnxruntime.
+
+         Case 1:
+             (input)-->Transpose(perm=a)-->Transpose(perm=b)-->
+         After:
+             (input)-->Transpose(perm=a)-->  (this path can be removed if the output is not used anymore)
+                 |
+                 +----->Transpose(perm=a*b)-->
+
+         Case 2 (Cast has only one child):
+             (input)-->Transpose(perm=a)--> Cast -->Transpose(perm=b)-->
+         After:
+             (input)-->Transpose(perm=a)-->  (this path can be removed if the output is not used anymore)
+                 |
+                 +----->Cast --> Transpose(perm=a*b)-->
+         """
+         transpose_b = transpose_node
+         if transpose_b.input[0] not in output_name_to_node:
+             return
+
+         transpose_a = output_name_to_node[transpose_b.input[0]]
+         if transpose_a.op_type != "Cast":
+             cast_node = None
+         else:
+             cast_node = transpose_a
+
+             cast_children = self.model.get_children(cast_node, input_name_to_nodes)
+             if cast_children and len(cast_children) > 1:
+                 return
+
+             if cast_node.input[0] not in output_name_to_node:
+                 return
+
+             transpose_a = output_name_to_node[cast_node.input[0]]
+
+         if transpose_a.op_type != "Transpose":
+             return
+
+         permutation = OnnxModel.get_node_attribute(transpose_b, "perm")
+         assert isinstance(permutation, list)
+
+         parent_permutation = OnnxModel.get_node_attribute(transpose_a, "perm")
+         assert isinstance(parent_permutation, list)
+
+         assert len(parent_permutation) == len(permutation)
+
+         output_permutation = []
+         for _j, index in enumerate(permutation):
+             output_permutation.append(parent_permutation[index])
+
+         if cast_node is None:
+             if FusionUtils.skip_parent(self.model, transpose_b, transpose_a, input_name_to_nodes):
+                 self.nodes_to_remove.append(transpose_a)
+         else:
+             if FusionUtils.skip_parent(self.model, cast_node, transpose_a, input_name_to_nodes):
+                 self.nodes_to_remove.append(transpose_a)
+         transpose_b.ClearField("attribute")
+         transpose_b.attribute.extend([helper.make_attribute("perm", output_permutation)])
+
+
+ class FusionInsertTranspose(Fusion):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model, "", "GroupNorm")
+
+     def create_transpose_node(self, input_name: str, perm: list[int], output_name=None):
+         """Append a Transpose node after an input."""
+         node_name = self.model.create_node_name("Transpose")
+         if output_name is None:
+             output_name = node_name + "_out" + "-" + input_name
+         transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
+         transpose_node.attribute.extend([helper.make_attribute("perm", perm)])
+         return transpose_node
+
+     def fuse(
+         self,
+         group_norm_node: NodeProto,
+         input_name_to_nodes: dict[str, list[NodeProto]],
+         output_name_to_node: dict[str, NodeProto],
+     ):
+         """
+         This optimization inserts a Transpose, and the onnxruntime transpose optimizer will remove it together with
+         another Transpose, so the net effect is one less Transpose after onnxruntime optimization.
+         Before:
+             --> Gemm --> Unsqueeze(axes=[2]) --> Unsqueeze(axes=[3]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
+         After:
+             --> Gemm --> Unsqueeze(axes=[1]) --> Unsqueeze(axes=[2]) --> Transpose([0,3,1,2]) --> Add --> Transpose([0,2,3,1]) --> GroupNorm
+         """
+         gemm_path = self.model.match_parent_path(
+             group_norm_node, ["Transpose", "Add", "Unsqueeze", "Unsqueeze", "Gemm"], [0, 0, None, 0, 0]
+         )
+         if gemm_path is None:
+             return
+         transpose, add, unsqueeze_3, unsqueeze_2, gemm = gemm_path
+         if self.model.find_graph_output(unsqueeze_3.output[0]):
+             return
+
+         permutation = OnnxModel.get_node_attribute(transpose, "perm")
+         assert isinstance(permutation, list)
+         if permutation != [0, 2, 3, 1]:
+             return
+
+         if not (
+             len(unsqueeze_3.input) == 2
+             and self.model.get_constant_value(unsqueeze_3.input[1]) == 3
+             and len(unsqueeze_2.input) == 2
+             and self.model.get_constant_value(unsqueeze_2.input[1]) == 2
+             and len(self.model.get_children(gemm, input_name_to_nodes)) == 1
+             and len(self.model.get_children(unsqueeze_3, input_name_to_nodes)) == 1
+             and len(self.model.get_children(unsqueeze_2, input_name_to_nodes)) == 1
+         ):
+             return
+
+         # Here we use hard-coded names so that they can be shared across the whole model.
+         axes_1 = "ort_const_unsqueeze_axes_1"
+         if self.model.get_initializer(axes_1) is None:
+             self.add_initializer(
+                 name=axes_1,
+                 data_type=TensorProto.INT64,
+                 dims=[1],
+                 vals=[1],
+                 raw=False,
+             )
+
+         axes_2 = "ort_const_unsqueeze_axes_2"
+         if self.model.get_initializer(axes_2) is None:
+             self.add_initializer(
+                 name=axes_2,
+                 data_type=TensorProto.INT64,
+                 dims=[1],
+                 vals=[2],
+                 raw=False,
+             )
+
+         unsqueeze_3.input[1] = "ort_const_unsqueeze_axes_2"
+         unsqueeze_2.input[1] = "ort_const_unsqueeze_axes_1"
+         transpose_output_name = self.model.create_node_name("Transpose") + "_NCHW"
+         self.model.replace_input_of_all_nodes(unsqueeze_3.output[0], transpose_output_name)
+         new_transpose = self.create_transpose_node(unsqueeze_3.output[0], [0, 3, 1, 2], transpose_output_name)
+         self.model.add_node(new_transpose, self.this_graph_name)
+         self.increase_counter("Insert Transpose")
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+
7
+ import numpy
8
+ from numpy import array_equal, ndarray
9
+ from onnx import NodeProto, TensorProto, helper, numpy_helper
10
+ from onnx_model import OnnxModel
11
+
12
+ logger = getLogger(__name__)
13
+
14
+
15
+ class FusionUtils:
16
+ def __init__(self, model: OnnxModel):
17
+ self.model: OnnxModel = model
18
+
19
+ def cast_graph_input_to_int32(self, input_name: str) -> tuple[bool, str]:
20
+ graph_input = self.model.find_graph_input(input_name)
21
+ if graph_input is not None and graph_input.type.tensor_type.elem_type != TensorProto.INT32:
22
+ cast_output, cast_node = self.cast_input_to_int32(input_name)
23
+ logger.debug(f"Casted graph input {input_name} to int32")
24
+ return True, cast_output
25
+
26
+ logger.debug(f"Did not cast graph input {input_name} to int32: found {graph_input is not None}")
27
+ return False, input_name
28
+
29
+ def cast_input(self, input_name: str, target_type="int32"):
30
+ output_name = input_name + "_" + target_type
31
+
32
+ if target_type == "int32":
33
+ to_type = int(TensorProto.INT32)
34
+ elif target_type == "float32":
35
+ to_type = int(TensorProto.FLOAT)
36
+ elif target_type == "float16":
37
+ to_type = int(TensorProto.FLOAT16)
38
+ else:
39
+ raise ValueError("Invalid target_type: {target_type}")
40
+
41
+ cast_node = self.add_cast_node(input_name, to_type, output_name)
42
+
43
+ return output_name, cast_node
44
+
45
+ def add_cast_node(
46
+ self,
47
+ input_name: str,
48
+ to_type: int,
49
+ output_name: str | None = None,
50
+ output_name_to_node=None,
51
+ graph_name: str | None = None,
52
+ ):
53
+ if output_name is None:
54
+ output_name = input_name + f"_cast_to_{to_type}"
55
+
56
+ # Avoid consequent Cast nodes.
57
+ inputs = [input_name]
58
+ if output_name_to_node is None:
59
+ output_name_to_node = self.model.output_name_to_node()
60
+ if input_name in output_name_to_node:
61
+ parent_node = output_name_to_node[input_name]
62
+ if parent_node and parent_node.op_type == "Cast":
63
+ inputs = [parent_node.input[0]]
64
+
65
+ cast_node = helper.make_node("Cast", inputs=inputs, outputs=[output_name])
66
+
67
+ cast_node.attribute.extend([helper.make_attribute("to", to_type)])
68
+ self.model.add_node(cast_node, graph_name=graph_name)
69
+
70
+ return cast_node
71
+
72
+ def cast_input_to_int32(self, input_name: str):
73
+ return self.cast_input(input_name, "int32")
74
+
75
+ def remove_cast_int32(self, input_name: str):
76
+ input_name_to_nodes = self.model.input_name_to_nodes()
77
+ nodes = input_name_to_nodes[input_name]
78
+ for node in nodes:
79
+ if node.op_type == "Cast":
80
+ is_int32 = False
81
+ for att in node.attribute:
82
+ if att.name == "to" and att.i == int(TensorProto.INT32):
83
+ is_int32 = True
84
+ break
85
+ if is_int32:
86
+ output_name = node.output[0]
87
+ self.model.remove_node(node)
88
+ self.model.replace_input_of_all_nodes(output_name, input_name)
89
+
90
+ @staticmethod
91
+ def update_node_input(node, i, new_input_name, input_name_to_nodes):
92
+ old_input_reference = 0
93
+ if (node.input[i] in input_name_to_nodes) and node in input_name_to_nodes[node.input[i]]:
94
+ input_name_to_nodes[node.input[i]].remove(node)
95
+ old_input_reference = len(input_name_to_nodes[node.input[i]])
96
+
97
+ node.input[i] = new_input_name
98
+
99
+ if new_input_name in input_name_to_nodes:
100
+ input_name_to_nodes[new_input_name].append(node)
101
+ else:
102
+ input_name_to_nodes[new_input_name] = [node]
103
+
104
+ return old_input_reference
105
+
106
+ @staticmethod
107
+ def skip_parent(model: OnnxModel, node, parent_node, input_name_to_nodes, node_input_index=0, parent_input_index=0):
108
+ """
109
+ Before:
110
+ (input)-->parent-->node-->(output)
111
+ After:
112
+ (input)-->parent-->
113
+ |
114
+ +----->node-->(output)
115
+
116
+ This function returns a flag whether the parent node can be removed.
117
+ """
118
+
119
+ old_input_name = node.input[node_input_index]
120
+ new_input_name = parent_node.input[parent_input_index]
121
+ old_input_reference = FusionUtils.update_node_input(node, node_input_index, new_input_name, input_name_to_nodes)
122
+
123
+ # We can remove the first Transpose if its output is not used (linked to graph output or other nodes) anymore.
124
+ parent_can_be_removed = (old_input_reference == 0) and not model.find_graph_output(old_input_name)
125
+
126
+ return parent_can_be_removed
127
+
128
+ def get_squeeze_or_unsqueeze_axes(self, node: NodeProto) -> ndarray | None:
129
+ assert node.op_type in ["Squeeze", "Unsqueeze"]
130
+
131
+ # For opset >= 13, axes is an input instead of an attribute.
132
+ if len(node.input) > 1:
133
+ return self.model.get_constant_value(node.input[1])
134
+
135
+ axes = None
136
+ for attr in node.attribute:
137
+ if attr.name == "axes":
138
+ axes = helper.get_attribute_value(attr)
139
+ return axes
140
+
141
+ @staticmethod
142
+ def check_node_attribute(node, attribute_name: str, expected_value, default_value=None):
143
+ """Verify that a node has expected value for an attribute.
144
+
145
+ Args:
146
+ node (NodeProto): a node to check
147
+ attribute_name (str): name of attribute
148
+ expected_value (Any): expected value of the attribute
149
+ default_value (Any, optional): default value if the attribute does not exist. Defaults to None.
150
+
151
+ Returns:
152
+ bool: whether the check is passed or not
153
+ """
154
+ value = default_value
155
+ for attr in node.attribute:
156
+ if attr.name == attribute_name:
157
+ value = helper.get_attribute_value(attr)
158
+
159
+ if isinstance(expected_value, list):
160
+ return (isinstance(value, (ndarray, list))) and array_equal(expected_value, value, equal_nan=False)
161
+ else:
162
+ return value == expected_value
163
+
164
+ @staticmethod
165
+ def transpose_2d_int8_tensor(tensor: TensorProto):
166
+ """Transpose a 2-D INT8 TensorProto
167
+ Args:
168
+ tensor (TensorProto): tensor to be transposed
169
+ Returns:
170
+ tensor (TensorProto): transposed tensor
171
+ """
172
+ if not isinstance(tensor, TensorProto):
173
+ raise TypeError(f"Expected input type is an ONNX TensorProto but got {type(tensor)}")
174
+
175
+ if len(tensor.dims) != 2 or tensor.data_type != TensorProto.INT8:
176
+ raise ValueError("Only INT8 2-D tensors can be transposed")
177
+
178
+ if tensor.raw_data:
179
+ int32_data = numpy.reshape(numpy.frombuffer(tensor.raw_data, dtype="int8"), tensor.dims)
180
+ int32_transposed_data = numpy.transpose(int32_data, [1, 0])
181
+ tensor.raw_data = int32_transposed_data.tobytes()
182
+
183
+ else:
184
+ raise ValueError("only raw buffer supported")
185
+
186
+ return tensor
187
+
188
+ @staticmethod
189
+ def check_qdq_node_for_fusion(node: NodeProto, model: OnnxModel, allow_per_tensor_quantization_only=True):
190
+ """Verify if a provided QuantizeLinear (Q) / DequantizeLinear (DQ) node is a good candidate for fusion.
191
+ It is a good candidate for fusion if:
192
+ (1) The Q/DQ node is for per-tensor quantization if allow_per_tensor_quantization_only is `True`
193
+ (2) The Q/DQ node should have constant scale
194
+ (3) The Q/DQ node should have a zero point of 0
195
+ Args:
196
+ node (NodeProto): a Q/DQ node to check
197
+ Returns:
198
+ bool: whether the check is passed or not
199
+ """
200
+ if node.op_type not in {"QuantizeLinear", "DequantizeLinear"}:
201
+ logger.debug(f"Provided node is not a Q/DQ node. Op Type: {node.op_type}")
202
+
203
+ scale = model.get_constant_value(node.input[1])
204
+
205
+ # Scale is not constant
206
+ if scale is None:
207
+ return False
208
+
209
+ # Not per-tensor quantization
210
+ scale_has_single_element = scale.ndim == 0 or (scale.ndim == 1 and scale.shape[0] == 1)
211
+ if allow_per_tensor_quantization_only and not scale_has_single_element:
212
+ return False
213
+
214
+ # If the Q/DQ node has no zero point input, it is assumed to be 0 (per ONNX spec)
215
+ if len(node.input) == 2:
216
+ return True
217
+
218
+ # Zero point should be constant and should have a value of 0
219
+ zero_point = model.get_constant_value(node.input[2])
220
+
221
+ # Zero point and scale should have same number of dims
222
+ if scale.ndim != zero_point.ndim:
223
+ return False
224
+
225
+ # Zero point is not constant or zero point is not zero
226
+ if zero_point is None:
227
+ return False
228
+
229
+ return numpy.all(zero_point == 0)
230
+
231
+ def check_node_input_value(self, node, input_index: int, expected_value):
232
+ """Verify that a node has expected input value
233
+
234
+ Args:
235
+ node (NodeProto): a node to check
236
+ input_index (int): index of its input to be verified
237
+ expected_value (Any): expected value of the input
238
+
239
+ Returns:
240
+ bool: whether the check is passed or not
241
+ """
242
+ assert len(node.input) > input_index
243
+
244
+ value = self.model.get_constant_value(node.input[input_index])
245
+
246
+ if isinstance(expected_value, list):
247
+ return (isinstance(value, (ndarray, list))) and array_equal(expected_value, value, equal_nan=False)
248
+ else:
249
+ return value == expected_value
250
+
251
+ def remove_identity_nodes(self):
252
+ """Remove Identity nodes, except those right before graph output."""
253
+ nodes_to_remove = []
254
+ graph_output_names = self.model.get_graphs_output_names()
255
+ for node in self.model.nodes():
256
+ if node.op_type == "Identity":
257
+ if node.output[0] not in graph_output_names:
258
+ self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
259
+ nodes_to_remove.append(node)
260
+
261
+ if nodes_to_remove:
262
+ self.model.remove_nodes(nodes_to_remove)
263
+ logger.info(f"Removed {len(nodes_to_remove)} Identity nodes")
264
+
265
+ def remove_cascaded_cast_nodes(self):
266
+ self.model.remove_cascaded_cast_nodes()
267
+
268
+ def remove_useless_cast_nodes(self):
269
+ self.model.remove_useless_cast_nodes()
270
+
271
+ def remove_useless_reshape_nodes(self):
272
+ """Remove reshape node that is not needed based on symbolic shape inference: input and output has same shape"""
273
+ shape_infer = self.model.infer_runtime_shape(update=True)
274
+ if shape_infer is None:
275
+ return
276
+
277
+ nodes_to_remove = []
278
+ for node in self.model.nodes():
279
+ if node.op_type == "Reshape":
280
+ input_shape = shape_infer.get_edge_shape(node.input[0])
281
+ output_shape = shape_infer.get_edge_shape(node.output[0])
282
+ if input_shape and output_shape and input_shape == output_shape:
283
+ logger.info(
284
+ f"Remove reshape node {node.name} since its input shape is same as output: {input_shape}"
285
+ )
286
+ nodes_to_remove.append(node)
287
+
288
+ if nodes_to_remove:
289
+ graph_input_names = set(self.model.get_graphs_input_names())
290
+ graph_output_names = set(self.model.get_graphs_output_names())
291
+ for node in nodes_to_remove:
292
+ if bool(set(node.output) & graph_output_names):
293
+ if (
294
+ not bool(set(node.input) & graph_input_names)
295
+ and len(self.model.input_name_to_nodes()[node.input[0]]) == 1 # parent has only one child
296
+ ):
297
+ self.model.replace_output_of_all_nodes(node.input[0], node.output[0])
298
+ else:
299
+ continue
300
+ else:
301
+ self.model.replace_input_of_all_nodes(node.output[0], node.input[0])
302
+ self.model.remove_node(node)
303
+
304
+
305
+ class NumpyHelper:
306
+ @staticmethod
307
+ def to_array(tensor: TensorProto, fill_zeros: bool = False) -> ndarray:
308
+ # When weights are in external data format but not presented, we can still test the optimizer with two changes:
309
+ # (1) set fill_zeros = True (2) change load_external_data=False in optimizer.py
310
+ if fill_zeros:
311
+ return ndarray(
312
+ shape=tensor.dims,
313
+ dtype=helper.tensor_dtype_to_np_dtype(tensor.data_type),
314
+ )
315
+
316
+ if tensor.data_type == TensorProto.BFLOAT16:
317
+ import onnx_ir as ir # noqa: PLC0415
318
+
319
+ # Use onnx_ir to correctly handle bfloat16 tensors
320
+ return ir.from_proto(tensor).numpy()
321
+ return numpy_helper.to_array(tensor)
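Many of the helpers above are static methods over raw NodeProto objects, so they can be exercised without a full optimization pipeline. A minimal sketch covering check_node_attribute and check_qdq_node_for_fusion; it assumes the onnxruntime/transformers directory is placed on sys.path (the module uses flat imports such as `from onnx_model import OnnxModel`), and every node, name, and value in it is made up for illustration:

```python
# Sketch only: exercise two FusionUtils checks on hand-built protos.
import os
import sys

import numpy as np
import onnxruntime
from onnx import TensorProto, helper, numpy_helper

# fusion_utils.py and onnx_model.py use flat imports, so extend sys.path first.
sys.path.append(os.path.join(os.path.dirname(onnxruntime.__file__), "transformers"))
from fusion_utils import FusionUtils
from onnx_model import OnnxModel

# check_node_attribute compares an attribute with an expected value,
# falling back to default_value when the attribute is absent.
node = helper.make_node("Transpose", inputs=["x"], outputs=["y"], perm=[0, 2, 1])
assert FusionUtils.check_node_attribute(node, "perm", [0, 2, 1])
assert not FusionUtils.check_node_attribute(node, "perm", [0, 1, 2])

# check_qdq_node_for_fusion accepts a DQ node whose scale is a constant
# per-tensor value and whose zero point is constant zero.
scale = numpy_helper.from_array(np.array(0.02, dtype=np.float32), name="s")
zero_point = numpy_helper.from_array(np.array(0, dtype=np.int8), name="zp")
dq = helper.make_node("DequantizeLinear", inputs=["q", "s", "zp"], outputs=["x2"])
graph = helper.make_graph(
    [dq],
    "qdq_check",
    inputs=[helper.make_tensor_value_info("q", TensorProto.INT8, [4])],
    outputs=[helper.make_tensor_value_info("x2", TensorProto.FLOAT, [4])],
    initializer=[scale, zero_point],
)
model = OnnxModel(helper.make_model(graph))
assert FusionUtils.check_qdq_node_for_fusion(dq, model)
```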