onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,811 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from logging import getLogger
7
+ from typing import Dict, List, Optional, Tuple, Union
8
+
9
+ from fusion_base import Fusion
10
+ from fusion_utils import FusionUtils
11
+ from onnx import NodeProto, TensorProto, helper
12
+ from onnx_model import OnnxModel
13
+
14
+ logger = getLogger(__name__)
15
+
16
+
17
+ class FusionEmbedLayerNoMask(Fusion):
18
+ """
19
+ Fuse embedding layer into one node (EmbedLayerNormalization).
20
+ It supports the following model types: BERT, DistilBert, ALBert.
21
+ """
22
+
23
def __init__(self, model: OnnxModel, description: str = "no mask"):
    """Create a fusion pass that folds the embedding subgraph into EmbedLayerNormalization.

    Args:
        model: the ONNX model wrapper to search and rewrite.
        description: short label used by the base Fusion class for logging.
    """
    # The fusion is anchored on the normalization node that follows the
    # embedding additions; either plain or skip variant may appear.
    anchor_ops = ["LayerNormalization", "SkipLayerNormalization"]
    super().__init__(model, "EmbedLayerNormalization", anchor_ops, description)

    self.utils = FusionUtils(model)

    # Lazily-computed shape inference result (see check_embedding usage elsewhere).
    self.shape_infer = None
    self.shape_infer_done = False

    # Per-fuse scratch state, reset on each fuse call of FusionEmbedLayerNormalization.
    self.attention = None
    self.embed_node = None
37
+
38
def match_two_gather(self, add: NodeProto) -> Union[None, Tuple[NodeProto, NodeProto]]:
    """Return the two Gather nodes feeding inputs 0 and 1 of *add*, or None.

    The embedding layer adds word embeddings (input 0) and segment/position
    embeddings (input 1); both must come directly from Gather nodes.
    """
    first = self.model.match_parent_path(add, ["Gather"], [0])
    if first is None:
        return None

    second = self.model.match_parent_path(add, ["Gather"], [1])
    return None if second is None else (first[0], second[0])
48
+
49
def check_attention_subgraph(
    self,
    layernorm: NodeProto,
    input_name_to_nodes: Dict[str, List[NodeProto]],
    is_distil_bert: bool,
) -> bool:
    """Check that LayerNormalization has a child of Attention node or subgraph like Attention.

    Args:
        layernorm (NodeProto): LayerNormalization node
        input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
        is_distil_bert (bool): whether it is DistilBert or not

    Returns:
        bool: whether there is Attention node or subgraph like Attention
    """
    # Fast path: a fused Attention node directly consumes the layernorm output.
    self.attention = self.model.find_first_child_by_type(
        layernorm, "Attention", input_name_to_nodes, recursive=False
    )

    if self.attention is not None:
        return True

    if layernorm.output[0] not in input_name_to_nodes:
        return False
    children = input_name_to_nodes[layernorm.output[0]]
    children_types = sorted([child.op_type for child in children])

    # Try find MultiHeadAttention: three Q/K/V MatMul children plus a
    # SkipLayerNormalization whose parent path leads back through MHA.
    if children_types == ["MatMul", "MatMul", "MatMul", "SkipLayerNormalization"]:
        for node in children:
            if node.op_type == "SkipLayerNormalization":
                path1 = self.model.match_parent_path(
                    node,
                    ["Add", "MatMul", "MultiHeadAttention", "MatMul"],
                    [None, None, 0, 0],
                )
                # The trailing MatMul must read the layernorm output to confirm
                # this MHA actually attends over the embedding output.
                if path1 is not None and path1[-1].input[0] == layernorm.output[0]:
                    self.cross_attention = path1[2]
                    return True

    # In case user disables attention fusion, check whether subgraph looks like Attention.
    # For Albert, there is MatMul+Add after embedding layer before attention.
    if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes:
        grandchildren = input_name_to_nodes[children[0].output[0]]
        if (
            len(grandchildren) == 1
            and grandchildren[0].op_type == "Add"
            and grandchildren[0].output[0] in input_name_to_nodes
        ):
            nodes = input_name_to_nodes[grandchildren[0].output[0]]
            for node in nodes:
                if node.op_type == "Attention":
                    self.attention = node
                    return True
            # No fused Attention found; re-derive child types one level deeper
            # so the shape checks below apply to the Albert layout too.
            children_types = sorted([child.op_type for child in nodes])

    # Two Shape nodes might be merged by ORT
    if is_distil_bert:
        # SkipLayerNormalization might exist when model has been optimized by ORT first.
        if (
            children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"]
            and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"]
            and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"]
        ):
            logger.debug("No Attention like subgraph in children of LayerNormalization")
            return False
    else:
        if children_types != [
            "Add",
            "MatMul",
            "MatMul",
            "MatMul",
        ] and children_types != [
            "MatMul",
            "MatMul",
            "MatMul",
            "SkipLayerNormalization",
        ]:
            logger.debug("No Attention like subgraph in children of LayerNormalization")
            return False

    return True
132
+
133
def match_position_embedding_distilbert(self, position_embedding_gather, input_ids, output_name_to_node):
    """Match position embedding path from input_ids to Gather for DistilBert.

    Pattern is like the following:
             (input_ids)
                  |
                 Shape
                  |   \\
                  |    Gather (indices=1)
                  |       |
                  |      Cast (optional)
                  |       |
                  |      Range (start=0, end=*, delta=1)
                  |       |
                  |    Unsqueeze
                  |    /
                 Expand
                   |
                 Gather

    Returns:
        bool: True when the DistilBert position-embedding pattern is matched.
    """
    # The Expand may feed Gather directly from Shape, or (when ORT has
    # rewritten the shape computation) via Where/Reshape.
    path1 = self.model.match_parent_path(position_embedding_gather, ["Expand", "Shape"], [1, 1])
    if path1 is None:
        path1 = self.model.match_parent_path(
            position_embedding_gather,
            ["Expand", "Where", "Reshape", "Shape"],
            [1, 1, 2, 0],
        )
        if path1 is None:
            return False

    expand, shape = path1[0], path1[-1]
    # The Shape must be computed from the same input_ids tensor.
    if shape.input[0] != input_ids:
        return False

    # The Cast between Gather and Range is optional, so try both variants.
    _, path2, _ = self.model.match_parent_paths(
        expand,
        [
            (["Unsqueeze", "Range", "Cast", "Gather", "Shape"], [0, 0, 1, 0, 0]),
            (["Unsqueeze", "Range", "Gather", "Shape"], [0, 0, 1, 0]),
        ],
        output_name_to_node,
    )
    if path2 is None:
        return False

    # Range must generate positions 0..seq_len-1: start=0 (input 0) and delta=1 (input 2).
    range_node = path2[1]
    if not (
        self.utils.check_node_input_value(range_node, 0, 0) and self.utils.check_node_input_value(range_node, 2, 1)
    ):
        return False

    # Gather(indices=1) picks the sequence-length dimension from the shape.
    gather_node = path2[-2]
    if not (self.utils.check_node_input_value(gather_node, 1, 1)):
        return False

    # This second Shape must also read input_ids.
    shape_node = path2[-1]
    if shape_node.input[0] != input_ids:
        return False

    return True
194
+
195
def match_position_embedding_roberta(self, position_embedding_gather, input_ids, output_name_to_node):
    """Match position embedding path from input_ids to Gather for Roberta.

    NOTE: This matcher is intentionally disabled (it always returns False)
    because the EmbedLayerNormalization kernel does not yet support Roberta's
    position offset. The reference implementation is preserved below inside
    this docstring for when support is added.

    Roberta Embedding Layer Pattern (* is optional since it might be removed by ORT, ? is the padding word id):
    (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Mul -- Cast(to=7) -- Add(B=1) -- Cast(to=7)* --> Gather
                |                                                    ^
                V                                                    |
                +------------------------------+

    Roberta new pattern from transformers v4.9:
    (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Add(B=0) -- Mul -- Cast(to=7) -- Add(B=1) --> Gather
                |                                                                 ^
                V                                                                 |
                +-------------------------------------------+

    Disabled reference implementation:

    start_node = position_embedding_gather
    start_index = 1

    # match optional Cast node.
    parent = self.model.get_parent(start_node, start_index, output_name_to_node)
    if parent is None:
        return
    if parent.op_type == "Cast":
        if OnnxModel.get_node_attribute(parent, "to") != 7:
            return
        start_node = parent
        start_index = 0

    i, path, return_indices = self.model.match_parent_paths(
        start_node,
        [ (['Add', 'Cast', 'Mul', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0]),
          (['Add', 'Cast', 'Mul', 'Add', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0, 0])],
        output_name_to_node)

    if path is not None:
        # constant input of Add shall be 1.
        i, value = self.model.get_constant_input(path[0])
        if value != 1:
            return False

        _, self.padding_word_id = self.model.get_constant_input(path[-1])

        return input_ids == path[-1].input[0]
    """

    return False
241
+
242
def match_position_embedding_bert(self, position_embedding_gather, input_ids, output_name_to_node):
    """Match position embedding path from input_ids to Gather for BERT.

    BERT Embedding Layer Pattern:
                                (input_ids)
                               /          \\
                              /          Shape
                             /              |
                            /             Gather (indices=1)
                           /                |
                          /                Add (optional, B=0)
                         /                  |
        Gather (segment_ids) Unsqueeze (axes=0)
           \\        |                      |
            \\     Gather                Slice (data[1,512], starts=0, ends=*, axes=1, steps=1)
             \\    /                        |
              Add                        Gather
                 \\                      /
                        Add
                         |
                LayerNormalization

    Returns:
        bool: True when the BERT position-embedding pattern is matched.
    """
    path = self.model.match_parent_path(
        position_embedding_gather,
        ["Slice", "Unsqueeze"],
        [1, 2],
        output_name_to_node,
    )
    if path is None:
        return False

    # Renamed from `slice` to avoid shadowing the builtin.
    slice_node, unsqueeze = path
    # Slice data must be a constant [1, max_seq_len] position table, sliced
    # from 0 along axis 1 with step 1 (step input may be absent).
    slice_weight = self.model.get_constant_value(slice_node.input[0])
    if not (
        slice_weight is not None
        and len(slice_weight.shape) == 2
        and slice_weight.shape[0] == 1
        and self.utils.check_node_input_value(slice_node, 1, [0])
        and self.utils.check_node_input_value(slice_node, 3, [1])
        and (len(slice_node.input) == 4 or self.utils.check_node_input_value(slice_node, 4, [1]))
    ):
        return False

    # Unsqueeze moved `axes` from attribute to input in opset 13.
    opset_version = self.model.get_opset_version()
    if opset_version < 13:
        if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]):
            return False
    else:
        if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
            return False

    # Walk back through the optional Add(B=0) to the Gather(indices=1)
    # that extracts sequence length from the shape of input_ids.
    node = self.model.get_parent(unsqueeze, 0, output_name_to_node)
    if node is None:
        return False
    if node.op_type == "Add":
        if not self.utils.check_node_input_value(node, 1, 0):
            return False
        gather = self.model.get_parent(node, 0, output_name_to_node)
    else:
        gather = node

    if gather is None or gather.op_type != "Gather":
        return False
    if not (self.utils.check_node_input_value(gather, 1, 1)):
        return False

    shape = self.model.get_parent(gather, 0, output_name_to_node)
    if shape is None or shape.op_type != "Shape":
        return False

    # Finally, the Shape must be computed over the same input_ids tensor.
    return input_ids == shape.input[0]
313
+
314
def match_position_embedding(self, position_embedding_gather, input_ids, output_name_to_node):
    """Try each supported position-embedding pattern in turn (BERT first).

    Returns True when either the BERT or the DistilBert pattern matches.
    """
    matched_bert = self.match_position_embedding_bert(position_embedding_gather, input_ids, output_name_to_node)
    if matched_bert:
        return True

    # TODO: Support roberta (position starts from 2 instead of 0) in EmbedLayerNormalization kernel
    # related: https://github.com/huggingface/transformers/issues/10736
    # if self.match_position_embedding_roberta(position_embedding_gather, input_ids, output_name_to_node):
    #     return True

    return self.match_position_embedding_distilbert(position_embedding_gather, input_ids, output_name_to_node)
327
+
328
+ def check_embedding(self, word_embedding_gather, segment_embedding_gather, position_embedding_gather):
329
+ """Sanity check of embedding weights, and match hidden_size of weights and shape of inputs."""
330
+ input_ids = word_embedding_gather.input[1]
331
+ segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None
332
+ position_ids = position_embedding_gather.input[1]
333
+
334
+ if not self.shape_infer_done:
335
+ self.shape_infer = self.model.infer_runtime_shape(update=True)
336
+ self.shape_infer_done = True
337
+
338
+ if self.shape_infer is not None:
339
+ input_ids_shape = self.shape_infer.get_edge_shape(input_ids)
340
+ position_ids_shape = self.shape_infer.get_edge_shape(position_ids)
341
+ assert input_ids_shape and position_ids_shape
342
+ if not (
343
+ len(input_ids_shape) == 2
344
+ and len(position_ids_shape) == 2
345
+ and input_ids_shape[1] == position_ids_shape[1]
346
+ ):
347
+ logger.info(
348
+ f"Cannot fuse EmbedLayerNormalization: input_ids and position_ids not matched in 2nd dimension: {input_ids_shape} vs {position_ids_shape}"
349
+ )
350
+ return False
351
+
352
+ if segment_ids and not self.shape_infer.compare_shape(input_ids, segment_ids):
353
+ logger.info(
354
+ f"Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {input_ids_shape} != {self.shape_infer.get_edge_shape(segment_ids)}"
355
+ )
356
+ return False
357
+
358
+ word_embedding_table = self.model.get_constant_value(word_embedding_gather.input[0])
359
+ if word_embedding_table is None or len(word_embedding_table.shape) != 2:
360
+ logger.info("Cannot fuse EmbedLayerNormalization: word embedding table is not expected")
361
+ return False
362
+
363
+ position_embedding_table = self.model.get_constant_value(position_embedding_gather.input[0])
364
+ if (
365
+ position_embedding_table is None
366
+ or len(position_embedding_table.shape) != 2
367
+ or (word_embedding_table.shape[1] != position_embedding_table.shape[1])
368
+ ):
369
+ logger.info("Cannot fuse EmbedLayerNormalization: position embedding table is not expected")
370
+ return False
371
+
372
+ if segment_ids:
373
+ segment_embedding_table = self.model.get_constant_value(segment_embedding_gather.input[0])
374
+ if (
375
+ segment_embedding_table is None
376
+ or len(segment_embedding_table.shape) != 2
377
+ or (word_embedding_table.shape[1] != segment_embedding_table.shape[1])
378
+ ):
379
+ logger.info("Cannot fuse EmbedLayerNormalization: segment embedding table is not expected")
380
+ return False
381
+
382
+ # In normal case, word embedding table is the largest, and segment embedding table is the smallest, while position embedding table is in between.
383
+ # TODO: use other information (like initializer names) to identify different embedding weights automatically.
384
+ if word_embedding_table.shape[0] <= position_embedding_table.shape[0]:
385
+ logger.warning(
386
+ f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]}"
387
+ )
388
+
389
+ if segment_ids:
390
+ if word_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
391
+ logger.warning(
392
+ f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
393
+ )
394
+
395
+ if position_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
396
+ logger.warning(
397
+ f"position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
398
+ )
399
+
400
+ return True
401
+
402
+ def cast_to_int32(self, input_name: str) -> Tuple[str, Union[None, NodeProto]]:
403
+ """Cast a graph input or node input to int32.
404
+
405
+ Args:
406
+ input_name (str): name of graph input or node input
407
+
408
+ Returns:
409
+ A tuple of casted input name and the cast node.
410
+ int32_output (str): If input is int32, it is the input name, Otherwise it is output name of Cast node.
411
+ input_cast_node (Union[None, NodeProto]): Cast node. It could be None if input is int32.
412
+ """
413
+ input_cast_node = None
414
+ graph_input = self.model.find_graph_input(input_name)
415
+ if graph_input is not None:
416
+ if graph_input.type.tensor_type.elem_type != TensorProto.INT32:
417
+ int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)
418
+ else:
419
+ int32_output = input_name
420
+ else:
421
+ int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)
422
+
423
+ return int32_output, input_cast_node
424
+
425
+ def create_fused_node(
426
+ self,
427
+ input_ids: str,
428
+ layernorm: NodeProto,
429
+ word_embedding_gather: NodeProto,
430
+ position_embedding_gather: NodeProto,
431
+ segment_embedding_gather: Union[None, NodeProto],
432
+ position_ids: Optional[str] = None,
433
+ embedding_sum_output=False,
434
+ embedding_sum_name=None,
435
+ ):
436
+ """Create an EmbedLayerNormalization node. Note that segment embedding is optional.
437
+
438
+ Args:
439
+ input_ids (str): input_ids for word embeddings
440
+ layernorm (NodeProto): LayerNormalization or SkipLayerNormalization node.
441
+ word_embedding_gather (NodeProto): the Gather node for word embedding
442
+ position_embedding_gather (NodeProto): the Gather node for position embedding
443
+ segment_embedding_gather (Union[None, NodeProto]): the Gather node for segment embedding, or None.
444
+
445
+ Returns:
446
+ NodeProto: the EmbedLayerNormalization node created.
447
+ """
448
+ nodes_to_add = []
449
+ input_ids, _ = self.cast_to_int32(input_ids)
450
+
451
+ node_name = self.model.create_node_name("EmbedLayerNormalization")
452
+
453
+ if layernorm.op_type == "LayerNormalization":
454
+ gamma = layernorm.input[1]
455
+ beta = layernorm.input[2]
456
+ else: # SkipLayerNormalization
457
+ gamma = layernorm.input[2]
458
+ beta = layernorm.input[3]
459
+
460
+ embed_node_inputs = None
461
+ if segment_embedding_gather is not None:
462
+ segment_ids, _ = self.cast_to_int32(segment_embedding_gather.input[1])
463
+
464
+ embed_node_inputs = [
465
+ input_ids,
466
+ segment_ids,
467
+ word_embedding_gather.input[0],
468
+ position_embedding_gather.input[0],
469
+ segment_embedding_gather.input[0],
470
+ gamma,
471
+ beta,
472
+ ]
473
+ else: # no segment embedding
474
+ embed_node_inputs = [
475
+ input_ids,
476
+ "",
477
+ word_embedding_gather.input[0],
478
+ position_embedding_gather.input[0],
479
+ "",
480
+ gamma,
481
+ beta,
482
+ ]
483
+
484
+ if position_ids is not None:
485
+ # Adding an empty input for mask before position_ids
486
+ embed_node_inputs.append("")
487
+ position_ids, _ = self.cast_to_int32(position_ids)
488
+ embed_node_inputs.append(position_ids)
489
+
490
+ embed_node_outputs = [node_name + "_output", node_name + "_dummy_mask_index"]
491
+ if embedding_sum_output:
492
+ name = embedding_sum_name if embedding_sum_name is not None else node_name + "_embedding_sum"
493
+ embed_node_outputs.append(name)
494
+
495
+ embed_node = helper.make_node(
496
+ "EmbedLayerNormalization",
497
+ embed_node_inputs,
498
+ outputs=embed_node_outputs,
499
+ name=node_name,
500
+ )
501
+
502
+ embed_node.domain = "com.microsoft"
503
+
504
+ # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
505
+ for att in layernorm.attribute:
506
+ if att.name == "epsilon":
507
+ embed_node.attribute.extend([att])
508
+
509
+ # Set default value to 1e-12 if no attribute is found.
510
+ # OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model can only work for 1.3.0 or later.
511
+ if len(embed_node.attribute) == 0:
512
+ embed_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])
513
+
514
+ # Make sure new EmbedLayerNormalization node is the last one in self.nodes_to_add.
515
+ nodes_to_add.append(embed_node)
516
+ for node in nodes_to_add:
517
+ self.node_name_to_graph_name[node.name] = self.this_graph_name
518
+ self.nodes_to_add.extend(nodes_to_add)
519
+
520
+ self.embed_node = embed_node
521
+ return embed_node
522
+
523
+ def finish_fusion(self, layernorm, embed_node):
524
+ self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0])
525
+ # use prune graph to remove nodes that is not needed
526
+ self.prune_graph = True
527
+
528
+ def is_skip_layer_norm_with_sum_output(self, node):
529
+ return (node.op_type == "SkipLayerNormalization") and len(node.output) > 3 and len(node.output[3]) > 0
530
+
531
+ def fuse_gpt2(
532
+ self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather=None
533
+ ):
534
+ # graph checks
535
+ # gpt2 has optional segment embedding, subgraph pattern is like
536
+ # input_ids position_ids
537
+ # | |
538
+ # token_ids Gather Gather
539
+ # | \ /
540
+ # Gather (optional) Add _ _ _ _ _
541
+ # \ | |
542
+ # LayerNormalization |
543
+ # | |
544
+ # Attention |
545
+ # | |
546
+ # Matmul |
547
+ # | /
548
+ # Add /
549
+ # \ /
550
+ # Add
551
+ two_gather = self.match_two_gather(add_before_layernorm)
552
+ if two_gather is None:
553
+ return False
554
+
555
+ word_embedding_gather, position_embedding_gather = two_gather
556
+ input_ids = word_embedding_gather.input[1]
557
+ position_ids = position_embedding_gather.input[1]
558
+
559
+ if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
560
+ return False
561
+
562
+ if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
563
+ return False
564
+
565
+ # If layernorm node is SkipLayerNormalization, we need look at its optional fourth output.
566
+ # If the add_before_layernorm node is an Add node, then the add_output output is the first output of this node.
567
+ # If the add_before_layernorm node is a SkipLayerNormalization node, then the add_output output
568
+ # is the (optional) fourth index output of this node.
569
+ # When add_before_layernorm is SkipLayerNormalization, add_before_layernorm and layernorm are same node.
570
+ if layernorm.op_type == "SkipLayerNormalization":
571
+ need_embedding_sum_output = self.is_skip_layer_norm_with_sum_output(layernorm)
572
+ sum_output_index = 3
573
+ node_with_sum_output = layernorm
574
+ sum_output = layernorm.output[3] if need_embedding_sum_output else None
575
+ is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
576
+ else: # layernorm.op_type == "LayerNormalization"
577
+ node_with_sum_output = add_before_layernorm
578
+ sum_output_index = 0 if add_before_layernorm.op_type == "Add" else 3
579
+ sum_output = (
580
+ add_before_layernorm.output[sum_output_index]
581
+ if len(add_before_layernorm.output) > sum_output_index
582
+ else None
583
+ )
584
+ is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
585
+ is_sum_used_by_multiple_nodes = (
586
+ sum_output and (sum_output in input_name_to_nodes) and len(input_name_to_nodes[sum_output]) > 1
587
+ )
588
+ need_embedding_sum_output = (sum_output is not None) and (
589
+ add_before_layernorm.op_type != "Add" or is_sum_graph_output or is_sum_used_by_multiple_nodes
590
+ )
591
+
592
+ # make the fused node
593
+ embed_node = self.create_fused_node(
594
+ input_ids,
595
+ layernorm,
596
+ word_embedding_gather,
597
+ position_embedding_gather,
598
+ optional_segment_gather,
599
+ position_ids,
600
+ embedding_sum_output=need_embedding_sum_output,
601
+ embedding_sum_name=sum_output if is_sum_graph_output else None,
602
+ )
603
+
604
+ if need_embedding_sum_output:
605
+ node_with_sum_output.output[sum_output_index] = "_no_use__to_be_removed_"
606
+ if not is_sum_graph_output:
607
+ self.model.replace_input_of_all_nodes(sum_output, embed_node.output[2])
608
+
609
+ self.finish_fusion(layernorm, embed_node)
610
+ return True
611
+
612
+ def fuse_distilbert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
613
+ """Fuse embedding layer for DistilBert
614
+ Args:
615
+ layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
616
+ add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
617
+ input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
618
+ output_name_to_node (Dict[str, List[NodeProto]]): map from output name to nodes
619
+ """
620
+
621
+ # DistilBert has no segment embedding, subgraph pattern is like
622
+ # input_ids
623
+ # | \
624
+ # | (position_embedding_subgraph)
625
+ # | |
626
+ # Gather Gather
627
+ # \ /
628
+ # Add
629
+ # |
630
+ # LayerNormalization
631
+ two_gather = self.match_two_gather(add_before_layernorm)
632
+ if two_gather is None:
633
+ return False
634
+
635
+ word_embedding_gather, position_embedding_gather = two_gather
636
+ input_ids = word_embedding_gather.input[1]
637
+
638
+ if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=True):
639
+ return False
640
+
641
+ if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
642
+ return False
643
+
644
+ if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
645
+ return False
646
+
647
+ embed_node = self.create_fused_node(
648
+ input_ids, layernorm, word_embedding_gather, position_embedding_gather, None
649
+ )
650
+ self.finish_fusion(layernorm, embed_node)
651
+ return True
652
+
653
+ def fuse_bert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
654
+ """Fuse embedding layer for Bert
655
+ Args:
656
+ layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
657
+ add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
658
+ input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
659
+ output_name_to_node (Dict[str, List[NodeProto]]): map from output name to nodes
660
+ """
661
+
662
+ add_2_gather = self.model.match_parent_path(add_before_layernorm, ["Add"], [0])
663
+ if add_2_gather is None:
664
+ return False
665
+
666
+ two_gather = self.match_two_gather(add_2_gather[0])
667
+ if two_gather is None:
668
+ return False
669
+
670
+ word_embedding_gather, segment_embedding_gather = two_gather
671
+
672
+ input_ids = word_embedding_gather.input[1]
673
+
674
+ if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
675
+ return False
676
+
677
+ position_embedding_path = self.model.match_parent_path(add_before_layernorm, ["Gather"], [1])
678
+ if position_embedding_path is None:
679
+ return False
680
+
681
+ position_embedding_gather = position_embedding_path[0]
682
+ if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
683
+ if not self.match_position_embedding(segment_embedding_gather, input_ids, output_name_to_node):
684
+ return False
685
+ # position and segment are switched
686
+ temp = segment_embedding_gather
687
+ segment_embedding_gather = position_embedding_gather
688
+ position_embedding_gather = temp
689
+
690
+ if not self.check_embedding(word_embedding_gather, segment_embedding_gather, position_embedding_gather):
691
+ return False
692
+
693
+ embed_node = self.create_fused_node(
694
+ input_ids,
695
+ layernorm,
696
+ word_embedding_gather,
697
+ position_embedding_gather,
698
+ segment_embedding_gather,
699
+ )
700
+ self.finish_fusion(layernorm, embed_node)
701
+ return True
702
+
703
+ def fuse(self, node, input_name_to_nodes, output_name_to_node):
704
+ first_add_path = self.model.match_parent_path(node, ["Add"], [0])
705
+ if node.op_type == "LayerNormalization":
706
+ if first_add_path is None:
707
+ return
708
+ add_before_layernorm = first_add_path[0]
709
+ optional_segment_gather = None
710
+ else: # SkipLayerNormalization
711
+ gather_0_path = self.model.match_parent_path(node, ["Gather"], [0])
712
+ gather_1_path = self.model.match_parent_path(node, ["Gather"], [1])
713
+ if gather_0_path is None and gather_1_path is not None:
714
+ if first_add_path is None:
715
+ return
716
+ add_before_layernorm = first_add_path[0]
717
+ optional_segment_gather = gather_1_path[0]
718
+ elif gather_0_path is not None and gather_1_path is None:
719
+ first_add_path = self.model.match_parent_path(node, ["Add"], [1])
720
+ if first_add_path is None:
721
+ return
722
+ add_before_layernorm = first_add_path[0]
723
+ optional_segment_gather = gather_0_path[0]
724
+ else:
725
+ add_before_layernorm = node # Add is fused into SkipLayerNormalization
726
+ optional_segment_gather = None
727
+
728
+ if self.fuse_gpt2(
729
+ node, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather
730
+ ):
731
+ return
732
+
733
+ if self.fuse_distilbert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
734
+ return
735
+
736
+ if self.fuse_bert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
737
+ return
738
+
739
+
740
+ class FusionEmbedLayerNormalization(FusionEmbedLayerNoMask):
741
+ def __init__(self, model: OnnxModel, use_mask_index=False):
742
+ super().__init__(model, "with mask")
743
+ self.use_mask_index = use_mask_index
744
+
745
+ def replace_mask(self, mask_int32, attention_nodes):
746
+ # Inputs of EmbedLayerNorm: input_ids, segment_ids (optional), word_embedding, position_embedding,
747
+ # segment_embedding (optional), gamma, beta, mask (optional), position_ids (optional)
748
+ embed_node = self.embed_node
749
+ if len(embed_node.input) == 7:
750
+ embed_node.input.append(mask_int32)
751
+ logger.debug("append mask to %s", embed_node.name)
752
+ elif len(embed_node.input) > 7 and not embed_node.input[7]:
753
+ embed_node.input[7] = mask_int32
754
+ logger.debug("replace mask in %s", embed_node.name)
755
+ else:
756
+ logger.debug("skip mask in %s", embed_node.name)
757
+ return
758
+
759
+ for attention_node in attention_nodes:
760
+ logger.debug("update mask_index in %s", attention_node.name)
761
+ if attention_node.op_type == "Attention":
762
+ attention_node.input[3] = embed_node.output[1]
763
+ elif attention_node.op_type == "MultiHeadAttention":
764
+ attention_node.input[4] = embed_node.output[1]
765
+
766
+ def fuse(self, node, input_name_to_nodes, output_name_to_node):
767
+ # Reset attention and embed_node so that we know fusion is successful when they are not None.
768
+ self.attention = None
769
+ self.cross_attention = None
770
+ self.embed_node = None
771
+ super().fuse(node, input_name_to_nodes, output_name_to_node)
772
+
773
+ if self.embed_node is None:
774
+ return
775
+
776
+ if not self.use_mask_index:
777
+ logger.debug("--use_mask_index is not set: EmbedLayerNormalization will not have mask")
778
+ self.increase_counter("EmbedLayerNormalization(no mask)")
779
+ return
780
+
781
+ if self.attention is None and self.cross_attention is None:
782
+ logger.debug("EmbedLayerNormalization will not have mask since attention node is not found")
783
+ self.increase_counter("EmbedLayerNormalization(no mask)")
784
+ return
785
+
786
+ if self.attention:
787
+ mask_int32 = self.attention.input[3]
788
+ else:
789
+ mask_int32 = self.cross_attention.input[4]
790
+
791
+ children_nodes = input_name_to_nodes[mask_int32]
792
+ if self.model.find_graph_input(mask_int32):
793
+ attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
794
+ self.replace_mask(mask_int32, attention_nodes)
795
+ self.increase_counter("EmbedLayerNormalization(with mask)")
796
+ return
797
+
798
+ if mask_int32 not in output_name_to_node:
799
+ logger.debug("EmbedLayerNormalization will not have mask since %s is not a node output", mask_int32)
800
+ self.increase_counter("EmbedLayerNormalization(no mask)")
801
+ return
802
+
803
+ node = output_name_to_node[mask_int32]
804
+ if node.op_type in ["ReduceSum", "Cast"]:
805
+ attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
806
+ if node.op_type == "ReduceSum":
807
+ mask_int32 = node.input[0]
808
+ if len(children_nodes) == len(attention_nodes):
809
+ self.nodes_to_remove.append(node)
810
+ self.replace_mask(mask_int32, attention_nodes)
811
+ self.increase_counter("EmbedLayerNormalization(with mask)")