onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,360 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+ from typing import Dict, Optional
7
+
8
+ from fusion_base import Fusion
9
+ from onnx import helper
10
+ from onnx_model import OnnxModel
11
+
12
+ logger = getLogger(__name__)
13
+
14
+
15
+ class FusionFastGelu(Fusion):
16
def __init__(self, model: OnnxModel):
    """Initialise the fusion pass for the given model.

    The two string arguments forwarded to the base class are presumably the
    operator type this pass emits ("FastGelu") and the operator type it
    searches for ("Tanh") -- confirm against Fusion.__init__.
    """
    fused_op_type, search_op_type = "FastGelu", "Tanh"
    super().__init__(model, fused_op_type, search_op_type)
18
+
19
def fuse(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
    """Try the FastGelu pattern matchers in order on one Tanh node.

    fuse_1, fuse_2 and fuse_3 are attempted in that order; the first one that
    returns a truthy value ends the search. Always returns None.
    """
    for matcher in (self.fuse_1, self.fuse_2, self.fuse_3):
        if matcher(tanh_node, input_name_to_nodes, output_name_to_node):
            break
28
+
29
def fuse_1(self, tanh_node, input_name_to_nodes, output_name_to_node) -> Optional[bool]:
    """
    Replace the tanh-based Gelu subgraph below with a single FastGelu node.

          +---------------------------+
          |                           |
          |                           v
        [root] --> Pow --> Mul -----> Add --> Mul --> Tanh --> Add --> Mul
          |       (Y=3)  (B=0.0447...)      (B=0.7978...)     (B=1)     ^
          |                                                             |
          +------> Mul(B=0.5) ------------------------------------------+

    The constant operand of each Add/Mul may be the first or the second input.

    Returns True when the fusion has been queued (matched nodes appended to
    self.nodes_to_remove, the new node appended to self.nodes_to_add);
    returns None as soon as any part of the pattern fails to match.
    """
    # --- Walk forward from Tanh: Tanh -> Add(1.0) -> Mul -------------------
    tanh_output = tanh_node.output[0]
    if tanh_output not in input_name_to_nodes:
        return None
    consumers = input_name_to_nodes[tanh_output]
    if not (len(consumers) == 1 and consumers[0].op_type == "Add"):
        return None
    add_after_tanh = consumers[0]

    if not self.model.has_constant_input(add_after_tanh, 1.0):
        return None

    add_output = add_after_tanh.output[0]
    if add_output not in input_name_to_nodes:
        return None
    consumers = input_name_to_nodes[add_output]
    if not (len(consumers) == 1 and consumers[0].op_type == "Mul"):
        return None
    mul_after_tanh = consumers[0]

    # --- The other branch into the final Mul: root -> Mul(0.5) -------------
    mul_half = self.model.match_parent(mul_after_tanh, "Mul", None, output_name_to_node)
    if mul_half is None:
        return None

    half_index = self.model.find_constant_input(mul_half, 0.5)
    if half_index < 0:
        return None

    # The input slot of Mul(0.5) that does not hold the constant carries root.
    root_slot = 0 if half_index == 1 else 1
    root_input = mul_half.input[root_slot]

    # root_node could be None when root_input is graph input
    root_node = self.model.get_parent(mul_half, root_slot, output_name_to_node)

    # --- Walk backward from Tanh: Mul(0.7978) <- Add <- Mul(0.0447) <- Pow -
    mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
    if mul_before_tanh is None:
        return None

    scale_index = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
    if scale_index < 0:
        return None

    add_before_tanh = self.model.match_parent(
        mul_before_tanh, "Add", 0 if scale_index == 1 else 1, output_name_to_node
    )
    if add_before_tanh is None:
        return None

    # Skip the root producer when searching the Add's parents for the Mul.
    excluded = [root_node] if root_node else []
    mul_after_pow = self.model.match_parent(
        add_before_tanh,
        "Mul",
        None,
        output_name_to_node,
        exclude=excluded,
    )
    if mul_after_pow is None:
        return None

    cubic_index = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
    if cubic_index < 0:
        return None

    pow_node = self.model.match_parent(
        mul_after_pow, "Pow", 0 if cubic_index == 1 else 1, output_name_to_node
    )
    if pow_node is None:
        return None

    # Pow must cube the very same tensor that feeds Mul(0.5).
    if not self.model.has_constant_input(pow_node, 3.0) or pow_node.input[0] != root_input:
        return None

    # --- Queue the replacement ----------------------------------------------
    subgraph_nodes = [
        mul_after_tanh,
        mul_half,
        add_after_tanh,
        tanh_node,
        mul_before_tanh,
        add_before_tanh,
        mul_after_pow,
        pow_node,
    ]
    safe = self.model.is_safe_to_fuse_nodes(
        subgraph_nodes,
        [mul_after_tanh.output[0]],
        input_name_to_nodes,
        output_name_to_node,
    )
    if not safe:
        return None

    self.nodes_to_remove.extend(subgraph_nodes)
    fused_node = helper.make_node(
        "FastGelu",
        inputs=[root_input],
        outputs=mul_after_tanh.output,
        name=self.model.create_node_name("FastGelu"),
    )
    fused_node.domain = "com.microsoft"
    self.nodes_to_add.append(fused_node)
    self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
    return True
136
+
137
def fuse_2(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
    """
    Replace the TensorFlow-style tanh Gelu subgraph below with one FastGelu node.

          +---------------------------+
          |                           |
          |                           v
        [root] --> Pow --> Mul -----> Add --> Mul --> Tanh --> Add --> Mul(B=0.5) --> Mul -->
          |       (Y=3)  (B=0.0447...)      (B=0.7978...)     (B=1)                    ^
          |                                                                            |
          +----------------------------------------------------------------------------+

    The constant operand of each Add/Mul may be the first or the second input.
    Unlike fuse_1, the root tensor must be produced by a node: a match is
    rejected when the final Mul's other input has no parent node.

    Returns True when the fusion has been queued (matched nodes appended to
    self.nodes_to_remove, the new node appended to self.nodes_to_add);
    returns None as soon as any part of the pattern fails to match.
    """
    # --- Walk forward from Tanh: Tanh -> Add(1.0) -> Mul(0.5) -> Mul -------
    tanh_output = tanh_node.output[0]
    if tanh_output not in input_name_to_nodes:
        return None
    consumers = input_name_to_nodes[tanh_output]
    if not (len(consumers) == 1 and consumers[0].op_type == "Add"):
        return None
    add_after_tanh = consumers[0]

    if not self.model.has_constant_input(add_after_tanh, 1.0):
        return None

    add_output = add_after_tanh.output[0]
    if add_output not in input_name_to_nodes:
        return None
    consumers = input_name_to_nodes[add_output]
    if not (len(consumers) == 1 and consumers[0].op_type == "Mul"):
        return None
    mul_half = consumers[0]

    half_index = self.model.find_constant_input(mul_half, 0.5)
    if half_index < 0:
        return None

    half_output = mul_half.output[0]
    if half_output not in input_name_to_nodes:
        return None
    consumers = input_name_to_nodes[half_output]
    if not (len(consumers) == 1 and consumers[0].op_type == "Mul"):
        return None
    mul_after_mul_half = consumers[0]

    # The final Mul's input that is not Mul(0.5)'s output comes from root.
    root_slot = 0 if mul_after_mul_half.input[1] == mul_half.output[0] else 1
    root_node = self.model.get_parent(mul_after_mul_half, root_slot, output_name_to_node)
    if root_node is None:
        return None

    # --- Walk backward from Tanh: Mul(0.7978) <- Add <- Mul(0.0447) <- Pow -
    mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
    if mul_before_tanh is None:
        return None

    scale_index = self.model.find_constant_input(mul_before_tanh, 0.7978, delta=0.0001)
    if scale_index < 0:
        return None

    add_before_tanh = self.model.match_parent(
        mul_before_tanh, "Add", 0 if scale_index == 1 else 1, output_name_to_node
    )
    if add_before_tanh is None:
        return None

    # Skip the root producer when searching the Add's parents for the Mul.
    mul_after_pow = self.model.match_parent(
        add_before_tanh, "Mul", None, output_name_to_node, exclude=[root_node]
    )
    if mul_after_pow is None:
        return None

    cubic_index = self.model.find_constant_input(mul_after_pow, 0.0447, delta=0.0001)
    if cubic_index < 0:
        return None

    pow_node = self.model.match_parent(
        mul_after_pow, "Pow", 0 if cubic_index == 1 else 1, output_name_to_node
    )
    if pow_node is None:
        return None

    # Pow must cube the output of the same node that feeds the final Mul.
    if not self.model.has_constant_input(pow_node, 3.0) or pow_node.input[0] != root_node.output[0]:
        return None

    # --- Queue the replacement ----------------------------------------------
    subgraph_nodes = [
        mul_after_mul_half,
        mul_half,
        add_after_tanh,
        tanh_node,
        mul_before_tanh,
        add_before_tanh,
        mul_after_pow,
        pow_node,
    ]
    safe = self.model.is_safe_to_fuse_nodes(
        subgraph_nodes,
        [mul_after_mul_half.output[0]],
        input_name_to_nodes,
        output_name_to_node,
    )
    if not safe:
        return None

    self.nodes_to_remove.extend(subgraph_nodes)
    fused_node = helper.make_node(
        "FastGelu",
        inputs=[root_node.output[0]],
        outputs=mul_after_mul_half.output,
        name=self.model.create_node_name("FastGelu"),
    )
    fused_node.domain = "com.microsoft"
    self.nodes_to_add.append(fused_node)
    self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
    return True
245
+
246
def fuse_3(self, tanh_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
    """
    Replace OpenAI's Gelu formulation (also used in Megatron) with one FastGelu node:
        Gelu(x) = x * 0.5 * (1.0 + torch.tanh(0.79788456 * x * (1.0 + 0.044715 * x * x)))

    Subgraph being matched:

          +------------ Mul (B=0.79788456) -------------------+
          |                                                   |
          +-------------------------------+                   |
          |                               |                   |
          |                               v                   v
        [root] --> Mul (B=0.044715) --> Mul --> Add(B=1) --> Mul --> Tanh --> Add(B=1) --> Mul -->
          |                                                                                 ^
          |                                                                                 |
          +-----------> Mul (B=0.5) --------------------------------------------------------+

    Returns True when the fusion has been queued (matched nodes appended to
    self.nodes_to_remove, the new node appended to self.nodes_to_add);
    returns None as soon as any part of the pattern fails to match.
    """
    # --- Walk forward from Tanh: Tanh -> Add(1.0) -> Mul -------------------
    tanh_output = tanh_node.output[0]
    if tanh_output not in input_name_to_nodes:
        return None

    consumers = input_name_to_nodes[tanh_output]
    if not (len(consumers) == 1 and consumers[0].op_type == "Add"):
        return None
    add_after_tanh = consumers[0]

    if not self.model.has_constant_input(add_after_tanh, 1.0):
        return None

    add_output = add_after_tanh.output[0]
    if add_output not in input_name_to_nodes:
        return None
    consumers = input_name_to_nodes[add_output]
    if not (len(consumers) == 1 and consumers[0].op_type == "Mul"):
        return None
    mul_last = consumers[0]

    # --- The other branch into the final Mul: root -> Mul(0.5) -------------
    mul_half = self.model.match_parent(mul_last, "Mul", None, output_name_to_node)
    if mul_half is None:
        return None

    half_index = self.model.find_constant_input(mul_half, 0.5)
    if half_index < 0:
        return None

    # The input slot of Mul(0.5) that does not hold the constant carries root.
    root_input = mul_half.input[0 if half_index == 1 else 1]

    # --- The Mul feeding Tanh has two parents: Add(1.0) and Mul(0.7978) -----
    mul_before_tanh = self.model.match_parent(tanh_node, "Mul", 0, output_name_to_node)
    if mul_before_tanh is None:
        return None

    add_1 = self.model.match_parent(mul_before_tanh, "Add", None, output_name_to_node)
    if add_1 is None:
        return None
    one_index = self.model.find_constant_input(add_1, 1.0)
    if one_index < 0:
        return None

    mul_7978 = self.model.match_parent(mul_before_tanh, "Mul", None, output_name_to_node)
    if mul_7978 is None:
        return None
    scale_index = self.model.find_constant_input(mul_7978, 0.7978, delta=0.0001)
    if scale_index < 0:
        return None
    if mul_7978.input[0 if scale_index == 1 else 1] != root_input:
        return None

    # --- Behind Add(1.0): Mul(root, Mul(root, 0.0447)) ----------------------
    mul_before_add_1 = self.model.match_parent(
        add_1, "Mul", 0 if one_index == 1 else 1, output_name_to_node
    )
    if mul_before_add_1 is None:
        return None

    # One input must be root; the other slot is where Mul(0.0447) is expected.
    if mul_before_add_1.input[0] == root_input:
        other_slot = 1
    elif mul_before_add_1.input[1] == root_input:
        other_slot = 0
    else:
        return None

    mul_0447 = self.model.match_parent(mul_before_add_1, "Mul", other_slot, output_name_to_node)
    if mul_0447 is None:
        return None
    cubic_index = self.model.find_constant_input(mul_0447, 0.0447, delta=0.0001)
    if cubic_index < 0:
        return None

    if mul_0447.input[0 if cubic_index == 1 else 1] != root_input:
        return None

    # --- Queue the replacement ----------------------------------------------
    subgraph_nodes = [
        mul_0447,
        mul_before_add_1,
        add_1,
        mul_before_tanh,
        tanh_node,
        add_after_tanh,
        mul_7978,
        mul_half,
        mul_last,
    ]
    safe = self.model.is_safe_to_fuse_nodes(
        subgraph_nodes,
        [mul_last.output[0]],
        input_name_to_nodes,
        output_name_to_node,
    )
    if not safe:
        return None

    self.nodes_to_remove.extend(subgraph_nodes)
    fused_node = helper.make_node(
        "FastGelu",
        inputs=[root_input],
        outputs=mul_last.output,
        name=self.model.create_node_name("FastGelu"),
    )
    fused_node.domain = "com.microsoft"
    self.nodes_to_add.append(fused_node)
    self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
    return True
@@ -0,0 +1,259 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+ from typing import Dict, Optional
7
+
8
+ from fusion_base import Fusion
9
+ from onnx import helper
10
+ from onnx_model import OnnxModel
11
+
12
+ logger = getLogger(__name__)
13
+
14
+
15
class FusionGelu(Fusion):
    """Fuse Erf-based Gelu subgraphs into a single ``Gelu`` node (domain ``com.microsoft``).

    ``fuse`` is handed an ``Erf`` node and tries three subgraph layouts around it,
    in order: ``fuse_1``, ``fuse_2`` and ``fuse_3``.
    """

    def __init__(self, model: OnnxModel):
        # NOTE(review): "Gelu" / "Erf" are presumably the fused op type and the op type
        # the base class searches for -- confirm against Fusion.__init__.
        super().__init__(model, "Gelu", "Erf")

    def fuse(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """Try each known layout around ``erf_node`` and stop at the first one that fuses."""
        for attempt in (self.fuse_1, self.fuse_2, self.fuse_3):
            if attempt(erf_node, input_name_to_nodes, output_name_to_node):
                return

    @staticmethod
    def _sole_consumer(node, op_type: str, input_name_to_nodes: Dict):
        """Return the only consumer of ``node.output[0]`` if it has type ``op_type``.

        Returns None when the output has no entry in ``input_name_to_nodes``, when it
        has more than one (or zero) consumers, or when the consumer has another op type.
        """
        output_name = node.output[0]
        if output_name not in input_name_to_nodes:
            return None
        consumers = input_name_to_nodes[output_name]
        if len(consumers) != 1 or consumers[0].op_type != op_type:
            return None
        return consumers[0]

    def _replace_with_gelu(
        self,
        subgraph_nodes,
        subgraph_input: str,
        subgraph_output: str,
        input_name_to_nodes: Dict,
        output_name_to_node: Dict,
    ) -> Optional[bool]:
        """Schedule ``subgraph_nodes`` for removal and add one ``Gelu`` node in their place.

        Returns True when the replacement was scheduled, or None when
        ``is_safe_to_fuse_nodes`` rejects the subgraph (nothing is modified then).
        """
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node
        ):
            return None

        self.nodes_to_remove.extend(subgraph_nodes)
        fused_node = helper.make_node(
            "Gelu", inputs=[subgraph_input], outputs=[subgraph_output], name=self.model.create_node_name("Gelu")
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
        self.increase_counter("Gelu")
        return True

    def fuse_1(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
        """
        This pattern is from PyTorch model
        Fuse Gelu with Erf into one node:
        Pattern 1:
                       +-------Mul(0.5)---------------------+
                       |                                    |
                       |                                    v
                    [root] --> Div -----> Erf  --> Add --> Mul -->
                              (B=1.4142...)       (1)

        Pattern 2:
                       +------------------------------------+
                       |                                    |
                       |                                    v
                    [root] --> Div -----> Erf  --> Add --> Mul -->Mul -->
                              (B=1.4142...)       (1)            (0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        Returns True when the subgraph was fused, otherwise None.
        """
        add_after_erf = self._sole_consumer(erf_node, "Add", input_name_to_nodes)
        if add_after_erf is None:
            return None

        if not self.model.has_constant_input(add_after_erf, 1):
            return None

        mul_after_erf = self._sole_consumer(add_after_erf, "Mul", input_name_to_nodes)
        if mul_after_erf is None:
            return None

        div = self.model.match_parent(erf_node, "Div", 0, output_name_to_node)
        if div is None:
            return None

        # The divisor (second input of Div) must be the constant ~sqrt(2).
        if self.model.find_constant_input(div, 1.4142, delta=0.001) != 1:
            return None

        subgraph_input = div.input[0]

        # Index of the Mul input that does NOT come from the Add node.
        another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
        if subgraph_input == mul_after_erf.input[another]:  # pattern 2
            # _sole_consumer also covers the case where this Mul has no consumer at all
            # (e.g. its output is a graph output); indexing the map directly would
            # raise KeyError there instead of simply declining the fusion.
            mul_half = self._sole_consumer(mul_after_erf, "Mul", input_name_to_nodes)
            if mul_half is None:
                return None
            if not self.model.has_constant_input(mul_half, 0.5):
                return None
            subgraph_output = mul_half.output[0]
        else:  # pattern 1
            mul_half = self.model.match_parent(mul_after_erf, "Mul", another, output_name_to_node)
            if mul_half is None:
                return None

            if not self.model.has_constant_input(mul_half, 0.5):
                return None

            if subgraph_input not in mul_half.input:
                return None

            subgraph_output = mul_after_erf.output[0]

        subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
        return self._replace_with_gelu(
            subgraph_nodes, subgraph_input, subgraph_output, input_name_to_nodes, output_name_to_node
        )

    def fuse_2(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
        """
        This pattern is from Keras model
        Fuse Gelu with Erf into one node:
                       +------------------------------------------+
                       |                                          |
                       |                                          v
                    [root] --> Div -----> Erf  --> Add --> Mul -->Mul
                              (B=1.4142...)       (A=1)   (A=0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        The divisor may also be produced by a Sqrt node that has a constant input 2.0;
        in that case the Sqrt node is removed together with the rest of the subgraph.

        Returns True when the subgraph was fused, otherwise None.
        """
        add_after_erf = self._sole_consumer(erf_node, "Add", input_name_to_nodes)
        if add_after_erf is None:
            return None

        if not self.model.has_constant_input(add_after_erf, 1):
            return None

        mul_after_erf = self._sole_consumer(add_after_erf, "Mul", input_name_to_nodes)
        if mul_after_erf is None:
            return None

        if not self.model.has_constant_input(mul_after_erf, 0.5):
            return None

        mul = self._sole_consumer(mul_after_erf, "Mul", input_name_to_nodes)
        if mul is None:
            return None

        div = self.model.match_parent(erf_node, "Div", 0, output_name_to_node)
        if div is None:
            return None

        sqrt_node = None
        if self.model.find_constant_input(div, 1.4142, delta=0.001) != 1:
            # No literal ~sqrt(2) divisor: accept Sqrt(2.0) feeding the divisor instead.
            sqrt_node = self.model.match_parent(div, "Sqrt", 1, output_name_to_node)
            if sqrt_node is None:
                return None
            if not self.model.has_constant_input(sqrt_node, 2.0):
                return None

        root_node = self.model.get_parent(div, 0, output_name_to_node)
        if root_node is None:
            return None

        # The final Mul must multiply by the same tensor that entered the Div.
        if root_node.output[0] not in mul.input:
            return None

        subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
        if sqrt_node is not None:
            subgraph_nodes.append(sqrt_node)

        return self._replace_with_gelu(
            subgraph_nodes, root_node.output[0], mul.output[0], input_name_to_nodes, output_name_to_node
        )

    def fuse_3(self, erf_node, input_name_to_nodes: Dict, output_name_to_node: Dict) -> Optional[bool]:
        """
        This pattern is from TensorFlow model
        Fuse Gelu with Erf into one node:
                       +----------------------------------------------+
                       |                                              |
                       |                                              v
                    [root] --> Mul -----> Erf    -->   Add --> Mul -->Mul
                              (A=0.7071067690849304)  (B=1)  (B=0.5)

        Note that constant input for Add and Mul could be first or second input: like either A=0.5 or B=0.5 is fine.

        Returns True when the subgraph was fused, otherwise None.
        """
        add_after_erf = self._sole_consumer(erf_node, "Add", input_name_to_nodes)
        if add_after_erf is None:
            return None

        if not self.model.has_constant_input(add_after_erf, 1):
            return None

        mul_half = self._sole_consumer(add_after_erf, "Mul", input_name_to_nodes)
        if mul_half is None:
            return None

        if not self.model.has_constant_input(mul_half, 0.5):
            return None

        first_mul = self.model.match_parent(erf_node, "Mul", 0, output_name_to_node)
        if first_mul is None:
            return None

        # Index of the ~1/sqrt(2) constant among first_mul's inputs (negative if absent).
        i = self.model.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
        if i < 0:
            return None

        # The root is whatever feeds the other (non-constant) input of first_mul.
        root_node = self.model.get_parent(first_mul, 0 if i == 1 else 1, output_name_to_node)
        if root_node is None:
            return None

        last_mul = self._sole_consumer(mul_half, "Mul", input_name_to_nodes)
        if last_mul is None:
            return None

        root_output = root_node.output[0]
        if not (last_mul.input[0] == root_output or last_mul.input[1] == root_output):
            return None

        subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
        return self._replace_with_gelu(
            subgraph_nodes, root_output, last_mul.output[0], input_name_to_nodes, output_name_to_node
        )
@@ -0,0 +1,25 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from fusion_base import Fusion
7
+ from onnx import helper
8
+ from onnx_model import OnnxModel
9
+
10
+
11
class FusionGeluApproximation(Fusion):
    """Replace each ``Gelu`` / ``BiasGelu`` node with a ``FastGelu`` node (domain ``com.microsoft``)."""

    def __init__(self, model: OnnxModel):
        super().__init__(model, "FastGelu", ["Gelu", "BiasGelu"], "GeluApproximation")

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        """Swap ``node`` for a FastGelu node that reuses its inputs and outputs unchanged."""
        name_suffix = f"{node.op_type}_Approximation"
        fast_gelu_name = self.model.create_node_name("FastGelu", name_suffix)

        fast_gelu = helper.make_node(
            "FastGelu",
            inputs=node.input,
            outputs=node.output,
            name=fast_gelu_name,
        )
        fast_gelu.domain = "com.microsoft"

        # Schedule the swap: the original node goes away, the replacement is added
        # and recorded under the graph currently being processed.
        self.nodes_to_remove.append(node)
        self.nodes_to_add.append(fast_gelu)
        self.node_name_to_graph_name[fast_gelu.name] = self.this_graph_name