onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,179 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+ from typing import Dict
7
+
8
+ import numpy as np
9
+ from fusion_base import Fusion
10
+ from onnx import TensorProto, helper
11
+ from onnx_model import OnnxModel
12
+
13
+ logger = getLogger(__name__)
14
+
15
+
16
class FusionGroupNorm(Fusion):
    """Fuse an InstanceNormalization-based group normalization subgraph into a com.microsoft GroupNorm node.

    fuse() is invoked with candidate Add nodes (the search op type passed to Fusion is "Add").
    """

    def __init__(self, model: OnnxModel, channels_last=True):
        """
        Args:
            model: the model to optimize.
            channels_last: when True, the fused GroupNorm consumes/produces NHWC tensors and
                Transpose nodes are inserted before and after it; when False, the node gets the
                attribute channels_last=0 and no Transpose nodes are added.
        """
        super().__init__(model, "GroupNorm", "Add")
        self.channels_last = channels_last

    def fuse(self, add_node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse Group Normalization subgraph into one node GroupNorm.
        The following is the pattern with swish activation:
              +----------------Shape-------------------------------+
              |                                                    |
              |    (0, 32, -1)                                     v     (512x1x1) (512x1x1)     (optional)
        [Root] --> Reshape -------> InstanceNormalization --> Reshape ---> Mul --> Add --> Mul--> [output]
        Bx512xHxW                 (scale=ones(32), B=zeros(32))                         |       ^    Bx512xHxW
                                                                                        |       |
                                                                                        +--->Sigmoid (optional)
        The Mul and Sigmoid before output is for Swish activation. They are optional.
        Swish is only recognized when the Sigmoid and the Mul both consume the Add output directly.

        The fused node is always created with groups=32, so the subgraph is rejected when the
        InstanceNormalization scale (one entry per group) has a different length.

        Args:
            add_node: candidate Add node that adds the per-channel bias.
            input_name_to_nodes: map from tensor name to the nodes consuming it.
            output_name_to_node: map from tensor name to the node producing it.

        Nothing is returned; on a match, initializers and nodes are registered on this Fusion
        (add_initializer, nodes_to_add, nodes_to_remove).
        """
        # Number of groups emitted in the fused node's "groups" attribute.
        num_groups = 32

        nodes = self.model.match_parent_path(
            add_node, ["Mul", "Reshape", "InstanceNormalization", "Reshape"], [0, 0, 0, 0], output_name_to_node
        )
        if nodes is None:
            return

        weight_mul, reshape_4d, instance_norm, reshape_3d = nodes
        root = reshape_3d.input[0]

        # The 4D Reshape must restore the original shape, taken from Shape(root).
        parents = self.model.match_parent_path(reshape_4d, ["Shape"], [1], output_name_to_node)
        if parents is None:
            return
        if parents[0].input[0] != root:
            return
        shape_node = parents[0]

        # Check whether it has swish activation: Add --> Sigmoid --> Mul with the Add output also
        # feeding the Mul. Only direct children of the Add are considered; a recursive search could
        # pick an unrelated downstream Mul and rewire its output to the fused node.
        swish_mul = self.model.find_first_child_by_type(add_node, "Mul", input_name_to_nodes, recursive=False)
        swish_sigmoid = None
        if swish_mul is not None:
            sigmoid_path = self.model.match_parent_path(swish_mul, ["Sigmoid"], [None], output_name_to_node)
            if sigmoid_path is not None and sigmoid_path[0].input[0] == add_node.output[0]:
                swish_sigmoid = sigmoid_path[0]

        weight_input = weight_mul.input[1 - self.model.input_index(reshape_4d.output[0], weight_mul)]
        if not self.model.is_constant_with_specified_dimension(weight_input, 3, "group norm weight"):
            return

        bias_input = add_node.input[1 - self.model.input_index(weight_mul.output[0], add_node)]
        if not self.model.is_constant_with_specified_dimension(bias_input, 3, "group norm bias"):
            return

        weight = self.model.get_constant_value(weight_input)
        if weight is None:
            return
        if not (len(weight.shape) == 3 and weight.shape[1] == 1 and weight.shape[2] == 1):
            return

        bias = self.model.get_constant_value(bias_input)
        if bias is None:
            return
        if not (len(bias.shape) == 3 and bias.shape[1] == 1 and bias.shape[2] == 1):
            return

        weight_elements = int(np.prod(weight.shape))
        bias_elements = int(np.prod(bias.shape))
        if weight_elements != bias_elements:
            return

        instance_norm_scale = self.model.get_constant_value(instance_norm.input[1])
        if instance_norm_scale is None or len(instance_norm_scale.shape) != 1:
            return

        # The scale has one entry per group; the fused node hard-codes num_groups, so any other group
        # count would change the normalization semantics.
        if instance_norm_scale.shape[0] != num_groups or weight_elements % num_groups != 0:
            return

        instance_norm_bias = self.model.get_constant_value(instance_norm.input[2])
        if instance_norm_bias is None or instance_norm_bias.shape != instance_norm_scale.shape:
            return

        # GroupNorm applies gamma/beta per channel only, so the InstanceNormalization itself must be
        # an identity affine transform (scale of ones, bias of zeros).
        if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale):
            return
        if not np.allclose(np.zeros_like(instance_norm_bias), instance_norm_bias):
            return

        group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm")

        self.add_initializer(
            name=group_norm_name + "_gamma",
            data_type=TensorProto.FLOAT,
            dims=[weight_elements],
            vals=weight,
        )

        self.add_initializer(
            name=group_norm_name + "_beta",
            data_type=TensorProto.FLOAT,
            dims=[bias_elements],
            vals=bias,
        )

        last_node = add_node
        subgraph_nodes = [add_node, weight_mul, reshape_4d, instance_norm, reshape_3d, shape_node]
        has_swish_activation = swish_mul is not None and swish_sigmoid is not None
        if has_swish_activation:
            subgraph_nodes.extend([swish_mul, swish_sigmoid])
            last_node = swish_mul

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            last_node.output,
            input_name_to_nodes,
            output_name_to_node,
        ):
            # Other consumers still need intermediate outputs: only replace the last node and keep
            # the rest of the subgraph alive.
            self.nodes_to_remove.extend([last_node])
        else:
            self.nodes_to_remove.extend(subgraph_nodes)

            # instance_norm_scale might from Constant node. Use prune graph to clear it.
            self.prune_graph = True

        input_name = root
        output_name = last_node.output[0]

        group_norm_input_name = input_name + "_NHWC" if self.channels_last else input_name
        group_norm_output_name = output_name + "_NHWC" if self.channels_last else output_name

        # NCHW to NHWC
        if self.channels_last:
            transpose_input = helper.make_node(
                "Transpose",
                [input_name],
                [group_norm_input_name],
                name=self.model.create_node_name("Transpose", name_prefix="Transpose_NCHW_to_NHWC"),
                perm=[0, 2, 3, 1],
            )
            self.nodes_to_add.append(transpose_input)
            self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name

        new_node = helper.make_node(
            "GroupNorm",
            inputs=[group_norm_input_name, group_norm_name + "_gamma", group_norm_name + "_beta"],
            outputs=[group_norm_output_name],
            name=group_norm_name,
        )

        # Carry over InstanceNormalization attributes (such as epsilon) to the fused node.
        new_node.attribute.extend(instance_norm.attribute)
        new_node.attribute.extend([helper.make_attribute("groups", num_groups)])
        new_node.attribute.extend([helper.make_attribute("activation", 1 if has_swish_activation else 0)])

        if not self.channels_last:
            new_node.attribute.extend([helper.make_attribute("channels_last", 0)])

        new_node.domain = "com.microsoft"
        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        # NHWC to NCHW
        if self.channels_last:
            transpose_output = helper.make_node(
                "Transpose",
                [group_norm_output_name],
                [output_name],
                name=self.model.create_node_name("Transpose", name_prefix="Transpose_NHWC_to_NCHW"),
                perm=[0, 3, 1, 2],
            )
            self.nodes_to_add.append(transpose_output)
            self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name
@@ -0,0 +1,465 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+ from typing import Dict, List
7
+
8
+ from fusion_base import Fusion
9
+ from onnx import TensorProto, helper
10
+ from onnx_model import OnnxModel
11
+
12
+ logger = getLogger(__name__)
13
+
14
+
15
class FusionLayerNormalization(Fusion):
    """Fuse a decomposed layer normalization subgraph into a single LayerNormalization node.

    fuse() is invoked with candidate ReduceMean nodes (the search op type passed to Fusion).
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "ReduceMean")

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse Layer Normalization subgraph into one node LayerNormalization:
              +----------------------+
              |                      |
              |                      v
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
                     (axis=2 or -1)  |      (Y=2)   (axis=2 or -1)  (E-6 or E-12 or 0)    ^
                                     |                                               |
                                     +-----------------------------------------------+

        It also handles cases of duplicated sub nodes exported from older version of PyTorch:
              +----------------------+
              |                      v
              |           +-------> Sub-----------------------------------------------+
              |           |                                                           |
              |           |                                                           v
          [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
              |                      ^
              |                      |
              +----------------------+

        Optional Cast nodes are accepted between Sub and Div, between Sub and Pow, and between
        Div and Mul; every matched Cast is removed together with the rest of the subgraph.

        Args:
            node: candidate ReduceMean node computing the mean of the root input.
            input_name_to_nodes: map from tensor name to the nodes consuming it.
            output_name_to_node: map from tensor name to the node producing it.

        Nothing is returned; on a match, nodes are registered on this Fusion
        (nodes_to_add, nodes_to_remove).
        """
        subgraph_nodes = []
        children = self.model.get_children(node, input_name_to_nodes)
        if len(children) == 0 or len(children) > 2:
            return

        root_input = node.input[0]

        if children[0].op_type != "Sub" or children[0].input[0] != root_input:
            return

        if len(children) == 2:
            if children[1].op_type != "Sub" or children[1].input[0] != root_input:
                return

        div_node = None
        # Cast between Sub and Div when the Sub --> Cast --> Div form was matched. It must be part of
        # the fused subgraph, otherwise is_safe_to_fuse_nodes rejects the match (Sub's output would
        # still be consumed by a node outside the subgraph).
        sub_to_div_cast = None
        for child in children:
            # Check if Sub --> Div exists
            div_node_1 = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)

            # Check if Sub --> Cast --> Div
            div_node_2 = self.model.match_child_path(child, ["Cast", "Div"], exclude=[])

            if div_node_1 is not None:
                div_node = div_node_1
                sub_to_div_cast = None
            elif div_node_2 is not None:
                div_node = div_node_2[-1]
                sub_to_div_cast = div_node_2[0]
        if div_node is None:
            return

        path_id, parent_nodes, _ = self.model.match_parent_paths(
            div_node,
            [
                (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
                (["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
            ],
            output_name_to_node,
        )
        if path_id < 0:
            return

        sub_node = parent_nodes[-1]
        if sub_node not in children:
            return

        # The Add before Sqrt carries epsilon as a constant input.
        second_add_node = parent_nodes[1]
        _, add_weight = self.model.get_constant_input(second_add_node)
        if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
            logger.debug(f"skip LayerNormalization fusion since epsilon value is not expected: {add_weight}")
            return

        # Variance must be computed with an exponent of 2 on the Pow node's second input.
        pow_node = parent_nodes[3]
        if self.model.find_constant_input(pow_node, 2.0) != 1:
            return

        # Consumer lookups use .get(): a tensor without consumers is simply not a match.
        div_consumers = input_name_to_nodes.get(div_node.output[0], [])
        if not div_consumers:
            return
        temp_node = div_consumers[0]
        if temp_node.op_type == "Cast":
            # Div --> Cast --> Mul
            subgraph_nodes.append(temp_node)  # add Cast node to list of subgraph nodes
            cast_consumers = input_name_to_nodes.get(temp_node.output[0], [])
            if not cast_consumers:
                return
            mul_node = cast_consumers[0]
        else:
            # Div --> Mul
            mul_node = temp_node
        if mul_node.op_type != "Mul":
            return

        mul_consumers = input_name_to_nodes.get(mul_node.output[0], [])
        if not mul_consumers:
            return
        last_add_node = mul_consumers[0]
        if last_add_node.op_type != "Add":
            return

        subgraph_nodes.append(node)
        subgraph_nodes.extend(children)
        subgraph_nodes.extend(parent_nodes[:-1])
        # The Sub --> Cast --> Div Cast may coincide with the Sub --> Cast --> Pow Cast already in
        # parent_nodes; avoid listing it twice.
        if sub_to_div_cast is not None and sub_to_div_cast not in subgraph_nodes:
            subgraph_nodes.append(sub_to_div_cast)

        subgraph_nodes.extend([last_add_node, mul_node, div_node])
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            last_add_node.output,
            input_name_to_nodes,
            output_name_to_node,
        ):
            logger.debug("It is not safe to fuse LayerNormalization node. Skip")
            return

        # The Mul's other input is the scale (gamma); the last Add's other input is the bias (beta).
        node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
        weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
        if not self.model.is_constant_with_specified_dimension(weight_input, 1, "layernorm weight"):
            return

        bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
        if not self.model.is_constant_with_specified_dimension(bias_input, 1, "layernorm bias"):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        normalize_node = helper.make_node(
            "LayerNormalization",
            inputs=[node.input[0], weight_input, bias_input],
            outputs=[last_add_node.output[0]],
            name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
        )
        normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))])
        self.nodes_to_add.append(normalize_node)
        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
144
+
145
+
146
class FusionLayerNormalizationNCHW(Fusion):
    """Fuse a channel-first LayerNormalization subgraph (ReduceMean over axis 1) into
    Transpose -> LayerNormalization -> Transpose.

    The Cx1x1 scale and bias constants of the original subgraph are re-created as 1-D
    initializers named "<original name>_NHWC" for the fused LayerNormalization node.
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "ReduceMean")

    def get_weight_or_bias(self, output_name, description):
        """Return the constant ``output_name`` reshaped from (C, 1, 1) to (C,).

        Args:
            output_name: name of the tensor expected to be a constant.
            description: human-readable label used only in debug log messages.

        Returns:
            The reshaped constant value, or None (after a debug log) when ``output_name`` is
            not a constant or its shape is not exactly 3-D with the last two dimensions equal to 1.
        """
        value = self.model.get_constant_value(output_name)
        if value is None:
            logger.debug(f"{description} {output_name} is not initializer.")
            return None

        if len(value.shape) != 3 or value.shape[1] != 1 or value.shape[2] != 1:
            logger.debug(f"{description} {output_name} shall have 3 dimensions Cx1x1. Got shape {value.shape}")
            return None

        return value.reshape([value.shape[0]])

    def create_transpose_node(self, input_name: str, perm: List[int], output_name=None):
        """Create (but do not register) a Transpose node reading ``input_name``.

        Args:
            input_name: tensor to transpose.
            perm: value for the Transpose ``perm`` attribute.
            output_name: output tensor name; when None, one is derived from the generated
                node name and ``input_name``.

        Returns:
            The new Transpose node. The caller must add it to the graph.
        """
        node_name = self.model.create_node_name("Transpose")

        if output_name is None:
            output_name = node_name + "_out" + "-" + input_name

        transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
        transpose_node.attribute.extend([helper.make_attribute("perm", perm)])

        return transpose_node

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Fuse Layer Normalization subgraph into one node LayerNormalization:
              +----------------------+
              |    NxCxHxW           |
              |                      v                                                     (Cx1x1)  (Cx1x1)
            [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add -->
                       (axes=1)        |      (Y=2)     (axes=1)     (E-6)             ^
                                       |                                               |
                                       +-----------------------------------------------+

        Fused subgraph:
                       (0,2,3,1)                            (0,3,1,2)
            [Root] --> Transpose --> LayerNormalization --> Transpose -->

        ``node`` is the first (mean) ReduceMean. On success the matched nodes are queued for
        removal, and two Transposes plus a LayerNormalization are queued for addition. The scale
        and bias are added as new "<name>_NHWC" initializers. On any mismatch the method returns
        without changing anything.
        """
        axes = OnnxModel.get_node_attribute(node, "axes")
        if (not isinstance(axes, list)) or axes != [1]:
            return

        children = self.model.get_children(node, input_name_to_nodes)
        if len(children) != 1:
            return

        root_input = node.input[0]

        # The mean must be subtracted from the same root tensor it was computed from.
        if children[0].op_type != "Sub" or children[0].input[0] != root_input:
            return
        sub = children[0]

        div_node = self.model.find_first_child_by_type(sub, "Div", input_name_to_nodes, recursive=False)
        if div_node is None:
            return

        parent_nodes = self.model.match_parent_path(
            div_node,
            ["Sqrt", "Add", "ReduceMean", "Pow", "Sub"],
            [1, 0, 0, 0, 0],
            output_name_to_node,
        )
        if parent_nodes is None:
            return

        _sqrt_node, second_add_node, reduce_mean_node, pow_node, sub_node = parent_nodes
        # Both the numerator and the variance branch must start from the same Sub.
        if sub != sub_node:
            return

        _, add_weight = self.model.get_constant_input(second_add_node)
        if add_weight is None or add_weight <= 0 or add_weight > 1.0e-4:
            logger.debug(f"skip LayerNormalization fusion since epsilon value is not expected: {add_weight}")
            return

        # The variance ReduceMean must also reduce over axis 1 only. If the attribute is missing
        # or not a list, skip this candidate instead of aborting the whole pass with an assertion.
        # (Presumably this happens when axes is supplied as an input in newer ReduceMean opsets.)
        axes = OnnxModel.get_node_attribute(reduce_mean_node, "axes")
        if (not isinstance(axes, list)) or axes != [1]:
            return

        if self.model.find_constant_input(pow_node, 2.0) != 1:
            return

        # This pattern has no Cast between Div and Mul: Div must feed the scale Mul directly.
        mul_node = input_name_to_nodes[div_node.output[0]][0]
        if mul_node.op_type != "Mul":
            return

        last_add_node = input_name_to_nodes[mul_node.output[0]][0]
        if last_add_node.op_type != "Add":
            return

        subgraph_nodes = [node, *parent_nodes, last_add_node, mul_node, div_node]

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            last_add_node.output,
            input_name_to_nodes,
            output_name_to_node,
        ):
            logger.debug("It is not safe to fuse LayerNormalization node. Skip")
            return

        # The scale is whichever Mul input is not the Div output; likewise for the bias on Add.
        weight_input = mul_node.input[1 - self.model.input_index(div_node.output[0], mul_node)]
        weight = self.get_weight_or_bias(weight_input, "layernorm weight")
        if weight is None:
            return

        bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
        bias = self.get_weight_or_bias(bias_input, "layernorm bias")
        if bias is None:
            return

        weight_nhwc = helper.make_tensor(weight_input + "_NHWC", TensorProto.FLOAT, weight.shape, weight)
        # The bias initializer must carry the bias values. It previously copied the weight.
        bias_nhwc = helper.make_tensor(bias_input + "_NHWC", TensorProto.FLOAT, bias.shape, bias)
        self.model.add_initializer(weight_nhwc, self.this_graph_name)
        self.model.add_initializer(bias_nhwc, self.this_graph_name)

        self.nodes_to_remove.extend(subgraph_nodes)

        # NCHW -> NHWC so LayerNormalization normalizes over the (now last) channel axis.
        transpose_input = self.create_transpose_node(node.input[0], [0, 2, 3, 1])

        layernorm_node_name = self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm")

        # NHWC -> NCHW, writing to the original subgraph output name so consumers are unchanged.
        transpose_output = self.create_transpose_node(
            layernorm_node_name + "_out_nhwc", [0, 3, 1, 2], last_add_node.output[0]
        )

        normalize_node = helper.make_node(
            "LayerNormalization",
            inputs=[transpose_input.output[0], weight_input + "_NHWC", bias_input + "_NHWC"],
            outputs=[layernorm_node_name + "_out_nhwc"],
            name=layernorm_node_name,
        )
        normalize_node.attribute.extend([helper.make_attribute("epsilon", float(add_weight))])

        self.nodes_to_add.append(transpose_input)
        self.nodes_to_add.append(normalize_node)
        self.nodes_to_add.append(transpose_output)
        self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name
        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
        self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name

        counter_name = "LayerNormalization(NHWC)"
        self.increase_counter(counter_name)
300
+
301
+
302
class FusionLayerNormalizationTF(Fusion):
    """Fuse the LayerNormalization subgraph emitted by TensorFlow converters (keras2onnx / tf2onnx)."""

    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "Add", "TF")

    def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
        """
        Layer Norm from Tensorflow model(using keras2onnx or tf2onnx):
         +------------------------------------+
         |                                    |
         |                                    |
       (Cast_1)                               |
         |                                    |
         |                                    v                                           (B)                             (B)             (A)
        Add --> (Cast_1) --> ReduceMean -->  Sub  --> Mul --> ReduceMean --> (Cast_3) --> Add --> Sqrt --> Reciprocal --> Mul --> Mul --> Sub --> Add
         |                       |                                                                                          |       ^              ^
         |                       |                                                                                          |       |              |
         |                       +--------------------------------------------------(Cast_2)-------------------------------|-------+              |
         |                                                                                                                  v                      |
         +---------------------------------------------------------------------------------------------------------------> Mul--------------------+

        ``node`` is the final Add. On success the matched nodes are queued for removal and a single
        LayerNormalization node is queued for addition. It writes ``node.output[0]`` and carries the
        matched epsilon. On any mismatch the method returns without changing anything.
        """
        _, parent_nodes, return_indice = self.model.match_parent_paths(
            node,
            [
                (
                    [
                        "Sub",
                        "Mul",
                        "Mul",
                        "Reciprocal",
                        "Sqrt",
                        "Add",
                        "ReduceMean",
                        "Mul",
                        "Sub",
                        "ReduceMean",
                    ],
                    [1, 1, None, 0, 0, 0, None, 0, 0, None],
                ),
                (
                    [
                        "Sub",
                        "Mul",
                        "Mul",
                        "Reciprocal",
                        "Sqrt",
                        "Add",
                        "Cast",
                        "ReduceMean",
                        "Mul",
                        "Sub",
                        "ReduceMean",
                    ],
                    [1, 1, None, 0, 0, 0, 0, None, 0, 0, None],
                ),
            ],
            output_name_to_node,
        )

        if parent_nodes is None:
            return

        # Each candidate path has exactly three wildcard (None) input indices; the index actually
        # matched for each wildcard is reported in return_indice.
        assert len(return_indice) == 3
        if not all(index in [0, 1] for index in return_indice):
            logger.debug(f"return indice is expected in [0, 1], but got {return_indice}")
            return

        (
            sub_node_0,
            mul_node_0,
            mul_node_1,
            reciprocal_node,
            sqrt_node,
            add_node_0,
        ) = parent_nodes[:6]
        reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes[-4:]

        # The 11-node path (second pattern) contains the Cast_3 between ReduceMean and Add.
        cast_node_3 = None
        if len(parent_nodes) == 11:
            cast_node_3 = parent_nodes[6]
            assert cast_node_3.op_type == "Cast"

        mul_node_3 = self.model.match_parent(node, "Mul", 0, output_name_to_node)
        if mul_node_3 is None:
            logger.debug("mul_node_3 not found")
            return

        # With casts present there is one extra node (Cast_1) between the root and the first ReduceMean.
        node_before_reduce = self.model.get_parent(reduce_mean_node_1, 0, output_name_to_node)
        root_node = (
            node_before_reduce
            if cast_node_3 is None
            else self.model.get_parent(node_before_reduce, 0, output_name_to_node)
        )
        if root_node is None:
            logger.debug("root node is none")
            return

        # The epsilon upper bound is only enforced for the cast-free pattern.
        _, epsilon = self.model.get_constant_input(add_node_0)
        if epsilon is None or epsilon <= 0 or (epsilon > 1.0e-5 and cast_node_3 is None):
            logger.debug("epsilon is not matched")
            return

        if cast_node_3 is None and (
            reduce_mean_node_1.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input
        ):
            logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
            return

        if cast_node_3 is not None and (
            node_before_reduce.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input
        ):
            logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
            return

        # mul_node_2 squares the centered input (x * x).
        if mul_node_2.input[0] != mul_node_2.input[1]:
            logger.debug("mul_node_2 shall have two same inputs")
            return

        subgraph_nodes = [
            node,
            sub_node_0,
            mul_node_0,
            mul_node_1,
            reciprocal_node,
            sqrt_node,
            add_node_0,
            reduce_mean_node_0,
            mul_node_2,
            sub_node_1,
            reduce_mean_node_1,
            mul_node_3,
        ]

        if cast_node_3 is not None:
            cast_node_2 = self.model.match_parent(mul_node_0, "Cast", 0, output_name_to_node)
            if cast_node_2 is None:
                logger.debug("cast_node_2 not found")
                return
            subgraph_nodes.extend([node_before_reduce, cast_node_2, cast_node_3])

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            node.output,
            self.model.input_name_to_nodes(),
            self.model.output_name_to_node(),
        ):
            logger.debug("not safe to fuse layer normalization")
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        weight_input = mul_node_1.input[1]
        bias_input = sub_node_0.input[0]

        fused_node = helper.make_node(
            "LayerNormalization",
            inputs=[mul_node_3.input[0], weight_input, bias_input],
            outputs=[node.output[0]],
            name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
        )
        fused_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name