onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,122 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from logging import getLogger
7
+ from typing import Dict, List, Union
8
+
9
+ from fusion_base import Fusion
10
+ from fusion_utils import NumpyHelper
11
+ from onnx import NodeProto, TensorProto, helper
12
+ from onnx_model import OnnxModel
13
+
14
+ logger = getLogger(__name__)
15
+
16
+
17
class FusionGemmFastGelu(Fusion):
    """Fuse a MatMul followed by FastGelu into a single com.microsoft GemmFastGelu node."""

    def __init__(self, model: OnnxModel):
        super().__init__(model, "GemmFastGelu", "FastGelu", "GemmFastGelu")
        # Symbolic shape inference is expensive, so it is run lazily and cached.
        self.shape_infer = None
        self.shape_infer_done = False

    def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]:
        """Return the rank of the tensor described by tensor_proto, or None if its shape is unset."""
        if tensor_proto.type.tensor_type.HasField("shape"):
            return len(tensor_proto.type.tensor_type.shape.dim)
        return None

    def get_dimensions(self, input_name: str) -> Union[int, None]:
        """Return the rank of the named tensor, or None when it cannot be determined.

        Graph inputs are consulted first; otherwise fall back to cached symbolic
        shape inference results.
        """
        graph_input = self.model.find_graph_input(input_name)
        if graph_input:
            return self.get_dimensions_from_tensor_proto(graph_input)

        if not self.shape_infer_done:
            self.shape_infer = self.model.infer_runtime_shape(update=True)
            self.shape_infer_done = True

        if self.shape_infer is not None:
            # Use .get to avoid KeyError when shape inference has no entry for this name.
            value_info = self.shape_infer.known_vi_.get(input_name)
            if value_info is not None:
                return self.get_dimensions_from_tensor_proto(value_info)

        return None

    def fuse(
        self,
        node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        """
        This pattern is from PyTorch bert model
        Fuse MatMul with FastGelu into one node:

            [root] --> MatMul --> FastGelu -->

        """
        # A FastGelu node with two inputs carries an optional bias as its second input.
        has_bias = len(node.input) == 2

        match_nodes = self.model.match_parent_path(node, ["MatMul"], [0])
        if match_nodes is None:
            return
        matmul = match_nodes[0]

        # matmul input X should >= two dimension, input weight should be two dimension
        weight_index = -1
        x_dims = 0
        weight = None

        for i, input in enumerate(matmul.input):
            initializer = self.model.get_initializer(input)
            if initializer is None:
                x_dims = self.get_dimensions(matmul.input[i])
            else:
                weight_index = i
                weight = NumpyHelper.to_array(initializer)
        if weight is None:
            return
        if len(weight.shape) != 2:
            return
        # x_dims is None when the rank could not be inferred; bail out instead of
        # raising TypeError on the comparison below.
        if x_dims is None or x_dims < len(weight.shape):
            return

        # bias weight should be one dimension
        bias_index = -1
        if has_bias:
            bias_weight = None
            for i, input in enumerate(node.input):
                initializer = self.model.get_initializer(input)
                if initializer is None:
                    continue
                bias_index = i
                bias_weight = NumpyHelper.to_array(initializer)
                break
            if bias_weight is None:
                return
            if len(bias_weight.shape) != 1:
                return

        subgraph_nodes = [node, matmul]
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node
        ):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        # Keep the activation input first and the weight second, regardless of the
        # order they appeared on the original MatMul.
        inputs = (
            [matmul.input[1 - weight_index], matmul.input[weight_index], node.input[bias_index]]
            if has_bias
            else [matmul.input[1 - weight_index], matmul.input[weight_index]]
        )

        fused_node = helper.make_node(
            "GemmFastGelu",
            inputs=inputs,
            outputs=node.output,
            name=self.model.create_node_name("GemmFastGelu"),
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
@@ -0,0 +1,546 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ from logging import getLogger
6
+
7
+ import numpy as np
8
+ from fusion_base import Fusion
9
+ from fusion_utils import FusionUtils
10
+ from onnx import helper
11
+ from onnx_model import OnnxModel
12
+
13
+ logger = getLogger(__name__)
14
+
15
+
16
class FusionGptAttentionPastBase(Fusion):
    """Base class for GPT Attention Fusion with past state"""

    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", ["LayerNormalization", "SkipLayerNormalization"], "with past")
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        self.casted_attention_mask = {}  # map from name of attention mask to the name that casted to int32
        # When not None, emitted as the mask_filter_value attribute of the fused
        # Attention node (see subclass create_attention_node).
        self.mask_filter_value = None

    def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node):
        """Match the past-state subgraph where past K and V are read via Gather nodes.

        Returns the name of the {past} tensor on success, or None when the
        subgraph does not match the pattern below.
        """
        # Pattern 1:
        #                        {past}
        #                     /        \
        #                    /          \
        #   Gather(axes=0, indices=0)  Gather(indices=1)
        #          |                          |
        #    Transpose (perm=0,1,3,2)         |
        #          |                          |
        #      Concat_k                    Concat_v
        #          |                        /
        #  Transpose (perm=0,1,3,2)        /
        #          |                      /
        #      Unsqueeze        Unsqueeze
        #            \        /
        #             \      /
        #               Concat
        #                 |
        #              {present}
        gather = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather is None or gather.op_type != "Gather":
            logger.debug("match_past_pattern_1: expect Gather for past")
            return None

        # The V branch must gather index 1 of the stacked past tensor.
        if self.model.find_constant_input(gather, 1) != 1:
            logger.debug("match_past_pattern_1: expect indices=1 for Gather of past")
            return None
        past = gather.input[0]

        # The K branch may reach its Gather directly or through a Transpose.
        parent = self.model.get_parent(concat_k, 0, output_name_to_node)
        if parent and parent.op_type == "Gather":
            gather_past_k = parent
        else:
            past_k_nodes = self.model.match_parent_path(concat_k, ["Transpose", "Gather"], [0, 0])
            if past_k_nodes is None:
                logger.debug("match_past_pattern_1: failed match Transpose and Gather")
                return None
            gather_past_k = past_k_nodes[-1]

        if self.model.find_constant_input(gather_past_k, 0) != 1:
            logger.debug("match_past_pattern_1: expect indices=0 for Gather k of past")
            return None
        past_k = gather_past_k.input[0]
        # Both branches must originate from the same {past} tensor.
        if past != past_k:
            logger.debug("match_past_pattern_1: expect past to be same")
            return None

        return past

    def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node):
        """Match the past-state subgraph where past K and V are read via Split + Squeeze.

        Returns the name of the {past} tensor on success, or None when the
        subgraph does not match the pattern below.
        """
        # Pattern 2:
        #      Split (QKV)
        #      / |   |
        #     /  |   +----------------------+
        #        |                          |
        #        |         {past}           |
        #        |           |              |
        #      Reshape     Split         Reshape
        #        |         /    \           |
        # Transpose_k  Squeeze  Squeeze  Transpose_v
        #        |      |        \        /
        #        +------|---+     \      /
        #               |   |      \    /
        #              Concat_k   Concat_v
        #               |            |
        #        Unsqueeze         Unsqueeze
        #                \         /
        #                 \       /
        #                   Concat
        #                     |
        #                  {present}
        #
        squeeze = self.model.get_parent(concat_v, 0, output_name_to_node)
        if squeeze is None or squeeze.op_type != "Squeeze":
            logger.debug("match_past_pattern_2: expect Squeeze as parent of concat_v")
            return None

        split = self.model.get_parent(squeeze, 0, output_name_to_node)
        if split is None or split.op_type != "Split":
            logger.debug("match_past_pattern_2: expect Split for past path")
            return None

        # Before opset 13, Squeeze axes and Split sizes are node attributes;
        # from opset 13 on they are regular inputs, hence the two branches.
        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            if not FusionUtils.check_node_attribute(squeeze, "axes", [0]):
                logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
                return None

            if not FusionUtils.check_node_attribute(split, "split", [1, 1]):
                logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
                return None
        else:
            if not self.utils.check_node_input_value(squeeze, 1, [0]):
                logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
                return None

            if not self.utils.check_node_input_value(split, 1, [1, 1]):
                logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
                return None

        if not FusionUtils.check_node_attribute(split, "axis", 0, default_value=0):
            logger.debug("match_past_pattern_2: attribute axis of Split are not expected in past path")
            return None
        past = split.input[0]

        past_k_nodes = self.model.match_parent_path(concat_k, ["Squeeze", "Split"], [0, 0])
        if past_k_nodes is None:
            logger.debug("match_past_pattern_2: failed to match past_k_nodes path")
            return None
        past_k = past_k_nodes[-1].input[0]

        # Both K and V must come from the same {past} tensor.
        if past != past_k:
            logger.info("match_past_pattern_2: expect past to be same")
            return None

        return past

    def match_present(self, concat_v, input_name_to_nodes):
        """Find the {present} output by walking concat_v -> Unsqueeze -> Concat.

        Returns the output name of the final Concat, or None if the children
        do not match.
        """
        unsqueeze_present_v = self.model.find_first_child_by_type(
            concat_v, "Unsqueeze", input_name_to_nodes, recursive=False
        )
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return None
        concat_present = self.model.find_first_child_by_type(
            unsqueeze_present_v, "Concat", input_name_to_nodes, recursive=False
        )
        if not concat_present:
            logger.info("expect concat for present")
            return None

        present = concat_present.output[0]
        return present

    def cast_attention_mask(self, input_name):
        """Return the name of the attention mask casted to int32, inserting a Cast if needed.

        Results are memoized in self.casted_attention_mask so the same mask is
        only casted once.
        """
        if input_name in self.casted_attention_mask:
            attention_mask_input_name = self.casted_attention_mask[input_name]
        elif self.model.find_graph_input(input_name):
            # casted flag is intentionally unused; only the new name is needed here.
            casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        else:
            # cast_node is intentionally unused; only the new name is needed here.
            attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        return attention_mask_input_name
169
+
170
+
171
+ class FusionGptAttention(FusionGptAttentionPastBase):
172
+ """
173
+ Fuse GPT-2 Attention with past state subgraph into one Attention node.
174
+ """
175
+
176
    def __init__(self, model: OnnxModel, num_heads: int):
        """Initialize GPT-2 attention fusion for the given model and number of attention heads."""
        super().__init__(model, num_heads)
178
+
179
+ def create_attention_node(
180
+ self,
181
+ fc_weight,
182
+ fc_bias,
183
+ gemm_qkv,
184
+ past,
185
+ present,
186
+ input,
187
+ output,
188
+ mask,
189
+ is_unidirectional,
190
+ ):
191
+ attention_node_name = self.model.create_node_name("GptAttention")
192
+ attention_node = helper.make_node(
193
+ "Attention",
194
+ inputs=[input, fc_weight, fc_bias, mask, past],
195
+ outputs=[attention_node_name + "_output", present],
196
+ name=attention_node_name,
197
+ )
198
+ attention_node.domain = "com.microsoft"
199
+ attention_node.attribute.extend(
200
+ [
201
+ helper.make_attribute("num_heads", self.num_heads),
202
+ helper.make_attribute("unidirectional", 1 if is_unidirectional else 0),
203
+ ]
204
+ )
205
+
206
+ if self.mask_filter_value is not None:
207
+ attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
208
+
209
+ matmul_node = helper.make_node(
210
+ "MatMul",
211
+ inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
212
+ outputs=[attention_node_name + "_matmul_output"],
213
+ name=attention_node_name + "_matmul",
214
+ )
215
+
216
+ add_node = helper.make_node(
217
+ "Add",
218
+ inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
219
+ outputs=[output],
220
+ name=attention_node_name + "_add",
221
+ )
222
+ self.nodes_to_add.extend([attention_node, matmul_node, add_node])
223
+ self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
224
+ self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
225
+ self.node_name_to_graph_name[add_node.name] = self.this_graph_name
226
+
227
+ def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
228
+ past = None
229
+ present = None
230
+ return_indice = []
231
+
232
+ is_normalize_node_skiplayernorm = normalize_node.op_type == "SkipLayerNormalization"
233
+ qkv_nodes = None
234
+
235
+ if not is_normalize_node_skiplayernorm:
236
+ qkv_nodes = self.model.match_parent_path(
237
+ normalize_node,
238
+ ["Add", "Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
239
+ [0, None, 0, 0, 0, 0, 0],
240
+ output_name_to_node=output_name_to_node,
241
+ return_indice=return_indice,
242
+ )
243
+ else:
244
+ qkv_nodes = self.model.match_parent_path(
245
+ normalize_node,
246
+ ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
247
+ [None, 0, 0, 0, 0, 0],
248
+ output_name_to_node=output_name_to_node,
249
+ return_indice=return_indice,
250
+ )
251
+
252
+ if qkv_nodes is None:
253
+ return
254
+
255
+ another_input = None
256
+ if not is_normalize_node_skiplayernorm:
257
+ (
258
+ add_qkv,
259
+ reshape_qkv,
260
+ gemm_qkv,
261
+ reshape_1,
262
+ reshape_2,
263
+ transpose_qkv,
264
+ matmul_qkv,
265
+ ) = qkv_nodes
266
+
267
+ another_input = add_qkv.input[1 - return_indice[0]]
268
+ else:
269
+ (
270
+ reshape_qkv,
271
+ gemm_qkv,
272
+ reshape_1,
273
+ reshape_2,
274
+ transpose_qkv,
275
+ matmul_qkv,
276
+ ) = qkv_nodes
277
+
278
+ v_nodes = self.model.match_parent_path(matmul_qkv, ["Concat", "Transpose", "Reshape", "Split"], [1, 1, 0, 0])
279
+ if v_nodes is None:
280
+ logger.debug("fuse_attention: failed to match v path")
281
+ return
282
+ (concat_v, transpose_v, reshape_v, split_fc) = v_nodes
283
+
284
+ # Try match pattern using Gemm + LayerNormalization
285
+ fc_nodes = self.model.match_parent_path(
286
+ split_fc,
287
+ ["Reshape", "Gemm", "Reshape", "LayerNormalization"],
288
+ [0, 0, 0, 0],
289
+ output_name_to_node,
290
+ )
291
+
292
+ # Try match pattern using Gemm + SkipLayerNormalization
293
+ if fc_nodes is None:
294
+ fc_nodes = self.model.match_parent_path(
295
+ split_fc,
296
+ ["Reshape", "Gemm", "Reshape", "SkipLayerNormalization"],
297
+ [0, 0, 0, 0],
298
+ output_name_to_node,
299
+ )
300
+
301
+ # Try match pattern using MatMul
302
+ if fc_nodes is None:
303
+ # LayerNormalization
304
+ fc_nodes = self.model.match_parent_path(
305
+ split_fc,
306
+ ["Add", "MatMul", "LayerNormalization"],
307
+ [0, None, 0],
308
+ output_name_to_node,
309
+ )
310
+
311
+ # SkipLayerNormalization
312
+ if fc_nodes is None:
313
+ fc_nodes = self.model.match_parent_path(
314
+ split_fc,
315
+ ["Add", "MatMul", "SkipLayerNormalization"],
316
+ [0, None, 0],
317
+ output_name_to_node,
318
+ )
319
+
320
+ if fc_nodes is None:
321
+ logger.debug("fuse_attention: failed to match fc path")
322
+ return
323
+
324
+ fc_weight = fc_nodes[1].input[1]
325
+ i, _ = self.model.get_constant_input(fc_nodes[0])
326
+ fc_bias = fc_nodes[0].input[i]
327
+ else:
328
+ fc_weight = fc_nodes[1].input[1]
329
+ fc_bias = fc_nodes[1].input[2]
330
+
331
+ layernorm_before_attention = fc_nodes[-1]
332
+
333
+ # `another_input` will be non-None only if
334
+ # (1) SkipLayerNorm fusion wasn't turned ON
335
+ # (2) SkipLayerNorm fusion was turned ON but upstream layer's LayerNorm + Add was not
336
+ # fused into a SkipLayerNorm. This can happen if the shapes to the Add node are different.
337
+ # So, keep the following check if SkipLayerNorm fusion is turned ON or OFF.
338
+ if another_input is not None and another_input not in layernorm_before_attention.input:
339
+ logger.debug("Upstream Add and (Skip)LayerNormalization shall have one same input")
340
+ return
341
+
342
+ is_unidirectional = True
343
+ slice_mask = None
344
+ input_mask_nodes = None
345
+ concat_k_to_match = None
346
+ qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "Div", "MatMul"], [0, 0, 0, 0, 0])
347
+ if qk_nodes is not None:
348
+ (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
349
+ mask_nodes = self.model.match_parent_path(
350
+ sub_qk,
351
+ [
352
+ "Mul",
353
+ "Sub",
354
+ "Slice",
355
+ "Slice",
356
+ "Unsqueeze",
357
+ "Sub",
358
+ "Squeeze",
359
+ "Slice",
360
+ "Shape",
361
+ "Div",
362
+ ],
363
+ [1, 0, 1, 0, 1, 0, 0, 0, 0, 0],
364
+ )
365
+ if mask_nodes is None:
366
+ logger.debug("fuse_attention: failed to match unidirectional mask path")
367
+ return
368
+ div_mask = mask_nodes[-1]
369
+ slice_mask = mask_nodes[3]
370
+
371
+ if div_qk != div_mask:
372
+ logger.debug("fuse_attention: skip since div_qk != div_mask")
373
+ return
374
+
375
+ if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
376
+ _, mul_val = self.model.get_constant_input(mask_nodes[0])
377
+ if mul_val != -10000:
378
+ self.mask_filter_value = -mul_val
379
+
380
+ else:
381
+ # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
382
+ i, qk_nodes, _ = self.model.match_parent_paths(
383
+ matmul_qkv,
384
+ [
385
+ (["Softmax", "Where", "Div", "MatMul"], [0, 0, 1, 0]),
386
+ (["Softmax", "Add", "Where", "Div", "MatMul"], [0, 0, None, 1, 0]),
387
+ ],
388
+ output_name_to_node,
389
+ )
390
+ if qk_nodes is None:
391
+ logger.debug("fuse_attention: failed to match qk nodes")
392
+ return
393
+
394
+ where_qk = qk_nodes[-3]
395
+ div_qk = qk_nodes[-2]
396
+ matmul_qk = qk_nodes[-1]
397
+
398
+ if i == 1:
399
+ add_qk = qk_nodes[1]
400
+ _, input_mask_nodes, _ = self.model.match_parent_paths(
401
+ add_qk,
402
+ [
403
+ (
404
+ ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze", "Reshape"],
405
+ [None, 0, 1, 0, 0, 0],
406
+ ),
407
+ (
408
+ ["Mul", "Sub", "Unsqueeze", "Unsqueeze", "Reshape"],
409
+ [None, 0, 1, 0, 0],
410
+ ),
411
+ (
412
+ ["Mul", "Sub", "Unsqueeze", "Unsqueeze"],
413
+ [None, 0, 1, 0],
414
+ ), # useless cast and reshape are removed.
415
+ ],
416
+ output_name_to_node,
417
+ )
418
+ if input_mask_nodes is None:
419
+ logger.debug("fuse_attention: failed to match input attention mask path")
420
+ return
421
+ if len(input_mask_nodes) > 1 and input_mask_nodes[0].op_type == "Mul":
422
+ _, mul_val = self.model.get_constant_input(input_mask_nodes[0])
423
+ if mul_val != -10000:
424
+ self.mask_filter_value = mul_val
425
+
426
+ i, mask_nodes, _ = self.model.match_parent_paths(
427
+ where_qk,
428
+ [
429
+ (
430
+ ["Cast", "Slice", "Slice", "Unsqueeze", "Sub", "Squeeze", "Slice", "Shape"],
431
+ [0, 0, 0, 1, 0, 0, 0, 0],
432
+ ),
433
+ # For Transformers >= 4.27, causal mask uses torch.bool instead of torch.uint8, so no Cast to bool.
434
+ (
435
+ ["Slice", "Slice", "Unsqueeze", "Sub", "Squeeze", "Slice", "Shape"],
436
+ [0, 0, 1, 0, 0, 0, 0],
437
+ ),
438
+ ],
439
+ output_name_to_node,
440
+ )
441
+ if mask_nodes is None:
442
+ # TODO: match mask path for GPT2LMHeadModel_BeamSearchStep.
443
+ logger.debug("fuse_attention: failed to match mask path")
444
+ return
445
+
446
+ slice_mask = mask_nodes[2 if i == 0 else 1]
447
+
448
+ div_or_concat = self.model.get_parent(mask_nodes[-1], 0, output_name_to_node)
449
+ if div_or_concat.op_type == "Div":
450
+ div_mask = div_or_concat
451
+ if div_qk != div_mask:
452
+ logger.debug("fuse_attention: skip since div_qk != div_mask")
453
+ return
454
+ elif div_or_concat.op_type == "Concat":
455
+ concat_k_to_match = div_or_concat
456
+ else:
457
+ logger.debug("fuse_attention: failed to match mask path")
458
+
459
+ # Validate that the mask data is either lower triangular (unidirectional) or all ones
460
+ mask_data = self.model.get_constant_value(slice_mask.input[0])
461
+ if not (
462
+ isinstance(mask_data, np.ndarray)
463
+ and len(mask_data.shape) == 4
464
+ and mask_data.shape[:2] == (1, 1)
465
+ and mask_data.shape[2] == mask_data.shape[3]
466
+ ):
467
+ logger.debug("fuse_attention: skip since mask shape is not 1x1xWxW")
468
+ return
469
+
470
+ if np.allclose(mask_data, np.ones_like(mask_data)):
471
+ is_unidirectional = False
472
+ elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))):
473
+ logger.debug("fuse_attention: skip since mask is neither lower triangular nor ones")
474
+ return
475
+
476
+ q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Split"], [0, 0, 0])
477
+ if q_nodes is None:
478
+ logger.debug("fuse_attention: failed to match q path")
479
+ return
480
+ (transpose_q, reshape_q, split_q) = q_nodes
481
+ if split_fc != split_q:
482
+ logger.debug("fuse_attention: skip since split_fc != split_q")
483
+ return
484
+
485
+ k_nodes = self.model.match_parent_path(matmul_qk, ["Concat", "Transpose", "Reshape", "Split"], [1, 1, 0, 0])
486
+ if k_nodes is None:
487
+ # This pattern is from pytorch 1.7.1 and transformers 4.6.1
488
+ k_nodes = self.model.match_parent_path(
489
+ matmul_qk,
490
+ ["Transpose", "Concat", "Transpose", "Reshape", "Split"],
491
+ [1, 0, 1, 0, 0],
492
+ )
493
+ if k_nodes is None:
494
+ logger.debug("fuse_attention: failed to match k path")
495
+ return
496
+ else:
497
+ (_, concat_k, transpose_k, reshape_k, split_k) = k_nodes
498
+ else:
499
+ (concat_k, transpose_k, reshape_k, split_k) = k_nodes
500
+ if split_fc != split_k:
501
+ logger.debug("fuse_attention: skip since split_fc != split_k")
502
+ return
503
+
504
+ if concat_k_to_match and concat_k != concat_k_to_match:
505
+ logger.debug("fuse_attention: skip since concat_k != concat_k_to_match")
506
+ return
507
+
508
+ attention_mask_input_name = ""
509
+ if input_mask_nodes is not None:
510
+ input_name = input_mask_nodes[-1].input[0]
511
+ attention_mask_input_name = self.cast_attention_mask(input_name)
512
+
513
+ # Match past and present paths
514
+ past = self.match_past_pattern_1(concat_k, concat_v, output_name_to_node) or self.match_past_pattern_2(
515
+ concat_k, concat_v, output_name_to_node
516
+ )
517
+ if past is None:
518
+ logger.info("fuse_attention: failed to match past path")
519
+ return
520
+ if not self.model.find_graph_input(past):
521
+ logger.debug("past is not graph input.")
522
+ # For GPT2LMHeadModel_BeamSearchStep, there is an extra Gather node to select beam index so it is not graph input.
523
+
524
+ present = self.match_present(concat_v, input_name_to_nodes)
525
+ if present is None:
526
+ logger.info("fuse_attention: failed to match present path")
527
+ return
528
+ if not self.model.find_graph_output(present):
529
+ logger.info("expect present to be graph output")
530
+ return
531
+
532
+ self.create_attention_node(
533
+ fc_weight,
534
+ fc_bias,
535
+ gemm_qkv,
536
+ past,
537
+ present,
538
+ layernorm_before_attention.output[0],
539
+ reshape_qkv.output[0],
540
+ attention_mask_input_name,
541
+ is_unidirectional,
542
+ )
543
+
544
+ # we rely on prune_graph() to clean old subgraph nodes:
545
+ # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
546
+ self.prune_graph = True