onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
@@ -0,0 +1,810 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from logging import getLogger
7
+
8
+ from fusion_base import Fusion
9
+ from fusion_utils import FusionUtils
10
+ from onnx import NodeProto, TensorProto, helper
11
+ from onnx_model import OnnxModel
12
+
13
+ logger = getLogger(__name__)
14
+
15
+
16
+ class FusionEmbedLayerNoMask(Fusion):
17
+ """
18
+ Fuse embedding layer into one node (EmbedLayerNormalization).
19
+ It supports the following model types: BERT, DistilBert, ALBert.
20
+ """
21
+
22
    def __init__(self, model: OnnxModel, description: str = "no mask"):
        """Initialize the embedding-layer fuser.

        Registers EmbedLayerNormalization as the fused op type and searches for
        fusion opportunities rooted at LayerNormalization / SkipLayerNormalization
        nodes.

        Args:
            model (OnnxModel): the model to apply fusion on.
            description (str): description suffix used by the base Fusion class
                when reporting fused node counts.
        """
        super().__init__(
            model,
            "EmbedLayerNormalization",
            ["LayerNormalization", "SkipLayerNormalization"],
            description,
        )
        self.utils = FusionUtils(model)
        # Symbolic shape inference is expensive, so it is run lazily (see
        # check_embedding) and cached via the shape_infer_done flag.
        self.shape_infer = None
        self.shape_infer_done = False

        # The following will be reset in each fuse call of FusionEmbedLayerNormalization
        self.attention = None
        self.embed_node = None
36
+
37
+ def match_two_gather(self, add: NodeProto) -> None | tuple[NodeProto, NodeProto]:
38
+ gather_0_path = self.model.match_parent_path(add, ["Gather"], [0])
39
+ if gather_0_path is None:
40
+ return None
41
+
42
+ gather_1_path = self.model.match_parent_path(add, ["Gather"], [1])
43
+ if gather_1_path is None:
44
+ return None
45
+
46
+ return gather_0_path[0], gather_1_path[0]
47
+
48
    def check_attention_subgraph(
        self,
        layernorm: NodeProto,
        input_name_to_nodes: dict[str, list[NodeProto]],
        is_distil_bert: bool,
    ) -> bool:
        """Check that LayerNormalization has a child of Attention node or subgraph like Attention.

        Args:
            layernorm (NodeProto): LayerNormalization node
            input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
            is_distil_bert (bool): whether it is DistilBert or not

        Returns:
            bool: whether there is Attention node or subgraph like Attention
        """
        # Fast path: a fused Attention node directly consumes the layernorm output.
        self.attention = self.model.find_first_child_by_type(
            layernorm, "Attention", input_name_to_nodes, recursive=False
        )

        if self.attention is not None:
            return True

        if layernorm.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[layernorm.output[0]]
        children_types = sorted([child.op_type for child in children])

        # Try find MultiHeadAttention
        if children_types == ["MatMul", "MatMul", "MatMul", "SkipLayerNormalization"]:
            for node in children:
                if node.op_type == "SkipLayerNormalization":
                    path1 = self.model.match_parent_path(
                        node,
                        ["Add", "MatMul", "MultiHeadAttention", "MatMul"],
                        [None, None, 0, 0],
                    )
                    if path1 is not None and path1[-1].input[0] == layernorm.output[0]:
                        # NOTE(review): this assigns self.cross_attention, which is
                        # not initialized in __init__ (only self.attention is) —
                        # confirm downstream consumers expect this attribute.
                        self.cross_attention = path1[2]
                        return True

        # In case user disables attention fusion, check whether subgraph looks like Attention.
        # For Albert, there is MatMul+Add after embedding layer before attention.
        if len(children) == 1 and children[0].op_type == "MatMul" and children[0].output[0] in input_name_to_nodes:
            grandchildren = input_name_to_nodes[children[0].output[0]]
            if (
                len(grandchildren) == 1
                and grandchildren[0].op_type == "Add"
                and grandchildren[0].output[0] in input_name_to_nodes
            ):
                nodes = input_name_to_nodes[grandchildren[0].output[0]]
                for node in nodes:
                    if node.op_type == "Attention":
                        self.attention = node
                        return True
                # Re-derive the child op types one level deeper (after MatMul+Add)
                # so the shape checks below apply to the Albert-style subgraph.
                children_types = sorted([child.op_type for child in nodes])

        # Two Shape nodes might be merged by ORT
        if is_distil_bert:
            # SkipLayerNormalization might exist when model has been optimized by ORT first.
            if (
                children_types != ["MatMul", "MatMul", "MatMul", "Shape", "SkipLayerNormalization"]
                and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape", "Shape"]
                and children_types != ["Add", "MatMul", "MatMul", "MatMul", "Shape"]
            ):
                logger.debug("No Attention like subgraph in children of LayerNormalization")
                return False
        else:
            if children_types != [
                "Add",
                "MatMul",
                "MatMul",
                "MatMul",
            ] and children_types != [
                "MatMul",
                "MatMul",
                "MatMul",
                "SkipLayerNormalization",
            ]:
                logger.debug("No Attention like subgraph in children of LayerNormalization")
                return False

        return True
131
+
132
+ def match_position_embedding_distilbert(self, position_embedding_gather, input_ids, output_name_to_node):
133
+ """ Match position embedding path from input_ids to Gather for DistilBert.
134
+
135
+ Pattern is like the following:
136
+ (input_ids)
137
+ |
138
+ Shape
139
+ | \
140
+ | Gather (indices=1)
141
+ | |
142
+ | Cast (optional)
143
+ | |
144
+ | Range (start=0, end=*, delta=1)
145
+ | |
146
+ | Unsqueeze
147
+ | /
148
+ Expand
149
+ |
150
+ Gather
151
+ """
152
+ # remove after tests pass
153
+ path1 = self.model.match_parent_path(position_embedding_gather, ["Expand", "Shape"], [1, 1])
154
+ if path1 is None:
155
+ path1 = self.model.match_parent_path(
156
+ position_embedding_gather,
157
+ ["Expand", "Where", "Reshape", "Shape"],
158
+ [1, 1, 2, 0],
159
+ )
160
+ if path1 is None:
161
+ return False
162
+
163
+ expand, shape = path1[0], path1[-1]
164
+ if shape.input[0] != input_ids:
165
+ return False
166
+
167
+ _, path2, _ = self.model.match_parent_paths(
168
+ expand,
169
+ [
170
+ (["Unsqueeze", "Range", "Cast", "Gather", "Shape"], [0, 0, 1, 0, 0]),
171
+ (["Unsqueeze", "Range", "Gather", "Shape"], [0, 0, 1, 0]),
172
+ ],
173
+ output_name_to_node,
174
+ )
175
+ if path2 is None:
176
+ return False
177
+
178
+ range_node = path2[1]
179
+ if not (
180
+ self.utils.check_node_input_value(range_node, 0, 0) and self.utils.check_node_input_value(range_node, 2, 1)
181
+ ):
182
+ return False
183
+
184
+ gather_node = path2[-2]
185
+ if not (self.utils.check_node_input_value(gather_node, 1, 1)):
186
+ return False
187
+
188
+ shape_node = path2[-1]
189
+ if shape_node.input[0] != input_ids:
190
+ return False
191
+
192
+ return True
193
+
194
    def match_position_embedding_roberta(self, position_embedding_gather, input_ids, output_name_to_node):
        """Match position embedding path from input_ids to Gather for Roberta.

        NOTE: intentionally disabled — this always returns False. Roberta position
        ids start from padding_word_id + 1 instead of 0, which the
        EmbedLayerNormalization kernel does not support; a draft matcher is kept
        below inside this docstring for future use.

        Roberta Embedding Layer Pattern (* is optional since it might be removed by ORT, ? is the padding word id):
        (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Mul -- Cast(to=7) -- Add(B=1) -- Cast(to=7)* --> Gather
                    |                                                     ^
                    V                                                     |
                    +------------------------------+

        Roberta new pattern from transformers v4.9:
        (input_ids) --> Equal(B=?) -- Not -- Cast(to=6) -- CumSum(axis=1) -- Add(B=0) -- Mul -- Cast(to=7) -- Add(B=1) --> Gather
                    |                                                                    ^
                    V                                                                    |
                    +-------------------------------------------+

        Draft implementation (not active):

        start_node = position_embedding_gather
        start_index = 1

        # match optional Cast node.
        parent = self.model.get_parent(start_node, start_index, output_name_to_node)
        if parent is None:
            return
        if parent.op_type == "Cast":
            if OnnxModel.get_node_attribute(parent, "to") != 7:
                return
            start_node = parent
            start_index = 0

        i, path, return_indices = self.model.match_parent_paths(
            start_node,
            [ (['Add', 'Cast', 'Mul', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0]),
              (['Add', 'Cast', 'Mul', 'Add', 'CumSum', 'Cast', 'Not', 'Equal'], [start_index, 0, 0, 0, 0, 0, 0, 0])],
            output_name_to_node)

        if path is not None:
            # constant input of Add shall be 1.
            i, value = self.model.get_constant_input(path[0])
            if value != 1:
                return False

            _, self.padding_word_id = self.model.get_constant_input(path[-1])

            return input_ids == path[-1].input[0]
        """

        return False
240
+
241
+ def match_position_embedding_bert(self, position_embedding_gather, input_ids, output_name_to_node):
242
+ """ Match position embedding path from input_ids to Gather for BERT.
243
+
244
+ BERT Embedding Layer Pattern:
245
+ (input_ids)
246
+ / \
247
+ / Shape
248
+ / |
249
+ / Gather (indices=1)
250
+ / |
251
+ / Add (optional, B=0)
252
+ / |
253
+ Gather (segment_ids) Unsqueeze (axes=0)
254
+ \\ | |
255
+ \\ Gather Slice (data[1,512], starts=0, ends=*, axes=1, steps=1)
256
+ \\ / |
257
+ Add Gather
258
+ \\ /
259
+ Add
260
+ |
261
+ LayerNormalization
262
+ """
263
+ path = self.model.match_parent_path(
264
+ position_embedding_gather,
265
+ ["Slice", "Unsqueeze"],
266
+ [1, 2],
267
+ output_name_to_node,
268
+ )
269
+ if path is None:
270
+ return False
271
+
272
+ slice, unsqueeze = path
273
+ slice_weight = self.model.get_constant_value(slice.input[0])
274
+ if not (
275
+ slice_weight is not None
276
+ and len(slice_weight.shape) == 2
277
+ and slice_weight.shape[0] == 1
278
+ and self.utils.check_node_input_value(slice, 1, [0])
279
+ and self.utils.check_node_input_value(slice, 3, [1])
280
+ and (len(slice.input) == 4 or self.utils.check_node_input_value(slice, 4, [1]))
281
+ ):
282
+ return False
283
+
284
+ opset_version = self.model.get_opset_version()
285
+ if opset_version < 13:
286
+ if not FusionUtils.check_node_attribute(unsqueeze, "axes", [0]):
287
+ return False
288
+ else:
289
+ if not self.utils.check_node_input_value(unsqueeze, 1, [0]):
290
+ return False
291
+
292
+ node = self.model.get_parent(unsqueeze, 0, output_name_to_node)
293
+ if node is None:
294
+ return False
295
+ if node.op_type == "Add":
296
+ if not self.utils.check_node_input_value(node, 1, 0):
297
+ return False
298
+ gather = self.model.get_parent(node, 0, output_name_to_node)
299
+ else:
300
+ gather = node
301
+
302
+ if gather is None or gather.op_type != "Gather":
303
+ return False
304
+ if not (self.utils.check_node_input_value(gather, 1, 1)):
305
+ return False
306
+
307
+ shape = self.model.get_parent(gather, 0, output_name_to_node)
308
+ if shape is None or shape.op_type != "Shape":
309
+ return False
310
+
311
+ return input_ids == shape.input[0]
312
+
313
+ def match_position_embedding(self, position_embedding_gather, input_ids, output_name_to_node):
314
+ if self.match_position_embedding_bert(position_embedding_gather, input_ids, output_name_to_node):
315
+ return True
316
+
317
+ # TODO: Support roberta (position starts from 2 instead of 0) in EmbedLayerNormalization kernel
318
+ # related: https://github.com/huggingface/transformers/issues/10736
319
+ # if self.match_position_embedding_roberta(position_embedding_gather, input_ids, output_name_to_node):
320
+ # return True
321
+
322
+ if self.match_position_embedding_distilbert(position_embedding_gather, input_ids, output_name_to_node):
323
+ return True
324
+
325
+ return False
326
+
327
+ def check_embedding(self, word_embedding_gather, segment_embedding_gather, position_embedding_gather):
328
+ """Sanity check of embedding weights, and match hidden_size of weights and shape of inputs."""
329
+ input_ids = word_embedding_gather.input[1]
330
+ segment_ids = segment_embedding_gather.input[1] if segment_embedding_gather else None
331
+ position_ids = position_embedding_gather.input[1]
332
+
333
+ if not self.shape_infer_done:
334
+ self.shape_infer = self.model.infer_runtime_shape(update=True)
335
+ self.shape_infer_done = True
336
+
337
+ if self.shape_infer is not None:
338
+ input_ids_shape = self.shape_infer.get_edge_shape(input_ids)
339
+ position_ids_shape = self.shape_infer.get_edge_shape(position_ids)
340
+ assert input_ids_shape and position_ids_shape
341
+ if not (
342
+ len(input_ids_shape) == 2
343
+ and len(position_ids_shape) == 2
344
+ and input_ids_shape[1] == position_ids_shape[1]
345
+ ):
346
+ logger.info(
347
+ f"Cannot fuse EmbedLayerNormalization: input_ids and position_ids not matched in 2nd dimension: {input_ids_shape} vs {position_ids_shape}"
348
+ )
349
+ return False
350
+
351
+ if segment_ids and not self.shape_infer.compare_shape(input_ids, segment_ids):
352
+ logger.info(
353
+ f"Cannot fuse EmbedLayerNormalization: input_ids and segment_ids does not have same shape: {input_ids_shape} != {self.shape_infer.get_edge_shape(segment_ids)}"
354
+ )
355
+ return False
356
+
357
+ word_embedding_table = self.model.get_constant_value(word_embedding_gather.input[0])
358
+ if word_embedding_table is None or len(word_embedding_table.shape) != 2:
359
+ logger.info("Cannot fuse EmbedLayerNormalization: word embedding table is not expected")
360
+ return False
361
+
362
+ position_embedding_table = self.model.get_constant_value(position_embedding_gather.input[0])
363
+ if (
364
+ position_embedding_table is None
365
+ or len(position_embedding_table.shape) != 2
366
+ or (word_embedding_table.shape[1] != position_embedding_table.shape[1])
367
+ ):
368
+ logger.info("Cannot fuse EmbedLayerNormalization: position embedding table is not expected")
369
+ return False
370
+
371
+ if segment_ids:
372
+ segment_embedding_table = self.model.get_constant_value(segment_embedding_gather.input[0])
373
+ if (
374
+ segment_embedding_table is None
375
+ or len(segment_embedding_table.shape) != 2
376
+ or (word_embedding_table.shape[1] != segment_embedding_table.shape[1])
377
+ ):
378
+ logger.info("Cannot fuse EmbedLayerNormalization: segment embedding table is not expected")
379
+ return False
380
+
381
+ # In normal case, word embedding table is the largest, and segment embedding table is the smallest, while position embedding table is in between.
382
+ # TODO: use other information (like initializer names) to identify different embedding weights automatically.
383
+ if word_embedding_table.shape[0] <= position_embedding_table.shape[0]:
384
+ logger.warning(
385
+ f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]}"
386
+ )
387
+
388
+ if segment_ids:
389
+ if word_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
390
+ logger.warning(
391
+ f"word_embedding_table ({word_embedding_gather.input[0]}) size {word_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
392
+ )
393
+
394
+ if position_embedding_table.shape[0] <= segment_embedding_table.shape[0]:
395
+ logger.warning(
396
+ f"position_embedding_table ({position_embedding_gather.input[0]}) size {position_embedding_table.shape[0]} <= segment_embedding_table ({segment_embedding_gather.input[0]}) size {segment_embedding_table.shape[0]}"
397
+ )
398
+
399
+ return True
400
+
401
+ def cast_to_int32(self, input_name: str) -> tuple[str, None | NodeProto]:
402
+ """Cast a graph input or node input to int32.
403
+
404
+ Args:
405
+ input_name (str): name of graph input or node input
406
+
407
+ Returns:
408
+ A tuple of casted input name and the cast node.
409
+ int32_output (str): If input is int32, it is the input name, Otherwise it is output name of Cast node.
410
+ input_cast_node (Union[None, NodeProto]): Cast node. It could be None if input is int32.
411
+ """
412
+ input_cast_node = None
413
+ graph_input = self.model.find_graph_input(input_name)
414
+ if graph_input is not None:
415
+ if graph_input.type.tensor_type.elem_type != TensorProto.INT32:
416
+ int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)
417
+ else:
418
+ int32_output = input_name
419
+ else:
420
+ int32_output, input_cast_node = self.utils.cast_input_to_int32(input_name)
421
+
422
+ return int32_output, input_cast_node
423
+
424
+ def create_fused_node(
425
+ self,
426
+ input_ids: str,
427
+ layernorm: NodeProto,
428
+ word_embedding_gather: NodeProto,
429
+ position_embedding_gather: NodeProto,
430
+ segment_embedding_gather: None | NodeProto,
431
+ position_ids: str | None = None,
432
+ embedding_sum_output=False,
433
+ embedding_sum_name=None,
434
+ ):
435
+ """Create an EmbedLayerNormalization node. Note that segment embedding is optional.
436
+
437
+ Args:
438
+ input_ids (str): input_ids for word embeddings
439
+ layernorm (NodeProto): LayerNormalization or SkipLayerNormalization node.
440
+ word_embedding_gather (NodeProto): the Gather node for word embedding
441
+ position_embedding_gather (NodeProto): the Gather node for position embedding
442
+ segment_embedding_gather (Union[None, NodeProto]): the Gather node for segment embedding, or None.
443
+
444
+ Returns:
445
+ NodeProto: the EmbedLayerNormalization node created.
446
+ """
447
+ nodes_to_add = []
448
+ input_ids, _ = self.cast_to_int32(input_ids)
449
+
450
+ node_name = self.model.create_node_name("EmbedLayerNormalization")
451
+
452
+ if layernorm.op_type == "LayerNormalization":
453
+ gamma = layernorm.input[1]
454
+ beta = layernorm.input[2]
455
+ else: # SkipLayerNormalization
456
+ gamma = layernorm.input[2]
457
+ beta = layernorm.input[3]
458
+
459
+ embed_node_inputs = None
460
+ if segment_embedding_gather is not None:
461
+ segment_ids, _ = self.cast_to_int32(segment_embedding_gather.input[1])
462
+
463
+ embed_node_inputs = [
464
+ input_ids,
465
+ segment_ids,
466
+ word_embedding_gather.input[0],
467
+ position_embedding_gather.input[0],
468
+ segment_embedding_gather.input[0],
469
+ gamma,
470
+ beta,
471
+ ]
472
+ else: # no segment embedding
473
+ embed_node_inputs = [
474
+ input_ids,
475
+ "",
476
+ word_embedding_gather.input[0],
477
+ position_embedding_gather.input[0],
478
+ "",
479
+ gamma,
480
+ beta,
481
+ ]
482
+
483
+ if position_ids is not None:
484
+ # Adding an empty input for mask before position_ids
485
+ embed_node_inputs.append("")
486
+ position_ids, _ = self.cast_to_int32(position_ids)
487
+ embed_node_inputs.append(position_ids)
488
+
489
+ embed_node_outputs = [node_name + "_output", node_name + "_dummy_mask_index"]
490
+ if embedding_sum_output:
491
+ name = embedding_sum_name if embedding_sum_name is not None else node_name + "_embedding_sum"
492
+ embed_node_outputs.append(name)
493
+
494
+ embed_node = helper.make_node(
495
+ "EmbedLayerNormalization",
496
+ embed_node_inputs,
497
+ outputs=embed_node_outputs,
498
+ name=node_name,
499
+ )
500
+
501
+ embed_node.domain = "com.microsoft"
502
+
503
+ # Pass attribute "epsilon" from normalize node to EmbedLayerNormalization.
504
+ for att in layernorm.attribute:
505
+ if att.name == "epsilon":
506
+ embed_node.attribute.extend([att])
507
+
508
+ # Set default value to 1e-12 if no attribute is found.
509
+ # OnnxRuntime 1.2.0 or older has no epsilon attribute. The optimized model can only work for 1.3.0 or later.
510
+ if len(embed_node.attribute) == 0:
511
+ embed_node.attribute.extend([helper.make_attribute("epsilon", 1.0e-12)])
512
+
513
+ # Make sure new EmbedLayerNormalization node is the last one in self.nodes_to_add.
514
+ nodes_to_add.append(embed_node)
515
+ for node in nodes_to_add:
516
+ self.node_name_to_graph_name[node.name] = self.this_graph_name
517
+ self.nodes_to_add.extend(nodes_to_add)
518
+
519
+ self.embed_node = embed_node
520
+ return embed_node
521
+
522
+ def finish_fusion(self, layernorm, embed_node):
523
+ self.model.replace_input_of_all_nodes(layernorm.output[0], embed_node.output[0])
524
+ # use prune graph to remove nodes that is not needed
525
+ self.prune_graph = True
526
+
527
+ def is_skip_layer_norm_with_sum_output(self, node):
528
+ return (node.op_type == "SkipLayerNormalization") and len(node.output) > 3 and len(node.output[3]) > 0
529
+
530
    def fuse_gpt2(
        self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather=None
    ):
        """Fuse embedding layer for GPT-2 style graphs (word + position Gathers added together).

        Args:
            layernorm (NodeProto): LayerNormalization or SkipLayerNormalization node.
            add_before_layernorm (NodeProto): Add node before the normalization, or the
                SkipLayerNormalization node itself when the Add is fused into it.
            input_name_to_nodes: map from input name to consumer nodes.
            output_name_to_node: map from output name to producer node.
            optional_segment_gather (NodeProto or None): Gather for the optional segment embedding.

        Returns:
            bool: True when fusion was applied.
        """
        # graph checks
        # gpt2 has optional segment embedding, subgraph pattern is like
        #                      input_ids  position_ids
        #                         |        |
        #    token_ids         Gather    Gather
        #       |                   \   /
        #    Gather (optional)       Add _ _ _ _ _
        #              \              |           |
        #               LayerNormalization        |
        #                         |               |
        #                     Attention           |
        #                         |               |
        #                       Matmul            |
        #                         |              /
        #                        Add            /
        #                          \           /
        #                             Add
        two_gather = self.match_two_gather(add_before_layernorm)
        if two_gather is None:
            return False

        word_embedding_gather, position_embedding_gather = two_gather
        input_ids = word_embedding_gather.input[1]
        position_ids = position_embedding_gather.input[1]

        if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
            return False

        if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
            return False

        # If layernorm node is SkipLayerNormalization, we need look at its optional fourth output.
        # If the add_before_layernorm node is an Add node, then the add_output output is the first output of this node.
        # If the add_before_layernorm node is a SkipLayerNormalization node, then the add_output output
        # is the (optional) fourth index output of this node.
        # When add_before_layernorm is SkipLayerNormalization, add_before_layernorm and layernorm are same node.
        if layernorm.op_type == "SkipLayerNormalization":
            need_embedding_sum_output = self.is_skip_layer_norm_with_sum_output(layernorm)
            sum_output_index = 3
            node_with_sum_output = layernorm
            sum_output = layernorm.output[3] if need_embedding_sum_output else None
            is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
        else:  # layernorm.op_type == "LayerNormalization"
            node_with_sum_output = add_before_layernorm
            sum_output_index = 0 if add_before_layernorm.op_type == "Add" else 3
            sum_output = (
                add_before_layernorm.output[sum_output_index]
                if len(add_before_layernorm.output) > sum_output_index
                else None
            )
            is_sum_graph_output = (sum_output is not None) and (self.model.find_graph_output(sum_output) is not None)
            # The sum must stay exposed when it is a graph output or consumed by more than one node.
            is_sum_used_by_multiple_nodes = (
                sum_output and (sum_output in input_name_to_nodes) and len(input_name_to_nodes[sum_output]) > 1
            )
            need_embedding_sum_output = (sum_output is not None) and (
                add_before_layernorm.op_type != "Add" or is_sum_graph_output or is_sum_used_by_multiple_nodes
            )

        # make the fused node
        embed_node = self.create_fused_node(
            input_ids,
            layernorm,
            word_embedding_gather,
            position_embedding_gather,
            optional_segment_gather,
            position_ids,
            embedding_sum_output=need_embedding_sum_output,
            embedding_sum_name=sum_output if is_sum_graph_output else None,
        )

        if need_embedding_sum_output:
            # Detach the old sum output; consumers are rewired to the fused node's 3rd output
            # (unless the name itself is a graph output, which the fused node now produces).
            node_with_sum_output.output[sum_output_index] = "_no_use__to_be_removed_"
            if not is_sum_graph_output:
                self.model.replace_input_of_all_nodes(sum_output, embed_node.output[2])

        self.finish_fusion(layernorm, embed_node)
        return True
610
+
611
+ def fuse_distilbert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
612
+ """Fuse embedding layer for DistilBert
613
+ Args:
614
+ layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
615
+ add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
616
+ input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
617
+ output_name_to_node (Dict[str, List[NodeProto]]): map from output name to nodes
618
+ """
619
+
620
+ # DistilBert has no segment embedding, subgraph pattern is like
621
+ # input_ids
622
+ # | \
623
+ # | (position_embedding_subgraph)
624
+ # | |
625
+ # Gather Gather
626
+ # \ /
627
+ # Add
628
+ # |
629
+ # LayerNormalization
630
+ two_gather = self.match_two_gather(add_before_layernorm)
631
+ if two_gather is None:
632
+ return False
633
+
634
+ word_embedding_gather, position_embedding_gather = two_gather
635
+ input_ids = word_embedding_gather.input[1]
636
+
637
+ if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=True):
638
+ return False
639
+
640
+ if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
641
+ return False
642
+
643
+ if not self.check_embedding(word_embedding_gather, None, position_embedding_gather):
644
+ return False
645
+
646
+ embed_node = self.create_fused_node(
647
+ input_ids, layernorm, word_embedding_gather, position_embedding_gather, None
648
+ )
649
+ self.finish_fusion(layernorm, embed_node)
650
+ return True
651
+
652
+ def fuse_bert(self, layernorm, add_before_layernorm, input_name_to_nodes, output_name_to_node):
653
+ """Fuse embedding layer for Bert
654
+ Args:
655
+ layernorm (NodeProto): node of LayerNormalization or SkipLayerNormalization
656
+ add_before_layernorm (NodeProto): the Add node before LayerNormalization, or the SkipLayerNormalization itself
657
+ input_name_to_nodes (Dict[str, List[NodeProto]]): map from input name to nodes
658
+ output_name_to_node (Dict[str, List[NodeProto]]): map from output name to nodes
659
+ """
660
+
661
+ add_2_gather = self.model.match_parent_path(add_before_layernorm, ["Add"], [0])
662
+ if add_2_gather is None:
663
+ return False
664
+
665
+ two_gather = self.match_two_gather(add_2_gather[0])
666
+ if two_gather is None:
667
+ return False
668
+
669
+ word_embedding_gather, segment_embedding_gather = two_gather
670
+
671
+ input_ids = word_embedding_gather.input[1]
672
+
673
+ if not self.check_attention_subgraph(layernorm, input_name_to_nodes, is_distil_bert=False):
674
+ return False
675
+
676
+ position_embedding_path = self.model.match_parent_path(add_before_layernorm, ["Gather"], [1])
677
+ if position_embedding_path is None:
678
+ return False
679
+
680
+ position_embedding_gather = position_embedding_path[0]
681
+ if not self.match_position_embedding(position_embedding_gather, input_ids, output_name_to_node):
682
+ if not self.match_position_embedding(segment_embedding_gather, input_ids, output_name_to_node):
683
+ return False
684
+ # position and segment are switched
685
+ temp = segment_embedding_gather
686
+ segment_embedding_gather = position_embedding_gather
687
+ position_embedding_gather = temp
688
+
689
+ if not self.check_embedding(word_embedding_gather, segment_embedding_gather, position_embedding_gather):
690
+ return False
691
+
692
+ embed_node = self.create_fused_node(
693
+ input_ids,
694
+ layernorm,
695
+ word_embedding_gather,
696
+ position_embedding_gather,
697
+ segment_embedding_gather,
698
+ )
699
+ self.finish_fusion(layernorm, embed_node)
700
+ return True
701
+
702
+ def fuse(self, node, input_name_to_nodes, output_name_to_node):
703
+ first_add_path = self.model.match_parent_path(node, ["Add"], [0])
704
+ if node.op_type == "LayerNormalization":
705
+ if first_add_path is None:
706
+ return
707
+ add_before_layernorm = first_add_path[0]
708
+ optional_segment_gather = None
709
+ else: # SkipLayerNormalization
710
+ gather_0_path = self.model.match_parent_path(node, ["Gather"], [0])
711
+ gather_1_path = self.model.match_parent_path(node, ["Gather"], [1])
712
+ if gather_0_path is None and gather_1_path is not None:
713
+ if first_add_path is None:
714
+ return
715
+ add_before_layernorm = first_add_path[0]
716
+ optional_segment_gather = gather_1_path[0]
717
+ elif gather_0_path is not None and gather_1_path is None:
718
+ first_add_path = self.model.match_parent_path(node, ["Add"], [1])
719
+ if first_add_path is None:
720
+ return
721
+ add_before_layernorm = first_add_path[0]
722
+ optional_segment_gather = gather_0_path[0]
723
+ else:
724
+ add_before_layernorm = node # Add is fused into SkipLayerNormalization
725
+ optional_segment_gather = None
726
+
727
+ if self.fuse_gpt2(
728
+ node, add_before_layernorm, input_name_to_nodes, output_name_to_node, optional_segment_gather
729
+ ):
730
+ return
731
+
732
+ if self.fuse_distilbert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
733
+ return
734
+
735
+ if self.fuse_bert(node, add_before_layernorm, input_name_to_nodes, output_name_to_node):
736
+ return
737
+
738
+
739
+ class FusionEmbedLayerNormalization(FusionEmbedLayerNoMask):
740
    def __init__(self, model: OnnxModel, use_mask_index=False):
        """Embedding-layer fusion that can also attach the attention mask to the fused node.

        Args:
            model (OnnxModel): model to optimize.
            use_mask_index (bool): when True, fuse() wires the mask input into the
                EmbedLayerNormalization node and updates attention nodes to consume
                its mask-index output.
        """
        super().__init__(model, "with mask")
        # Controls whether fuse() appends the mask input; see fuse() below.
        self.use_mask_index = use_mask_index
743
+
744
+ def replace_mask(self, mask_int32, attention_nodes):
745
+ # Inputs of EmbedLayerNorm: input_ids, segment_ids (optional), word_embedding, position_embedding,
746
+ # segment_embedding (optional), gamma, beta, mask (optional), position_ids (optional)
747
+ embed_node = self.embed_node
748
+ if len(embed_node.input) == 7:
749
+ embed_node.input.append(mask_int32)
750
+ logger.debug("append mask to %s", embed_node.name)
751
+ elif len(embed_node.input) > 7 and not embed_node.input[7]:
752
+ embed_node.input[7] = mask_int32
753
+ logger.debug("replace mask in %s", embed_node.name)
754
+ else:
755
+ logger.debug("skip mask in %s", embed_node.name)
756
+ return
757
+
758
+ for attention_node in attention_nodes:
759
+ logger.debug("update mask_index in %s", attention_node.name)
760
+ if attention_node.op_type == "Attention":
761
+ attention_node.input[3] = embed_node.output[1]
762
+ elif attention_node.op_type == "MultiHeadAttention":
763
+ attention_node.input[4] = embed_node.output[1]
764
+
765
+ def fuse(self, node, input_name_to_nodes, output_name_to_node):
766
+ # Reset attention and embed_node so that we know fusion is successful when they are not None.
767
+ self.attention = None
768
+ self.cross_attention = None
769
+ self.embed_node = None
770
+ super().fuse(node, input_name_to_nodes, output_name_to_node)
771
+
772
+ if self.embed_node is None:
773
+ return
774
+
775
+ if not self.use_mask_index:
776
+ logger.debug("--use_mask_index is not set: EmbedLayerNormalization will not have mask")
777
+ self.increase_counter("EmbedLayerNormalization(no mask)")
778
+ return
779
+
780
+ if self.attention is None and self.cross_attention is None:
781
+ logger.debug("EmbedLayerNormalization will not have mask since attention node is not found")
782
+ self.increase_counter("EmbedLayerNormalization(no mask)")
783
+ return
784
+
785
+ if self.attention:
786
+ mask_int32 = self.attention.input[3]
787
+ else:
788
+ mask_int32 = self.cross_attention.input[4]
789
+
790
+ children_nodes = input_name_to_nodes[mask_int32]
791
+ if self.model.find_graph_input(mask_int32):
792
+ attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
793
+ self.replace_mask(mask_int32, attention_nodes)
794
+ self.increase_counter("EmbedLayerNormalization(with mask)")
795
+ return
796
+
797
+ if mask_int32 not in output_name_to_node:
798
+ logger.debug("EmbedLayerNormalization will not have mask since %s is not a node output", mask_int32)
799
+ self.increase_counter("EmbedLayerNormalization(no mask)")
800
+ return
801
+
802
+ node = output_name_to_node[mask_int32]
803
+ if node.op_type in ["ReduceSum", "Cast"]:
804
+ attention_nodes = [node for node in children_nodes if node.op_type in ["Attention", "MultiHeadAttention"]]
805
+ if node.op_type == "ReduceSum":
806
+ mask_int32 = node.input[0]
807
+ if len(children_nodes) == len(attention_nodes):
808
+ self.nodes_to_remove.append(node)
809
+ self.replace_mask(mask_int32, attention_nodes)
810
+ self.increase_counter("EmbedLayerNormalization(with mask)")