onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/fusion_bart_attention.py
@@ -0,0 +1,435 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import logging
+
+ import numpy as np
+ from fusion_attention import AttentionMask, FusionAttention
+ from onnx import helper
+ from onnx_model import OnnxModel
+
+ logger = logging.getLogger(__name__)
+
+
+ class FusionBartAttention(FusionAttention):
+     """
+     Fuse Bart Attention subgraph into one Attention node.
+     """
+
+     def __init__(
+         self,
+         model: OnnxModel,
+         hidden_size: int,
+         num_heads: int,
+         attention_mask: AttentionMask,
+     ):
+         super().__init__(model, hidden_size, num_heads, attention_mask)
+
+     def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+         # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
+         qkv_nodes = self.model.match_parent_path(
+             normalize_node,
+             ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+             [1, 1, 0, 0, 0],
+         )
+         if qkv_nodes is not None:
+             (
+                 add_out,
+                 matmul_out,
+                 reshape_qkv,
+                 transpose_qkv,
+                 matmul_qkv,
+             ) = qkv_nodes
+         else:
+             logger.debug("fuse_attention: failed to match qkv path")
+             return
+
+         other_inputs = []
+         for input_ in normalize_node.input:
+             if input_ not in output_name_to_node:
+                 continue
+             if input_ == qkv_nodes[0].output[0]:
+                 continue
+             other_inputs.append(input_)
+         if len(other_inputs) != 1:
+             return
+         root_input = other_inputs[0]
+
+         # Sometimes the input name to the attention MatMul nodes does not match the input name to the end
+         # SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul
+         # nodes by getting the initial SkipLayerNormalization node and checking how many MatMul nodes are
+         # children nodes for each of its output names.
+         """
+                                         root_input
+             +---------------------------------------------------+
+             |                                                   |
+             |                                                   |
+         SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization
+         """
+         skip_layernorm = output_name_to_node[root_input]
+         # For some attention blocks, the end SkipLayerNormalization node may point to another node whose
+         # child is the LayerNormalization node.
+         if skip_layernorm.op_type in {"Add", "Clip"}:
+             skip_layernorm = self.model.get_children(skip_layernorm)[0]
+         for output in skip_layernorm.output:
+             if not output:
+                 continue
+             children = input_name_to_nodes[output]
+             children_types = [child.op_type for child in children]
+             if children_types.count("MatMul") >= 1:
+                 root_input = output
+                 break
+
+         graph_input_names = {node.name for node in self.model.graph().input}
+         graph_output_names = {node.name for node in self.model.graph().output}
+
+         v_nodes_past_or_present = self.model.match_parent_path(
+             matmul_qkv,
+             ["Transpose", "Reshape", "Add", "MatMul"],
+             [1, 0, 0, None],
+         )
+         v_nodes_with_past = self.model.match_parent_path(
+             matmul_qkv,
+             ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
+             [1, 1, 0, 0, None],
+         )
+         v_nodes_past_only_oai = self.model.match_parent_path(
+             matmul_qkv,
+             ["Transpose", "Reshape", "Reshape", "Transpose"],
+             [1, 0, 0, 0],
+         )
+         past_v, present_v = "", ""
+         v_nodes, add_v, matmul_v = [], None, None
+         if v_nodes_past_or_present is not None:
+             v_nodes = v_nodes_past_or_present
+             (transpose_v, reshape_v, add_v, matmul_v) = v_nodes
+
+             # Find past_v input name
+             start_child_nodes = input_name_to_nodes[add_v.output[0]]
+             for start_child_node in start_child_nodes:
+                 if start_child_node.op_type == "Concat":
+                     concat_v_nodes = self.model.match_parent_path(
+                         start_child_node,
+                         ["Reshape", "Transpose"],
+                         [0, 0],
+                     )
+                     if concat_v_nodes is not None:
+                         past_v = concat_v_nodes[-1].input[0]
+                     start_child_nodes = input_name_to_nodes[start_child_node.output[0]]
+                     break
+
+             # Find present_v output name
+             for start_child_node in start_child_nodes:
+                 start_grandchild_nodes = input_name_to_nodes[start_child_node.output[0]]
+                 for start_grandchild_node in start_grandchild_nodes:
+                     if start_grandchild_node.output[0] in graph_output_names:
+                         present_v = start_grandchild_node.output[0]
+                         break
+                 if present_v != "":
+                     break
+         elif v_nodes_with_past is not None:
+             v_nodes = v_nodes_with_past
+             (concat_v, transpose_v, reshape_v, add_v, matmul_v) = v_nodes
+             past_v = concat_v.input[0]
+             present_v = concat_v.output[0]
+         elif matmul_qkv.input[1] in graph_input_names:
+             # Hugging Face's cross-attention where past_v is used directly as value
+             past_v = matmul_qkv.input[1]
+         elif v_nodes_past_only_oai is not None:
+             # OpenAI's cross-attention where past_v is used directly as value
+             v_nodes = v_nodes_past_only_oai
+             past_v = v_nodes[-1].input[0]
+         else:
+             logger.debug("fuse_attention: failed to match v path")
+             return
+         past_v = past_v if past_v in graph_input_names else ""
+         present_v = present_v if present_v in graph_output_names else ""
+
+         qk_nodes_no_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
+         qk_nodes_with_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
+         qk_nodes, add_qk = [], None
+         if qk_nodes_no_mask is not None:
+             _, matmul_qk = qk_nodes_no_mask
+             qk_nodes = qk_nodes_no_mask
+         elif qk_nodes_with_mask is not None:
+             _, add_qk, matmul_qk = qk_nodes_with_mask
+             qk_nodes = qk_nodes_with_mask
+         else:
+             logger.debug("fuse_attention: failed to match qk path")
+             return
+
+         q_nodes_hf = self.model.match_parent_path(
+             matmul_qk,
+             ["Transpose", "Reshape", "Mul", "Add", "MatMul"],
+             [0, 0, 0, 0, 1],
+         )
+         q_nodes_oai = self.model.match_parent_path(
+             matmul_qk,
+             ["Mul", "Transpose", "Reshape", "Add", "MatMul"],
+             [0, 0, 0, 0, 1],
+         )
+         q_nodes = []
+         if q_nodes_hf is not None:
+             q_nodes = q_nodes_hf
+             (transpose_q, reshape_q, mul_q, add_q, matmul_q) = q_nodes
+         elif q_nodes_oai is not None:
+             q_nodes = q_nodes_oai
+             (mul_q, transpose_q, reshape_q, add_q, matmul_q) = q_nodes
+         else:
+             logger.debug("fuse_attention: failed to match q path")
+             return
+
+         k_nodes_no_past_hf = self.model.match_parent_path(
+             matmul_qk,
+             ["Transpose", "Reshape", "MatMul"],
+             [1, 0, 0],
+         )
+         k_nodes_with_past_hf = self.model.match_parent_path(
+             matmul_qk,
+             ["Transpose", "Concat", "Transpose", "Reshape", "MatMul"],
+             [1, 0, 1, 0, 0],
+         )
+         k_nodes_past_or_present_oai = self.model.match_parent_path(
+             matmul_qk,
+             ["Mul", "Transpose", "Reshape", "MatMul"],
+             [1, 0, 0, 0],
+         )
+         k_nodes_past_only_oai = self.model.match_parent_path(
+             matmul_qk,
+             ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"],
+             [1, 0, 0, 0, 0],
+         )
+         past_k, present_k = "", ""
+         k_nodes, add_k, matmul_k = [], None, None
+         if k_nodes_no_past_hf is not None:
+             k_nodes = k_nodes_no_past_hf
+             (transpose_k, reshape_k, matmul_k) = k_nodes
+
+             # Find present_k output name
+             transpose_k_nodes = input_name_to_nodes[reshape_k.output[0]]
+             for transpose_k_node in transpose_k_nodes:
+                 if transpose_k_node.output[0] in graph_output_names:
+                     present_k = transpose_k_node.output[0]
+                     break
+         elif k_nodes_with_past_hf is not None:
+             k_nodes = k_nodes_with_past_hf
+             (_, concat_k, transpose_k, reshape_k, matmul_k) = k_nodes
+             past_k = concat_k.input[0]
+             present_k = concat_k.output[0]
+         elif output_name_to_node[matmul_qk.input[1]].input[0] in graph_input_names:
+             # Hugging Face's cross-attention where past_k is used directly as key
+             k_nodes = [output_name_to_node[matmul_qk.input[1]]]
+             past_k = k_nodes[0].input[0]
+         elif k_nodes_past_or_present_oai is not None:
+             k_nodes = k_nodes_past_or_present_oai
+             (_, transpose_k, reshape_k, matmul_k) = k_nodes
+
+             # Find past_k input name
+             start_child_nodes = input_name_to_nodes[matmul_k.output[0]]
+             for start_child_node in start_child_nodes:
+                 if start_child_node.op_type == "Concat":
+                     concat_k_nodes = self.model.match_parent_path(
+                         start_child_node,
+                         ["Reshape", "Transpose"],
+                         [0, 0],
+                     )
+                     if concat_k_nodes is not None:
+                         past_k = concat_k_nodes[-1].input[0]
+                     start_child_nodes = input_name_to_nodes[start_child_node.output[0]]
+                     break
+
+             # Find present_k output name
+             for start_child_node in start_child_nodes:
+                 start_grandchild_nodes = input_name_to_nodes[start_child_node.output[0]]
+                 for start_grandchild_node in start_grandchild_nodes:
+                     if start_grandchild_node.output[0] in graph_output_names:
+                         present_k = start_grandchild_node.output[0]
+                         break
+                 if present_k != "":
+                     break
+         elif k_nodes_past_only_oai is not None:
+             # OpenAI's cross-attention where past_k is used directly as key
+             k_nodes = k_nodes_past_only_oai
+             past_k = k_nodes[-1].input[0]
+         else:
+             logger.debug("fuse_attention: failed to match k path")
+             return
+         past_k = past_k if past_k in graph_input_names else ""
+         present_k = present_k if present_k in graph_output_names else ""
+
+         if matmul_k is not None and add_k is None:
+             # Create empty Add node for attention graph
+             add_v_tensor = self.model.get_initializer(add_v.input[0])
+             bias_dim = add_v_tensor.dims[0]
+             dtype = add_v_tensor.data_type
+             empty_bias_name = "empty_bias"
+             empty_tensor = self.model.get_initializer(empty_bias_name)
+             if empty_tensor is None:
+                 self.add_initializer(
+                     empty_bias_name,
+                     dtype,
+                     dims=[bias_dim],
+                     vals=np.array([0.0] * bias_dim, dtype=helper.tensor_dtype_to_np_dtype(dtype)),
+                 )
+
+             add_name = self.model.create_node_name("Add")
+             add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k.name], add_name)
+
+         three_root_inputs = bool(past_k) and bool(past_v) and matmul_k is None and matmul_v is None
+         one_root_input = (
+             not three_root_inputs
+             and matmul_q.input[0] == root_input
+             and matmul_k.input[0] == root_input
+             and matmul_v.input[0] == root_input
+         )
+         two_root_inputs = (
+             not three_root_inputs
+             and matmul_q.input[0] == root_input
+             and matmul_k.input[0] == matmul_v.input[0]
+             and matmul_k.input[0] != matmul_q.input[0]
+         )
+
+         # There are 5 types of attention:
+         # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_no_mask
+         # 2) Decoder self attention with one_root_input=True and qk_nodes=qk_nodes_with_mask
+         # 3) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_no_mask
+         # 4) Decoder self attention with past with one_root_input=True and qk_nodes=qk_nodes_with_mask and past_k=past_decoder_key and past_v=past_decoder_value
+         # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_no_mask
+         encoder_attention = one_root_input and qk_nodes == qk_nodes_no_mask
+         decoder_self_attention = one_root_input and qk_nodes == qk_nodes_with_mask
+         decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_no_mask
+         decoder_self_attention_with_past = decoder_self_attention and bool(past_k) and bool(past_v)
+         decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_no_mask
+
+         # For decoder self-attentions, the attention mask needs to be included in the attention node
+         causal_mask = qk_nodes == qk_nodes_with_mask
+         mask_nodes = []
+         if causal_mask:
+             mask_nodes_bart = self.model.match_parent_path(
+                 add_qk,
+                 ["Where"],
+                 [1],
+             )
+             mask_nodes_whisper_hf = self.model.match_parent_path(
+                 add_qk,
+                 ["Slice", "Expand", "Where"],
+                 [1, 0, 1],
+             )
+             mask_nodes_whisper_oai = self.model.match_parent_path(
+                 add_qk,
+                 ["Slice", "Unsqueeze", "Gather", "Shape", "Add"],
+                 [1, 2, 0, 0, 0],
+             )
+             mask_nodes_whisper_oai_unit_test = self.model.match_parent_path(
+                 add_qk,
+                 ["Slice", "Slice"],
+                 [1, 0],
+             )
+             if mask_nodes_whisper_hf is not None:
+                 mask_nodes = mask_nodes_whisper_hf
+             elif mask_nodes_whisper_oai is not None:
+                 mask_nodes = mask_nodes_whisper_oai
+             elif mask_nodes_whisper_oai_unit_test is not None:
+                 mask_nodes = mask_nodes_whisper_oai_unit_test
+             elif mask_nodes_bart is not None:
+                 mask_nodes = mask_nodes_bart
+             else:
+                 logger.debug("fuse_attention: failed to match mask nodes")
+                 return
+             assert len(mask_nodes) > 0
+
+         if (
+             encoder_attention
+             or decoder_self_attention
+             or decoder_cross_attention
+             or decoder_self_attention_with_past
+             or decoder_cross_attention_with_past
+         ):
+             attention_last_node = reshape_qkv
+             num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
+
+             if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
+                 logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
+                 return
+
+             new_node = None
+             if decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past:
+                 # Note: Decoder attention with past key and past value is fused as multi-head attention
+                 # rather than attention because multi-head attention supports separate past key and past
+                 # value whereas attention supports concatenated past key and past value.
+                 new_node = (
+                     self.create_multihead_attention_node(
+                         q_matmul=matmul_q,
+                         k_matmul=matmul_k if decoder_cross_attention or decoder_self_attention_with_past else past_k,
+                         v_matmul=matmul_v if decoder_cross_attention or decoder_self_attention_with_past else past_v,
+                         q_add=add_q,
+                         k_add=add_k if decoder_cross_attention or decoder_self_attention_with_past else None,
+                         v_add=add_v if decoder_cross_attention or decoder_self_attention_with_past else None,
+                         num_heads=num_heads,
+                         hidden_size=hidden_size,
+                         output=attention_last_node.output[0],
+                         unidirectional=causal_mask,
+                         past_k=past_k if decoder_self_attention_with_past else "",
+                         past_v=past_v if decoder_self_attention_with_past else "",
+                         present_k=present_k,
+                         present_v=present_v,
+                     )
+                     if self.use_multi_head_attention
+                     else None
+                 )
+             else:
+                 # Temporarily set multi-head attention flag to false
+                 use_multi_head_attention_ground_truth = self.use_multi_head_attention
+                 self.use_multi_head_attention = False
+                 new_node = self.create_attention_node(
+                     mask_index=None,
+                     q_matmul=matmul_q,
+                     k_matmul=matmul_k,
+                     v_matmul=matmul_v,
+                     q_add=add_q,
+                     k_add=add_k,
+                     v_add=add_v,
+                     num_heads=num_heads,
+                     hidden_size=hidden_size,
+                     first_input=root_input,
+                     output=attention_last_node.output[0],
+                     causal=causal_mask,
+                     past_k=past_k,
+                     past_v=past_v,
+                     present_k=present_k,
+                     present_v=present_v,
+                 )
+                 self.use_multi_head_attention = use_multi_head_attention_ground_truth
+             if new_node is None:
+                 logger.debug("fuse_attention: failed to create fused node")
+                 return
+
+             self.nodes_to_add.append(new_node)
+             self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+             self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
+             self.nodes_to_remove.extend(qk_nodes)
+
+             # When using multi-head attention, keep MatMul nodes in original graph
+             if decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past:
+                 if len(q_nodes) > 0 and q_nodes[-1].op_type == "MatMul":
+                     q_nodes.pop()
+                 if len(k_nodes) > 0 and k_nodes[-1].op_type == "MatMul":
+                     k_nodes.pop()
+                 if len(v_nodes) > 0 and v_nodes[-1].op_type == "MatMul":
+                     v_nodes.pop()
+                 if self.disable_multi_head_attention_bias:
+                     if len(q_nodes) > 0 and q_nodes[-1].op_type == "Add":
+                         q_nodes.pop()
+                     if len(k_nodes) > 0 and k_nodes[-1].op_type == "Add":
+                         k_nodes.pop()
+                     if len(v_nodes) > 0 and v_nodes[-1].op_type == "Add":
+                         v_nodes.pop()
+
+             self.nodes_to_remove.extend(q_nodes)
+             self.nodes_to_remove.extend(k_nodes)
+             self.nodes_to_remove.extend(v_nodes)
+
+             # Use prune graph to remove mask nodes since they are shared by all attention nodes.
+             self.prune_graph = True
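
For orientation, below is a minimal sketch of driving this fusion by hand. The model path and the head/hidden sizes are hypothetical, and AttentionMask is constructed the same way as in the sibling fusion scripts; in the shipped package this class is normally reached indirectly through onnxruntime.transformers.optimizer.optimize_model(..., model_type="bart") rather than instantiated directly.

    import onnx
    from fusion_attention import AttentionMask
    from fusion_bart_attention import FusionBartAttention
    from onnx_model import OnnxModel

    # Load a BART-style decoder ONNX file (hypothetical path) and wrap it in
    # the OnnxModel helper used throughout these scripts.
    model = OnnxModel(onnx.load("bart_decoder.onnx"))

    # Hypothetical sizes for a bart-base style model.
    fusion = FusionBartAttention(model, hidden_size=768, num_heads=12, attention_mask=AttentionMask(model))

    # Fusion.apply() (see fusion_base.py below) walks the search op types and
    # calls fuse() on each candidate normalization node.
    fusion.apply()
    model.save_model_to_file("bart_decoder_fused.onnx")

Note that the fusion only rewrites subgraphs it can fully match; every failure path above logs at debug level and returns, leaving unmatched attention blocks untouched.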
onnxruntime/transformers/fusion_base.py
@@ -0,0 +1,141 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ from collections import defaultdict
+ from collections.abc import Sequence
+ from logging import getLogger
+ from typing import Any
+
+ import numpy as np
+ from onnx import NodeProto, TensorProto, helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class Fusion:
+     """
+     Base class for Graph Fusion
+     """
+
+     def __init__(
+         self,
+         model: OnnxModel,
+         fused_op_type: str,
+         search_op_types: str | list[str],
+         description: str = "",
+     ):
+         self.search_op_types: list[str] = [search_op_types] if isinstance(search_op_types, str) else search_op_types
+         self.fused_op_type: str = fused_op_type
+         self.description: str = f"{fused_op_type}({description})" if description else fused_op_type
+         self.model: OnnxModel = model
+         self.nodes_to_remove: list = []
+         self.nodes_to_add: list = []
+         self.prune_graph: bool = False
+         self.node_name_to_graph_name: dict = {}
+         self.this_graph_name: str | None = None
+         # It is optional that subclass updates fused_count since we will also check nodes_to_add to get counter.
+         self.fused_count: defaultdict = defaultdict(int)
+
+     def increase_counter(self, fused_op_name: str):
+         """
+         Increase counter of a fused operator.
+         """
+         self.fused_count[fused_op_name] += 1
+
+     def fuse(
+         self,
+         node: NodeProto,
+         input_name_to_nodes: dict[str, list[NodeProto]],
+         output_name_to_node: dict[str, NodeProto],
+     ):
+         """Interface for fusion that starts from a node"""
+         raise NotImplementedError
+
+     def apply(self):
+         """
+         Apply graph fusion on the whole model graph.
+         It searches for nodes of the given operator types, and starts fusion on each of those nodes.
+         """
+         logger.debug(f"start {self.description} fusion...")
+         input_name_to_nodes = self.model.input_name_to_nodes()
+         output_name_to_node = self.model.output_name_to_node()
+
+         # This assumes that two search ops will not be fused at the same time!
+         for search_op_type in self.search_op_types:
+             for node in self.model.get_nodes_by_op_type(search_op_type):
+                 graph = self.model.get_graph_by_node(node)
+                 if graph is None:
+                     raise Exception("Can not find node in any graph")
+                 self.this_graph_name = graph.name
+                 self.fuse(node, input_name_to_nodes, output_name_to_node)
+
+         op_list = [node.op_type for node in self.nodes_to_add]
+         if self.fused_count:
+             for key, value in self.fused_count.items():
+                 if value:
+                     logger.info(f"Fused {key}: {value}")
+         else:
+             count = op_list.count(self.fused_op_type)
+             if count > 0:
+                 logger.info(f"Fused {self.description}: {count}")
+
+         self.model.remove_nodes(self.nodes_to_remove)
+         self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)
+
+         if self.prune_graph:
+             self.model.prune_graph()
+         elif self.nodes_to_remove or self.nodes_to_add:
+             self.model.update_graph()
+
+     def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = True):
+         if raw:
+             if not isinstance(vals, np.ndarray):
+                 np_type = helper.tensor_dtype_to_np_dtype(data_type)
+                 bytes = np.array(vals, dtype=np_type).tobytes()
+             else:
+                 bytes = vals.tobytes()
+             tensor = helper.make_tensor(
+                 name=name,
+                 data_type=data_type,
+                 dims=dims,
+                 vals=bytes,
+                 raw=True,
+             )
+         else:
+             tensor = helper.make_tensor(
+                 name=name,
+                 data_type=data_type,
+                 dims=dims,
+                 vals=vals,
+                 raw=False,
+             )
+
+         self.model.add_initializer(tensor, self.this_graph_name)
+         return tensor
+
+     def remove_initializer(self, tensor: TensorProto):
+         self.model.remove_initializer(tensor)
+
+     def add_nodes_to_remove(self, nodes: list[NodeProto]):
+         # Some nodes are shared between paths (e.g. rotary embedding nodes in the Q and K paths).
+         # When path A is fused, its shared nodes are added to `self.nodes_to_remove`. But when path B
+         # is fused, its shared nodes are also added to `self.nodes_to_remove`. When the nodes are
+         # iteratively removed from `self.nodes_to_remove`, path A's shared nodes are removed first.
+         # Since path A's shared nodes are removed, path B's shared nodes are not removed because they
+         # were previously removed for path A. This causes an error to print in remove_node that a node
+         # has failed to be removed.
+         #
+         # To avoid this error, we pre-emptively check if the shared nodes are already in `self.nodes_to_remove`.
+         # We could alternatively convert `self.nodes_to_remove` to a set to avoid this issue, but there could
+         # be scenarios where the nodes need to be removed in a specific order and converting to a set would
+         # lose this order.
+         for node in nodes:
+             if node not in self.nodes_to_remove:
+                 self.nodes_to_remove.append(node)
+
+     def add_nodes_to_remove_with_nodes_to_keep(self, nodes: list[NodeProto], nodes_to_keep: list[NodeProto]):
+         for node in nodes:
+             if node not in self.nodes_to_remove and node not in nodes_to_keep:
+                 self.nodes_to_remove.append(node)
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+
7
+ from fusion_base import Fusion
8
+ from numpy import ndarray
9
+ from onnx import helper
10
+ from onnx_model import OnnxModel
11
+
12
+ logger = getLogger(__name__)
13
+
14
+
15
+ class FusionBiasAdd(Fusion):
16
+ def __init__(self, model: OnnxModel):
17
+ super().__init__(model, "BiasAdd", "Add")
18
+
19
+ def fuse(self, add_node, input_name_to_nodes: dict, output_name_to_node: dict):
20
+ """
21
+ Fuse Add bias and Add skip connection into BiasAdd
22
+ """
23
+
24
+ nodes = self.model.match_parent_path(
25
+ add_node,
26
+ ["Add", "MatMul", "BiasSplitGelu", "MatMul", "SkipLayerNormalization"],
27
+ [0, None, 0, 0, 0],
28
+ output_name_to_node,
29
+ )
30
+
31
+ if nodes is None:
32
+ return
33
+
34
+ bias_node = nodes[0]
35
+ skip_layer_norm = nodes[-1]
36
+
37
+ # Check skip connection is from SkipLayerNormalization output
38
+ if add_node.input[1] not in skip_layer_norm.output:
39
+ return
40
+
41
+ bias_index, bias_value = self.model.get_constant_input(bias_node)
42
+ if not (isinstance(bias_index, int) and (bias_value is not None) and isinstance(bias_value, ndarray)):
43
+ return
44
+ if bias_value.ndim != 1:
45
+ return
46
+
47
+ self.nodes_to_remove.extend([add_node, bias_node])
48
+ node_name = self.model.create_node_name("BiasAdd")
49
+ fused_node = helper.make_node(
50
+ "BiasAdd",
51
+ inputs=[bias_node.input[1 - bias_index], bias_node.input[bias_index], add_node.input[1]],
52
+ outputs=[add_node.output[0]],
53
+ name=node_name,
54
+ )
55
+ fused_node.domain = "com.microsoft"
56
+ self.nodes_to_add.append(fused_node)
57
+ self.node_name_to_graph_name[node_name] = self.this_graph_name
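
In graph terms, this pass replaces two chained Adds, the 1-D bias Add on the MatMul output and the following skip-connection Add, with a single BiasAdd(x, bias, skip) node in the com.microsoft domain. A minimal sketch of running just this pass is below; the model path is hypothetical, and in the package this fusion is typically driven as part of the stable-diffusion/UNet optimization passes rather than standalone.

    import onnx

    from fusion_bias_add import FusionBiasAdd
    from onnx_model import OnnxModel

    # Hypothetical path to a model that already contains the BiasSplitGelu and
    # SkipLayerNormalization nodes the parent-path match above requires.
    model = OnnxModel(onnx.load("unet_optimized.onnx"))

    FusionBiasAdd(model).apply()  # rewrites Add(bias) + Add(skip) into BiasAdd
    model.save_model_to_file("unet_biasadd_fused.onnx")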