onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/fusion_qordered_attention.py
@@ -0,0 +1,421 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ from logging import getLogger
+ from typing import Tuple
+
+ import numpy as np
+ from fusion_attention import AttentionMask
+ from fusion_base import Fusion
+ from fusion_utils import FusionUtils, NumpyHelper
+ from onnx import NodeProto, helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class FusionQOrderedAttention(Fusion):
+     def __init__(
+         self,
+         model: OnnxModel,
+         hidden_size: int,
+         num_heads: int,
+         attention_mask: AttentionMask,
+     ):
+         self.hidden_size = hidden_size
+         self.num_heads = num_heads
+         self.attention_mask = attention_mask
+
+         super().__init__(model, "QOrderedAttention", "QOrderedLayerNormalization")
+
+     def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> Tuple[int, int]:
+         """Detect num_heads and hidden_size from a reshape node.
+         Args:
+             reshape_q (NodeProto): reshape node for Q
+         Returns:
+             Tuple[int, int]: num_heads and hidden_size
+         """
+
+         # We assume that reshape fusion has been done, so the shape is a tensor like [0, 0, num_heads, head_size]
+         q_shape = self.model.get_initializer(reshape_q.input[1])
+         if q_shape is None:
+             logger.debug(f"{reshape_q.input[1]} is not initializer.")
+
+             # Check if the second input to Reshape flows through a Constant node
+             # TODO: Investigate why FusionAttention doesn't have such logic
+             constant_node = self.model.match_parent_path(reshape_q, ["Constant"], [1])
+
+             if constant_node is None:
+                 return self.num_heads, self.hidden_size  # Fall back to user specified value
+             else:
+                 constant_node = constant_node[0]
+
+             if len(constant_node.attribute) != 1:
+                 return self.num_heads, self.hidden_size  # Fall back to user specified value
+
+             # This is assuming it is a Tensor attribute (this is a safe assumption)
+             q_shape = constant_node.attribute[0].t
+
+         q_shape_value = NumpyHelper.to_array(q_shape)
+         if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0):
+             logger.debug(f"q_shape_value={q_shape_value}. Expected values are like [0, 0, num_heads, head_size].")
+             return self.num_heads, self.hidden_size  # Fall back to user specified value
+
+         num_heads = q_shape_value[2]
+         head_size = q_shape_value[3]
+         hidden_size = num_heads * head_size
+
+         if self.num_heads > 0 and num_heads != self.num_heads:
+             if self.num_heads_warning:
+                 logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
+                 self.num_heads_warning = False  # Do not show the warning more than once
+
+         if self.hidden_size > 0 and hidden_size != self.hidden_size:
+             if self.hidden_size_warning:
+                 logger.warning(
+                     f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
+                 )
+                 self.hidden_size_warning = False  # Do not show the warning more than once
+
+         return num_heads, hidden_size
+
+     def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+         add_before_layernorm = self.model.match_parent_path(
+             normalize_node,
+             ["QuantizeLinear", "Add"],
+             [0, 0],
+         )
+
+         if add_before_layernorm is not None:
+             start_node = add_before_layernorm[-1]
+         else:
+             return
+
+         # Input QDQ nodes
+         dequantize_input = self.model.match_parent_path(
+             start_node,
+             ["DequantizeLinear"],
+             [None],
+         )
+
+         if dequantize_input is None:
+             logger.debug("fuse_qordered_attention: failed to match input qdq nodes path")
+             return
+
+         dequantize_input = dequantize_input[-1]
+
+         # QKV nodes
+         qkv_nodes = self.model.match_parent_path(
+             start_node,
+             ["Add", "MatMul", "Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear", "MatMul"],
+             [None, None, 0, 0, 0, 0, 0],
+         )
+
+         if qkv_nodes is None:
+             logger.debug("fuse_qordered_attention: failed to match qkv path")
+             return
+
+         (_, projection_matmul, reshape_qkv, transpose_qkv, dequantize_qkv, quantize_qkv, matmul_qkv) = qkv_nodes
+
+         # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
+         if not FusionUtils.check_qdq_node_for_fusion(quantize_qkv, self.model):
+             return
+
+         if not FusionUtils.check_qdq_node_for_fusion(dequantize_qkv, self.model):
+             return
+
+         # Identify the root input to the Attention node
+         other_inputs = []
+         for _i, input in enumerate(start_node.input):
+             if input not in output_name_to_node:
+                 continue
+
+             if input == qkv_nodes[0].output[0]:
+                 continue
+
+             other_inputs.append(input)
+
+         if len(other_inputs) != 1:
+             return
+
+         root_input = other_inputs[0]
+
+         # V nodes
+         v_nodes = self.model.match_parent_path(
+             matmul_qkv,
+             ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
+             [1, 0, 0, 0, 0, None],
+         )
+
+         if v_nodes is None:
+             logger.debug("fuse_qordered_attention: failed to match v path")
+             return
+
+         (_, _, dequantize_v, quantize_v, add_v, matmul_v) = v_nodes
+
+         # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
+         if not FusionUtils.check_qdq_node_for_fusion(quantize_v, self.model):
+             return
+
+         if not FusionUtils.check_qdq_node_for_fusion(dequantize_v, self.model):
+             return
+
+         # V MatMul weight
+         dequantize_v_matmul_weight = self.model.match_parent_path(matmul_v, ["DequantizeLinear"], [1])
+
+         if dequantize_v_matmul_weight is None:
+             logger.debug("fuse_qordered_attention: failed to match v path")
+             return
+
+         dequantize_v_matmul_weight = dequantize_v_matmul_weight[0]
+
+         if self.model.get_constant_value(dequantize_v_matmul_weight.input[0]) is None:
+             return
+
+         # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
+         # Per-channel scales are supported for weights alone
+         if not FusionUtils.check_qdq_node_for_fusion(dequantize_v_matmul_weight, self.model, False):
+             return
+
+         # QK nodes
+         qk_nodes = self.model.match_parent_path(
+             matmul_qkv,
+             [
+                 "DequantizeLinear",
+                 "QuantizeLinear",
+                 "Softmax",
+                 "Add",
+                 "Div",
+                 "DequantizeLinear",
+                 "QuantizeLinear",
+                 "MatMul",
+             ],
+             [0, 0, 0, 0, None, 0, 0, 0],
+         )
+
+         if qk_nodes is None:
+             logger.debug("fuse_qordered_attention: failed to match qk path")
+             return
+
+         (
+             dequantize_qk_softmax,
+             quantize_qk_softmax,
+             softmax_qk,
+             add_qk,
+             div_qk,
+             dequantize_qk,
+             quantize_qk,
+             matmul_qk,
+         ) = qk_nodes
+
+         # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
+         if not FusionUtils.check_qdq_node_for_fusion(quantize_qk_softmax, self.model):
+             return
+
+         if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk_softmax, self.model):
+             return
+
+         if not FusionUtils.check_qdq_node_for_fusion(quantize_qk, self.model):
+             return
+
+         if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk, self.model):
+             return
+
+         # Q nodes
+         q_nodes = self.model.match_parent_path(
+             matmul_qk,
+             ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
+             [0, 0, 0, 0, 0, None],
+         )
+
+         if q_nodes is None:
+             logger.debug("fuse_qordered_attention: failed to match q path")
+             return
+
+         (_, reshape_q, dequantize_q, quantize_q, add_q, matmul_q) = q_nodes
+
+         # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
+         if not FusionUtils.check_qdq_node_for_fusion(quantize_q, self.model):
+             return
+
+         if not FusionUtils.check_qdq_node_for_fusion(dequantize_q, self.model):
+             return
+
+         # Q MatMul weight
+         dequantize_q_matmul_weight = self.model.match_parent_path(matmul_q, ["DequantizeLinear"], [1])
+
+         if dequantize_q_matmul_weight is None:
+             logger.debug("fuse_qordered_attention: failed to match q path")
+             return
+
+         dequantize_q_matmul_weight = dequantize_q_matmul_weight[0]
+
+         if self.model.get_constant_value(dequantize_q_matmul_weight.input[0]) is None:
+             return
+
+         # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
+         # Per-channel scales are supported for weights alone
+         if not FusionUtils.check_qdq_node_for_fusion(dequantize_q_matmul_weight, self.model, False):
+             return
+
+         # K nodes
+         k_nodes = self.model.match_parent_path(
+             matmul_qk,
+             ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
+             [1, 0, 0, 0, 0, None],
+         )
+
+         if k_nodes is None:
+             logger.debug("fuse_qordered_attention: failed to match k path")
+             return
+
+         (_, _, dequantize_k, quantize_k, add_k, matmul_k) = k_nodes
+
+         # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
+         if not FusionUtils.check_qdq_node_for_fusion(quantize_k, self.model):
+             return
+
+         if not FusionUtils.check_qdq_node_for_fusion(dequantize_k, self.model):
+             return
+
+         # K MatMul weight
+         dequantize_k_matmul_weight = self.model.match_parent_path(matmul_k, ["DequantizeLinear"], [1])
+
+         if dequantize_k_matmul_weight is None:
+             logger.debug("fuse_qordered_attention: failed to match k path")
+             return
+
+         dequantize_k_matmul_weight = dequantize_k_matmul_weight[0]
+
+         if self.model.get_constant_value(dequantize_k_matmul_weight.input[0]) is None:
+             return
+
+         # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
+         # Per-channel scales are supported for weights alone
+         if not FusionUtils.check_qdq_node_for_fusion(dequantize_k_matmul_weight, self.model, False):
+             return
+
+         # Mask nodes
+         mask_nodes = self.model.match_parent_path(
+             add_qk, ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]
+         )
+
+         if mask_nodes is None:
+             logger.debug("fuse_qordered_attention: failed to match mask_nodes path")
+             return
+
+         # Ascertain `qkv_hidden_sizes` attribute value
+         q_weight = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
+         k_weight = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
+         v_weight = self.model.get_initializer(dequantize_v_matmul_weight.input[0])
+
+         qw = NumpyHelper.to_array(q_weight)
+         kw = NumpyHelper.to_array(k_weight)
+         vw = NumpyHelper.to_array(v_weight)
+
+         qw_out_size = np.prod(qw.shape[1:])
+         kw_out_size = np.prod(kw.shape[1:])
+         vw_out_size = np.prod(vw.shape[1:])
+
+         # Form QOrderedAttention node
+         if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
+             mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
+
+             # Ascertain `num_heads` and `hidden_size`
+             num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
+
+             # Formulate the inputs
+             # Actual quantized input
+             attention_inputs = [dequantize_input.input[0]]
+             attention_inputs.append(dequantize_input.input[1])
+
+             attention_inputs.append(dequantize_q.input[1])
+             attention_inputs.append(dequantize_k.input[1])
+             attention_inputs.append(dequantize_v.input[1])
+
+             attention_inputs.append(dequantize_q_matmul_weight.input[0])
+             attention_inputs.append(dequantize_k_matmul_weight.input[0])
+             attention_inputs.append(dequantize_v_matmul_weight.input[0])
+
+             attention_inputs.append(dequantize_q_matmul_weight.input[1])
+             attention_inputs.append(dequantize_k_matmul_weight.input[1])
+             attention_inputs.append(dequantize_v_matmul_weight.input[1])
+
+             if self.model.get_initializer(add_q.input[0]):
+                 attention_inputs.append(add_q.input[0])
+             else:  # second input is the constant bias
+                 attention_inputs.append(add_q.input[1])
+
+             if self.model.get_initializer(add_k.input[0]):
+                 attention_inputs.append(add_k.input[0])
+             else:  # second input is the constant bias
+                 attention_inputs.append(add_k.input[1])
+
+             if self.model.get_initializer(add_v.input[0]):
+                 attention_inputs.append(add_v.input[0])
+             else:  # second input is the constant bias
+                 attention_inputs.append(add_v.input[1])
+
+             attention_inputs.append(quantize_qk.input[1])
+             attention_inputs.append(quantize_qk_softmax.input[1])
+             attention_inputs.append(dequantize_qkv.input[1])
+
+             # Mask input
+             if mask_index is not None:
+                 attention_inputs.append(mask_index)
+             else:
+                 attention_inputs.append("")
+
+             # The MatMul weight 'B' and 'bias' need some post-processing
+             # Transpose weight 'B' from order ROW to order COL
+             # This offline transpose is needed only while using the CUDA EP
+             # TODO: Make this fusion logic EP-agnostic?
+             q_weight_tensor = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
+             FusionUtils.transpose_2d_int8_tensor(q_weight_tensor)
+
+             k_weight_tensor = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
+             FusionUtils.transpose_2d_int8_tensor(k_weight_tensor)
+
+             v_weight_tensor = self.model.get_initializer(dequantize_v_matmul_weight.input[0])
+             FusionUtils.transpose_2d_int8_tensor(v_weight_tensor)
+
+             # Name and create Attention node
+             attention_node_name = self.model.create_node_name("QOrderedAttention")
+
+             attention_node = helper.make_node(
+                 "QOrderedAttention",
+                 inputs=attention_inputs,
+                 outputs=[reshape_qkv.output[0]],
+                 name=attention_node_name,
+             )
+
+             self.model.replace_node_input(dequantize_qkv, dequantize_qkv.input[0], attention_node.output[0])
+             self.model.replace_node_input(projection_matmul, projection_matmul.input[0], dequantize_qkv.output[0])
+
+             attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
+             attention_node.attribute.extend([helper.make_attribute("order_input", 1)])
+             attention_node.attribute.extend([helper.make_attribute("order_weight", 0)])
+             attention_node.attribute.extend([helper.make_attribute("order_output", 1)])
+             attention_node.attribute.extend(
+                 [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
+             )
+
+             attention_node.domain = "com.microsoft"
+
+             self.nodes_to_add.append(attention_node)
+             self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
+
+             self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, quantize_qkv, matmul_qkv])
+             self.nodes_to_remove.extend(qk_nodes)
+             self.nodes_to_remove.extend(q_nodes)
+             self.nodes_to_remove.extend(k_nodes)
+             self.nodes_to_remove.extend(v_nodes)
+             self.nodes_to_remove.extend(
+                 [dequantize_q_matmul_weight, dequantize_k_matmul_weight, dequantize_v_matmul_weight]
+             )
+
+             # Use prune graph to remove mask nodes since they are shared by all attention nodes.
+             # self.nodes_to_remove.extend(mask_nodes)
+             self.prune_graph = True
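
For context, these fusion passes are normally reached through the high-level optimizer in onnxruntime.transformers rather than instantiated by hand. Below is a minimal sketch (not part of the package diff): the model path and the num_heads/hidden_size values are placeholders, and whether the QOrdered* fusions actually fire depends on the exported graph containing the quantized Q/DQ patterns matched above.

    # Hypothetical driver script; "bert_int8.onnx" is a placeholder path.
    from onnxruntime.transformers import optimizer

    opt_model = optimizer.optimize_model(
        "bert_int8.onnx",   # a quantized BERT-style model (assumption)
        model_type="bert",
        num_heads=12,       # fallback values; detection from the Reshape
        hidden_size=768,    # shape tensor takes precedence (see above)
    )
    opt_model.save_model_to_file("bert_int8_fused.onnx")

Note that num_heads and hidden_size act only as fallbacks here: get_num_heads_and_hidden_size above prefers the values detected from the Reshape node's shape tensor.
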
onnxruntime/transformers/fusion_qordered_gelu.py
@@ -0,0 +1,119 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+
+ from logging import getLogger
+ from typing import Dict
+
+ from fusion_base import Fusion
+ from fusion_utils import FusionUtils
+ from onnx import helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class FusionQOrderedGelu(Fusion):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model, "QOrderedGelu", ["Gelu", "FastGelu"])
+
+     def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
+         """
+         INPUT PATTERN
+         Fuse (quantized) Gelu subgraph into one node QOrderedGelu:
+             -> quantized input -> DQ -> Gelu -> Q ->
+
+         (or)
+
+             -> quantized input -> DQ -> FastGelu -> Q ->
+
+         OUTPUT PATTERN
+             -> QOrderedGelu ->
+         """
+         gelu_children = self.model.get_children(node, input_name_to_nodes)
+
+         # Should only have 1 child - QuantizeLinear (or)
+         # Should have 2 children - QuantizeLinear + Shape
+         if not (
+             (len(gelu_children) == 1 and gelu_children[0].op_type == "QuantizeLinear")
+             or (
+                 len(gelu_children) == 2
+                 and gelu_children[0].op_type == "QuantizeLinear"
+                 and gelu_children[1].op_type == "Shape"
+             )
+         ):
+             return
+
+         downstream_quantize_node = gelu_children[0]
+         downstream_shape_node = None
+
+         if len(gelu_children) == 2:
+             downstream_shape_node = gelu_children[1]
+
+         if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model):
+             return
+
+         # The first input to Gelu should flow through a DequantizeLinear node
+         first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
+             node,
+             [(["DequantizeLinear"], [0])],
+             output_name_to_node,
+         )
+
+         if first_path_id < 0:
+             return
+
+         upstream_dequantize_node = first_input_parent_nodes[0]
+
+         if not FusionUtils.check_qdq_node_for_fusion(upstream_dequantize_node, self.model):
+             return
+
+         # Fusion logic
+         subgraph_nodes = [node]  # Gelu/FastGelu
+         subgraph_nodes.extend([downstream_quantize_node, upstream_dequantize_node])  # Relevant Q, DQ nodes
+
+         if not self.model.is_safe_to_fuse_nodes(
+             subgraph_nodes,
+             (
+                 [node.output[0], downstream_quantize_node.output[0]]
+                 if downstream_shape_node is not None
+                 else downstream_quantize_node.output
+             ),
+             input_name_to_nodes,
+             output_name_to_node,
+         ):
+             logger.debug("It is not safe to fuse QOrderedGelu node. Skip")
+             return
+
+         self.nodes_to_remove.extend(subgraph_nodes)
+
+         ordered_gelu_node = helper.make_node(
+             "QOrderedGelu",
+             inputs=[
+                 upstream_dequantize_node.input[0],
+                 upstream_dequantize_node.input[1],
+                 downstream_quantize_node.input[1],
+             ],
+             outputs=[downstream_quantize_node.output[0]],
+             name=self.model.create_node_name("QOrderedGelu", name_prefix="QOrderedGelu"),
+         )
+
+         # Arrange the downstream Shape's input to be fed from the
+         # downstream QuantizeLinear node, so that fusion will
+         # be deemed safe
+         if downstream_shape_node is not None:
+             self.model.replace_node_input(
+                 downstream_shape_node, downstream_shape_node.input[0], downstream_quantize_node.output[0]
+             )
+
+         # TODO: We only support CuBlasLt order ORDER_ROW for now.
+         # Once we start supporting other data ordering format(s), we
+         # will support user configuring the data ordering for the op.
+         ordered_gelu_node.attribute.extend([helper.make_attribute("order_X", 1)])
+         ordered_gelu_node.attribute.extend([helper.make_attribute("order_Y", 1)])
+
+         ordered_gelu_node.domain = "com.microsoft"
+
+         self.nodes_to_add.append(ordered_gelu_node)
+         self.node_name_to_graph_name[ordered_gelu_node.name] = self.this_graph_name
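
A single pass can also be applied directly. The sketch below is a minimal example assuming the package's transformers directory is on sys.path, since these modules use flat intra-package imports (e.g. `from fusion_base import Fusion`); the model paths are placeholders.

    import onnx

    from fusion_qordered_gelu import FusionQOrderedGelu
    from onnx_model import OnnxModel

    model = OnnxModel(onnx.load("model_int8.onnx"))  # placeholder path
    FusionQOrderedGelu(model).apply()  # visits every Gelu/FastGelu node
    onnx.save(model.model, "model_int8_fused.onnx")

Fusion.apply() iterates the nodes whose op type matches the search types passed to __init__ ("Gelu" and "FastGelu" here), calls fuse() on each, and then commits nodes_to_add/nodes_to_remove to the graph.
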
onnxruntime/transformers/fusion_qordered_layernorm.py
@@ -0,0 +1,123 @@
+ # -------------------------------------------------------------------------
+ # Copyright (c) Microsoft Corporation. All rights reserved.
+ # Licensed under the MIT License.
+ # --------------------------------------------------------------------------
+ from logging import getLogger
+ from typing import Dict
+
+ from fusion_base import Fusion
+ from fusion_utils import FusionUtils
+ from onnx import helper
+ from onnx_model import OnnxModel
+
+ logger = getLogger(__name__)
+
+
+ class FusionQOrderedLayerNormalization(Fusion):
+     def __init__(self, model: OnnxModel):
+         super().__init__(model, "QOrderedLayerNormalization", "LayerNormalization")
+
+     def fuse(self, node, input_name_to_nodes: Dict, output_name_to_node: Dict):
+         """
+         Fuse (quantized) Layer Normalization subgraph into one node QOrderedLayerNormalization:
+             quantized input -> DQ
+                                 |
+                                 |
+             (other inputs)-> LayerNormalization --> Q -->
+
+         should become
+
+             (quantized input + other inputs)-> QOrderedLayerNormalization --> Q -->
+         """
+
+         children = self.model.get_children(node, input_name_to_nodes)
+
+         # Should only have 1 child - QuantizeLinear (or)
+         # Should have 2 children - QuantizeLinear + Shape
+         if not (
+             (len(children) == 1 and children[0].op_type == "QuantizeLinear")
+             or (len(children) == 2 and children[0].op_type == "QuantizeLinear" and children[1].op_type == "Shape")
+         ):
+             return
+
+         downstream_quantize_node = children[0]
+         downstream_shape_node = None
+
+         if len(children) == 2:
+             downstream_shape_node = children[1]
+
+         if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model):
+             return
+
+         # The first input to LayerNormalization should flow through a DequantizeLinear node
+         first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
+             node,
+             [(["DequantizeLinear"], [0])],
+             output_name_to_node,
+         )
+
+         if first_path_id < 0:
+             return
+
+         upstream_dequantize_node = first_input_parent_nodes[0]
+
+         if not FusionUtils.check_qdq_node_for_fusion(upstream_dequantize_node, self.model):
+             return
+
+         # Fusion logic
+         subgraph_nodes = [node]  # LayerNormalization
+         subgraph_nodes.extend([downstream_quantize_node])  # Q node after LayerNormalization
+
+         upstream_dequantize_node_children = self.model.get_children(upstream_dequantize_node, input_name_to_nodes)
+
+         # In GPT2, the DQ node will be feeding a residual downstream Add and hence,
+         # we do not want to remove it
+         if len(upstream_dequantize_node_children) == 1:
+             subgraph_nodes.extend([upstream_dequantize_node])  # DQ node before LayerNormalization
+
+         if not self.model.is_safe_to_fuse_nodes(
+             subgraph_nodes,
+             (
+                 [node.output[0], downstream_quantize_node.output[0]]
+                 if downstream_shape_node is not None
+                 else downstream_quantize_node.output
+             ),
+             input_name_to_nodes,
+             output_name_to_node,
+         ):
+             logger.debug("It is not safe to fuse QOrderedLayerNormalization node. Skip")
+             return
+
+         self.nodes_to_remove.extend(subgraph_nodes)
+
+         normalize_node = helper.make_node(
+             "QOrderedLayerNormalization",
+             inputs=[
+                 upstream_dequantize_node.input[0],
+                 upstream_dequantize_node.input[1],
+                 node.input[1],
+                 node.input[2],
+                 downstream_quantize_node.input[1],
+             ],
+             outputs=[downstream_quantize_node.output[0]],
+             name=self.model.create_node_name("QOrderedLayerNormalization", name_prefix="QOrderedLayerNormalization"),
+         )
+
+         # Arrange the downstream Shape's input to be fed from the
+         # downstream QuantizeLinear node, so that fusion will
+         # be deemed safe
+         if downstream_shape_node is not None:
+             self.model.replace_node_input(
+                 downstream_shape_node, downstream_shape_node.input[0], downstream_quantize_node.output[0]
+             )
+
+         # TODO: We only support CuBlasLt order ORDER_ROW for now.
+         # Once we start supporting other data ordering format(s), we
+         # will support user configuring the data ordering for the op.
+         normalize_node.attribute.extend([helper.make_attribute("order_X", 1)])
+         normalize_node.attribute.extend([helper.make_attribute("order_Y", 1)])
+
+         normalize_node.domain = "com.microsoft"
+
+         self.nodes_to_add.append(normalize_node)
+         self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
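
Pass ordering matters for these fusions: FusionQOrderedAttention above anchors its pattern match on QOrderedLayerNormalization nodes (see its super().__init__ call), so this LayerNormalization fusion must run first to create those anchors. Below is a minimal ordering sketch under the same sys.path assumption as before; note that FusionQOrderedAttention does not appear to initialize the self.num_heads_warning / self.hidden_size_warning flags it reads (FusionAttention does), so passing num_heads/hidden_size consistent with the model avoids that warning path.

    import onnx

    from fusion_attention import AttentionMask
    from fusion_qordered_attention import FusionQOrderedAttention
    from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
    from onnx_model import OnnxModel

    model = OnnxModel(onnx.load("model_int8.onnx"))  # placeholder path
    FusionQOrderedLayerNormalization(model).apply()  # creates the anchor nodes
    FusionQOrderedAttention(
        model,
        hidden_size=768,  # placeholder; must match the model
        num_heads=12,     # placeholder; must match the model
        attention_mask=AttentionMask(model),
    ).apply()
    onnx.save(model.model, "model_int8_fused.onnx")
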