onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
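Before the per-file diffs below, a minimal usage sketch (not taken from this diff) of what the wheel provides: the DirectML build of ONNX Runtime registers the DmlExecutionProvider, which can be requested when creating an InferenceSession. The model path, input name handling, and tensor shape here are placeholders.

    import numpy as np
    import onnxruntime as ort

    # The onnxruntime-directml build exposes DmlExecutionProvider in addition to CPU.
    session = ort.InferenceSession(
        "model.onnx",  # placeholder path
        providers=["DmlExecutionProvider", "CPUExecutionProvider"],
    )

    input_name = session.get_inputs()[0].name
    dummy = np.zeros((1, 3, 224, 224), dtype=np.float32)  # placeholder shape; match your model
    outputs = session.run(None, {input_name: dummy})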
onnxruntime/transformers/onnx_model_bert_tf.py
@@ -0,0 +1,588 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ import logging
+
+ import numpy as np
+ import onnx
+ from onnx import TensorProto, helper, numpy_helper
+ from onnx_model_bert import BertOnnxModel
+
+ logger = logging.getLogger(__name__)
+
+
+ class BertOnnxModelTF(BertOnnxModel):
+     def __init__(self, model, num_heads, hidden_size):
+         super().__init__(model, num_heads, hidden_size)
+
+     def remove_identity(self):
+         nodes_to_remove = []
+         for node in self.nodes():
+             if node.op_type == "Identity":
+                 if not self.find_graph_output(node.output[0]):
+                     self.replace_input_of_all_nodes(node.output[0], node.input[0])
+                     nodes_to_remove.append(node)
+         self.remove_nodes(nodes_to_remove)
+         logger.info(f"Removed Identity count: {len(nodes_to_remove)}")
+
+     def match_mask_path(self, add_or_sub_before_softmax):
+         mask_nodes = self.match_parent_path(
+             add_or_sub_before_softmax,
+             ["Mul", "Sub", "Reshape", "Cast"],
+             [1, None, 1, 0],
+         )
+         if mask_nodes is not None:
+             return mask_nodes
+
+         mask_nodes = self.match_parent_path(
+             add_or_sub_before_softmax,
+             ["Mul", "Sub", "Cast", "Slice", "Unsqueeze"],
+             [1, 0, 1, 0, 0],
+         )
+         if mask_nodes is not None:
+             return mask_nodes
+
+         mask_nodes = self.match_parent_path(
+             add_or_sub_before_softmax,
+             ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
+             [1, None, 1, 0, 0],
+         )
+
+         return mask_nodes
+
+     def get_2d_initializers_from_parent_subgraphs(self, current_node):
+         """
+         Find initializers that are 2D. Returns a dictionary with name as key and shape as value.
+         """
+         parent_nodes = self.get_parent_subgraph_nodes(current_node, [])
+         initializers = {}
+         for node in parent_nodes:
+             for input in node.input:
+                 initializer = self.get_initializer(input)
+                 if initializer:
+                     temp = numpy_helper.to_array(initializer)
+                     if len(temp.shape) == 2:
+                         initializers[initializer.name] = temp.shape
+
+         return initializers
+
+     def find_segment_ids(self, segment_embedding, input_ids):
+         input_name_to_nodes = self.input_name_to_nodes()
+         if segment_embedding not in input_name_to_nodes:
+             return None
+
+         nodes = input_name_to_nodes[segment_embedding]
+         if len(nodes) != 1:
+             return None
+
+         graph_inputs = self.get_graph_inputs(nodes[0], recursive=True)
+         if len(graph_inputs) > 1:
+             print("Found multiple candidates of segment_ids", graph_inputs)
+             return None
+         # Find segment ids in graph inputs. The segment id input must not be the same as input_ids.
+         if len(graph_inputs) == 1 and graph_inputs[0] != input_ids:
+             return graph_inputs[0]
+
+         # If the segment id candidate is the same as the input_ids, try to assign alternative segment ids and simplify the graph if needed.
+         segment_ids = nodes[0].input[1]
+         _, segment_id_path, _ = self.match_parent_paths(
+             nodes[0],
+             [
+                 (
+                     ["ConstantOfShape", "Cast", "Concat", "Slice", "Cast", "Shape"],
+                     [1, 0, 0, 0, 0, 0],
+                 ),
+                 (
+                     [
+                         "ConstantOfShape",
+                         "Cast",
+                         "Concat",
+                         "Unsqueeze",
+                         "Squeeze",
+                         "Slice",
+                         "Cast",
+                         "Shape",
+                     ],
+                     [1, 0, 0, 0, 0, 0, 0, 0],
+                 ),
+             ],
+             None,
+         )
+
+         if segment_id_path and input_ids and input_ids == segment_id_path[-1].input[0]:
+             logger.debug("Simplify segment id path...")
+             constantofshape_node = segment_id_path[0]
+             graph_name = self.get_graph_by_node(constantofshape_node).name
+             self.add_node(
+                 helper.make_node("Shape", inputs=[input_ids], outputs=["input_shape"]),
+                 graph_name,
+             )
+             constantofshape_value = helper.get_attribute_value(constantofshape_node.attribute[0])
+             self.add_node(
+                 helper.make_node(
+                     "ConstantOfShape",
+                     inputs=["input_shape"],
+                     outputs=["zeros_for_input_shape"],
+                     value=constantofshape_value,
+                 ),
+                 graph_name,
+             )
+             segment_ids = "zeros_for_input_shape"
+         return segment_ids
+
+     def find_input_ids(self, word_embedding):
+         input_name_to_nodes = self.input_name_to_nodes()
+         if word_embedding not in input_name_to_nodes:
+             return None
+
+         nodes = input_name_to_nodes[word_embedding]
+         if len(nodes) != 1:
+             return None
+
+         graph_inputs = self.get_graph_inputs(nodes[0], recursive=True)
+         if len(graph_inputs) == 1:
+             return graph_inputs[0]
+
+         print("Found multiple candidates of input_ids", graph_inputs)
+         return None
+
+     def find_mask_input(self, excluded_graph_inputs):
+         for node in self.nodes():
+             if node.op_type == "Softmax":
+                 mask_path = self.match_parent_path(
+                     node,
+                     ["Add", "Mul", "Sub", "Cast", "Slice", "Unsqueeze"],
+                     [0, 1, None, 1, 0, 0],
+                 )
+                 if mask_path is None:
+                     continue
+                 (
+                     add_node,
+                     mul_node,
+                     sub_node,
+                     cast_node,
+                     slice_node,
+                     unsqueeze_node,
+                 ) = mask_path
+                 if self.has_constant_input(mul_node, -10000) and self.has_constant_input(sub_node, 1):
+                     graph_inputs = self.get_graph_inputs(sub_node, recursive=True)
+                     inputs = [input for input in graph_inputs if input not in excluded_graph_inputs]
+                     if len(inputs) > 1:
+                         print("Found multiple candidates of mask input", inputs)
+                         return None
+                     if len(inputs) == 1:
+                         return inputs[0]
+                     # Duplicated input found. Try to simplify the graph.
+                     path_to_be_simplified = self.match_parent_path(
+                         mask_path[-1],
+                         [
+                             "ConstantOfShape",
+                             "Cast",
+                             "Concat",
+                             "Unsqueeze",
+                             "Squeeze",
+                             "Slice",
+                             "Cast",
+                             "Shape",
+                         ],
+                         [0, 0, 0, 0, 0, 0, 0, 0],
+                     )
+                     duplicated_inputs = [input for input in graph_inputs if input in excluded_graph_inputs]
+                     # Simplify graph for dynamic axes.
+                     if (
+                         path_to_be_simplified
+                         and duplicated_inputs
+                         and len(duplicated_inputs) == 1
+                         and duplicated_inputs[0] == path_to_be_simplified[-1].input[0]
+                     ):
+                         logger.debug("Simplify mask path...")
+                         constantofshape_node = path_to_be_simplified[0]
+                         constantofshape_value = helper.get_attribute_value(constantofshape_node.attribute[0])
+                         graph_name = self.get_graph_by_node(constantofshape_node).name
+                         self.add_node(
+                             helper.make_node(
+                                 "Shape",
+                                 inputs=[duplicated_inputs[0]],
+                                 outputs=["input_shape_for_mask"],
+                             ),
+                             graph_name,
+                         )
+                         self.add_node(
+                             helper.make_node(
+                                 "ConstantOfShape",
+                                 inputs=["input_shape_for_mask"],
+                                 outputs=[unsqueeze_node.input[0]],
+                                 value=constantofshape_value,
+                             ),
+                             graph_name,
+                         )
+                     return unsqueeze_node.input[0]
+         return None
+
+     def create_embedding_subgraph(self, normalize_node, word_embedding, segment_embedding, position_embedding):
+         input_ids = self.find_input_ids(word_embedding)
+         if input_ids is None:
+             logger.info("Failed to find input_ids. Cannot fuse embedding layer.")
+             return False
+
+         segment_ids = self.find_segment_ids(segment_embedding, input_ids)
+         if segment_ids is None:
+             logger.info("Failed to find segment_ids. Cannot fuse embedding layer.")
+             return False
+
+         mask_input = self.find_mask_input([segment_ids, input_ids])
+         if mask_input is None:
+             logger.info("Failed to find input_mask. Cannot fuse embedding layer.")
+             return False
+
+         self.bert_inputs = [input_ids, segment_ids, mask_input]
+
+         mask_index = self.create_node_name("mask_index")
+         self.attention_mask.set_mask_indice(mask_input, mask_index)
+
+         if self.find_graph_input(input_ids).type.tensor_type.elem_type != TensorProto.INT32:
+             casted, input_ids = self.utils.cast_graph_input_to_int32(input_ids)
+
+         if self.find_graph_input(segment_ids):
+             casted, segment_ids = self.utils.cast_graph_input_to_int32(segment_ids)
+         else:
+             segment_ids, segment_id_cast_node = self.utils.cast_input_to_int32(segment_ids)
+
+         if self.find_graph_input(mask_input):
+             casted, mask_input = self.utils.cast_graph_input_to_int32(mask_input)
+         else:
+             mask_input, mask_input_cast_node = self.utils.cast_input_to_int32(mask_input)
+
+         embed_output = self.create_node_name("embed_output")
+         embed_node = onnx.helper.make_node(
+             "EmbedLayerNormalization",
+             inputs=[
+                 input_ids,
+                 segment_ids,
+                 word_embedding,
+                 position_embedding,
+                 segment_embedding,
+                 normalize_node.input[1],  # gamma
+                 normalize_node.input[2],  # beta
+                 mask_input,
+             ],
+             outputs=[embed_output, mask_index],
+             name="EmbedLayer",
+         )
+         embed_node.domain = "com.microsoft"
+         self.replace_input_of_all_nodes(normalize_node.output[0], embed_output)
+         self.add_node(embed_node, self.get_graph_by_node(normalize_node).name)
+
+     def process_embedding(self):
+         """
+         Automatically detect word, segment and position embeddings.
+         """
+         logger.info("start processing embedding layer...")
+         output_name_to_node = self.output_name_to_node()
+
+         layer_norm_nodes = self.get_nodes_by_op_type("LayerNormalization")
+         for layer_norm_node in layer_norm_nodes:
+             pos_embed_path = self.match_parent_path(
+                 layer_norm_node,
+                 ["Add", "Reshape", "Slice"],
+                 [0, 1, 0],
+                 output_name_to_node,
+             )
+             if pos_embed_path is None:
+                 continue
+
+             add_node, reshape_node, slice_node = pos_embed_path
+             initializer = self.get_initializer(slice_node.input[0])
+             if initializer is None:
+                 continue
+
+             temp = numpy_helper.to_array(initializer)
+             if len(temp.shape) == 2:
+                 logger.info(f"Found position embedding. name:{initializer.name}, shape:{temp.shape}")
+                 position_embedding = initializer.name
+             else:
+                 logger.info(f"Failed to find position embedding. name:{initializer.name}, shape:{temp.shape}")
+                 return
+
+             first_parent = self.get_parent(add_node, 0, output_name_to_node)
+             if first_parent is not None and first_parent.op_type == "Add":
+                 embeddings = self.get_2d_initializers_from_parent_subgraphs(first_parent)
+                 if len(embeddings) != 2:
+                     logger.warning(
+                         f"Failed to find two embeddings (word and segment) from Add node. Found {embeddings}"
+                     )
+                     return
+
+                 word_embedding = None
+                 segment_embedding = None
+                 for name, shape in embeddings.items():
+                     if shape[0] == 2:
+                         segment_embedding = name
+                         logger.info(f"Found segment embedding. name:{name}, shape:{shape}")
+                     else:
+                         word_embedding = name
+                         logger.info(f"Found word embedding. name:{name}, shape:{shape}")
+
+                 if word_embedding is None or segment_embedding is None:
+                     logger.info("Failed to find both word and segment embedding")
+                     return
+
+                 logger.info("Create Embedding node")
+                 self.create_embedding_subgraph(
+                     layer_norm_node,
+                     word_embedding,
+                     segment_embedding,
+                     position_embedding,
+                 )
+                 # Prune graph to remove those original embedding nodes.
+                 self.prune_graph()
+                 break
+
+     def check_attention_input(self, matmul_q, matmul_k, matmul_v, parent, output_name_to_node):
+         for x in [matmul_q, matmul_k, matmul_v]:
+             root_input = x.input[0]
+             root_node = output_name_to_node[root_input]
+             if root_node == parent:
+                 continue
+             logger.debug(f"Check attention input failed:{root_input}, {parent.output[0]}")
+             return False
+
+         return True
+
+     def fuse_attention(self):
+         output_name_to_node = self.output_name_to_node()
+
+         nodes_to_remove = []
+         attention_count = 0
+
+         start_nodes = []
+         skip_layer_norm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization")
+         layer_norm_nodes = self.get_nodes_by_op_type("LayerNormalization")
+         # Sometimes we cannot fuse SkipLayerNormalization since the Add before LayerNormalization has an output that is used by nodes outside the fused pattern.
+         # Conceptually we treat the Add before LayerNormalization as a SkipLayerNormalization node since they share the same pattern.
+         start_nodes.extend(skip_layer_norm_nodes)
+         start_nodes.extend(layer_norm_nodes)
+
+         for normalize_node in start_nodes:
+             graph_name = self.get_graph_by_node(normalize_node).name
+             # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
+             if normalize_node.op_type == "LayerNormalization":
+                 add_before_layernorm = self.match_parent(normalize_node, "Add", 0)
+                 if add_before_layernorm is not None:
+                     normalize_node = add_before_layernorm  # noqa: PLW2901
+                 else:
+                     continue
+             parent = self.get_parent(normalize_node, 1)
+             if parent is None or parent.op_type not in [
+                 "SkipLayerNormalization",
+                 "LayerNormalization",
+                 "Reshape",
+             ]:
+                 parent = self.get_parent(normalize_node, 0)
+                 if parent is None or parent.op_type not in [
+                     "SkipLayerNormalization",
+                     "LayerNormalization",
+                     "Reshape",
+                 ]:
+                     logger.debug("Failed to match parent of normalize_node")
+                     continue
+
+             qkv_nodes = self.match_parent_path(
+                 normalize_node,
+                 ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+                 [0, 0, 0, 0, 0],
+             )
+             if qkv_nodes is None:
+                 qkv_nodes = self.match_parent_path(
+                     normalize_node,
+                     ["MatMul", "Reshape", "Transpose", "MatMul"],
+                     [1, 0, 0, 0],
+                 )
+             if qkv_nodes is None:
+                 qkv_nodes = self.match_parent_path(normalize_node, ["Add", "Einsum", "Einsum"], [0, 0, 0])
+             if qkv_nodes is None:
+                 logger.debug("Failed to match qkv nodes")
+                 continue
+
+             matmul_qkv = qkv_nodes[-1]
+             v_nodes = self.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0])
+             if v_nodes is None:
+                 v_nodes = self.match_parent_path(matmul_qkv, ["Add", "Einsum"], [1, 0])
+                 if v_nodes is None:
+                     logger.debug("Failed to match v path")
+                     continue
+
+             add_v = v_nodes[-2]
+             matmul_v = v_nodes[-1]
+             qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
+             if qk_nodes is None:
+                 qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Einsum"], [0, 0, 0])
+                 if qk_nodes is None:
+                     logger.debug("Failed to match qk_paths")
+                     continue
+             matmul_qk = qk_nodes[-1]
+
+             q_nodes = self.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0])
+             if q_nodes is None:
+                 q_nodes = self.match_parent_path(matmul_qk, ["Add", "Einsum"], [0, 0])
+                 if q_nodes is None:
+                     logger.debug("Failed to match q path")
+                     continue
+
+             add_q = q_nodes[-2]
+             matmul_q = q_nodes[-1]
+
+             k_nodes = self.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0])
+             if k_nodes is None:
+                 k_nodes = self.match_parent_path(matmul_qk, ["Mul", "Add", "Einsum"], [1, 0, 0])
+                 if k_nodes is None:
+                     logger.debug("Failed to match k path")
+                     continue
+             add_k = k_nodes[-2]
+             matmul_k = k_nodes[-1]
+
+             mask_nodes = self.match_mask_path(qk_nodes[1])
+
+             if mask_nodes is None:
+                 logger.debug("Cannot find mask_nodes.")
+                 continue
+
+             if not self.has_constant_input(mask_nodes[1], 1):
+                 logger.debug("Sub node expected to have an input with constant value 1.0.")
+                 continue
+
+             # add a squeeze node to convert a 3-d mask to 2-d
+             squeeze_node = self.match_parent_path(mask_nodes[-1], ["Squeeze"], [0]) or self.match_parent_path(
+                 mask_nodes[-1], ["Expand"], [0]
+             )
+             squeeze_node_name = "Squeeze_3d_to_2d_mask"
+             squeeze_output_name = squeeze_node_name + "_output"
+             if squeeze_node is None and len(mask_nodes) == 5 and self.find_graph_input(mask_nodes[-1].input[0]) is None:
+                 mask_input = mask_nodes[-1].input[1]
+                 self.add_node(
+                     helper.make_node(
+                         "Squeeze",
+                         [mask_input],
+                         [squeeze_output_name],
+                         squeeze_node_name,
+                         axes=[1],
+                     ),
+                     graph_name,
+                 )
+                 mask_nodes[-1].input[0] = squeeze_output_name
+
+             is_same_root = self.check_attention_input(matmul_q, matmul_k, matmul_v, parent, output_name_to_node)
+             if is_same_root:
+                 mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
+                 logger.debug("Create an Attention node.")
+
+                 # For tf models, q and v are flipped.
+                 attention_node = self.attention_fusion.create_attention_node(
+                     mask_index=mask_index,
+                     q_matmul=matmul_k,
+                     k_matmul=matmul_q,
+                     v_matmul=matmul_v,
+                     q_add=add_k,
+                     k_add=add_q,
+                     v_add=add_v,
+                     num_heads=self.num_heads,
+                     hidden_size=self.hidden_size,
+                     first_input=parent.output[0],
+                     output=qkv_nodes[2].output[0],
+                 )
+                 if attention_node is None:
+                     continue
+
+                 if qkv_nodes[1].op_type == "Einsum":
+                     # add reshape before einsum
+                     tensor = helper.make_tensor(
+                         name=qkv_nodes[1].name + "_newshape",
+                         data_type=TensorProto.INT64,
+                         dims=[4],
+                         vals=np.int64(
+                             [
+                                 [
+                                     0,
+                                     0,
+                                     self.num_heads,
+                                     int(self.hidden_size / self.num_heads),
+                                 ]
+                             ]
+                         ).tobytes(),
+                         raw=True,
+                     )
+                     self.add_initializer(tensor, graph_name)
+                     reshape_ = helper.make_node(
+                         "Reshape",
+                         inputs=[
+                             attention_node.output[0],
+                             qkv_nodes[1].name + "_newshape",
+                         ],
+                         outputs=[qkv_nodes[1].name + "_reshape_output"],
+                         name=qkv_nodes[1].name + "_reshape",
+                     )
+                     qkv_nodes[1].input[0] = qkv_nodes[1].name + "_reshape_output"
+                     self.add_node(reshape_, graph_name)
+                 if parent.op_type == "Reshape":
+                     # Temporary workaround: we require the SkipLayerNormalization and Attention ops to be fed with 3-d input.
+                     hidden_size = numpy_helper.to_array(self.get_initializer(parent.input[1]))[1]
+                     tensor = helper.make_tensor(
+                         name=parent.name + "_modified",
+                         data_type=TensorProto.INT64,
+                         dims=[3],
+                         vals=np.int64([[1, -1, hidden_size]]).tobytes(),
+                         raw=True,
+                     )
+                     self.add_initializer(tensor, graph_name)
+                     parent.input[1] = parent.name + "_modified"
+
+                 self.add_node(attention_node, graph_name)
+                 attention_count += 1
+
+                 nodes_to_remove.extend(qkv_nodes[2:])
+                 nodes_to_remove.extend(qk_nodes)
+                 nodes_to_remove.extend(q_nodes)
+                 nodes_to_remove.extend(k_nodes)
+                 nodes_to_remove.extend(v_nodes)
+                 nodes_to_remove.extend(mask_nodes)
+             else:
+                 logger.debug("Root node not matched.")
+                 continue
+         self.remove_nodes(nodes_to_remove)
+         self.update_graph()
+         logger.info(f"Fused Attention count:{attention_count}")
+
+     def preprocess(self):
+         self.remove_identity()
+         self.process_embedding()
+         self.skip_reshape()
+
+     def skip_reshape(self):
+         count = 0
+         reshape_nodes = self.get_nodes_by_op_type("Reshape")
+         for reshape_node in reshape_nodes:
+             parent = self.get_parent(reshape_node, 0)
+             if parent is not None and parent.op_type == "Reshape":
+                 reshape_node.input[0] = parent.input[0]
+                 count += 1
+
+         if count > 0:
+             logger.info(f"Skipped consecutive Reshape count: {count}")
+
+     def remove_reshape_before_first_attention(self):
+         attention_nodes = self.get_nodes_by_op_type("Attention")
+         for attention_node in attention_nodes:
+             path = self.match_parent_path(attention_node, ["Reshape", "EmbedLayerNormalization"], [0, 0])
+             if path is None:
+                 continue
+             logger.info("Remove Reshape before first Attention node.")
+             reshape, _ = path
+             self.replace_input_of_all_nodes(reshape.output[0], reshape.input[0])
+             self.remove_node(reshape)
+             break
+
+     def postprocess(self):
+         self.remove_reshape_before_first_attention()
+         self.prune_graph()
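The BertOnnxModelTF class above is the TF-exported BERT path of the transformers optimizer. Below is a hedged sketch of reaching it through the public entry point; the file names and head/hidden sizes are placeholders, and model_type="bert_tf" is assumed to map to this class in the optimizer's model-type registry.

    from onnxruntime.transformers import optimizer

    opt_model = optimizer.optimize_model(
        "bert_tf_exported.onnx",  # placeholder: a BERT model exported from TensorFlow
        model_type="bert_tf",     # assumed to select BertOnnxModelTF
        num_heads=12,             # BERT-base values; adjust to your model
        hidden_size=768,
    )
    print(opt_model.get_fused_operator_statistics())
    opt_model.save_model_to_file("bert_tf_optimized.onnx")  # placeholder output path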
onnxruntime/transformers/onnx_model_clip.py
@@ -0,0 +1,42 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ from logging import getLogger
+
+ from fusion_attention_clip import FusionAttentionClip
+ from onnx import ModelProto
+ from onnx_model_bert import BertOnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class ClipOnnxModel(BertOnnxModel):
+     def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
+         super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
+         self.clip_attention_fusion = FusionAttentionClip(self, self.hidden_size, self.num_heads)
+
+     def get_fused_operator_statistics(self):
+         """
+         Returns node count of fused operators.
+         """
+         op_count = {}
+         ops = [
+             "Attention",
+             "FastGelu",
+             "Gelu",
+             "LayerNormalization",
+             "QuickGelu",
+             "BiasGelu",
+             "SkipLayerNormalization",
+         ]
+         for op in ops:
+             nodes = self.get_nodes_by_op_type(op)
+             op_count[op] = len(nodes)
+
+         logger.info(f"Optimized operators:{op_count}")
+         return op_count
+
+     def fuse_attention(self):
+         self.clip_attention_fusion.apply()
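ClipOnnxModel above mainly wires in FusionAttentionClip and reports per-operator fusion counts. A hedged sketch of how that statistics helper is typically consumed after optimizing a CLIP text or vision encoder; the path is a placeholder and model_type="clip" is assumed to select this class.

    from onnxruntime.transformers import optimizer

    clip_model = optimizer.optimize_model(
        "clip_text_encoder.onnx",  # placeholder path
        model_type="clip",
        num_heads=0,               # 0 lets the fusion code infer the values
        hidden_size=0,
    )
    stats = clip_model.get_fused_operator_statistics()  # dict of fused op name -> node count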
onnxruntime/transformers/onnx_model_conformer.py
@@ -0,0 +1,32 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ import logging
+
+ from fusion_attention import AttentionMask
+ from fusion_conformer_attention import FusionConformerAttention
+ from fusion_options import FusionOptions
+ from onnx_model_bert import BertOnnxModel
+
+ logger = logging.getLogger(__name__)
+
+
+ class ConformerOnnxModel(BertOnnxModel):
+     def __init__(self, model, num_heads, hidden_size):
+         super().__init__(model, num_heads, hidden_size)
+         self.attention_mask = AttentionMask(self)
+         self.attention_fusion = FusionConformerAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
+
+     def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False):
+         self.attention_fusion.use_multi_head_attention = False if options is None else options.use_multi_head_attention
+         self.attention_fusion.disable_multi_head_attention_bias = (
+             False if options is None else options.disable_multi_head_attention_bias
+         )
+         super().optimize(options, add_dynamic_axes)
+
+     def fuse_attention(self):
+         self.attention_fusion.apply()
+
+     def preprocess(self):
+         self.adjust_reshape_and_expand()
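ConformerOnnxModel above forwards two FusionOptions flags into FusionConformerAttention before running the base optimization. A hedged sketch of passing those options in; the model path and dimensions are placeholders, and model_type="conformer" is assumed to be registered with the optimizer in this version.

    from onnxruntime.transformers import optimizer
    from onnxruntime.transformers.fusion_options import FusionOptions

    options = FusionOptions("conformer")
    options.use_multi_head_attention = True            # read by ConformerOnnxModel.optimize()
    options.disable_multi_head_attention_bias = False

    conformer = optimizer.optimize_model(
        "conformer_encoder.onnx",  # placeholder path
        model_type="conformer",    # assumed registry key for ConformerOnnxModel
        num_heads=8,               # placeholder values; match your model
        hidden_size=512,
        optimization_options=options,
    )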