onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305) hide show
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
@@ -0,0 +1,142 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ import logging
6
+ from typing import Optional
7
+
8
+ from fusion_attention import AttentionMask
9
+ from fusion_bart_attention import FusionBartAttention
10
+ from fusion_options import FusionOptions
11
+ from fusion_reshape import FusionReshape
12
+ from onnx import numpy_helper
13
+ from onnx_model import OnnxModel
14
+ from onnx_model_bert import BertOnnxModel
15
+
16
+ logger = logging.getLogger(__name__)
17
+
18
+
19
class FusionBartReshape(FusionReshape):
    """Replaces the dynamic Shape/Gather/Unsqueeze/Concat subgraph feeding a
    Reshape node in BART models with a constant shape tensor.
    """

    def __init__(self, model: OnnxModel):
        super().__init__(model)

    def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node):
        """Try to replace reshape_node's shape input with a constant.

        Two patterns are handled:
        * only the batch dimension is dynamic (path1 is None): Concat inputs
          1-3 must be single-element constant initializers with values
          (-1, >0, >0);
        * batch and sequence dimensions are dynamic: Concat inputs 2-3 must be
          single-element positive constant initializers.
        """
        if reshape_node.input[1] not in output_name_to_node:
            return

        concat_node = output_name_to_node[reshape_node.input[1]]
        if concat_node.op_type != "Concat" or len(concat_node.input) != 4:
            return

        # Dynamic batch dimension: Shape -> Gather(0) -> Unsqueeze -> Concat[0]
        path0 = self.model.match_parent_path(
            concat_node,
            ["Unsqueeze", "Gather", "Shape"],
            [0, 0, 0],
            output_name_to_node,
        )
        if path0 is None:
            return

        (_, gather_0, shape_0) = path0

        shape = []
        gather_value = self.model.get_constant_value(gather_0.input[1])
        if gather_value == 0:
            shape.append(0)

        # Dynamic sequence dimension: Shape -> Gather(1) -> Unsqueeze -> Concat[1]
        path1 = self.model.match_parent_path(
            concat_node,
            ["Unsqueeze", "Gather", "Shape"],
            [1, 0, 0],
            output_name_to_node,
        )
        if path1 is None:
            input_1_proto = self.model.get_initializer(concat_node.input[1])
            input_2_proto = self.model.get_initializer(concat_node.input[2])
            input_3_proto = self.model.get_initializer(concat_node.input[3])
            if input_1_proto is None or input_2_proto is None or input_3_proto is None:
                return

            input_1 = numpy_helper.to_array(input_1_proto)
            input_2 = numpy_helper.to_array(input_2_proto)
            input_3 = numpy_helper.to_array(input_3_proto)
            if len(input_1) != 1 or len(input_2) != 1 or len(input_3) != 1:
                return

            if not (input_1[0] == -1 and input_2[0] > 0 and input_3[0] > 0):
                return

            shape.extend(input_1)
            shape.extend(input_2)
            shape.extend(input_3)

            gemm_path_with_bias = self.model.match_parent_path(
                reshape_node, ["Add", "MatMul"], [0, 1], output_name_to_node
            )
            gemm_path_no_bias = self.model.match_parent_path(reshape_node, ["MatMul"], [0], output_name_to_node)
            if gemm_path_with_bias is not None:
                gemm_path = gemm_path_with_bias
            elif gemm_path_no_bias is not None:
                gemm_path = gemm_path_no_bias
            else:
                return

            top_matmul = gemm_path[-1]
            root_input = top_matmul.input[0]
            # BUGFIX: the Shape node producing the dynamic batch dimension must
            # read the same root input as the MatMul (mirrors the check in the
            # else-branch below). Previously root_input was computed but never
            # checked, so unrelated Reshape nodes could be fused incorrectly.
            if shape_0.input[0] != root_input:
                return

            self.replace_reshape_node(shape, reshape_node, concat_node)
        else:
            (_, gather_1, shape_1) = path1

            gather_value = self.model.get_constant_value(gather_1.input[1])
            if gather_value == 1:
                shape.append(0)

            input_2_proto = self.model.get_initializer(concat_node.input[2])
            input_3_proto = self.model.get_initializer(concat_node.input[3])
            if input_2_proto is None or input_3_proto is None:
                return

            input_2 = numpy_helper.to_array(input_2_proto)
            input_3 = numpy_helper.to_array(input_3_proto)
            if len(input_2) != 1 or len(input_3) != 1:
                return

            if not (input_2[0] > 0 and input_3[0] > 0):
                return

            shape.extend(input_2)
            shape.extend(input_3)

            gemm_path = self.model.match_parent_path(
                reshape_node, ["Mul", "Add", "MatMul"], [0, 0, 1], output_name_to_node
            )
            if gemm_path is None:
                return

            top_matmul = gemm_path[-1]
            root_input = top_matmul.input[0]
            # Both dynamic dimensions must come from the MatMul's root input.
            if shape_0.input[0] != root_input or shape_1.input[0] != root_input:
                return

            self.replace_reshape_node(shape, reshape_node, concat_node)
121
+
122
+
123
class BartOnnxModel(BertOnnxModel):
    """BERT-style optimizer specialized for BART: installs the BART attention
    fusion and a BART-specific Reshape fusion preprocessing step."""

    def __init__(self, model, num_heads, hidden_size, model_impl="hf"):
        # model_impl is accepted for interface compatibility; it is not used here.
        super().__init__(model, num_heads, hidden_size)
        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
        self.bart_reshape_fusion_preprocess = FusionBartReshape(self)

    def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
        """Propagate MHA-related options to the attention fusion, then run the
        generic BERT optimization passes."""
        use_mha = options.use_multi_head_attention if options is not None else False
        disable_mha_bias = options.disable_multi_head_attention_bias if options is not None else False
        self.attention_fusion.use_multi_head_attention = use_mha
        self.attention_fusion.disable_multi_head_attention_bias = disable_mha_bias
        super().optimize(options, add_dynamic_axes)

    def fuse_attention(self):
        """Run the BART attention fusion."""
        self.attention_fusion.apply()

    def preprocess(self):
        """Adjust Reshape/Expand nodes, then run the BART Reshape fusion."""
        self.adjust_reshape_and_expand()
        self.bart_reshape_fusion_preprocess.apply()
@@ -0,0 +1,481 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+
6
+ from logging import getLogger
7
+ from typing import List, Optional
8
+
9
+ from convert_to_packing_mode import PackingMode
10
+ from fusion_attention import AttentionMask, FusionAttention
11
+ from fusion_bart_attention import FusionBartAttention
12
+ from fusion_biasgelu import FusionBiasGelu
13
+ from fusion_embedlayer import FusionEmbedLayerNormalization
14
+ from fusion_fastgelu import FusionFastGelu
15
+ from fusion_gelu import FusionGelu
16
+ from fusion_gelu_approximation import FusionGeluApproximation
17
+ from fusion_gemmfastgelu import FusionGemmFastGelu
18
+ from fusion_layernorm import FusionLayerNormalization, FusionLayerNormalizationTF
19
+ from fusion_options import AttentionMaskFormat, FusionOptions
20
+ from fusion_qordered_attention import FusionQOrderedAttention
21
+ from fusion_qordered_gelu import FusionQOrderedGelu
22
+ from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
23
+ from fusion_qordered_matmul import FusionQOrderedMatMul
24
+ from fusion_quickgelu import FusionQuickGelu
25
+ from fusion_reshape import FusionReshape
26
+ from fusion_rotary_attention import FusionRotaryEmbeddings
27
+ from fusion_shape import FusionShape
28
+ from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization
29
+ from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
30
+ from fusion_utils import FusionUtils
31
+ from onnx import ModelProto, TensorProto, helper
32
+ from onnx_model import OnnxModel
33
+
34
+ logger = getLogger(__name__)
35
+
36
+
37
+ class BertOnnxModel(OnnxModel):
38
def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
    """Initialize BERT ONNX Model.

    Args:
        model (ModelProto): the ONNX model
        num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
        hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
    """
    # Either both are auto-detected (0), or both are given and hidden_size
    # must be a multiple of num_heads.
    assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)

    super().__init__(model)
    self.num_heads = num_heads
    self.hidden_size = hidden_size

    self.attention_mask = AttentionMask(self)
    self.attention_fusion = FusionAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
    # Only relevant in models with Q-DQ nodes.
    self.qordered_attention_fusion = FusionQOrderedAttention(
        self, self.hidden_size, self.num_heads, self.attention_mask
    )
    self.utils = FusionUtils(self)
58
+
59
def fuse_attention(self):
    """Run the attention fusion passes configured in __init__."""
    self.attention_fusion.apply()
    # Only relevant in models with Q-DQ nodes
    self.qordered_attention_fusion.apply()
63
+
64
def fuse_gelu(self):
    """Run all Gelu-pattern fusions, in the same order as before."""
    # FusionQOrderedGelu is only relevant in models with Q-DQ nodes.
    for fusion_class in (FusionGelu, FusionFastGelu, FusionQuickGelu, FusionQOrderedGelu):
        fusion_class(self).apply()
74
+
75
def fuse_bias_gelu(self, is_fastgelu):
    """Run the bias + (Fast)Gelu fusion."""
    FusionBiasGelu(self, is_fastgelu).apply()
78
+
79
def gelu_approximation(self):
    """Run the Gelu approximation fusion."""
    FusionGeluApproximation(self).apply()
82
+
83
def fuse_gemm_fast_gelu(self):
    """Run the GemmFastGelu fusion."""
    FusionGemmFastGelu(self).apply()
86
+
87
def fuse_add_bias_skip_layer_norm(self):
    """Run the bias + SkipLayerNormalization fusion."""
    FusionBiasSkipLayerNormalization(self).apply()
90
+
91
def fuse_reshape(self):
    """Run the Reshape fusion."""
    FusionReshape(self).apply()
94
+
95
def fuse_shape(self):
    """Run the Shape fusion."""
    FusionShape(self).apply()
98
+
99
def fuse_embed_layer(self, use_mask_index):
    """Run the EmbedLayerNormalization fusion."""
    FusionEmbedLayerNormalization(self, use_mask_index).apply()
102
+
103
def fuse_layer_norm(self):
    """Run all LayerNormalization fusions, in the same order as before."""
    # FusionQOrderedLayerNormalization is only relevant in models with Q-DQ nodes.
    for fusion_class in (
        FusionLayerNormalization,
        FusionLayerNormalizationTF,
        FusionQOrderedLayerNormalization,
    ):
        fusion_class(self).apply()
113
+
114
def fuse_simplified_layer_norm(self):
    """Run the SimplifiedLayerNormalization fusion."""
    FusionSimplifiedLayerNormalization(self).apply()
117
+
118
def fuse_skip_layer_norm(self, shape_infer=True):
    """Run the SkipLayerNormalization fusion.

    Args:
        shape_infer: forwarded to the fusion unchanged.
    """
    FusionSkipLayerNormalization(self, shape_infer=shape_infer).apply()
121
+
122
def fuse_skip_simplified_layer_norm(self):
    """Run the SkipSimplifiedLayerNormalization fusion."""
    FusionSkipSimplifiedLayerNormalization(self).apply()
125
+
126
def fuse_rotary_embeddings(self):
    """Fuse rotary embeddings, then drop RotaryEmbedding function definitions
    that are no longer referenced by any non-MS-domain node."""
    FusionRotaryEmbeddings(self).apply()

    # Domains of RotaryEmbedding nodes that were NOT converted to the
    # com.microsoft contrib op; their function definitions must be kept.
    domains_to_keep = {
        node.domain
        for node in self.model.graph.node
        if node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft"
    }
    # Snapshot the removal set first, then remove, to avoid mutating the
    # repeated field while scanning it.
    removable = [
        fn for fn in self.model.functions if "RotaryEmbedding" in fn.name and fn.domain not in domains_to_keep
    ]
    for fn in removable:
        self.model.functions.remove(fn)
144
+
145
# Only relevant in models with Q-DQ nodes
def fuse_qordered_mamtul(self):
    """Run the QOrderedMatMul fusion.

    NOTE(review): the method name contains a typo ("mamtul" for "matmul") but
    is kept as-is because callers reference it by this name.
    """
    FusionQOrderedMatMul(self).apply()
149
+
150
def get_graph_inputs_from_node_type(self, op_type: str, input_indices: List[int], casted: bool):
    """
    Get graph inputs that feed into node type (like EmbedLayerNormalization or Attention).
    Returns a list of the graph input names based on the filter whether it is casted or not.
    """
    graph_inputs = []
    output_name_to_node = self.output_name_to_node()

    for node in self.get_nodes_by_op_type(op_type):
        candidate_names = [node.input[i] for i in input_indices if i < len(node.input)]
        for name in candidate_names:
            if self.find_graph_input(name):
                # The node consumes the graph input directly (no Cast).
                if not casted:
                    graph_inputs.append(name)
            elif name in output_name_to_node:
                # The node may consume a Cast of a graph input; report the
                # original graph input name in that case.
                parent = output_name_to_node[name]
                if parent.op_type == "Cast" and self.find_graph_input(parent.input[0]) is not None:
                    if casted:
                        graph_inputs.append(parent.input[0])
    return graph_inputs
171
+
172
+ def get_graph_inputs_from_fused_nodes(self, casted: bool):
173
+ inputs = self.get_graph_inputs_from_node_type("EmbedLayerNormalization", [0, 1, 7], casted)
174
+ inputs += self.get_graph_inputs_from_node_type("Attention", [3], casted)
175
+ return inputs
176
+
177
+ def change_graph_inputs_to_int32(self):
178
+ """Change data type of all graph inputs to int32 type, and add Cast node if needed."""
179
+ graph = self.graph()
180
+ add_cast_count = 0
181
+ remove_cast_count = 0
182
+ for graph_input in graph.input:
183
+ new_node, removed_nodes = self.change_graph_input_type(graph_input, TensorProto.INT32)
184
+ if new_node:
185
+ add_cast_count += 1
186
+ remove_cast_count += len(removed_nodes)
187
+ logger.info(
188
+ f"Graph inputs are changed to int32. Added {add_cast_count} Cast nodes, and removed {remove_cast_count} Cast nodes."
189
+ )
190
+
191
+ def use_dynamic_axes(self, dynamic_batch_dim="batch_size", dynamic_seq_len="max_seq_len"):
192
+ """
193
+ Update input and output shape to use dynamic axes.
194
+ """
195
+ bert_graph_inputs = self.get_graph_inputs_from_fused_nodes(
196
+ casted=True
197
+ ) + self.get_graph_inputs_from_fused_nodes(casted=False)
198
+
199
+ for input in self.model.graph.input:
200
+ if input.name in bert_graph_inputs:
201
+ dim_proto = input.type.tensor_type.shape.dim[0]
202
+ dim_proto.dim_param = dynamic_batch_dim
203
+ if dynamic_seq_len is not None:
204
+ dim_proto = input.type.tensor_type.shape.dim[1]
205
+ dim_proto.dim_param = dynamic_seq_len
206
+
207
+ for output in self.model.graph.output:
208
+ dim_proto = output.type.tensor_type.shape.dim[0]
209
+ dim_proto.dim_param = dynamic_batch_dim
210
+
211
+ def preprocess(self):
212
+ self.adjust_reshape_and_expand()
213
+ return
214
+
215
+ def adjust_reshape_and_expand(self):
216
+ nodes_to_remove = []
217
+ for node in self.nodes():
218
+ if node.op_type == "Reshape":
219
+ # Clean up unnecessary reshape nodes.
220
+ # Find reshape nodes with no actually data in "shape" attribute and remove.
221
+ reshape_shape = self.get_constant_value(node.input[1])
222
+ if reshape_shape is not None and reshape_shape.size == 0:
223
+ nodes_to_remove.extend([node])
224
+ self.replace_input_of_all_nodes(node.output[0], node.input[0])
225
+ continue
226
+
227
+ # Find path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", simplify the graph by
228
+ # changing current reshape's input to output of slice.
229
+ reshape_path = self.match_parent_path(
230
+ node,
231
+ ["Expand", "Expand", "Reshape", "Slice"],
232
+ [0, 0, 0, 0],
233
+ self.output_name_to_node(),
234
+ )
235
+ if reshape_path is not None:
236
+ expand_node = reshape_path[-3]
237
+ expand_shape_value = self.get_constant_value(expand_node.input[1])
238
+
239
+ reshape_before_expand = reshape_path[-2]
240
+ shape_value = self.get_constant_value(reshape_before_expand.input[1])
241
+
242
+ slice_node = reshape_path[-1]
243
+ if (
244
+ expand_shape_value is not None
245
+ and shape_value is not None
246
+ and len(expand_shape_value) == 2
247
+ and len(shape_value) == 1
248
+ and expand_shape_value[1] == shape_value[0]
249
+ ):
250
+ node.input[0] = slice_node.output[0]
251
+
252
+ if nodes_to_remove:
253
+ self.remove_nodes(nodes_to_remove)
254
+ logger.info(f"Removed Reshape and Expand count: {len(nodes_to_remove)}")
255
+
256
    def clean_graph(self):
        """Remove mask-length computation subgraphs made redundant by fusion.

        Two rewrites over the whole graph:
          1. For EmbedLayerNormalization/ReduceSum/Attention nodes fed (at a
             fixed input index) by the chain
             Shape -> Gather -> Unsqueeze -> Concat -> ConstantOfShape -> Cast
             starting at the first graph input, feed ConstantOfShape directly
             from the Shape output (drops the Gather/Unsqueeze/Concat part).
          2. For Attention nodes whose optional mask_index input comes from
             Shape -> ConstantOfShape -> Cast -> ReduceSum over the first
             graph input, re-create the Attention node without that mask
             input and schedule the original node for removal.

        NOTE(review): rewrite 1 mutates edges while iterating, so the
        producer map is rebuilt right after each rewiring; keep that refresh
        in place when modifying this method.
        """
        output_name_to_node = self.output_name_to_node()
        nodes_to_remove = []
        for node in self.nodes():
            # Before:
            #  input_ids --> Shape --> Gather(indices=0) --> Unsqueeze ------+
            #          |                                                     |
            #          |                                                     v
            #          +----> Shape --> Gather(indices=1) --> Unsqueeze---> Concat --> ConstantOfShape -->Cast --> EmbedLayerNormalization/ReduceSum
            # After:
            #  input_ids --> Shape --> ConstantOfShape -->Cast --> EmbedLayerNormalization/ReduceSum
            # TODO: merge ConstantOfShape -->Cast to ConstantOfShape (need update the data type of value)
            op_input_id = {"EmbedLayerNormalization": 1, "ReduceSum": 0, "Attention": 3}
            if node.op_type in op_input_id:
                # i is the input index of `node` that carries the mask/shape tensor.
                i = op_input_id[node.op_type]
                parent_nodes = self.match_parent_path(
                    node,
                    [
                        "Cast",
                        "ConstantOfShape",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                    ],
                    [i, 0, 0, 0, 0, 0],
                    output_name_to_node,
                )
                if parent_nodes is not None:
                    (
                        cast,
                        constantOfShape,  # noqa: N806
                        concat,
                        unsqueeze,
                        gather,
                        shape,
                    ) = parent_nodes
                    # Only rewrite chains rooted at the first graph input.
                    if shape.input[0] == self.graph().input[0].name:
                        constantOfShape.input[0] = shape.output[0]
                        # Topology changed: rebuild the producer map so later
                        # match_parent_path calls see the new edges.
                        output_name_to_node = self.output_name_to_node()

            if node.op_type == "Attention":
                # Before:
                #  input_ids --> Shape -->ConstantOfShape -->Cast --> ReduceSum --> Attention
                # After:
                #  remove this path, and remove the optional mask_index input of Attention node.
                parent_nodes = self.match_parent_path(
                    node,
                    ["ReduceSum", "Cast", "ConstantOfShape", "Shape"],
                    [3, 0, 0, 0],
                    output_name_to_node,
                )
                if parent_nodes is not None:
                    if parent_nodes[-1].input[0] == self.graph().input[0].name:
                        # Re-create the Attention node without its last
                        # (optional mask_index) input; keep all outputs.
                        attention_node = helper.make_node(
                            "Attention",
                            inputs=node.input[0 : len(node.input) - 1],
                            outputs=node.output,
                            name=node.name + "_remove_mask",
                        )
                        attention_node.domain = "com.microsoft"
                        attention_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
                        self.add_node(attention_node, self.get_graph_by_node(attention_node).name)
                        nodes_to_remove.append(node)
        self.remove_nodes(nodes_to_remove)
+ def postprocess(self):
323
+ self.clean_graph()
324
+ self.prune_graph()
325
+
326
+ def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
327
+ if (options is not None) and not options.enable_shape_inference:
328
+ self.disable_shape_inference()
329
+
330
+ self.utils.remove_identity_nodes()
331
+
332
+ # Remove cast nodes that having same data type of input and output based on symbolic shape inference.
333
+ self.utils.remove_useless_cast_nodes()
334
+
335
+ if (options is None) or options.enable_layer_norm:
336
+ self.fuse_layer_norm()
337
+ self.fuse_simplified_layer_norm()
338
+
339
+ if (options is None) or options.enable_gelu:
340
+ self.fuse_gelu()
341
+
342
+ self.preprocess()
343
+
344
+ self.fuse_reshape()
345
+
346
+ if (options is None) or options.enable_skip_layer_norm:
347
+ self.fuse_skip_layer_norm(options.enable_shape_inference)
348
+ self.fuse_skip_simplified_layer_norm()
349
+
350
+ if (options is None) or options.enable_rotary_embeddings:
351
+ self.fuse_rotary_embeddings()
352
+
353
+ if options is not None:
354
+ self.attention_mask.set_mask_format(options.attention_mask_format)
355
+ if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention):
356
+ self.attention_fusion = FusionAttention(
357
+ self,
358
+ self.hidden_size,
359
+ self.num_heads,
360
+ self.attention_mask,
361
+ options.use_multi_head_attention,
362
+ )
363
+
364
+ if (options is None) or options.enable_attention:
365
+ self.fuse_attention()
366
+
367
+ # Perform the MatMul fusion after the Attention fusion as we do not
368
+ # want to fuse the MatMuls inside the Attention subgraphs
369
+ if (options is None) or options.enable_qordered_matmul:
370
+ self.fuse_qordered_mamtul()
371
+
372
+ self.fuse_shape()
373
+
374
+ if (options is None) or options.enable_embed_layer_norm:
375
+ use_mask_index = options.attention_mask_format == AttentionMaskFormat.MaskIndexEnd
376
+ self.fuse_embed_layer(use_mask_index)
377
+
378
+ # Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
379
+ self.utils.remove_useless_reshape_nodes()
380
+
381
+ self.postprocess()
382
+
383
+ # Bias fusion is done after postprocess to avoid extra Reshape between bias and Gelu/FastGelu/SkipLayerNormalization
384
+ if (options is None) or options.enable_bias_gelu:
385
+ # Fuse Gelu and Add Bias before it.
386
+ self.fuse_bias_gelu(is_fastgelu=True)
387
+ self.fuse_bias_gelu(is_fastgelu=False)
388
+
389
+ if (options is None) or options.enable_bias_skip_layer_norm:
390
+ # Fuse SkipLayerNormalization and Add Bias before it.
391
+ self.fuse_add_bias_skip_layer_norm()
392
+
393
+ if options is not None and options.enable_gelu_approximation:
394
+ self.gelu_approximation()
395
+
396
+ if options is not None and options.enable_gemm_fast_gelu:
397
+ self.fuse_gemm_fast_gelu()
398
+
399
+ self.remove_unused_constant()
400
+
401
+ # Use symbolic batch dimension in input and output.
402
+ if add_dynamic_axes:
403
+ self.use_dynamic_axes()
404
+
405
+ logger.info(f"opset version: {self.get_opset_version()}")
406
+
407
+ def get_fused_operator_statistics(self):
408
+ """
409
+ Returns node count of fused operators.
410
+ """
411
+ op_count = {}
412
+ ops = [
413
+ "EmbedLayerNormalization",
414
+ "Attention",
415
+ "MultiHeadAttention",
416
+ "Gelu",
417
+ "FastGelu",
418
+ "BiasGelu",
419
+ "GemmFastGelu",
420
+ "LayerNormalization",
421
+ "SimplifiedLayerNormalization",
422
+ "SkipLayerNormalization",
423
+ "SkipSimplifiedLayerNormalization",
424
+ "RotaryEmbedding",
425
+ ]
426
+ q_ops = [
427
+ "QOrderedAttention",
428
+ "QOrderedGelu",
429
+ "QOrderedLayerNormalization",
430
+ "QOrderedMatMul",
431
+ ]
432
+ for op in ops + q_ops:
433
+ nodes = self.get_nodes_by_op_type(op)
434
+ op_count[op] = len(nodes)
435
+
436
+ logger.info(f"Optimized operators: {op_count}")
437
+ return op_count
438
+
439
+ def is_fully_optimized(self, fused_op_count=None):
440
+ """
441
+ Returns True when the model is fully optimized.
442
+ """
443
+ if fused_op_count is None:
444
+ fused_op_count = self.get_fused_operator_statistics()
445
+
446
+ def op_count(op_name: str):
447
+ return fused_op_count.get(op_name) or 0
448
+
449
+ embed = op_count("EmbedLayerNormalization")
450
+ attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("QOrderedAttention")
451
+ gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu")
452
+ layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization")
453
+ simple_layer_norm = op_count("SimplifiedLayerNormalization") + op_count("SkipSimplifiedLayerNormalization")
454
+
455
+ is_perfect = (
456
+ (embed > 0)
457
+ and (attention > 0)
458
+ and (attention == gelu)
459
+ and ((layer_norm >= 2 * attention) or (simple_layer_norm >= 2 * attention))
460
+ )
461
+
462
+ if layer_norm == 0:
463
+ logger.debug("Layer Normalization not fused")
464
+
465
+ if simple_layer_norm == 0:
466
+ logger.debug("Simple Layer Normalization not fused")
467
+
468
+ if gelu == 0:
469
+ logger.debug("Gelu (or FastGelu) not fused")
470
+
471
+ if embed == 0:
472
+ logger.debug("EmbedLayerNormalization not fused")
473
+
474
+ if attention == 0:
475
+ logger.warning("Attention (or MultiHeadAttention) not fused")
476
+
477
+ return is_perfect
478
+
479
+ def convert_to_packing_mode(self, use_symbolic_shape_infer: bool = False):
480
+ packing_mode = PackingMode(self)
481
+ packing_mode.convert(use_symbolic_shape_infer)