onnxruntime-directml 1.24.1 (cp314-cp314-win_amd64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/quantization/fusions/fusion.py
@@ -0,0 +1,311 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ from __future__ import annotations
+
+ from collections import deque
+
+ import onnx
+
+ from ..onnx_model import ONNXModel
+
+
+ class Fusion:
+     """
+     Base class for fusions.
+     """
+
+     def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str):
+         self.search_op_type: str = search_op_type
+         self.fused_op_type: str = fused_op_type
+         self.model: ONNXModel = model
+         self.nodes_to_remove: list = []
+         self.nodes_to_add: list = []
+
+         self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_"
+         self._new_node_name_suffix = None  # int|None used to create unique node names for the fused ops.
+
+     def fuse(
+         self,
+         node: onnx.NodeProto,
+         input_name_to_nodes: dict[str, list[onnx.NodeProto]],
+         output_name_to_node: dict[str, onnx.NodeProto],
+     ):
+         """
+         Interface function for derived fusion classes. Tries to fuse a node sequence containing
+         the specified node.
+         """
+         raise NotImplementedError
+
+     def apply(self) -> bool:
+         """
+         Apply graph fusion on the entire model graph.
+         """
+         input_name_to_nodes = self.model.input_name_to_nodes()
+         output_name_to_node = self.model.output_name_to_node()
+
+         for node in self.model.nodes():
+             if node.op_type == self.search_op_type:
+                 self.fuse(node, input_name_to_nodes, output_name_to_node)
+
+         self.model.remove_nodes(self.nodes_to_remove)
+         self.model.add_nodes(self.nodes_to_add)
+
+         graph_updated = bool(self.nodes_to_remove or self.nodes_to_add)
+
+         if graph_updated:
+             self.model.remove_unused_constant()
+
+         return graph_updated
+
+     def create_unique_node_name(self):
+         prefix = self._new_node_name_prefix
+
+         if self._new_node_name_suffix is None:
+             largest_suffix: int = self.model.get_largest_node_name_suffix(prefix)
+             self._new_node_name_suffix = largest_suffix + 1
+
+         new_name = f"{prefix}{self._new_node_name_suffix!s}"
+         self._new_node_name_suffix += 1
+
+         return new_name
+
+     @staticmethod
+     def is_safe_to_fuse_nodes(
+         nodes_to_remove: list[onnx.NodeProto],
+         keep_outputs: list[str],
+         input_name_to_nodes: dict[str, list[onnx.NodeProto]],
+         output_name_to_node: dict[str, onnx.NodeProto],
+     ) -> bool:
+         for node_to_remove in nodes_to_remove:
+             for output_to_remove in node_to_remove.output:
+                 if output_to_remove in keep_outputs:
+                     continue
+
+                 if output_to_remove in input_name_to_nodes:
+                     for impacted_node in input_name_to_nodes[output_to_remove]:
+                         if impacted_node not in nodes_to_remove:
+                             # Not safe to remove nodes since output is used by impacted_node
+                             return False
+         return True
+
+     @staticmethod
+     def get_node_attribute(node: onnx.NodeProto, attribute_name: str):
+         for attr in node.attribute:
+             if attr.name == attribute_name:
+                 value = onnx.helper.get_attribute_value(attr)
+                 return value
+         return None
+
+     @staticmethod
+     def input_index(node_output: str, child_node: onnx.NodeProto) -> int:
+         for index, input_name in enumerate(child_node.input):
+             if input_name == node_output:
+                 return index
+         return -1
+
+     @staticmethod
+     def tensor_shape_to_list(tensor_type) -> list[int]:
+         shape_list = []
+         for d in tensor_type.shape.dim:
+             if d.HasField("dim_value"):
+                 shape_list.append(d.dim_value)  # known dimension
+             elif d.HasField("dim_param"):
+                 shape_list.append(d.dim_param)  # unknown dimension with symbolic name
+             else:
+                 shape_list.append("?")  # shall not happen
+         return shape_list
+
+     def get_constant_input(self, node: onnx.NodeProto):
+         for i, inp in enumerate(node.input):
+             value = self.model.get_constant_value(inp)
+             if value is not None:
+                 return i, value
+
+         return None, None
+
+     def find_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> int:
+         i, value = self.get_constant_input(node)
+         if value is not None and value.size == 1 and abs(value - expected_value) < delta:
+             return i
+
+         return -1
+
+     def has_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> bool:
+         return self.find_constant_input(node, expected_value, delta) >= 0
+
+     def is_constant_with_specified_rank(self, output_name: str, rank: int) -> bool:
+         value = self.model.get_constant_value(output_name)
+         if value is None:
+             return False  # Not an initializer
+
+         if len(value.shape) != rank:
+             return False  # Wrong dimensions
+
+         return True
+
+     def match_first_parent(
+         self,
+         node: onnx.NodeProto,
+         parent_op_type: str,
+         output_name_to_node: dict[str, onnx.NodeProto] | None = None,
+         exclude: list[onnx.NodeProto] = [],  # noqa: B006
+     ) -> tuple[onnx.NodeProto | None, int | None]:
+         """
+         Find parent node based on constraints on op_type.
+
+         Args:
+             node: current node.
+             parent_op_type (str): constraint of parent node op_type.
+             output_name_to_node (dict): dictionary with output name as key, and node as value.
+             exclude (list): list of nodes that are excluded (not allowed to match as parent).
+
+         Returns:
+             parent: The matched parent node. None if not found.
+             index: The input index of matched parent node. None if not found.
+         """
+         if output_name_to_node is None:
+             output_name_to_node = self.model.output_name_to_node()
+
+         for i, inp in enumerate(node.input):
+             if inp in output_name_to_node:
+                 parent = output_name_to_node[inp]
+                 if parent.op_type == parent_op_type and parent not in exclude:
+                     return parent, i
+
+         return None, None
+
+     def match_parent(
+         self,
+         node: onnx.NodeProto,
+         parent_op_type: str,
+         input_index: int | None = None,
+         output_name_to_node: dict[str, onnx.NodeProto] | None = None,
+         exclude: list[onnx.NodeProto] = [],  # noqa: B006
+         return_indice: list[int] | None = None,
+     ) -> onnx.NodeProto | None:
+         """
+         Find parent node based on constraints on op_type and index.
+         When input_index is None, we will find the first parent node based on constraints,
+         and return_indice will be appended the corresponding input index.
+
+         Args:
+             node (str): current node name.
+             parent_op_type (str): constraint of parent node op_type.
+             input_index (int or None): only check the parent given input index of current node.
+             output_name_to_node (dict): dictionary with output name as key, and node as value.
+             exclude (list): list of nodes that are excluded (not allowed to match as parent).
+             return_indice (list): a list to append the input index when input_index is None.
+
+         Returns:
+             parent: The matched parent node.
+         """
+         assert node is not None
+         assert input_index is None or input_index >= 0
+
+         if output_name_to_node is None:
+             output_name_to_node = self.model.output_name_to_node()
+
+         if input_index is None:
+             parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude)
+             if return_indice is not None:
+                 return_indice.append(index)
+             return parent
+
+         if input_index >= len(node.input):
+             # Input index out of bounds.
+             return None
+
+         parent = self.model.get_parent(node, input_index, output_name_to_node)
+         if parent is not None and parent.op_type == parent_op_type and parent not in exclude:
+             return parent
+
+         return None
+
+     def match_parent_path(
+         self,
+         node: onnx.NodeProto,
+         parent_op_types: list[str],
+         parent_input_index: list[int] | None = None,
+         output_name_to_node: dict[str, onnx.NodeProto] | None = None,
+         return_indice: list[int] | None = None,
+     ) -> list[onnx.NodeProto] | None:
+         """
+         Find a sequence of input edges based on constraints on parent op_type and index.
+         When input_index is None, we will find the first parent node based on constraints,
+         and return_indice will be appended the corresponding input index.
+
+         Args:
+             node (str): current node name.
+             parent_op_types (str): constraint of parent node op_type of each input edge.
+             parent_input_index (list): constraint of input index of each input edge. None means no constraint.
+             output_name_to_node (dict): dictionary with output name as key, and node as value.
+             return_indice (list): a list to append the input index
+                 when there is no constraint on input index of an edge.
+
+         Returns:
+             parents: a list of matched parent node.
+         """
+         if parent_input_index is not None:
+             assert len(parent_input_index) == len(parent_op_types)
+
+         if output_name_to_node is None:
+             output_name_to_node = self.model.output_name_to_node()
+
+         current_node = node
+         matched_parents = []
+         for i, op_type in enumerate(parent_op_types):
+             matched_parent = self.match_parent(
+                 current_node,
+                 op_type,
+                 parent_input_index[i] if parent_input_index is not None else None,
+                 output_name_to_node,
+                 exclude=[],
+                 return_indice=return_indice,
+             )
+             if matched_parent is None:
+                 return None
+
+             matched_parents.append(matched_parent)
+             current_node = matched_parent
+
+         return matched_parents
+
+     def match_parent_paths(
+         self,
+         node: onnx.NodeProto,
+         paths: list[tuple[list[str], list[int]]],
+         output_name_to_node: dict[str, onnx.NodeProto],
+     ) -> tuple[int, list[onnx.NodeProto] | None, list[int] | None]:
+         """
+         Find a matching parent path to the given node.
+         """
+         for i, path in enumerate(paths):
+             return_indice = []
+             matched = self.match_parent_path(node, path[0], path[1], output_name_to_node, return_indice)
+             if matched:
+                 return i, matched, return_indice
+         return -1, None, None
+
+     def find_first_child_by_type(
+         self,
+         node: onnx.NodeProto,
+         child_type: str,
+         input_name_to_nodes: dict[str, list[onnx.NodeProto]] | None = None,
+         recursive: bool = True,
+     ) -> onnx.NodeProto | None:
+         children = self.model.get_children(node, input_name_to_nodes)
+         dq = deque(children)
+         while len(dq) > 0:
+             current_node = dq.pop()
+             if current_node.op_type == child_type:
+                 return current_node
+
+             if recursive:
+                 children = self.model.get_children(current_node, input_name_to_nodes)
+                 for child in children:
+                     dq.appendleft(child)
+
+         return None
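
For orientation, a minimal sketch of how a derived fusion might use this base class; it is not part of the package diff. It relies only on what the hunk above shows (onnx, ONNXModel from onnxruntime.quantization.onnx_model, and the Fusion helpers match_parent, is_safe_to_fuse_nodes, and create_unique_node_name); the FusionIdentityRelu class and the file paths are hypothetical.

import onnx
from onnxruntime.quantization.fusions.fusion import Fusion
from onnxruntime.quantization.onnx_model import ONNXModel


class FusionIdentityRelu(Fusion):
    """Hypothetical example: collapse Relu(Identity(x)) into a single Relu."""

    def __init__(self, model: ONNXModel):
        # Search every Relu node; the fused op type is also Relu.
        super().__init__(model, "Relu", "Relu")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        identity = self.match_parent(node, "Identity", 0, output_name_to_node)
        if identity is None:
            return
        # Only rewrite when the Identity output feeds nothing outside the fused subgraph.
        if not self.is_safe_to_fuse_nodes(
            [identity, node], [node.output[0]], input_name_to_nodes, output_name_to_node
        ):
            return
        self.nodes_to_remove.extend([identity, node])
        self.nodes_to_add.append(
            onnx.helper.make_node(
                "Relu",
                name=self.create_unique_node_name(),
                inputs=[identity.input[0]],
                outputs=[node.output[0]],
            )
        )


# Usage (paths are placeholders); apply() rewrites the graph in one pass using
# the nodes_to_remove/nodes_to_add lists populated by fuse().
onnx_model = ONNXModel(onnx.load("model.onnx"))
if FusionIdentityRelu(onnx_model).apply():
    onnx.save(onnx_model.model, "model_fused.onnx")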
onnxruntime/quantization/fusions/fusion_gelu.py
@@ -0,0 +1,272 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License. See License.txt in the project root for
+ # license information.
+ # --------------------------------------------------------------------------
+ from __future__ import annotations
+
+ import onnx
+
+ from ..onnx_model import ONNXModel
+ from .fusion import Fusion
+
+
+ class FusionGelu(Fusion):
+     def __init__(self, model: ONNXModel):
+         super().__init__(model, "Gelu", "Erf")
+
+     def fuse(
+         self,
+         erf_node: onnx.NodeProto,
+         input_name_to_nodes: dict[str, list[onnx.NodeProto]],
+         output_name_to_node: dict[str, onnx.NodeProto],
+     ):
+         """
+         Interface function that tries to fuse a node sequence containing an Erf node into a single
+         Gelu node.
+         """
+         if (
+             self.fuse_1(erf_node, input_name_to_nodes, output_name_to_node)
+             or self.fuse_2(erf_node, input_name_to_nodes, output_name_to_node)
+             or self.fuse_3(erf_node, input_name_to_nodes, output_name_to_node)
+         ):
+             self.model.set_opset_import("com.microsoft", 1)
+
+     def fuse_1(
+         self,
+         erf_node: onnx.NodeProto,
+         input_name_to_nodes: dict[str, list[onnx.NodeProto]],
+         output_name_to_node: dict[str, onnx.NodeProto],
+     ) -> bool:
+         """
+         This pattern is from PyTorch model
+         Fuse Gelu with Erf into one node:
+         Pattern 1:
+                    +-------Mul(0.5)---------------------+
+                    |                                    |
+                    |                                    v
+                 [root] --> Div -----> Erf --> Add --> Mul -->
+                           (B=1.4142...)       (1)
+
+         Pattern 2:
+                    +------------------------------------+
+                    |                                    |
+                    |                                    v
+                 [root] --> Div -----> Erf --> Add --> Mul -->Mul -->
+                           (B=1.4142...)       (1)           (0.5)
+
+         Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
+         """
+         if erf_node.output[0] not in input_name_to_nodes:
+             return False
+         children = input_name_to_nodes[erf_node.output[0]]
+         if len(children) != 1 or children[0].op_type != "Add":
+             return False
+         add_after_erf = children[0]
+
+         if not self.has_constant_input(add_after_erf, 1):
+             return False
+
+         if add_after_erf.output[0] not in input_name_to_nodes:
+             return False
+
+         children = input_name_to_nodes[add_after_erf.output[0]]
+         if len(children) != 1 or children[0].op_type != "Mul":
+             return False
+
+         mul_after_erf = children[0]
+
+         div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
+         if div is None:
+             return False
+
+         if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
+             return False
+
+         subgraph_input = div.input[0]
+
+         another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
+         if subgraph_input == mul_after_erf.input[another]:  # pattern 2
+             children = input_name_to_nodes[mul_after_erf.output[0]]
+             if len(children) != 1 or children[0].op_type != "Mul":
+                 return False
+             mul_half = children[0]
+             if not self.has_constant_input(mul_half, 0.5):
+                 return False
+             subgraph_output = mul_half.output[0]
+         else:  # pattern 1
+             mul_half = self.match_parent(mul_after_erf, "Mul", another, output_name_to_node)
+             if mul_half is None:
+                 return False
+
+             if not self.has_constant_input(mul_half, 0.5):
+                 return False
+
+             if subgraph_input not in mul_half.input:
+                 return False
+
+             subgraph_output = mul_after_erf.output[0]
+
+         subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
+         if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
+             return False
+
+         self.nodes_to_remove.extend(subgraph_nodes)
+         fused_node = onnx.helper.make_node(
+             "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[subgraph_output]
+         )
+         fused_node.domain = "com.microsoft"
+         self.nodes_to_add.append(fused_node)
+         return True
+
+     def fuse_2(
+         self,
+         erf_node: onnx.NodeProto,
+         input_name_to_nodes: dict[str, list[onnx.NodeProto]],
+         output_name_to_node: dict[str, onnx.NodeProto],
+     ) -> bool:
+         """
+         This pattern is from Keras model
+         Fuse Gelu with Erf into one node:
+                    +------------------------------------------+
+                    |                                          |
+                    |                                          v
+                 [root] --> Div -----> Erf --> Add --> Mul -->Mul
+                           (B=1.4142...)       (A=1)   (A=0.5)
+
+         Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
+         """
+         if erf_node.output[0] not in input_name_to_nodes:
+             return False
+         children = input_name_to_nodes[erf_node.output[0]]
+         if len(children) != 1 or children[0].op_type != "Add":
+             return False
+         add_after_erf = children[0]
+
+         if not self.has_constant_input(add_after_erf, 1):
+             return False
+
+         if add_after_erf.output[0] not in input_name_to_nodes:
+             return False
+         children = input_name_to_nodes[add_after_erf.output[0]]
+         if len(children) != 1 or children[0].op_type != "Mul":
+             return False
+         mul_after_erf = children[0]
+
+         if not self.has_constant_input(mul_after_erf, 0.5):
+             return False
+
+         if mul_after_erf.output[0] not in input_name_to_nodes:
+             return False
+         children = input_name_to_nodes[mul_after_erf.output[0]]
+         if len(children) != 1 or children[0].op_type != "Mul":
+             return False
+         mul = children[0]
+
+         div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
+         if div is None:
+             return False
+
+         sqrt_node = None
+         if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
+             sqrt_node = self.match_parent(div, "Sqrt", 1, output_name_to_node)
+             if sqrt_node is None:
+                 return False
+             if not self.has_constant_input(sqrt_node, 2.0):
+                 return False
+
+         subgraph_input = div.input[0]
+
+         if subgraph_input not in mul.input:
+             return False
+
+         subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
+         if sqrt_node:
+             subgraph_nodes.append(sqrt_node)
+
+         if not self.is_safe_to_fuse_nodes(subgraph_nodes, [mul.output[0]], input_name_to_nodes, output_name_to_node):
+             return False
+
+         self.nodes_to_remove.extend(subgraph_nodes)
+         fused_node = onnx.helper.make_node(
+             "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[mul.output[0]]
+         )
+         fused_node.domain = "com.microsoft"
+         self.nodes_to_add.append(fused_node)
+         return True
+
+     def fuse_3(
+         self,
+         erf_node: onnx.NodeProto,
+         input_name_to_nodes: dict[str, list[onnx.NodeProto]],
+         output_name_to_node: dict[str, onnx.NodeProto],
+     ) -> bool:
+         """
+         This pattern is from TensorFlow model
+         Fuse Gelu with Erf into one node:
+                    +----------------------------------------------+
+                    |                                              |
+                    |                                              v
+                 [root] --> Mul -----> Erf --> Add --> Mul -->Mul
+                           (A=0.7071067690849304)  (B=1)  (B=0.5)
+
+         Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.
+         """
+
+         if erf_node.output[0] not in input_name_to_nodes:
+             return False
+         children = input_name_to_nodes[erf_node.output[0]]
+         if len(children) != 1 or children[0].op_type != "Add":
+             return False
+         add_after_erf = children[0]
+
+         if not self.has_constant_input(add_after_erf, 1):
+             return False
+
+         if add_after_erf.output[0] not in input_name_to_nodes:
+             return False
+         children = input_name_to_nodes[add_after_erf.output[0]]
+         if len(children) != 1 or children[0].op_type != "Mul":
+             return False
+         mul_half = children[0]
+
+         if not self.has_constant_input(mul_half, 0.5):
+             return False
+
+         first_mul = self.match_parent(erf_node, "Mul", 0, output_name_to_node)
+         if first_mul is None:
+             return False
+
+         i = self.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
+         if i < 0:
+             return False
+
+         root_input_index = 1 - i
+         subgraph_input = first_mul.input[root_input_index]
+
+         if mul_half.output[0] not in input_name_to_nodes:
+             return False
+         children = input_name_to_nodes[mul_half.output[0]]
+         if len(children) != 1 or children[0].op_type != "Mul":
+             return False
+         last_mul = children[0]
+
+         if not (last_mul.input[0] == subgraph_input or last_mul.input[1] == subgraph_input):
+             return False
+
+         subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
+         if not self.is_safe_to_fuse_nodes(
+             subgraph_nodes,
+             [last_mul.output[0]],
+             input_name_to_nodes,
+             output_name_to_node,
+         ):
+             return False
+
+         self.nodes_to_remove.extend(subgraph_nodes)
+         fused_node = onnx.helper.make_node(
+             "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[last_mul.output[0]]
+         )
+         fused_node.domain = "com.microsoft"
+         self.nodes_to_add.append(fused_node)
+         return True
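
A similarly minimal sketch of applying the Gelu fusion above to an existing model; the file paths are placeholders, and the import paths follow the package layout in the file list.

import onnx
from onnxruntime.quantization.fusions.fusion_gelu import FusionGelu
from onnxruntime.quantization.onnx_model import ONNXModel

# Load a model and wrap it in the quantization ONNXModel helper.
onnx_model = ONNXModel(onnx.load("model.onnx"))  # placeholder path

# apply() visits every Erf node, tries the three Erf-based patterns above, and
# returns True if any subgraph was rewritten into a com.microsoft Gelu node.
if FusionGelu(onnx_model).apply():
    onnx.save(onnx_model.model, "model_gelu_fused.onnx")  # placeholder path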