onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl

This diff shows the contents of publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions exactly as they appear in their public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
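Before the per-file diffs, one note on usage: this wheel ships the ONNX Runtime Python bindings built against DirectML (see onnxruntime/capi/DirectML.dll in the listing above). The snippet below is a minimal usage sketch and is not part of the package diff; the model path and the zero-filled float32 input are placeholders.

import numpy as np
import onnxruntime as ort

# "model.onnx" is a placeholder; any ONNX model is loaded the same way.
# DirectML is listed first so it is preferred; CPU remains as a fallback.
session = ort.InferenceSession(
    "model.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)

inp = session.get_inputs()[0]
# Substitute 1 for any symbolic/dynamic dimension to build a dummy input
# (assumes a float32 input tensor).
shape = [d if isinstance(d, int) else 1 for d in inp.shape]
outputs = session.run(None, {inp.name: np.zeros(shape, dtype=np.float32)})
print(session.get_providers(), [o.shape for o in outputs])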
onnxruntime/transformers/onnx_model_bart.py
@@ -0,0 +1,141 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import logging
+
+ from fusion_attention import AttentionMask
+ from fusion_bart_attention import FusionBartAttention
+ from fusion_options import FusionOptions
+ from fusion_reshape import FusionReshape
+ from onnx import numpy_helper
+ from onnx_model import OnnxModel
+ from onnx_model_bert import BertOnnxModel
+
+ logger = logging.getLogger(__name__)
+
+
+ class FusionBartReshape(FusionReshape):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model)
+
+     def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node):
+         if reshape_node.input[1] not in output_name_to_node:
+             return
+
+         concat_node = output_name_to_node[reshape_node.input[1]]
+         if concat_node.op_type != "Concat" or len(concat_node.input) != 4:
+             return
+
+         path0 = self.model.match_parent_path(
+             concat_node,
+             ["Unsqueeze", "Gather", "Shape"],
+             [0, 0, 0],
+             output_name_to_node,
+         )
+         if path0 is None:
+             return
+
+         (_, gather_0, shape_0) = path0
+
+         shape = []
+         gather_value = self.model.get_constant_value(gather_0.input[1])
+         if gather_value == 0:
+             shape.append(0)
+
+         path1 = self.model.match_parent_path(
+             concat_node,
+             ["Unsqueeze", "Gather", "Shape"],
+             [1, 0, 0],
+             output_name_to_node,
+         )
+         if path1 is None:
+             input_1_proto = self.model.get_initializer(concat_node.input[1])
+             input_2_proto = self.model.get_initializer(concat_node.input[2])
+             input_3_proto = self.model.get_initializer(concat_node.input[3])
+             if input_1_proto is None or input_2_proto is None or input_3_proto is None:
+                 return
+
+             input_1 = numpy_helper.to_array(input_1_proto)
+             input_2 = numpy_helper.to_array(input_2_proto)
+             input_3 = numpy_helper.to_array(input_3_proto)
+             if len(input_1) != 1 or len(input_2) != 1 or len(input_3) != 1:
+                 return
+
+             if not (input_1[0] == -1 and input_2[0] > 0 and input_3[0] > 0):
+                 return
+
+             shape.extend(input_1)
+             shape.extend(input_2)
+             shape.extend(input_3)
+             gemm_path_with_bias = self.model.match_parent_path(
+                 reshape_node, ["Add", "MatMul"], [0, 1], output_name_to_node
+             )
+             gemm_path_no_bias = self.model.match_parent_path(reshape_node, ["MatMul"], [0], output_name_to_node)
+             if gemm_path_with_bias is not None:
+                 gemm_path = gemm_path_with_bias
+             elif gemm_path_no_bias is not None:
+                 gemm_path = gemm_path_no_bias
+             else:
+                 return
+
+             top_matmul = gemm_path[-1]
+             root_input = top_matmul.input[0]
+
+             self.replace_reshape_node(shape, reshape_node, concat_node)
+         else:
+             (_, gather_1, shape_1) = path1
+
+             gather_value = self.model.get_constant_value(gather_1.input[1])
+             if gather_value == 1:
+                 shape.append(0)
+
+             input_2_proto = self.model.get_initializer(concat_node.input[2])
+             input_3_proto = self.model.get_initializer(concat_node.input[3])
+             if input_2_proto is None or input_3_proto is None:
+                 return
+
+             input_2 = numpy_helper.to_array(input_2_proto)
+             input_3 = numpy_helper.to_array(input_3_proto)
+             if len(input_2) != 1 or len(input_3) != 1:
+                 return
+
+             if not (input_2[0] > 0 and input_3[0] > 0):
+                 return
+
+             shape.extend(input_2)
+             shape.extend(input_3)
+             gemm_path = self.model.match_parent_path(
+                 reshape_node, ["Mul", "Add", "MatMul"], [0, 0, 1], output_name_to_node
+             )
+             if gemm_path is None:
+                 return
+
+             top_matmul = gemm_path[-1]
+             root_input = top_matmul.input[0]
+             if shape_0.input[0] != root_input or shape_1.input[0] != root_input:
+                 return
+
+             self.replace_reshape_node(shape, reshape_node, concat_node)
+
+
+ class BartOnnxModel(BertOnnxModel):
+     def __init__(self, model, num_heads, hidden_size, model_impl="hf"):
+         super().__init__(model, num_heads, hidden_size)
+         self.attention_mask = AttentionMask(self)
+         self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
+         self.bart_reshape_fusion_preprocess = FusionBartReshape(self)
+
+     def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False):
+         self.attention_fusion.use_multi_head_attention = False if options is None else options.use_multi_head_attention
+         self.attention_fusion.disable_multi_head_attention_bias = (
+             False if options is None else options.disable_multi_head_attention_bias
+         )
+         super().optimize(options, add_dynamic_axes)
+
+     def fuse_attention(self):
+         self.attention_fusion.apply()
+
+     def preprocess(self):
+         self.adjust_reshape_and_expand()
+         self.bart_reshape_fusion_preprocess.apply()
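The BartOnnxModel shown above is normally reached through onnxruntime.transformers.optimizer.optimize_model rather than constructed directly. The following is a minimal sketch, not part of the package diff; the file paths are placeholders and the num_heads/hidden_size values are illustrative (they match bart-base).

from onnxruntime.transformers.fusion_options import FusionOptions
from onnxruntime.transformers.optimizer import optimize_model

# Placeholder paths; model_type="bart" routes optimization through BartOnnxModel.
options = FusionOptions("bart")
optimized = optimize_model(
    "bart_encoder.onnx",   # hypothetical exported BART ONNX file
    model_type="bart",
    num_heads=12,          # illustrative values for bart-base
    hidden_size=768,
    optimization_options=options,
)
optimized.save_model_to_file("bart_encoder_opt.onnx")
print(optimized.get_fused_operator_statistics())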
onnxruntime/transformers/onnx_model_bert.py
@@ -0,0 +1,488 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ from logging import getLogger
+
+ from convert_to_packing_mode import PackingMode
+ from fusion_attention import AttentionMask, FusionAttention
+ from fusion_bart_attention import FusionBartAttention
+ from fusion_biasgelu import FusionBiasGelu
+ from fusion_constant_fold import FusionConstantFold
+ from fusion_embedlayer import FusionEmbedLayerNormalization
+ from fusion_fastgelu import FusionFastGelu
+ from fusion_gelu import FusionGelu
+ from fusion_gelu_approximation import FusionGeluApproximation
+ from fusion_gemmfastgelu import FusionGemmFastGelu
+ from fusion_layernorm import FusionLayerNormalization, FusionLayerNormalizationTF
+ from fusion_options import AttentionMaskFormat, FusionOptions
+ from fusion_qordered_attention import FusionQOrderedAttention
+ from fusion_qordered_gelu import FusionQOrderedGelu
+ from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
+ from fusion_qordered_matmul import FusionQOrderedMatMul
+ from fusion_quickgelu import FusionQuickGelu
+ from fusion_reshape import FusionReshape
+ from fusion_rotary_attention import FusionRotaryEmbeddings
+ from fusion_shape import FusionShape
+ from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization
+ from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
+ from fusion_utils import FusionUtils
+ from onnx import ModelProto, TensorProto, helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class BertOnnxModel(OnnxModel):
+     def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
+         """Initialize BERT ONNX Model.
+
+         Args:
+             model (ModelProto): the ONNX model
+             num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
+             hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
+         """
+         assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)
+
+         super().__init__(model)
+         self.num_heads = num_heads
+         self.hidden_size = hidden_size
+
+         self.attention_mask = AttentionMask(self)
+         self.attention_fusion = FusionAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
+         self.qordered_attention_fusion = FusionQOrderedAttention(
+             self, self.hidden_size, self.num_heads, self.attention_mask
+         )
+         self.utils = FusionUtils(self)
+
+     def fuse_constant_fold(self):
+         fusion = FusionConstantFold(self)
+         fusion.apply()
+
+     def fuse_attention(self):
+         self.attention_fusion.apply()
+         # Only relevant in models with Q-DQ nodes
+         self.qordered_attention_fusion.apply()
+
+     def fuse_gelu(self):
+         fusion = FusionGelu(self)
+         fusion.apply()
+         fusion = FusionFastGelu(self)
+         fusion.apply()
+         fusion = FusionQuickGelu(self)
+         fusion.apply()
+         # Only relevant in models with Q-DQ nodes
+         fusion = FusionQOrderedGelu(self)
+         fusion.apply()
+
+     def fuse_bias_gelu(self, is_fastgelu):
+         fusion = FusionBiasGelu(self, is_fastgelu)
+         fusion.apply()
+
+     def gelu_approximation(self):
+         fusion = FusionGeluApproximation(self)
+         fusion.apply()
+
+     def fuse_gemm_fast_gelu(self):
+         fusion = FusionGemmFastGelu(self)
+         fusion.apply()
+
+     def fuse_add_bias_skip_layer_norm(self):
+         fusion = FusionBiasSkipLayerNormalization(self)
+         fusion.apply()
+
+     def fuse_reshape(self):
+         fusion = FusionReshape(self)
+         fusion.apply()
+
+     def fuse_shape(self):
+         fusion = FusionShape(self)
+         fusion.apply()
+
+     def fuse_embed_layer(self, use_mask_index):
+         fusion = FusionEmbedLayerNormalization(self, use_mask_index)
+         fusion.apply()
+
+     def fuse_layer_norm(self):
+         fusion = FusionLayerNormalization(self)
+         fusion.apply()
+
+         fusion = FusionLayerNormalizationTF(self)
+         fusion.apply()
+
+         # Only relevant in models with Q-DQ nodes
+         fusion = FusionQOrderedLayerNormalization(self)
+         fusion.apply()
+
+     def fuse_simplified_layer_norm(self):
+         fusion = FusionSimplifiedLayerNormalization(self)
+         fusion.apply()
+
+     def fuse_skip_layer_norm(self, shape_infer=True):
+         fusion = FusionSkipLayerNormalization(self, shape_infer=shape_infer)
+         fusion.apply()
+
+     def fuse_skip_simplified_layer_norm(self):
+         fusion = FusionSkipSimplifiedLayerNormalization(self)
+         fusion.apply()
+
+     def fuse_rotary_embeddings(self):
+         fusion = FusionRotaryEmbeddings(self)
+         fusion.apply()
+         # Remove non-MS domain functions
+         rot_emb_nodes = list(
+             filter(
+                 lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft",
+                 self.model.graph.node,
+             )
+         )
+         non_ms_domains_to_keep = {node.domain for node in rot_emb_nodes}
+         i = 0
+         while i < len(self.model.functions):
+             fn = self.model.functions[i]
+             if "RotaryEmbedding" in fn.name and fn.domain not in non_ms_domains_to_keep:
+                 self.model.functions.remove(fn)
+             else:
+                 i += 1
+
+     # Only relevant in models with Q-DQ nodes
+     def fuse_qordered_mamtul(self):
+         fusion = FusionQOrderedMatMul(self)
+         fusion.apply()
+
+     def get_graph_inputs_from_node_type(self, op_type: str, input_indices: list[int], casted: bool):
+         """
+         Get graph inputs that feed into node type (like EmbedLayerNormalization or Attention).
+         Returns a list of the graph input names based on the filter whether it is casted or not.
+         """
+         graph_inputs = []
+
+         output_name_to_node = self.output_name_to_node()
+         nodes = self.get_nodes_by_op_type(op_type)
+         for node in nodes:
+             bert_inputs = [node.input[i] for i in input_indices if i < len(node.input)]
+             for bert_input in bert_inputs:
+                 if self.find_graph_input(bert_input):
+                     if not casted:
+                         graph_inputs.append(bert_input)
+                 elif bert_input in output_name_to_node:
+                     parent = output_name_to_node[bert_input]
+                     if parent.op_type == "Cast" and self.find_graph_input(parent.input[0]) is not None:
+                         if casted:
+                             graph_inputs.append(parent.input[0])
+         return graph_inputs
+
+     def get_graph_inputs_from_fused_nodes(self, casted: bool):
+         inputs = self.get_graph_inputs_from_node_type("EmbedLayerNormalization", [0, 1, 7], casted)
+         inputs += self.get_graph_inputs_from_node_type("Attention", [3], casted)
+         return inputs
+
+     def change_graph_inputs_to_int32(self):
+         """Change data type of all graph inputs to int32 type, and add Cast node if needed."""
+         graph = self.graph()
+         add_cast_count = 0
+         remove_cast_count = 0
+         for graph_input in graph.input:
+             new_node, removed_nodes = self.change_graph_input_type(graph_input, TensorProto.INT32)
+             if new_node:
+                 add_cast_count += 1
+             remove_cast_count += len(removed_nodes)
+         logger.info(
+             f"Graph inputs are changed to int32. Added {add_cast_count} Cast nodes, and removed {remove_cast_count} Cast nodes."
+         )
+
+     def use_dynamic_axes(self, dynamic_batch_dim="batch_size", dynamic_seq_len="max_seq_len"):
+         """
+         Update input and output shape to use dynamic axes.
+         """
+         bert_graph_inputs = self.get_graph_inputs_from_fused_nodes(
+             casted=True
+         ) + self.get_graph_inputs_from_fused_nodes(casted=False)
+
+         for input in self.model.graph.input:
+             if input.name in bert_graph_inputs:
+                 dim_proto = input.type.tensor_type.shape.dim[0]
+                 dim_proto.dim_param = dynamic_batch_dim
+                 if dynamic_seq_len is not None:
+                     dim_proto = input.type.tensor_type.shape.dim[1]
+                     dim_proto.dim_param = dynamic_seq_len
+
+         for output in self.model.graph.output:
+             dim_proto = output.type.tensor_type.shape.dim[0]
+             dim_proto.dim_param = dynamic_batch_dim
+
+     def preprocess(self):
+         self.adjust_reshape_and_expand()
+         return
+
+     def adjust_reshape_and_expand(self):
+         nodes_to_remove = []
+         for node in self.nodes():
+             if node.op_type == "Reshape":
+                 # Clean up unnecessary reshape nodes.
+                 # Find reshape nodes with no actually data in "shape" attribute and remove.
+                 reshape_shape = self.get_constant_value(node.input[1])
+                 if reshape_shape is not None and reshape_shape.size == 0:
+                     nodes_to_remove.extend([node])
+                     self.replace_input_of_all_nodes(node.output[0], node.input[0])
+                     continue
+
+                 # Find path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", simplify the graph by
+                 # changing current reshape's input to output of slice.
+                 reshape_path = self.match_parent_path(
+                     node,
+                     ["Expand", "Expand", "Reshape", "Slice"],
+                     [0, 0, 0, 0],
+                     self.output_name_to_node(),
+                 )
+                 if reshape_path is not None:
+                     expand_node = reshape_path[-3]
+                     expand_shape_value = self.get_constant_value(expand_node.input[1])
+
+                     reshape_before_expand = reshape_path[-2]
+                     shape_value = self.get_constant_value(reshape_before_expand.input[1])
+
+                     slice_node = reshape_path[-1]
+                     if (
+                         expand_shape_value is not None
+                         and shape_value is not None
+                         and len(expand_shape_value) == 2
+                         and len(shape_value) == 1
+                         and expand_shape_value[1] == shape_value[0]
+                     ):
+                         node.input[0] = slice_node.output[0]
+
+         if nodes_to_remove:
+             self.remove_nodes(nodes_to_remove)
+             logger.info(f"Removed Reshape and Expand count: {len(nodes_to_remove)}")
+
+     def clean_graph(self):
+         output_name_to_node = self.output_name_to_node()
+         nodes_to_remove = []
+         for node in self.nodes():
+             # Before:
+             #  input_ids --> Shape --> Gather(indices=0) --> Unsqueeze ------+
+             #          |                                                     |
+             #          |                                                     v
+             #          +----> Shape --> Gather(indices=1) --> Unsqueeze---> Concat --> ConstantOfShape -->Cast --> EmbedLayerNormaliation/ReduceSum
+             # After:
+             #  input_ids --> Shape --> ConstantOfShape -->Cast --> EmbedLayerNormaliation/ReduceSum
+             # TODO: merge ConstantOfShape -->Cast to ConstantOfShape (need update the data type of value)
+             op_input_id = {"EmbedLayerNormalization": 1, "ReduceSum": 0, "Attention": 3}
+             if node.op_type in op_input_id:
+                 i = op_input_id[node.op_type]
+                 parent_nodes = self.match_parent_path(
+                     node,
+                     [
+                         "Cast",
+                         "ConstantOfShape",
+                         "Concat",
+                         "Unsqueeze",
+                         "Gather",
+                         "Shape",
+                     ],
+                     [i, 0, 0, 0, 0, 0],
+                     output_name_to_node,
+                 )
+                 if parent_nodes is not None:
+                     (
+                         cast,
+                         constantOfShape,  # noqa: N806
+                         concat,
+                         unsqueeze,
+                         gather,
+                         shape,
+                     ) = parent_nodes
+                     if shape.input[0] == self.graph().input[0].name:
+                         constantOfShape.input[0] = shape.output[0]
+                         output_name_to_node = self.output_name_to_node()
+
+             if node.op_type == "Attention":
+                 # Before:
+                 #   input_ids --> Shape -->ConstantOfShape -->Cast --> ReduceSum --> Attention
+                 # After:
+                 #   remove this path, and remove the optional mask_index input of Attention node.
+                 parent_nodes = self.match_parent_path(
+                     node,
+                     ["ReduceSum", "Cast", "ConstantOfShape", "Shape"],
+                     [3, 0, 0, 0],
+                     output_name_to_node,
+                 )
+                 if parent_nodes is not None:
+                     if parent_nodes[-1].input[0] == self.graph().input[0].name:
+                         attention_node = helper.make_node(
+                             "Attention",
+                             inputs=node.input[0 : len(node.input) - 1],
+                             outputs=node.output,
+                             name=node.name + "_remove_mask",
+                         )
+                         attention_node.domain = "com.microsoft"
+                         attention_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
+                         self.add_node(attention_node, self.get_graph_by_node(attention_node).name)
+                         nodes_to_remove.append(node)
+         self.remove_nodes(nodes_to_remove)
+
+     def postprocess(self):
+         self.clean_graph()
+         self.prune_graph()
+
+     def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False):
+         if (options is not None) and not options.enable_shape_inference:
+             self.disable_shape_inference()
+
+         self.utils.remove_identity_nodes()
+
+         # Remove cast nodes that having same data type of input and output based on symbolic shape inference.
+         self.utils.remove_useless_cast_nodes()
+
+         # Apply any missed constant-folding model optimizations (e.g. for Dynamo-exported models)
+         self.fuse_constant_fold()
+
+         if (options is None) or options.enable_layer_norm:
+             self.fuse_layer_norm()
+             self.fuse_simplified_layer_norm()
+
+         if (options is None) or options.enable_gelu:
+             self.fuse_gelu()
+
+         self.preprocess()
+
+         self.fuse_reshape()
+
+         if (options is None) or options.enable_skip_layer_norm:
+             self.fuse_skip_layer_norm(options.enable_shape_inference)
+             self.fuse_skip_simplified_layer_norm()
+
+         if (options is None) or options.enable_rotary_embeddings:
+             self.fuse_rotary_embeddings()
+
+         if options is not None:
+             self.attention_mask.set_mask_format(options.attention_mask_format)
+             if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention):
+                 self.attention_fusion = FusionAttention(
+                     self,
+                     self.hidden_size,
+                     self.num_heads,
+                     self.attention_mask,
+                     options.use_multi_head_attention,
+                 )
+
+         if (options is None) or options.enable_attention:
+             self.fuse_attention()
+
+         # Perform the MatMul fusion after the Attention fusion as we do not
+         # want to fuse the MatMuls inside the Attention subgraphs
+         if (options is None) or options.enable_qordered_matmul:
+             self.fuse_qordered_mamtul()
+
+         self.fuse_shape()
+
+         if (options is None) or options.enable_embed_layer_norm:
+             use_mask_index = options.attention_mask_format == AttentionMaskFormat.MaskIndexEnd
+             self.fuse_embed_layer(use_mask_index)
+
+         # Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
+         self.utils.remove_useless_reshape_nodes()
+
+         self.postprocess()
+
+         # Bias fusion is done after postprocess to avoid extra Reshape between bias and Gelu/FastGelu/SkipLayerNormalization
+         if (options is None) or options.enable_bias_gelu:
+             # Fuse Gelu and Add Bias before it.
+             self.fuse_bias_gelu(is_fastgelu=True)
+             self.fuse_bias_gelu(is_fastgelu=False)
+
+         if (options is None) or options.enable_bias_skip_layer_norm:
+             # Fuse SkipLayerNormalization and Add Bias before it.
+             self.fuse_add_bias_skip_layer_norm()
+
+         if options is not None and options.enable_gelu_approximation:
+             self.gelu_approximation()
+
+         if options is not None and options.enable_gemm_fast_gelu:
+             self.fuse_gemm_fast_gelu()
+
+         self.remove_unused_constant()
+
+         # Use symbolic batch dimension in input and output.
+         if add_dynamic_axes:
+             self.use_dynamic_axes()
+
+         logger.info(f"opset version: {self.get_opset_version()}")
+
+     def get_fused_operator_statistics(self):
+         """
+         Returns node count of fused operators.
+         """
+         op_count = {}
+         ops = [
+             "EmbedLayerNormalization",
+             "Attention",
+             "MultiHeadAttention",
+             "Gelu",
+             "FastGelu",
+             "BiasGelu",
+             "GemmFastGelu",
+             "LayerNormalization",
+             "SimplifiedLayerNormalization",
+             "SkipLayerNormalization",
+             "SkipSimplifiedLayerNormalization",
+             "RotaryEmbedding",
+         ]
+         q_ops = [
+             "QOrderedAttention",
+             "QOrderedGelu",
+             "QOrderedLayerNormalization",
+             "QOrderedMatMul",
+         ]
+         for op in ops + q_ops:
+             nodes = self.get_nodes_by_op_type(op)
+             op_count[op] = len(nodes)
+
+         logger.info(f"Optimized operators: {op_count}")
+         return op_count
+
+     def is_fully_optimized(self, fused_op_count=None):
+         """
+         Returns True when the model is fully optimized.
+         """
+         if fused_op_count is None:
+             fused_op_count = self.get_fused_operator_statistics()
+
+         def op_count(op_name: str):
+             return fused_op_count.get(op_name) or 0
+
+         embed = op_count("EmbedLayerNormalization")
+         attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("QOrderedAttention")
+         gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu")
+         layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization")
+         simple_layer_norm = op_count("SimplifiedLayerNormalization") + op_count("SkipSimplifiedLayerNormalization")
+
+         is_perfect = (
+             (embed > 0)
+             and (attention > 0)
+             and (attention == gelu)
+             and ((layer_norm >= 2 * attention) or (simple_layer_norm >= 2 * attention))
+         )
+
+         if layer_norm == 0:
+             logger.debug("Layer Normalization not fused")
+
+         if simple_layer_norm == 0:
+             logger.debug("Simple Layer Normalization not fused")
+
+         if gelu == 0:
+             logger.debug("Gelu (or FastGelu) not fused")
+
+         if embed == 0:
+             logger.debug("EmbedLayerNormalization not fused")
+
+         if attention == 0:
+             logger.warning("Attention (or MultiHeadAttention) not fused")
+
+         return is_perfect
+
+     def convert_to_packing_mode(self, use_symbolic_shape_infer: bool = False):
+         packing_mode = PackingMode(self)
+         packing_mode.convert(use_symbolic_shape_infer)
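For BERT-style models the same entry point applies with model_type="bert"; get_fused_operator_statistics and is_fully_optimized, both defined above, then report whether the expected fusions were applied. Below is a minimal sketch with placeholder paths; the float16 conversion is an optional follow-up step, not something this diff requires.

from onnxruntime.transformers.optimizer import optimize_model

# Placeholder paths; num_heads=0 and hidden_size=0 ask the optimizer to
# detect them, as documented in BertOnnxModel.__init__ above.
model = optimize_model("bert_model.onnx", model_type="bert", num_heads=0, hidden_size=0)

stats = model.get_fused_operator_statistics()
if not model.is_fully_optimized(stats):
    print("Some expected fusions did not apply:", stats)

model.convert_float_to_float16()  # optional: halve the weights for GPU/DirectML use
model.save_model_to_file("bert_model_opt.onnx")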