onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/onnx_model_t5.py
@@ -0,0 +1,791 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import logging
from typing import Optional, Union

import numpy as np
from fusion_attention import AttentionMask, FusionAttention
from fusion_base import Fusion
from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel
from onnx_model_bert import BertOnnxModel

logger = logging.getLogger(__name__)


class FusionT5Attention(FusionAttention):
    """
    Fuse T5 Attention subgraph into one Attention node.
    """

    def __init__(
        self,
        model: OnnxModel,
        hidden_size: int,
        num_heads: int,
        attention_mask: AttentionMask,
    ):
        super().__init__(
            model,
            hidden_size,
            num_heads,
            attention_mask,
            use_multi_head_attention=False,
            search_op_types=["SkipSimplifiedLayerNormalization", "Add"],
        )
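        # Annotation (not part of the released file): static_kv == 1 marks cross
        # attention (K/V taken from the encoder output); fuse_t5_decoder() sets
        # it per matched subgraph, with 0 meaning decoder self attention.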
        self.static_kv = 1

    def create_attention_node(
        self,
        mask_index: str,
        q_matmul: NodeProto,
        k_matmul: NodeProto,
        v_matmul: NodeProto,
        num_heads: int,
        hidden_size: int,
        input: str,
        output: str,
        add_qk_str: str,
        scale: Optional[float] = None,
    ) -> Union[NodeProto, None]:
        """Create an Attention node.
        Args:
            mask_index (str): mask input
            q_matmul (NodeProto): MatMul node in the fully connected layer for Q
            k_matmul (NodeProto): MatMul node in the fully connected layer for K
            v_matmul (NodeProto): MatMul node in the fully connected layer for V
            num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.
            hidden_size (int): hidden dimension. If a model is pruned, it is the hidden dimension after pruning.
            input (str): input name
            output (str): output name
            add_qk_str (str): name of the tensor added to Q x K' before softmax (the relative position bias)
            scale (Optional[float]): value for the scale attribute; None keeps the operator default
        Returns:
            Union[NodeProto, None]: the node created or None if failed.
        """
        assert num_heads > 0

        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        q_weight = self.model.get_initializer(q_matmul.input[1])
        k_weight = self.model.get_initializer(k_matmul.input[1])
        v_weight = self.model.get_initializer(v_matmul.input[1])

        if q_weight is None:
            print(
                f"{q_matmul.input[1]} is not an initializer. "
                "Please set do_constant_folding=True in torch.onnx.export to unblock attention fusion"
            )
            return None

        qw = NumpyHelper.to_array(q_weight)
        kw = NumpyHelper.to_array(k_weight)
        vw = NumpyHelper.to_array(v_weight)

        # assert q and k have same shape as expected
        assert qw.shape == kw.shape

        qw_in_size = qw.shape[0]
        kw_in_size = kw.shape[0]
        vw_in_size = vw.shape[0]

        assert qw_in_size == kw_in_size == vw_in_size

        if hidden_size > 0 and hidden_size != qw_in_size:
            logger.warning(
                f"Input hidden size ({hidden_size}) is not same as weight matrix dimension of q,k,v ({qw_in_size}). "
                "Please provide a correct input hidden size or pass in 0"
            )

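        # Annotation (not part of the released file): the Q/K/V projection
        # weights are packed into one [in_size, 3 * out_size] initializer.
        # Stacking [in, out] matrices on axis=1 gives [in, 3, out], whose
        # row-major bytes match the packed layout, e.g.:
        #     qw = kw = vw = np.ones((4, 4), dtype=np.float32)
        #     packed = np.stack((qw, kw, vw), axis=1)  # shape (4, 3, 4)
        #     packed.reshape(4, 12)                    # [in_size, 3 * out_size]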
        qw_out_size = np.prod(qw.shape[1:])
        qkv_weight = np.stack((qw, kw, vw), axis=1)
        qkv_weight_dim = 3 * qw_out_size

        attention_node_name = self.model.create_node_name("Attention")

        weight = helper.make_tensor(
            name=attention_node_name + "_qkv_weight",
            data_type=TensorProto.FLOAT,
            dims=[qw_in_size, qkv_weight_dim],
            vals=qkv_weight.tobytes(),
            raw=True,
        )

        self.model.add_initializer(weight, self.this_graph_name)

        attention_inputs = [
            input,
            attention_node_name + "_qkv_weight",
            "",
        ]
        if mask_index is not None:
            attention_inputs.append(mask_index)
        else:
            attention_inputs.append("")

        if add_qk_str is not None:
            attention_inputs.append("")  # no past
            attention_inputs.append(add_qk_str)

        attention_node = helper.make_node(
            "Attention",
            inputs=attention_inputs,
            outputs=[output],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

        if scale is not None:
            attention_node.attribute.extend([helper.make_attribute("scale", scale)])

        if self.mask_filter_value is not None:
            attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

        return attention_node

    def create_mha_node(
        self,
        query: str,
        key: str,
        value: str,
        mask_index: str,
        res_pos_bias: str,
        past_key: str,
        past_value: str,
        output: str,
        present_key: str,
        present_value: str,
        num_heads: int,
        hidden_size: int,
    ) -> Union[NodeProto, None]:
        assert num_heads > 0

        if hidden_size > 0 and (hidden_size % num_heads) != 0:
            logger.debug(f"input hidden size {hidden_size} is not a multiple of num of heads {num_heads}")
            return None

        attention_node_name = self.model.create_node_name("MultiHeadAttention")
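        # Annotation (not part of the released file): inputs are assembled in
        # the order expected by the com.microsoft MultiHeadAttention contrib op:
        # query, key, value, bias, key_padding_mask, attention bias (relative
        # position bias), past_key, past_value; absent optional inputs are
        # passed as empty strings.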
        attention_inputs = [
            query,
            "" if key is None else key,  # key
            "" if value is None else value,  # value
            "",  # bias
        ]
        if mask_index is not None:
            attention_inputs.append(mask_index)
        else:
            attention_inputs.append("")

        if res_pos_bias is not None:
            attention_inputs.append(res_pos_bias)
        else:
            attention_inputs.append("")

        if past_key is not None:
            assert past_value is not None
            attention_inputs.append(past_key)
            attention_inputs.append(past_value)

        attention_outputs = [output]
        if present_key is not None:
            assert present_value is not None
            attention_outputs.append(present_key)
            attention_outputs.append(present_value)

        attention_node = helper.make_node(
            "MultiHeadAttention",
            inputs=attention_inputs,
            outputs=attention_outputs,
            name=attention_node_name,
        )

        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
        attention_node.attribute.extend([helper.make_attribute("scale", 1.0)])
        if self.mask_filter_value is not None:
            attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

        self.increase_counter("MultiHeadAttention")
        return attention_node

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        self.fuse_t5_encoder(normalize_node, input_name_to_nodes, output_name_to_node)
        self.fuse_t5_decoder(normalize_node, input_name_to_nodes, output_name_to_node)

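    # Annotation (not part of the released file): the encoder path below walks
    # the exported T5 self-attention subgraph bottom-up from the output
    # projection (MatMul -> Reshape -> Transpose -> MatMul), checks the
    # Softmax(Add(...)) chain, the attention-mask subgraph, and the
    # RelativePositionBias producer, then replaces the whole match with a
    # single fused Attention node.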
    def fuse_t5_encoder(self, normalize_node, input_name_to_nodes, output_name_to_node):
        if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add":
            return

        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ["MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 0, 0, 0],
        )
        if qkv_nodes is None:
            return

        _, reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes

        qkv_shape_nodes = self.model.match_parent_path(
            reshape_qkv,
            ["Concat", "Unsqueeze", "Gather", "Shape"],
            [1, 0, 0, 0],
        )
        if qkv_shape_nodes is None:
            return
        input_shape_node = qkv_shape_nodes[-1]

        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Transpose", "Reshape", "MatMul"],
            [1, 0, 0],
        )
        if v_nodes is None:
            return
        _, reshape_v, matmul_v = v_nodes
        # todo: check reshape_v parent nodes

        qk_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Softmax", "Add", "MatMul"],
            [0, 0, 0],
        )
        if qk_nodes is None:
            return
        _, add_qk, matmul_qk = qk_nodes

        mask_index = None
        mask_nodes = self.model.match_parent_path(
            add_qk,
            ["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
            [1, 1, 0, 1, 0, 0],
        )
        if mask_nodes is None:
            return
        mul_node = mask_nodes[1]
        if mask_nodes[1].op_type != "Mul":
            return

        _, mul_val = self.model.get_constant_input(mul_node)
        if mul_val != -10000:
            self.mask_filter_value = mul_val

        mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])

        res_pos_bias = None
        rpb_nodes = self.model.match_parent_path(
            add_qk,
            ["Add", "RelativePositionBias"],
            [1, 0],
        )
        if rpb_nodes is None:
            return
        rpb_add_node = rpb_nodes[0]
        res_pos_bias = rpb_add_node.input[0]

        k_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "MatMul"],
            [1, 0, 0],
        )
        if k_nodes is None:
            return
        _, reshape_k, matmul_k = k_nodes
        # todo: check reshape_k parent nodes

        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "MatMul"],
            [0, 0, 0],
        )
        if q_nodes is None:
            return

        transpose_q, reshape_q, matmul_q = q_nodes
        # todo: check reshape_q parent nodes

        if matmul_q.input[0] != input_shape_node.input[0]:
            return

        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

        new_node = self.create_attention_node(
            mask_index,
            matmul_q,
            matmul_k,
            matmul_v,
            q_num_heads,
            q_hidden_size,
            input_shape_node.input[0],
            reshape_qkv.output[0],
            res_pos_bias,
            1.0,
        )
        if new_node is None:
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend(qkv_nodes[1:])
        self.nodes_to_remove.extend(qk_nodes)
        self.nodes_to_remove.extend(k_nodes[:-1])
        if v_nodes is not None:
            self.nodes_to_remove.extend(v_nodes[:-1])
        self.nodes_to_remove.extend(q_nodes[:-1])

        self.prune_graph = True

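    # Annotation (not part of the released file): the decoder path distinguishes
    # self attention (past/present tensors named *_self, static_kv == 0) from
    # cross attention (K/V derived from the encoder output or past_*_cross
    # inputs, static_kv == 1) and fuses the match into a MultiHeadAttention
    # node built by create_mha_node().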
    def fuse_t5_decoder(self, normalize_node, input_name_to_nodes, output_name_to_node):
        if normalize_node.op_type != "SkipSimplifiedLayerNormalization" and normalize_node.op_type != "Add":
            return

        qkv_nodes = self.model.match_parent_path(
            normalize_node,
            ["MatMul", "Reshape", "Transpose", "MatMul"],
            [1, 0, 0, 0],
        )
        if qkv_nodes is None:
            return

        _, reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes

        qkv_shape_nodes = self.model.match_parent_path(
            reshape_qkv,
            ["Concat", "Unsqueeze", "Gather", "Shape"],
            [1, 0, 0, 0],
        )
        if qkv_shape_nodes is None:
            return
        input_shape_node = qkv_shape_nodes[-1]

        value = None
        past_value = None
        present_value = None
        v_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Concat", "Transpose", "Reshape", "MatMul"],
            [1, 1, 0, 0],
        )
        if v_nodes is None:
            v_nodes = self.model.match_parent_path(
                matmul_qkv,
                ["Transpose", "Reshape", "MatMul"],
                [1, 0, 0],
            )
            if v_nodes is not None:
                transpose_v, reshape_v, matmul_v = v_nodes
                value = reshape_v.input[0]
                present_value = transpose_v.output[0]
                if "present_value" not in present_value:
                    return
                if matmul_v.input[0] != input_shape_node.input[0]:
                    self.static_kv = 1
                else:
                    self.static_kv = 0
            else:
                past_value = matmul_qkv.input[1]
                if past_value in output_name_to_node:
                    return
                if "past_value_cross" not in past_value:
                    return
                self.static_kv = 1
        else:
            concat_v, _, reshape_v, _ = v_nodes
            past_value = concat_v.input[0]
            if past_value in output_name_to_node:
                return
            if "past_value_self" not in past_value:
                return
            present_value = concat_v.output[0]
            if "present_value_self" not in present_value:
                return
            value = reshape_v.input[0]
            self.static_kv = 0

        qk_nodes = self.model.match_parent_path(
            matmul_qkv,
            ["Softmax", "Add", "MatMul"],
            [0, 0, 0],
        )
        if qk_nodes is None:
            return
        _, add_qk, matmul_qk = qk_nodes

        mask_index = None
        res_pos_bias = None
        if self.static_kv == 1:
            mask_nodes = self.model.match_parent_path(
                add_qk,
                ["Add", "Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"],
                [1, 1, 0, 1, 0, 0],
            )
            if mask_nodes is None:
                return
            mul_node = mask_nodes[1]
            if mask_nodes[1].op_type != "Mul":
                return

            _, mul_val = self.model.get_constant_input(mul_node)
            if mul_val != -10000:
                self.mask_filter_value = mul_val

            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
        else:
            rpb_nodes = self.model.match_parent_path(
                add_qk,
                ["Add", "Slice"],
                [1, 0],
            )
            if rpb_nodes is not None:
                res_pos_bias = add_qk.input[1]
            else:
                rpb_nodes = self.model.match_parent_path(
                    add_qk,
                    ["Add", "RelativePositionBias"],
                    [1, 0],
                )
                if rpb_nodes is None:
                    return
                res_pos_bias = add_qk.input[1]

        key = None
        past_key = None
        present_key = None
        if self.static_kv == 1:
            k_nodes = self.model.match_parent_path(
                matmul_qk,
                ["Transpose", "Reshape", "MatMul"],
                [1, 0, 0],
            )
            if k_nodes is not None:
                transpose_k, reshape_k, _ = k_nodes
                key = reshape_k.input[0]
                present_key_transpose_nodes = input_name_to_nodes[reshape_k.output[0]]
                for present_key_transpose_node in present_key_transpose_nodes:
                    present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
                    if present_key_candidate is not None:
                        present_key = present_key_candidate.name
                        break
                if present_key is None:
                    return
                if "present_key_cross" not in present_key:
                    return
            else:
                k_nodes = self.model.match_parent_path(
                    matmul_qk,
                    ["Transpose"],
                    [1],
                )
                if k_nodes is None:
                    return
                transpose_k = k_nodes[0]

                past_key = transpose_k.input[0]
                if past_key in output_name_to_node:
                    return
                if "past_key_cross" not in past_key:
                    return
        else:
            idx, k_nodes, _ = self.model.match_parent_paths(
                matmul_qk,
                [
                    (["Transpose", "Concat", "Reshape", "MatMul"], [1, 0, 1, 0]),
                    (["Transpose", "Concat", "Transpose", "Reshape", "MatMul"], [1, 0, 1, 0, 0]),
                ],
                output_name_to_node,
            )
            past_key_transpose_node = None
            present_key_transpose_nodes = None
            if k_nodes is not None:
                concat_k, reshape_k = k_nodes[1], k_nodes[-2]
                key = reshape_k.input[0]

                if idx == 0:
                    past_key_transpose_node = output_name_to_node[concat_k.input[0]]
                    past_key = past_key_transpose_node.input[0]
                else:
                    past_key = concat_k.input[0]
                if past_key in output_name_to_node:
                    return
                if "past_key_self" not in past_key:
                    return

                if idx == 0:
                    present_key_transpose_nodes = input_name_to_nodes[concat_k.output[0]]
                    for present_key_transpose_node in present_key_transpose_nodes:
                        present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
                        if present_key_candidate is not None:
                            present_key = present_key_candidate.name
                            break
                else:
                    present_key = concat_k.output[0]
                if present_key is None:
                    return
                if "present_key_self" not in present_key:
                    return
            else:
                k_nodes = self.model.match_parent_path(
                    matmul_qk,
                    ["Transpose", "Reshape", "MatMul"],
                    [1, 0, 0],
                )
                if k_nodes is None:
                    return
                _, reshape_k, _ = k_nodes
                key = reshape_k.input[0]
                present_key_transpose_nodes = input_name_to_nodes[reshape_k.output[0]]
                for present_key_transpose_node in present_key_transpose_nodes:
                    present_key_candidate = self.model.find_graph_output(present_key_transpose_node.output[0])
                    if present_key_candidate is not None:
                        present_key = present_key_candidate.name
                        break
                if present_key is None:
                    return
                if "present_key_self" not in present_key:
                    return

        q_nodes = self.model.match_parent_path(
            matmul_qk,
            ["Transpose", "Reshape", "MatMul"],
            [0, 0, 0],
        )
        if q_nodes is None:
            return

        transpose_q, reshape_q, matmul_q = q_nodes

        if matmul_q.input[0] != input_shape_node.input[0]:
            return

        q_num_heads, q_hidden_size = self.get_num_heads_and_hidden_size(reshape_q)

        if self.static_kv == 1 and past_key is not None:
            key = past_key
            value = past_value
            past_key = None
            past_value = None

        new_node = self.create_mha_node(
            matmul_q.output[0],
            key,
            value,
            mask_index,
            res_pos_bias,
            past_key,
            past_value,
            reshape_qkv.output[0],
            present_key,
            present_value,
            q_num_heads,
            q_hidden_size,
        )
        if new_node is None:
            return

        self.nodes_to_add.append(new_node)
        self.node_name_to_graph_name[new_node.name] = self.this_graph_name

        self.nodes_to_remove.extend(qkv_nodes[1:])
        self.nodes_to_remove.extend(qk_nodes)
        self.nodes_to_remove.extend(k_nodes[:-1])
        if v_nodes is not None:
            self.nodes_to_remove.extend(v_nodes[:-1])
        self.nodes_to_remove.extend(q_nodes[:-1])

        self.prune_graph = True

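# Annotation (not part of the released file): this fusion matches the exported
# T5 relative-position-bias computation (relative positions bucketed via the
# Log/Div/Min chain below, then gathered from a learned bias table) and
# replaces it with the com.microsoft RelativePositionBias contrib op.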
class FusionRelativePositionBiasBlock(Fusion):
    def __init__(self, model: OnnxModel, max_distance: int):
        super().__init__(model, "RelativePositionBias", ["Add", "Slice"])
        self.max_distance = max_distance
        # bidirectional=(not self.is_decoder)
        self.is_bidirectional = False

    def fuse(self, node, input_name_to_nodes, output_name_to_node):
        # TODO: Optimization opportunity: only last dimension of relative_position_bias is used in decoder.
        # Cuda kernel can be optimized to only compute last dimension.
        if node.op_type != "Add" and node.op_type != "Slice":
            return

        compute_bias_nodes = self.model.match_parent_path(
            node, ["Unsqueeze", "Transpose", "Gather", "Where"], [0, 0, 0, 1]
        )
        if compute_bias_nodes is None:
            compute_bias_nodes = self.model.match_parent_path(
                node, ["Unsqueeze", "Transpose", "Gather", "Add", "Where"], [0, 0, 0, 1, 1]
            )
            if compute_bias_nodes is None:
                return

        gather = compute_bias_nodes[2]
        where = compute_bias_nodes[-1]
        unsqueeze = compute_bias_nodes[0]

        compute_buckets_nodes = self.model.match_parent_path(
            where,
            ["Min", "ConstantOfShape", "Shape", "Add", "Cast", "Mul", "Div", "Log", "Div"],
            [2, 1, 0, 0, 0, 0, 0, 0, 0],
        )
        if compute_buckets_nodes is None:
            return

        div = compute_buckets_nodes[-1]

        range_nodes = self.model.match_parent_path(
            div,
            ["Cast", "Neg", "Min", "ConstantOfShape", "Shape", "Sub", "Unsqueeze", "Range"],
            [0, 0, 0, 1, 0, 0, 0, 0],
        )
        if range_nodes is None:
            range_nodes = self.model.match_parent_path(
                div, ["Cast", "Abs", "Sub", "Unsqueeze", "Range"], [0, 0, 0, 0, 0]
            )
            self.is_bidirectional = True
            if range_nodes is None:
                return

        range_node = range_nodes[-1]

        self.nodes_to_remove.extend(compute_bias_nodes)
        self.nodes_to_remove.extend(compute_buckets_nodes)
        self.nodes_to_remove.extend(range_nodes)

        node_name_prefix = "encoder" if self.is_bidirectional else "decoder"

        table_weight_i = self.model.get_initializer(gather.input[0])
        table_weight = NumpyHelper.to_array(table_weight_i)
        table_weight_t = np.transpose(table_weight)
        bias_table = helper.make_tensor(
            name=self.model.create_node_name("bias_table_weight", name_prefix=node_name_prefix),
            data_type=TensorProto.FLOAT,
            dims=[np.shape(table_weight)[0], np.shape(table_weight)[1]],
            vals=table_weight_t.tobytes(),
            raw=True,
        )

        self.model.add_initializer(bias_table, self.this_graph_name)
        inputs = [bias_table.name, range_node.input[1], range_node.input[1]]
        outputs = [unsqueeze.output[0]]
        rpb_node = helper.make_node(
            "RelativePositionBias",
            inputs=inputs,
            outputs=outputs,
            name=self.model.create_node_name("RelativePositionBias", name_prefix=node_name_prefix),
        )
        rpb_node.domain = "com.microsoft"
        rpb_node.attribute.extend([helper.make_attribute("max_distance", self.max_distance)])
        rpb_node.attribute.extend([helper.make_attribute("is_bidirectional", self.is_bidirectional)])

        self.nodes_to_add.append(rpb_node)
        self.node_name_to_graph_name[rpb_node.name] = self.this_graph_name

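# Annotation (not part of the released file): the max_distance of 128 passed
# below matches the Hugging Face T5 default relative_attention_max_distance;
# the TODO in __init__ notes it could instead be read from the model.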
class T5OnnxModel(BertOnnxModel):
    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)
        self.attention_mask = AttentionMask(self)
        self.attention_fusion = FusionT5Attention(self, self.hidden_size, self.num_heads, self.attention_mask)
        self.layer_norm_fusion = FusionSimplifiedLayerNormalization(self)
        self.skip_layer_norm_fusion = FusionSkipSimplifiedLayerNormalization(self)
        # TODO: consider retrieve max_distance from model.
        # math.log(max_distance / (num_buckets // 2))
        self.rpb_fusion = FusionRelativePositionBiasBlock(self, 128)

    def fuse_attention(self):
        self.attention_fusion.apply()

    def fuse_layer_norm(self):
        self.layer_norm_fusion.apply()

    def fuse_skip_layer_norm(self):
        self.skip_layer_norm_fusion.apply()

    # Remove get_extended_attention_mask() since it generates all zeros.
    def remove_extended_mask_decoder_init(self):
        nodes_to_remove = []
        for node in self.nodes():
            if node.op_type == "Add":
                extended_mask_nodes = self.match_parent_path(
                    node,
                    [
                        "Mul",
                        "Sub",
                        "Mul",
                        "Unsqueeze",
                        "Cast",
                        "LessOrEqual",
                        "Tile",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                    ],
                    [1, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0],
                )
                if extended_mask_nodes is None:
                    continue

                rpb_nodes = self.match_parent_path(node, ["RelativePositionBias"], [0])
                if rpb_nodes is None:
                    continue

                rpb_node = rpb_nodes[0]
                rpb_node.output[0] = node.output[0]

                nodes_to_remove.extend(extended_mask_nodes)
                nodes_to_remove.append(node)
        self.remove_nodes(nodes_to_remove)

    def remove_extended_mask_decoder(self):
        nodes_to_remove = []
        for node in self.nodes():
            if node.op_type == "Add":
                extended_mask_nodes = self.match_parent_path(
                    node,
                    [
                        "Mul",
                        "Sub",
                        "Mul",
                        "Unsqueeze",
                        "Concat",
                        "Cast",
                        "LessOrEqual",
                        "Tile",
                        "Concat",
                        "Unsqueeze",
                        "Gather",
                        "Shape",
                    ],
                    [1, 0, 1, 0, 0, 1, 0, 0, 1, 1, 0, 0],
                )
                if extended_mask_nodes is None:
                    continue

                rpb_nodes = self.match_parent_path(node, ["Slice", "RelativePositionBias"], [0, 0])
                if rpb_nodes is None:
                    continue

                rpb_node = rpb_nodes[0]
                rpb_node.output[0] = node.output[0]

                nodes_to_remove.extend(extended_mask_nodes)
                nodes_to_remove.append(node)
        self.remove_nodes(nodes_to_remove)

    def preprocess(self):
        self.adjust_reshape_and_expand()
        self.rpb_fusion.apply()

    def postprocess(self):
        # remove get_extended_attention_mask() since it generates all zeros.
        self.remove_extended_mask_decoder_init()
        self.remove_extended_mask_decoder()

        self.prune_graph()
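
For context, T5OnnxModel is registered under model_type "t5" in onnxruntime/transformers/optimizer.py, so the fusions in this file are normally driven through optimize_model. A minimal sketch of that entry point (file paths, num_heads, and hidden_size below are hypothetical placeholders):

    from onnxruntime.transformers.optimizer import optimize_model

    # model_type="t5" routes optimization through T5OnnxModel, applying
    # FusionT5Attention and FusionRelativePositionBiasBlock among others.
    opt_model = optimize_model(
        "t5_encoder.onnx",  # hypothetical exported T5 encoder
        model_type="t5",
        num_heads=8,
        hidden_size=512,
    )
    opt_model.save_model_to_file("t5_encoder_fused.onnx")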