onnxruntime-directml 1.24.1-cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
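
For context on what this wheel provides, here is a minimal usage sketch (not taken from the diff itself) of loading a model on the DirectML execution provider that this build ships. The model path, input name, and input shape below are hypothetical placeholders, not files contained in the package.

# Minimal sketch (illustrative only; "model.onnx" and the input shape are assumptions).
import numpy as np
import onnxruntime as ort

# DmlExecutionProvider is the DirectML provider bundled in this wheel;
# CPUExecutionProvider is listed as a fallback.
session = ort.InferenceSession(
    "model.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)

# Feed a dummy input matching the model's first input; the shape here is assumed.
input_name = session.get_inputs()[0].name
dummy_input = np.random.rand(1, 3, 224, 224).astype(np.float32)
outputs = session.run(None, {input_name: dummy_input})
print([o.shape for o in outputs])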
onnxruntime/transformers/fusion_gpt_attention_megatron.py
@@ -0,0 +1,355 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ from logging import getLogger
+
+ import numpy as np
+ from fusion_gpt_attention import FusionGptAttentionPastBase
+ from onnx import helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ def is_close(value, expected_value):
+     return abs(value - expected_value) <= 1e-6
+
+
+ class FusionGptAttentionMegatron(FusionGptAttentionPastBase):
+     """
+     Fuse GPT-2 Attention with past state subgraph from Megatron into one Attention node.
+     """
+
+     def __init__(self, model: OnnxModel, num_heads: int):
+         super().__init__(model, num_heads)
+
+     def fuse_attention_node(
+         self,
+         matmul_before_split,
+         add_before_split,
+         past,
+         present,
+         input,
+         reshape_qkv,
+         mask,
+     ):
+         attention_node_name = self.model.create_node_name("GptAttention")
+         int32_mask = self.cast_attention_mask(mask)
+         output = reshape_qkv.output[0]
+         i = 1 if (add_before_split.input[0] == matmul_before_split.output[0]) else 0
+         attention_node = helper.make_node(
+             "Attention",
+             inputs=[
+                 input,
+                 matmul_before_split.input[1],
+                 add_before_split.input[i],
+                 int32_mask,
+                 past,
+             ],
+             outputs=[output, present],
+             name=attention_node_name,
+         )
+         attention_node.domain = "com.microsoft"
+         attention_node.attribute.extend(
+             [
+                 helper.make_attribute("num_heads", self.num_heads),
+                 helper.make_attribute("unidirectional", 0),  # unidirectional shall not be ON for 4D attention mask
+             ]
+         )
+         if self.mask_filter_value is not None:
+             attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
+
+         nodes_to_add = [attention_node]
+         self.nodes_to_add.extend(nodes_to_add)
+
+         for node in nodes_to_add:
+             self.node_name_to_graph_name[node.name] = self.this_graph_name
+
+         self.nodes_to_remove.append(reshape_qkv)
+
+         # we rely on prune_graph() to clean old subgraph nodes
+         self.prune_graph = True
+
+     def match_mask(self, sub_qk, mul_qk, matmul_qk, layernorm_before_attention):
+         mask_nodes = self.model.match_parent_path(sub_qk, ["Mul", "Sub", "Slice", "Slice"], [1, 0, 1, 0])
+         if mask_nodes is None:
+             logger.debug("fuse_attention: failed to match unidirectional mask path")
+             return None
+         (mul_mask, sub_mask, last_slice_mask, slice_mask) = mask_nodes
+
+         if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
+             _, mul_val = self.model.get_constant_input(mask_nodes[0])
+             if mul_val != 10000:
+                 self.mask_filter_value = -mul_val
+
+         if mul_qk.input[1] != last_slice_mask.output[0]:
+             logger.debug("fuse_attention failed: mul_qk.input[1] != last_slice_mask.output[0]")
+             return None
+
+         if not self.utils.check_node_input_value(mul_mask, 1, 10000.0):
+             logger.debug("fuse_attention failed: mul_mask input 1 is not constant 10000.0")
+             return None
+
+         if not self.utils.check_node_input_value(sub_mask, 0, 1.0):
+             logger.debug("fuse_attention failed: sub_mask input 0 is not constant 1.0")
+             return None
+
+         if not self.model.find_graph_input(slice_mask.input[0]):
+             logger.info("expect slick_mask input 0 to be graph input")
+             return None
+
+         if not self.utils.check_node_input_value(last_slice_mask, 1, [0]):
+             logger.debug("fuse_attention failed: last_slice_mask input 1 (starts) is not constant [0]")
+             return None
+
+         if not self.utils.check_node_input_value(last_slice_mask, 3, [3]):
+             logger.debug("fuse_attention failed: last_slice_mask input 3 (axes) is not constant [3]")
+             return False
+
+         if not self.utils.check_node_input_value(last_slice_mask, 4, [1]):
+             logger.debug("fuse_attention failed: last_slice_mask input 4 (steps) is not constant [1]")
+             return False
+
+         if not self.utils.check_node_input_value(slice_mask, 3, [2]):
+             logger.debug("fuse_attention failed: slice_mask input 3 (axes) is not constant [2]")
+             return None
+
+         if not self.utils.check_node_input_value(slice_mask, 4, [1]):
+             logger.debug("fuse_attention failed: slice_mask input 4 (steps) is not constant [1]")
+             return None
+
+         last_slice_path = self.model.match_parent_path(
+             last_slice_mask, ["Unsqueeze", "Gather", "Shape", "MatMul"], [2, 0, 0, 0]
+         )
+         if last_slice_path is None or last_slice_path[-1] != matmul_qk:
+             logger.debug("fuse_attention: failed to match last slice path")
+             return None
+
+         first_slice_path = self.model.match_parent_path(
+             slice_mask, ["Unsqueeze", "Gather", "Shape", "MatMul"], [2, 0, 0, 0]
+         )
+         if first_slice_path is None or first_slice_path[-1] != matmul_qk:
+             logger.debug("fuse_attention: failed to match first slice path")
+             return None
+
+         first_slice_sub = self.model.match_parent_path(
+             slice_mask,
+             ["Unsqueeze", "Sub", "Gather", "Shape", "MatMul"],
+             [1, 0, 0, 0, 0],
+         )
+         if first_slice_sub is None or first_slice_sub[-1] != matmul_qk:
+             logger.debug("fuse_attention: failed to match last slice sub path")
+             return None
+
+         first_slice_sub_1 = self.model.match_parent_path(
+             slice_mask,
+             ["Unsqueeze", "Sub", "Gather", "Shape", "LayerNormalization"],
+             [1, 0, 1, 0, 0],
+         )
+
+         if first_slice_sub_1 is None:
+             first_slice_sub_1 = self.model.match_parent_path(
+                 slice_mask,
+                 ["Unsqueeze", "Sub", "Gather", "Shape", "SkipLayerNormalization"],
+                 [1, 0, 1, 0, 0],
+             )
+
+         if first_slice_sub_1 is None or first_slice_sub_1[-1] != layernorm_before_attention:
+             logger.debug("fuse_attention: failed to match last slice sub path 1")
+             return None
+
+         return slice_mask.input[0]
+
+     def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+         past = None
+         present = None
+
+         is_normalize_node_skiplayernorm = normalize_node.op_type == "SkipLayerNormalization"
+         qkv_nodes = None
+
+         if not is_normalize_node_skiplayernorm:
+             qkv_nodes = self.model.match_parent_path(
+                 normalize_node,
+                 ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+                 [0, 1, None, 0, 0, 0],
+                 output_name_to_node=output_name_to_node,
+             )
+         else:
+             qkv_nodes = self.model.match_parent_path(
+                 normalize_node,
+                 ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+                 [1, None, 0, 0, 0],
+                 output_name_to_node=output_name_to_node,
+             )
+
+         if qkv_nodes is None:
+             return
+
+         skip_input = None
+         if not is_normalize_node_skiplayernorm:
+             (
+                 add_skip,
+                 add_after_attention,
+                 matmul_after_attention,
+                 reshape_qkv,
+                 transpose_qkv,
+                 matmul_qkv,
+             ) = qkv_nodes
+
+             skip_input = add_skip.input[0]
+         else:
+             (
+                 add_after_attention,
+                 matmul_after_attention,
+                 reshape_qkv,
+                 transpose_qkv,
+                 matmul_qkv,
+             ) = qkv_nodes
+
+             skip_input = normalize_node.input[0]
+
+         v_nodes = self.model.match_parent_path(
+             matmul_qkv,
+             [
+                 "Concat",
+                 "Transpose",
+                 "Reshape",
+                 "Split",
+                 "Add",
+                 "MatMul",
+                 "LayerNormalization",
+             ],
+             [1, 1, 0, 0, 0, None, 0],
+         )
+
+         if v_nodes is None:
+             v_nodes = self.model.match_parent_path(
+                 matmul_qkv,
+                 [
+                     "Concat",
+                     "Transpose",
+                     "Reshape",
+                     "Split",
+                     "Add",
+                     "MatMul",
+                     "SkipLayerNormalization",
+                 ],
+                 [1, 1, 0, 0, 0, None, 0],
+             )
+
+         if v_nodes is None:
+             logger.debug("fuse_attention: failed to match v path")
+             return
+         (
+             concat_v,
+             transpose_v,
+             reshape_v,
+             split_v,
+             add_before_split,
+             matmul_before_split,
+             layernorm_before_attention,
+         ) = v_nodes
+
+         if (
+             layernorm_before_attention.op_type == "LayerNormalization"
+             and skip_input != layernorm_before_attention.input[0]
+         ):
+             logger.debug("fuse_attention: skip_input != layernorm_before_attention.input[0]")
+             return
+
+         if (
+             layernorm_before_attention.op_type == "SkipLayerNormalization"
+             and skip_input != layernorm_before_attention.output[3]
+         ):
+             logger.debug("fuse_attention: skip_input != layernorm_before_attention.input[0]")
+             return
+
+         qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "MatMul"], [0, 0, 0, 0])
+         if qk_nodes is None:
+             logger.debug("fuse_attention: failed to match qk path")
+             return None
+         (softmax_qk, sub_qk, mul_qk, matmul_qk) = qk_nodes
+         if self.model.get_node_attribute(softmax_qk, "axis") != 3:
+             logger.debug("fuse_attention failed: softmax_qk axis != 3")
+             return None
+
+         attention_mask = self.match_mask(sub_qk, mul_qk, matmul_qk, layernorm_before_attention)
+
+         q_nodes = self.model.match_parent_path(matmul_qk, ["Div", "Transpose", "Reshape", "Split"], [0, 0, 0, 0])
+         if q_nodes is None:
+             logger.debug("fuse_attention: failed to match q path")
+             return
+         (div_q, transpose_q, reshape_q, split_q) = q_nodes
+         if split_v != split_q:
+             logger.debug("fuse_attention: skip since split_v != split_q")
+             return
+
+         k_nodes = self.model.match_parent_path(
+             matmul_qk,
+             ["Div", "Transpose", "Concat", "Transpose", "Reshape", "Split"],
+             [1, 0, 0, 1, 0, 0],
+         )
+         if k_nodes is None:
+             logger.debug("fuse_attention: failed to match k path")
+             return
+         (div_k, _, concat_k, transpose_k, reshape_k, split_k) = k_nodes
+         if split_v != split_k:
+             logger.debug("fuse_attention: skip since split_v != split_k")
+             return
+
+         i, value = self.model.get_constant_input(reshape_k)
+         if not (
+             isinstance(value, np.ndarray)
+             and list(value.shape) == [4]
+             and value[0] == 0
+             and value[1] == 0
+             and value[2] > 0
+             and value[3] > 0
+         ):
+             logger.debug("fuse_attention: reshape constant input is not [0, 0, N, H]")
+             return
+
+         num_heads = value[2]
+         if num_heads != self.num_heads:
+             logger.info(f"Detected num_heads={num_heads}. Ignore user specified value {self.num_heads}")
+             self.num_heads = num_heads
+
+         hidden_size_per_head = value[3]
+         i, value = self.model.get_constant_input(div_k)
+         expected_value = float(np.sqrt(np.sqrt(hidden_size_per_head)))
+         if not is_close(value, expected_value):
+             logger.debug(f"fuse_attention: div_k value={value} expected={expected_value}")
+             return
+
+         i, value = self.model.get_constant_input(div_q)
+         if not is_close(value, expected_value):
+             logger.debug(f"fuse_attention: div_q value={value} expected={expected_value}")
+             return
+
+         # Match past and present paths
+         past = self.match_past_pattern_2(concat_k, concat_v, output_name_to_node)
+         if past is None:
+             logger.debug("fuse_attention: match past failed")
+             return
+         if not self.model.find_graph_input(past):
+             logger.debug("fuse_attention: past is not graph input.")
+             # For GPT2LMHeadModel_BeamSearchStep, there is an extra Gather node to select beam index so it is not graph input.
+
+         present = self.match_present(concat_v, input_name_to_nodes)
+         if present is None:
+             logger.debug("fuse_attention: match present failed")
+             return
+         if not self.model.find_graph_output(present):
+             logger.info("fuse_attention: expect present to be graph output")
+             return
+
+         self.fuse_attention_node(
+             matmul_before_split,
+             add_before_split,
+             past,
+             present,
+             layernorm_before_attention.output[0],
+             reshape_qkv,
+             attention_mask,
+         )
onnxruntime/transformers/fusion_gpt_attention_no_past.py
@@ -0,0 +1,260 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ from logging import getLogger
+
+ from fusion_base import Fusion
+ from onnx import helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class FusionGptAttentionNoPast(Fusion):
+     """
+     Fuse GPT-2 Attention without past state into one Attention node.
+     This does not support attention_mask graph input right now.
+     """
+
+     def __init__(self, model: OnnxModel, num_heads: int):
+         super().__init__(model, "Attention", ["LayerNormalization", "SkipLayerNormalization"], "without past")
+         # TODO: detect num_heads from graph like FusionAttention
+         self.num_heads = num_heads
+         self.mask_filter_value = None
+
+     def create_attention_node(self, gemm, gemm_qkv, input, output):
+         attention_node_name = self.model.create_node_name("Attention")
+         attention_node = helper.make_node(
+             "Attention",
+             inputs=[input, gemm.input[1], gemm.input[2]],
+             outputs=[attention_node_name + "_output"],
+             name=attention_node_name,
+         )
+         attention_node.domain = "com.microsoft"
+         attention_node.attribute.extend(
+             [
+                 helper.make_attribute("num_heads", self.num_heads),
+                 helper.make_attribute("unidirectional", 1),
+             ]
+         )
+         if self.mask_filter_value is not None:
+             attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
+
+         matmul_node = helper.make_node(
+             "MatMul",
+             inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
+             outputs=[attention_node_name + "_matmul_output"],
+             name=attention_node_name + "_matmul",
+         )
+
+         add_node = helper.make_node(
+             "Add",
+             inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
+             outputs=[output],
+             name=attention_node_name + "_add",
+         )
+
+         self.nodes_to_add.extend([attention_node, matmul_node, add_node])
+         self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
+         self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
+         self.node_name_to_graph_name[add_node.name] = self.this_graph_name
+
+     def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+         # (TODO) hasesh/tlwu: Investigate what fixes the following logic needs in order
+         # to fuse the Attention sub-graph. With some changes to other fusions, this stopped
+         # working.
+         return_indice = []
+
+         is_normalize_node_skiplayernorm = normalize_node.op_type == "SkipLayerNormalization"
+         qkv_nodes = None
+
+         if not is_normalize_node_skiplayernorm:
+             qkv_nodes = self.model.match_parent_path(
+                 normalize_node,
+                 ["Add", "Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
+                 [0, None, 0, 0, 0, 0, 0],
+                 output_name_to_node=output_name_to_node,
+                 return_indice=return_indice,
+             )
+         else:
+             qkv_nodes = self.model.match_parent_path(
+                 normalize_node,
+                 ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
+                 [None, 0, 0, 0, 0, 0],
+                 output_name_to_node=output_name_to_node,
+                 return_indice=return_indice,
+             )
+
+         if qkv_nodes is None:
+             return
+
+         another_input = None
+         if not is_normalize_node_skiplayernorm:
+             (
+                 add_qkv,
+                 reshape_qkv,
+                 gemm_qkv,
+                 reshape_1,
+                 reshape_2,
+                 transpose_qkv,
+                 matmul_qkv,
+             ) = qkv_nodes
+
+             another_input = add_qkv.input[1 - return_indice[0]]
+         else:
+             (
+                 reshape_qkv,
+                 gemm_qkv,
+                 reshape_1,
+                 reshape_2,
+                 transpose_qkv,
+                 matmul_qkv,
+             ) = qkv_nodes
+
+         v_nodes = self.model.match_parent_path(
+             matmul_qkv,
+             ["Transpose", "Reshape", "Split", "Reshape", "Gemm", "Reshape"],
+             [1, 0, 0, 0, 0, 0],
+         )
+         if v_nodes is None:
+             logger.debug("fuse_attention: failed to match v path")
+             return
+         (
+             transpose_v,
+             reshape_v,
+             split_v,
+             reshape_after_gemm,
+             gemm,
+             reshape_before_gemm,
+         ) = v_nodes
+
+         layernorm_before_attention = self.model.get_parent(reshape_before_gemm, 0, output_name_to_node)
+         if layernorm_before_attention is None or (
+             layernorm_before_attention.op_type != "LayerNormalization"
+             and layernorm_before_attention.op_type != "SkipLayerNormalization"
+         ):
+             if layernorm_before_attention.op_type != "Add":
+                 logger.debug(f"failed to get (skip)layernorm before gemm. Got {layernorm_before_attention.op_type}")
+                 return
+
+         # `another_input` will be non-None only if
+         # (1) SkipLayerNorm fusion wasn't turned ON
+         # (2) SkipLayerNorm fusion was turned ON but upstream layer's LayerNorm + Add was not
+         # fused into a SkipLayerNorm. This can happen if the shapes to the Add node are different.
+         # So, keep the following check if SkipLayerNorm fusion is turned ON or OFF.
+         if another_input is not None:
+             if another_input not in layernorm_before_attention.input:
+                 # match openai-gpt
+                 if another_input not in layernorm_before_attention.output:
+                     logger.debug("Add and (Skip)LayerNormalization shall have one same input")
+                     return
+
+         qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "Div", "MatMul"], [0, 0, 0, 0, 0])
+         if qk_nodes is not None:
+             (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
+             mask_nodes = self.model.match_parent_path(
+                 sub_qk,
+                 [
+                     "Mul",
+                     "Sub",
+                     "Slice",
+                     "Slice",
+                     "Unsqueeze",
+                     "Sub",
+                     "Squeeze",
+                     "Slice",
+                     "Shape",
+                     "Div",
+                 ],
+                 [1, 0, 1, 0, 1, 0, 0, 0, 0, 0],
+             )
+             if mask_nodes is None:
+                 logger.debug("fuse_attention: failed to match mask path")
+                 return
+             div_mask = mask_nodes[-1]
+
+             if div_qk != div_mask:
+                 logger.debug("fuse_attention: skip since div_qk != div_mask")
+                 return
+             if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
+                 _, mul_val = self.model.get_constant_input(mask_nodes[0])
+                 if mul_val != -10000:
+                     self.mask_filter_value = mul_val
+
+         else:
+             # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
+             qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Where", "Div", "MatMul"], [0, 0, 1, 0])
+             if qk_nodes is not None:
+                 (softmax_qk, where_qk, div_qk, matmul_qk) = qk_nodes
+                 mask_nodes = self.model.match_parent_path(
+                     where_qk,
+                     [
+                         "Cast",
+                         "Slice",
+                         "Slice",
+                         "Unsqueeze",
+                         "Sub",
+                         "Squeeze",
+                         "Slice",
+                         "Shape",
+                         "Div",
+                     ],
+                     [0, 0, 0, 1, 0, 0, 0, 0, 0],
+                 )
+                 if mask_nodes is None:
+                     logger.debug("fuse_attention: failed to match mask path")
+                     return
+                 div_mask = mask_nodes[-1]
+
+                 if div_qk != div_mask:
+                     logger.debug("fuse_attention: skip since div_qk != div_mask")
+                     return
+             else:
+                 # match openai-gpt
+                 qk_nodes = self.model.match_parent_path(
+                     matmul_qkv,
+                     ["Softmax", "Add", "Mul", "Div", "MatMul"],
+                     [0, 0, 0, 0, 0],
+                 )
+                 if qk_nodes is None:
+                     logger.debug("fuse_attention: failed to match qk path")
+                     return
+                 (softmax_qk, add_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
+                 mask_nodes = self.model.match_parent_path(
+                     mul_qk,
+                     ["Slice", "Slice", "Unsqueeze", "Squeeze", "Slice", "Shape", "Div"],
+                     [1, 0, 2, 0, 0, 0, 0],
+                 )
+                 if mask_nodes is None:
+                     logger.debug("fuse_attention: failed to match mask path")
+                     return
+                 div_mask = mask_nodes[-1]
+
+                 if div_qk != div_mask:
+                     logger.debug("fuse_attention: skip since div_qk != div_mask")
+                     return
+
+         q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Split"], [0, 0, 0])
+         if q_nodes is None:
+             logger.debug("fuse_attention: failed to match q path")
+             return
+         (transpose_q, reshape_q, split_q) = q_nodes
+         if split_v != split_q:
+             logger.debug("fuse_attention: skip since split_v != split_q")
+             return
+
+         k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Split"], [1, 0, 0])
+         if k_nodes is None:
+             logger.debug("fuse_attention: failed to match k path")
+             return
+         (transpose_k, reshape_k, split_k) = k_nodes
+         if split_v != split_k:
+             logger.debug("fuse_attention: skip since split_v != split_k")
+             return
+
+         self.create_attention_node(gemm, gemm_qkv, layernorm_before_attention.output[0], reshape_qkv.output[0])
+
+         # we rely on prune_graph() to clean old subgraph nodes:
+         # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
+         self.prune_graph = True
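
The two files above (FusionGptAttentionMegatron and FusionGptAttentionNoPast) are fusion passes that ONNX Runtime's transformer optimizer applies when optimizing GPT-2 style graphs; they are not usually invoked directly. Below is a minimal sketch of driving them through the public optimizer entry point; "gpt2.onnx" is a hypothetical exported model path, and num_heads/hidden_size must match that model.

# Minimal sketch (illustrative only; the model path and head/hidden sizes are assumptions).
from onnxruntime.transformers import optimizer

opt_model = optimizer.optimize_model(
    "gpt2.onnx",
    model_type="gpt2",  # selects the GPT-2 passes, including the attention fusions shown above
    num_heads=12,
    hidden_size=768,
)
opt_model.save_model_to_file("gpt2_opt.onnx")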