onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,180 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+
7
+ import numpy as np
8
+ from fusion_base import Fusion
9
+ from onnx import TensorProto, helper
10
+ from onnx_model import OnnxModel
11
+
12
+ logger = getLogger(__name__)
13
+
14
+
15
class FusionGroupNorm(Fusion):
    """Fuse an InstanceNormalization-based GroupNorm subgraph into one com.microsoft::GroupNorm node."""

    def __init__(self, model: OnnxModel, channels_last: bool = True):
        """
        Args:
            model: the OnnxModel to search and rewrite.
            channels_last: when True (default), wrap the fused GroupNorm with
                NCHW->NHWC / NHWC->NCHW Transpose nodes, since the contrib op
                defaults to channels-last layout.
        """
        super().__init__(model, "GroupNorm", "Add")
        self.channels_last = channels_last

    def fuse(self, add_node, input_name_to_nodes: dict, output_name_to_node: dict):
        """
        Fuse Group Normalization subgraph into one node GroupNorm.
        The following is the pattern with swish activation:
               +----------------Shape-------------------------------+
               |                                                    |
               | (0, 32, -1)                                        v  (512x1x1)  (512x1x1) (optional)
            [Root] --> Reshape -------> InstanceNormalization --> Reshape ---> Mul --> Add --> Mul--> [output]
            Bx512xHxW (scale=ones(32), B=zeros(32))                                     |       ^     Bx512xHxW
                                                                                        |       |
                                                                                        +--->Sigmoid (optional)
        The Mul and Sigmoid before output is for Swish activation. They are optional.

        Args:
            add_node: the Add node that anchors the pattern (bias add).
            input_name_to_nodes: map from tensor name to consumer nodes.
            output_name_to_node: map from tensor name to producer node.
        """
        nodes = self.model.match_parent_path(
            add_node, ["Mul", "Reshape", "InstanceNormalization", "Reshape"], [0, 0, 0, 0], output_name_to_node
        )
        if nodes is None:
            return

        weight_mul, reshape_4d, instance_norm, reshape_3d = nodes
        root = reshape_3d.input[0]

        # The 4D reshape must take its target shape from a Shape node fed by the same root.
        parents = self.model.match_parent_path(reshape_4d, ["Shape"], [1], output_name_to_node)
        if parents is None:
            return
        if parents[0].input[0] != root:
            return
        shape_node = parents[0]

        # Check whether it has swish activation (Mul with a Sigmoid branch after the Add).
        swish_mul = self.model.find_first_child_by_type(add_node, "Mul")
        swish_sigmoid = None
        if swish_mul is not None:
            sigmoid_path = self.model.match_parent_path(swish_mul, ["Sigmoid"], [None], output_name_to_node)
            if sigmoid_path is not None:
                swish_sigmoid = sigmoid_path[0]

        # gamma: the Mul input that is not the Reshape output; must be a (C,1,1) constant.
        weight_input = weight_mul.input[1 - self.model.input_index(reshape_4d.output[0], weight_mul)]
        if not self.model.is_constant_with_specified_dimension(weight_input, 3, "group norm weight"):
            return

        # beta: the Add input that is not the Mul output; must be a (C,1,1) constant.
        bias_input = add_node.input[1 - self.model.input_index(weight_mul.output[0], add_node)]
        if not self.model.is_constant_with_specified_dimension(bias_input, 3, "layernorm bias"):
            return

        weight = self.model.get_constant_value(weight_input)
        if weight is None:
            return

        if not (len(weight.shape) == 3 and weight.shape[1] == 1 and weight.shape[2] == 1):
            return

        bias = self.model.get_constant_value(bias_input)
        if bias is None:
            return
        if not (len(bias.shape) == 3 and bias.shape[1] == 1 and bias.shape[2] == 1):
            return

        weight_elements = int(np.prod(weight.shape))
        bias_elements = int(np.prod(bias.shape))
        if weight_elements != bias_elements:
            return

        # InstanceNormalization scale must be a 1-D tensor; its length is the group count.
        instance_norm_scale = self.model.get_constant_value(instance_norm.input[1])
        if instance_norm_scale is None or len(instance_norm_scale.shape) != 1:
            return
        num_groups = int(instance_norm_scale.shape[0])

        instance_norm_bias = self.model.get_constant_value(instance_norm.input[2])
        # Bug fix: the original compared instance_norm_scale.shape to itself (always equal),
        # so a bias whose shape differed from the scale was never rejected.
        if instance_norm_bias is None or instance_norm_bias.shape != instance_norm_scale.shape:
            return

        # GroupNorm folds scale/bias into gamma/beta, so the InstanceNormalization
        # must be the trivial one: scale of all ones and bias of all zeros.
        if not np.allclose(np.ones_like(instance_norm_scale), instance_norm_scale):
            return
        if not np.allclose(np.zeros_like(instance_norm_bias), instance_norm_bias):
            return

        group_norm_name = self.model.create_node_name("GroupNorm", name_prefix="GroupNorm")

        self.add_initializer(
            name=group_norm_name + "_gamma",
            data_type=TensorProto.FLOAT,
            dims=[weight_elements],
            vals=weight,
        )

        self.add_initializer(
            name=group_norm_name + "_beta",
            data_type=TensorProto.FLOAT,
            dims=[bias_elements],
            vals=bias,
        )

        last_node = add_node
        subgraph_nodes = [add_node, weight_mul, reshape_4d, instance_norm, reshape_3d, shape_node]
        has_swish_activation = bool(swish_mul and swish_sigmoid)
        if has_swish_activation:
            subgraph_nodes.extend([swish_mul, swish_sigmoid])
            last_node = swish_mul

        # If other nodes consume intermediate outputs, only the last node can be removed;
        # the rest of the subgraph stays to feed those consumers.
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            last_node.output,
            input_name_to_nodes,
            output_name_to_node,
        ):
            self.nodes_to_remove.extend([last_node])
        else:
            self.nodes_to_remove.extend(subgraph_nodes)

        # instance_norm_scale might come from a Constant node. Use prune graph to clear it.
        self.prune_graph = True

        input_name = root
        output_name = last_node.output[0]

        group_norm_input_name = input_name + "_NHWC" if self.channels_last else input_name
        group_norm_output_name = output_name + "_NHWC" if self.channels_last else output_name

        # NCHW to NHWC
        if self.channels_last:
            transpose_input = helper.make_node(
                "Transpose",
                [input_name],
                [group_norm_input_name],
                name=self.model.create_node_name("Transpose", name_prefix="Transpose_NCHW_to_NHWC"),
                perm=[0, 2, 3, 1],
            )
            self.nodes_to_add.append(transpose_input)
            self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name

        new_node = helper.make_node(
            "GroupNorm",
            inputs=[group_norm_input_name, group_norm_name + "_gamma", group_norm_name + "_beta"],
            outputs=[group_norm_output_name],
            name=group_norm_name,
        )

        # Carry over InstanceNormalization attributes (e.g. epsilon), then add GroupNorm-specific ones.
        new_node.attribute.extend(instance_norm.attribute)

        new_node.attribute.extend([helper.make_attribute("groups", num_groups)])
        new_node.attribute.extend([helper.make_attribute("activation", 1 if has_swish_activation else 0)])

        if not self.channels_last:
            new_node.attribute.extend([helper.make_attribute("channels_last", 0)])

        new_node.domain = "com.microsoft"
        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        # NHWC to NCHW
        if self.channels_last:
            transpose_output = helper.make_node(
                "Transpose",
                [group_norm_output_name],
                [output_name],
                name=self.model.create_node_name("Transpose", name_prefix="Transpose_NHWC_to_NCHW"),
                perm=[0, 3, 1, 2],
            )
            self.nodes_to_add.append(transpose_output)
            self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name
@@ -0,0 +1,489 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+
7
+ from fusion_base import Fusion
8
+ from onnx import TensorProto, helper
9
+ from onnx_model import OnnxModel
10
+
11
+ logger = getLogger(__name__)
12
+
13
+
14
class FusionLayerNormalization(Fusion):
    """Fuse the decomposed LayerNormalization subgraph into a single
    LayerNormalization node (starting the match from a ReduceMean node).
    """

    def __init__(self, model: OnnxModel, check_constant_and_dimension: bool = True, force: bool = False):
        super().__init__(model, "LayerNormalization", "ReduceMean")
        # When True, weight and bias must be 1-D constant initializers.
        self.check_constant_and_dimension = check_constant_and_dimension
        # When True, fuse even if removing the subgraph is not provably safe;
        # duplicated computation is cleaned up later by prune_graph.
        self.force = force

    def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
        """
        Fuse Layer Normalization subgraph into one node LayerNormalization:
          +----------------------+
          |                      |
          |                      v
        [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add
         (axis=2 or -1)  |      (Y=2)   (axis=2 or -1)  (B=E-6 or E-12)    ^
                         |                                                 |
                         +-------------------------------------------------+

        It also handles cases of duplicated sub nodes exported from older version of PyTorch:
          +----------------------+
          |                      v
          |           +-------> Sub-----------------------------------------------+
          |           |                                                           |
          |           |                                                           v
        [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div  --> Mul --> Add
          |                      ^
          |                      |
          +----------------------+
        """
        children = self.model.get_children(node, input_name_to_nodes)
        if len(children) == 0 or len(children) > 2:
            return

        root_input = node.input[0]

        # All Sub children must subtract the mean from the root input.
        if children[0].op_type != "Sub" or children[0].input[0] != root_input:
            return

        if len(children) == 2:
            if children[1].op_type != "Sub" or children[1].input[0] != root_input:
                return

        div_node = None
        for child in children:
            # Check if Sub --> Div exists
            div_node_1 = self.model.find_first_child_by_type(child, "Div", input_name_to_nodes, recursive=False)
            if div_node_1 is not None:
                div_node = div_node_1
                break
            else:
                # Check if Sub --> Cast --> Div
                div_node_2 = self.model.match_child_path(child, ["Cast", "Div"])
                if div_node_2 is not None:
                    div_node = div_node_2[-1]
                    break

        if div_node is None:
            return

        # Match the variance branch feeding the Div, with or without a Cast.
        _path_id, parent_nodes, _ = self.model.match_parent_paths(
            div_node,
            [
                (["Sqrt", "Add", "ReduceMean", "Pow", "Sub"], [1, 0, 0, 0, 0]),
                (["Sqrt", "Add", "ReduceMean", "Pow", "Cast", "Sub"], [1, 0, 0, 0, 0, 0]),
            ],
            output_name_to_node,
        )
        if parent_nodes is None:
            return

        sub_node = parent_nodes[-1]
        if sub_node not in children:
            return

        add_eps_node = parent_nodes[1]
        _i, epsilon = self.model.get_constant_input(add_eps_node)
        if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4:
            # Note: message previously said "SkipLayerNormalization" by mistake.
            logger.debug(f"skip LayerNormalization fusion since epsilon value is not expected: {epsilon}")
            return

        pow_node = parent_nodes[3]
        if self.model.find_constant_input(pow_node, 2.0) != 1:
            return

        if div_node.output[0] not in input_name_to_nodes:
            return

        # In MMDit model, Div might have two Mul+Add children paths.
        div_children = input_name_to_nodes[div_node.output[0]]
        for temp_node in div_children:
            # Reset per candidate path: a Cast collected for a path that fails
            # the checks below must not leak into a later successful fusion's
            # removal set (it would delete a node other consumers still need).
            subgraph_nodes = []
            if temp_node.op_type == "Cast":
                # Div --> Cast --> Mul
                subgraph_nodes.append(temp_node)  # add Cast node to list of subgraph nodes
                if temp_node.output[0] not in input_name_to_nodes:
                    continue
                mul_node = input_name_to_nodes[temp_node.output[0]][0]
            else:
                # Div --> Mul
                mul_node = temp_node
            if mul_node.op_type != "Mul":
                continue

            if mul_node.output[0] not in input_name_to_nodes:
                continue
            last_add_node = input_name_to_nodes[mul_node.output[0]][0]
            if last_add_node.op_type != "Add":
                continue

            subgraph_nodes.append(node)
            subgraph_nodes.extend(children)
            subgraph_nodes.extend(parent_nodes[:-1])

            subgraph_nodes.extend([last_add_node, mul_node, div_node])

            # The Mul's non-normalized input is the layernorm weight (gamma).
            node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
            weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
            if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(
                weight_input, 1, "layernorm weight"
            ):
                continue

            # The Add's non-Mul input is the layernorm bias (beta).
            bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
            if self.check_constant_and_dimension and not self.model.is_constant_with_specified_dimension(
                bias_input, 1, "layernorm bias"
            ):
                continue

            layer_norm_output = last_add_node.output[0]
            if not self.model.is_safe_to_fuse_nodes(
                subgraph_nodes,
                last_add_node.output,
                input_name_to_nodes,
                output_name_to_node,
            ):
                # If it is not safe to fuse, some computation may be duplicated if we force to fuse it.
                # It is unknown whether force fusion brings performance gain or loss.
                # User need test performance impact to see whether forcing fusion can help.
                if self.force:
                    self.prune_graph = True
                else:
                    logger.debug("It is not safe to fuse LayerNormalization node. Skip")
                    continue
            else:
                self.nodes_to_remove.extend(subgraph_nodes)

            normalize_node = helper.make_node(
                "LayerNormalization",
                inputs=[node.input[0], weight_input, bias_input],
                outputs=[layer_norm_output],
                name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
            )
            normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
            self.nodes_to_add.append(normalize_node)
            self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
168
+
169
+
170
class FusionLayerNormalizationNCHW(Fusion):
    """Fuse a LayerNormalization subgraph that normalizes the channel axis of
    an NCHW tensor (ReduceMean axes=[1]) into
    Transpose(NCHW->NHWC) --> LayerNormalization --> Transpose(NHWC->NCHW).
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "ReduceMean")

    def get_weight_or_bias(self, output_name, description):
        """Return a Cx1x1 constant flattened to shape (C,), or None when the
        value is not a constant or does not have the expected shape."""
        value = self.model.get_constant_value(output_name)
        if value is None:
            logger.debug(f"{description} {output_name} is not initializer.")
            return None

        if len(value.shape) != 3 or value.shape[1] != 1 or value.shape[2] != 1:
            logger.debug(f"{description} {output_name} shall have 3 dimensions Cx1x1. Got shape {value.shape}")
            return None

        return value.reshape([value.shape[0]])

    def create_transpose_node(self, input_name: str, perm: list[int], output_name=None):
        """Append a Transpose node after an input"""
        node_name = self.model.create_node_name("Transpose")

        if output_name is None:
            output_name = node_name + "_out" + "-" + input_name

        transpose_node = helper.make_node("Transpose", inputs=[input_name], outputs=[output_name], name=node_name)
        transpose_node.attribute.extend([helper.make_attribute("perm", perm)])

        return transpose_node

    def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
        """
        Fuse Layer Normalization subgraph into one node LayerNormalization:
          +----------------------+
          |       NxCxHxW        |
          |                      v  (Cx1x1)  (Cx1x1)
        [Root] --> ReduceMean -->  Sub  --> Pow --> ReduceMean --> Add --> Sqrt --> Div --> Mul --> Add -->
          (axes=1)       |      (Y=2)   (axes=1)  (E-6)    ^
                         |                                 |
                         +---------------------------------+

        Fused subgraph:
               (0,2,3,1)                            (0,3,1,2)
        [Root] --> Transpose --> LayerNormalization --> Transpose -->
        """
        axes = OnnxModel.get_node_attribute(node, "axes")
        if (not isinstance(axes, list)) or axes != [1]:
            return

        subgraph_nodes = []
        children = self.model.get_children(node, input_name_to_nodes)
        if len(children) != 1:
            return

        root_input = node.input[0]

        if children[0].op_type != "Sub" or children[0].input[0] != root_input:
            return
        sub = children[0]

        div_node = self.model.find_first_child_by_type(sub, "Div", input_name_to_nodes, recursive=False)
        if div_node is None:
            return

        parent_nodes = self.model.match_parent_path(
            div_node,
            ["Sqrt", "Add", "ReduceMean", "Pow", "Sub"],
            [1, 0, 0, 0, 0],
            output_name_to_node,
        )
        if parent_nodes is None:
            return

        _sqrt_node, second_add_node, reduce_mean_node, pow_node, sub_node = parent_nodes
        if sub != sub_node:
            return

        _i, epsilon = self.model.get_constant_input(second_add_node)
        if epsilon is None or epsilon <= 0 or epsilon > 1.0e-4:
            # Note: message previously said "SkipLayerNormalization" by mistake.
            logger.debug(f"skip LayerNormalization fusion since epsilon value is not expected: {epsilon}")
            return

        # The inner ReduceMean must also reduce the channel axis.
        # Guarded return instead of assert: asserts are stripped under -O.
        axes = OnnxModel.get_node_attribute(reduce_mean_node, "axes")
        if not isinstance(axes, list) or axes != [1]:
            return

        if self.model.find_constant_input(pow_node, 2.0) != 1:
            return

        # Guard dict lookups: a graph output may have no consumer nodes.
        if div_node.output[0] not in input_name_to_nodes:
            return
        temp_node = input_name_to_nodes[div_node.output[0]][0]
        mul_node = temp_node
        if mul_node.op_type != "Mul":
            return

        if mul_node.output[0] not in input_name_to_nodes:
            return
        last_add_node = input_name_to_nodes[mul_node.output[0]][0]
        if last_add_node.op_type != "Add":
            return

        subgraph_nodes.append(node)
        subgraph_nodes.extend(parent_nodes)
        subgraph_nodes.extend([last_add_node, mul_node, div_node])

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            last_add_node.output,
            input_name_to_nodes,
            output_name_to_node,
        ):
            logger.debug("It is not safe to fuse LayerNormalization node. Skip")
            return

        node_before_weight = div_node if temp_node.op_type != "Cast" else temp_node
        weight_input = mul_node.input[1 - self.model.input_index(node_before_weight.output[0], mul_node)]
        weight = self.get_weight_or_bias(weight_input, "layernorm weight")
        if weight is None:
            return

        bias_input = last_add_node.input[1 - self.model.input_index(mul_node.output[0], last_add_node)]
        bias = self.get_weight_or_bias(bias_input, "layernorm bias")
        if bias is None:
            return

        weight_nhwc = helper.make_tensor(weight_input + "_NHWC", TensorProto.FLOAT, weight.shape, weight)

        # Bug fix: the bias initializer must be built from the bias values;
        # the original passed the weight's shape and data here, so beta was
        # silently replaced by gamma in the fused node.
        bias_nhwc = helper.make_tensor(bias_input + "_NHWC", TensorProto.FLOAT, bias.shape, bias)
        self.model.add_initializer(weight_nhwc, self.this_graph_name)
        self.model.add_initializer(bias_nhwc, self.this_graph_name)

        self.nodes_to_remove.extend(subgraph_nodes)

        transpose_input = self.create_transpose_node(node.input[0], [0, 2, 3, 1])

        layernorm_node_name = self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm")

        transpose_output = self.create_transpose_node(
            layernorm_node_name + "_out_nhwc", [0, 3, 1, 2], last_add_node.output[0]
        )

        normalize_node = helper.make_node(
            "LayerNormalization",
            inputs=[transpose_input.output[0], weight_input + "_NHWC", bias_input + "_NHWC"],
            outputs=[layernorm_node_name + "_out_nhwc"],
            name=layernorm_node_name,
        )
        normalize_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])

        self.nodes_to_add.append(transpose_input)
        self.nodes_to_add.append(normalize_node)
        self.nodes_to_add.append(transpose_output)
        self.node_name_to_graph_name[transpose_input.name] = self.this_graph_name
        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
        self.node_name_to_graph_name[transpose_output.name] = self.this_graph_name

        counter_name = "LayerNormalization(NHWC)"
        self.increase_counter(counter_name)
324
+
325
+
326
class FusionLayerNormalizationTF(Fusion):
    """Fuse the LayerNormalization pattern emitted by TensorFlow exporters
    (keras2onnx / tf2onnx), with an optional Cast variant."""

    def __init__(self, model: OnnxModel):
        super().__init__(model, "LayerNormalization", "Add", "TF")

    def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
        """
        Layer Norm from Tensorflow model(using keras2onnx or tf2onnx):
          +------------------------------------+
          |                                    |
          |                                    |
        (Cast_1)                               |
          |                                    |
          |                                    v                                           (B)                             (B)             (A)
        Add --> (Cast_1) --> ReduceMean -->  Sub  --> Mul --> ReduceMean --> (Cast_3) --> Add --> Sqrt --> Reciprocal --> Mul --> Mul --> Sub --> Add
          |          |                        |                                                                            |       ^              ^
          |          |                        |                                                                            |       |              |
          |          +--------------------------------------------------(Cast_2)-------------------------------|-------+  |
          |                                   v                                                                            |
          +---------------------------------------------------------------------------------------------------------------> Mul--------------------+
        """
        return_indice = []
        _, parent_nodes, return_indice = self.model.match_parent_paths(
            node,
            [
                (
                    [
                        "Sub",
                        "Mul",
                        "Mul",
                        "Reciprocal",
                        "Sqrt",
                        "Add",
                        "ReduceMean",
                        "Mul",
                        "Sub",
                        "ReduceMean",
                    ],
                    [1, 1, None, 0, 0, 0, None, 0, 0, None],
                ),
                (
                    [
                        "Sub",
                        "Mul",
                        "Mul",
                        "Reciprocal",
                        "Sqrt",
                        "Add",
                        "Cast",
                        "ReduceMean",
                        "Mul",
                        "Sub",
                        "ReduceMean",
                    ],
                    [1, 1, None, 0, 0, 0, 0, None, 0, 0, None],
                ),
            ],
            output_name_to_node,
        )

        if parent_nodes is None:
            return

        assert len(return_indice) == 3
        if not (return_indice[0] in [0, 1] and return_indice[1] in [0, 1] and return_indice[2] in [0, 1]):
            # Bug fix: the original message was missing the f-prefix, so the
            # braces were logged literally instead of the actual indices.
            logger.debug(f"return indice is expected in [0, 1], but got {return_indice}")
            return

        (
            sub_node_0,
            mul_node_0,
            mul_node_1,
            reciprocal_node,
            sqrt_node,
            add_node_0,
        ) = parent_nodes[:6]
        reduce_mean_node_0, mul_node_2, sub_node_1, reduce_mean_node_1 = parent_nodes[-4:]

        # The Cast variant matched 11 nodes; the plain variant matched 10.
        cast_node_3 = None
        if len(parent_nodes) == 11:
            cast_node_3 = parent_nodes[6]
            assert cast_node_3.op_type == "Cast"

        mul_node_3 = self.model.match_parent(node, "Mul", 0, output_name_to_node)
        if mul_node_3 is None:
            logger.debug("mul_node_3 not found")
            return

        node_before_reduce = self.model.get_parent(reduce_mean_node_1, 0, output_name_to_node)
        root_node = (
            node_before_reduce
            if cast_node_3 is None
            else self.model.get_parent(node_before_reduce, 0, output_name_to_node)
        )
        if root_node is None:
            logger.debug("root node is none")
            return

        _i, epsilon = self.model.get_constant_input(add_node_0)
        if epsilon is None or epsilon <= 0 or (epsilon > 1.0e-5 and cast_node_3 is None):
            logger.debug("epsilon is not matched")
            return

        # Both branches (mean and variance) must start from the same root.
        if cast_node_3 is None and (
            reduce_mean_node_1.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input
        ):
            logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
            return

        if cast_node_3 is not None and (
            node_before_reduce.input[0] not in mul_node_3.input or reduce_mean_node_1.input[0] not in sub_node_1.input
        ):
            logger.debug("reduce_mean_node_1 and mul_node_3 shall link from root node")
            return

        # mul_node_2 computes the square: both inputs must be the same tensor.
        if mul_node_2.input[0] != mul_node_2.input[1]:
            logger.debug("mul_node_2 shall have two same inputs")
            return

        subgraph_nodes = [
            node,
            sub_node_0,
            mul_node_0,
            mul_node_1,
            reciprocal_node,
            sqrt_node,
            add_node_0,
            reduce_mean_node_0,
            mul_node_2,
            sub_node_1,
            reduce_mean_node_1,
            mul_node_3,
        ]

        if cast_node_3 is not None:
            cast_node_2 = self.model.match_parent(mul_node_0, "Cast", 0, output_name_to_node)
            if cast_node_2 is None:
                logger.debug("cast_node_2 not found")
                return
            subgraph_nodes.extend([node_before_reduce, cast_node_2, cast_node_3])

        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes,
            node.output,
            self.model.input_name_to_nodes(),
            self.model.output_name_to_node(),
        ):
            logger.debug("not safe to fuse layer normalization")
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        weight_input = mul_node_1.input[1]
        bias_input = sub_node_0.input[0]

        fused_node = helper.make_node(
            "LayerNormalization",
            inputs=[mul_node_3.input[0], weight_input, bias_input],
            outputs=[node.output[0]],
            name=self.model.create_node_name("LayerNormalization", name_prefix="LayerNorm"),
        )
        fused_node.attribute.extend([helper.make_attribute("epsilon", float(epsilon))])
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name