onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
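The listing above is the DirectML build of ONNX Runtime: it bundles onnxruntime/capi/DirectML.dll and exposes the DmlExecutionProvider. As a minimal sketch of exercising the installed wheel — the model path and input shape below are assumptions for illustration, not part of the package:

import numpy as np
import onnxruntime as ort

# DirectML first, CPU as a fallback; both providers ship with this wheel.
session = ort.InferenceSession(
    "model.onnx",  # hypothetical model path
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)
input_name = session.get_inputs()[0].name
x = np.zeros((1, 3, 224, 224), dtype=np.float32)  # assumed input shape
outputs = session.run(None, {input_name: x})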
onnxruntime/transformers/fusion_qordered_attention.py
@@ -0,0 +1,420 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+from logging import getLogger
+
+import numpy as np
+from fusion_attention import AttentionMask
+from fusion_base import Fusion
+from fusion_utils import FusionUtils, NumpyHelper
+from onnx import NodeProto, helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class FusionQOrderedAttention(Fusion):
+    def __init__(
+        self,
+        model: OnnxModel,
+        hidden_size: int,
+        num_heads: int,
+        attention_mask: AttentionMask,
+    ):
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.attention_mask = attention_mask
+
+        super().__init__(model, "QOrderedAttention", "QOrderedLayerNormalization")
+
+    def get_num_heads_and_hidden_size(self, reshape_q: NodeProto) -> tuple[int, int]:
+        """Detect num_heads and hidden_size from a reshape node.
+        Args:
+            reshape_q (NodeProto): reshape node for Q
+        Returns:
+            Tuple[int, int]: num_heads and hidden_size
+        """
+
+        # we assume that reshape fusion has done, so the shape is a tensor like [0, 0, num_heads, head_size]
+        q_shape = self.model.get_initializer(reshape_q.input[1])
+        if q_shape is None:
+            logger.debug(f"{reshape_q.input[1]} is not initializer.")
+
+            # Check if the second input to Reshape flows through a Constant node
+            # TODO: Investigate why FusionAttention doesn't have such logic
+            constant_node = self.model.match_parent_path(reshape_q, ["Constant"], [1])
+
+            if constant_node is None:
+                return self.num_heads, self.hidden_size  # Fall back to user specified value
+            else:
+                constant_node = constant_node[0]
+
+            if len(constant_node.attribute) != 1:
+                return self.num_heads, self.hidden_size  # Fall back to user specified value
+
+            # This is assuming it is a Tensor attribute (this is a safe assumption)
+            q_shape = constant_node.attribute[0].t
+
+        q_shape_value = NumpyHelper.to_array(q_shape)
+        if len(q_shape_value) != 4 or (q_shape_value[2] <= 0 or q_shape_value[3] <= 0):
+            logger.debug(f"q_shape_value={q_shape_value}. Expected value are like [0, 0, num_heads, head_size].")
+            return self.num_heads, self.hidden_size  # Fall back to user specified value
+
+        num_heads = q_shape_value[2]
+        head_size = q_shape_value[3]
+        hidden_size = num_heads * head_size
+
+        if self.num_heads > 0 and num_heads != self.num_heads:
+            if self.num_heads_warning:
+                logger.warning(f"--num_heads is {self.num_heads}. Detected value is {num_heads}. Using detected value.")
+                self.num_heads_warning = False  # Do not show the warning more than once
+
+        if self.hidden_size > 0 and hidden_size != self.hidden_size:
+            if self.hidden_size_warning:
+                logger.warning(
+                    f"--hidden_size is {self.hidden_size}. Detected value is {hidden_size}. Using detected value."
+                )
+                self.hidden_size_warning = False  # Do not show the warning more than once
+
+        return num_heads, hidden_size
+
+    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+        add_before_layernorm = self.model.match_parent_path(
+            normalize_node,
+            ["QuantizeLinear", "Add"],
+            [0, 0],
+        )
+
+        if add_before_layernorm is not None:
+            start_node = add_before_layernorm[-1]
+        else:
+            return
+
+        # Input QDQ nodes
+        dequantize_input = self.model.match_parent_path(
+            start_node,
+            ["DequantizeLinear"],
+            [None],
+        )
+
+        if dequantize_input is None:
+            logger.debug("fuse_qordered_attention: failed to match input qdq nodes path")
+            return
+
+        dequantize_input = dequantize_input[-1]
+
+        # QKV nodes
+        qkv_nodes = self.model.match_parent_path(
+            start_node,
+            ["Add", "MatMul", "Reshape", "Transpose", "DequantizeLinear", "QuantizeLinear", "MatMul"],
+            [None, None, 0, 0, 0, 0, 0],
+        )
+
+        if qkv_nodes is None:
+            logger.debug("fuse_qordered_attention: failed to match qkv path")
+            return
+
+        (_, projection_matmul, reshape_qkv, transpose_qkv, dequantize_qkv, quantize_qkv, matmul_qkv) = qkv_nodes
+
+        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
+        if not FusionUtils.check_qdq_node_for_fusion(quantize_qkv, self.model):
+            return
+
+        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qkv, self.model):
+            return
+
+        # Identify the root input to the Attention node
+        other_inputs = []
+        for _i, input in enumerate(start_node.input):
+            if input not in output_name_to_node:
+                continue
+
+            if input == qkv_nodes[0].output[0]:
+                continue
+
+            other_inputs.append(input)
+
+        if len(other_inputs) != 1:
+            return
+
+        root_input = other_inputs[0]
+
+        # V nodes
+        v_nodes = self.model.match_parent_path(
+            matmul_qkv,
+            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
+            [1, 0, 0, 0, 0, None],
+        )
+
+        if v_nodes is None:
+            logger.debug("fuse_qordered_attention: failed to match v path")
+            return
+
+        (_, _, dequantize_v, quantize_v, add_v, matmul_v) = v_nodes
+
+        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
+        if not FusionUtils.check_qdq_node_for_fusion(quantize_v, self.model):
+            return
+
+        if not FusionUtils.check_qdq_node_for_fusion(dequantize_v, self.model):
+            return
+
+        # V MatMul weight
+        dequantize_v_matmul_weight = self.model.match_parent_path(matmul_v, ["DequantizeLinear"], [1])
+
+        if dequantize_v_matmul_weight is None:
+            logger.debug("fuse_qordered_attention: failed to match v path")
+            return
+
+        dequantize_v_matmul_weight = dequantize_v_matmul_weight[0]
+
+        if self.model.get_constant_value(dequantize_v_matmul_weight.input[0]) is None:
+            return
+
+        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
+        # Per-channel scales are supported for weights alone
+        if not FusionUtils.check_qdq_node_for_fusion(dequantize_v_matmul_weight, self.model, False):
+            return
+
+        # QK nodes
+        qk_nodes = self.model.match_parent_path(
+            matmul_qkv,
+            [
+                "DequantizeLinear",
+                "QuantizeLinear",
+                "Softmax",
+                "Add",
+                "Div",
+                "DequantizeLinear",
+                "QuantizeLinear",
+                "MatMul",
+            ],
+            [0, 0, 0, 0, None, 0, 0, 0],
+        )
+
+        if qk_nodes is None:
+            logger.debug("fuse_qordered_attention: failed to match qk path")
+            return
+
+        (
+            dequantize_qk_softmax,
+            quantize_qk_softmax,
+            softmax_qk,
+            add_qk,
+            div_qk,
+            dequantize_qk,
+            quantize_qk,
+            matmul_qk,
+        ) = qk_nodes
+
+        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
+        if not FusionUtils.check_qdq_node_for_fusion(quantize_qk_softmax, self.model):
+            return
+
+        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk_softmax, self.model):
+            return
+
+        if not FusionUtils.check_qdq_node_for_fusion(quantize_qk, self.model):
+            return
+
+        if not FusionUtils.check_qdq_node_for_fusion(dequantize_qk, self.model):
+            return
+
+        # Q nodes
+        q_nodes = self.model.match_parent_path(
+            matmul_qk,
+            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
+            [0, 0, 0, 0, 0, None],
+        )
+
+        if q_nodes is None:
+            logger.debug("fuse_qordered_attention: failed to match q path")
+            return
+
+        (_, reshape_q, dequantize_q, quantize_q, add_q, matmul_q) = q_nodes
+
+        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
+        if not FusionUtils.check_qdq_node_for_fusion(quantize_q, self.model):
+            return
+
+        if not FusionUtils.check_qdq_node_for_fusion(dequantize_q, self.model):
+            return
+
+        # Q MatMul weight
+        dequantize_q_matmul_weight = self.model.match_parent_path(matmul_q, ["DequantizeLinear"], [1])
+
+        if dequantize_q_matmul_weight is None:
+            logger.debug("fuse_qordered_attention: failed to match q path")
+            return
+
+        dequantize_q_matmul_weight = dequantize_q_matmul_weight[0]
+
+        if self.model.get_constant_value(dequantize_q_matmul_weight.input[0]) is None:
+            return
+
+        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
+        # Per-channel scales are supported for weights alone
+        if not FusionUtils.check_qdq_node_for_fusion(dequantize_q_matmul_weight, self.model, False):
+            return
+
+        # K nodes
+        k_nodes = self.model.match_parent_path(
+            matmul_qk,
+            ["Transpose", "Reshape", "DequantizeLinear", "QuantizeLinear", "Add", "MatMul"],
+            [1, 0, 0, 0, 0, None],
+        )
+
+        if k_nodes is None:
+            logger.debug("fuse_qordered_attention: failed to match k path")
+            return
+
+        (_, _, dequantize_k, quantize_k, add_k, matmul_k) = k_nodes
+
+        # Make sure the Q/DQ has the proper zero points and constant per-tensor scales
+        if not FusionUtils.check_qdq_node_for_fusion(quantize_k, self.model):
+            return
+
+        if not FusionUtils.check_qdq_node_for_fusion(dequantize_k, self.model):
+            return
+
+        # K MatMul weight
+        dequantize_k_matmul_weight = self.model.match_parent_path(matmul_k, ["DequantizeLinear"], [1])
+
+        if dequantize_k_matmul_weight is None:
+            logger.debug("fuse_qordered_attention: failed to match k path")
+            return
+
+        dequantize_k_matmul_weight = dequantize_k_matmul_weight[0]
+
+        if self.model.get_constant_value(dequantize_k_matmul_weight.input[0]) is None:
+            return
+
+        # Make sure the upstream DequantizeLinear-1 has the proper zero points and scales
+        # Per-channel scales are supported for weights alone
+        if not FusionUtils.check_qdq_node_for_fusion(dequantize_k_matmul_weight, self.model, False):
+            return
+
+        # Mask nodes
+        mask_nodes = self.model.match_parent_path(
+            add_qk, ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [None, 0, 1, 0, 0]
+        )
+
+        if mask_nodes is None:
+            logger.debug("fuse_qordered_attention: failed to match mask_nodes path")
+            return
+
+        # Ascertain `qkv_hidden_sizes` attribute value
+        q_weight = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
+        k_weight = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
+        v_weight = self.model.get_initializer(dequantize_v_matmul_weight.input[0])
+
+        qw = NumpyHelper.to_array(q_weight)
+        kw = NumpyHelper.to_array(k_weight)
+        vw = NumpyHelper.to_array(v_weight)
+
+        qw_out_size = np.prod(qw.shape[1:])
+        kw_out_size = np.prod(kw.shape[1:])
+        vw_out_size = np.prod(vw.shape[1:])
+
+        # Form QOrderedAttention node
+        if matmul_v.input[0] == root_input and matmul_q.input[0] == root_input and matmul_k.input[0] == root_input:
+            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
+
+            # Ascertain `num_heads` and `hidden_size`
+            num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
+
+            # Formulate the inputs
+            # Actual quantized input
+            attention_inputs = [dequantize_input.input[0]]
+            attention_inputs.append(dequantize_input.input[1])
+
+            attention_inputs.append(dequantize_q.input[1])
+            attention_inputs.append(dequantize_k.input[1])
+            attention_inputs.append(dequantize_v.input[1])
+
+            attention_inputs.append(dequantize_q_matmul_weight.input[0])
+            attention_inputs.append(dequantize_k_matmul_weight.input[0])
+            attention_inputs.append(dequantize_v_matmul_weight.input[0])
+
+            attention_inputs.append(dequantize_q_matmul_weight.input[1])
+            attention_inputs.append(dequantize_k_matmul_weight.input[1])
+            attention_inputs.append(dequantize_v_matmul_weight.input[1])
+
+            if self.model.get_initializer(add_q.input[0]):
+                attention_inputs.append(add_q.input[0])
+            else:  # second input is the constant bias
+                attention_inputs.append(add_q.input[1])
+
+            if self.model.get_initializer(add_k.input[0]):
+                attention_inputs.append(add_k.input[0])
+            else:  # second input is the constant bias
+                attention_inputs.append(add_k.input[1])
+
+            if self.model.get_initializer(add_v.input[0]):
+                attention_inputs.append(add_v.input[0])
+            else:  # second input is the constant bias
+                attention_inputs.append(add_v.input[1])
+
+            attention_inputs.append(quantize_qk.input[1])
+            attention_inputs.append(quantize_qk_softmax.input[1])
+            attention_inputs.append(dequantize_qkv.input[1])
+
+            # Mask input
+            if mask_index is not None:
+                attention_inputs.append(mask_index)
+            else:
+                attention_inputs.append("")
+
+            # The MatMul weight 'B' and 'bias' need some post-processing
+            # Transpose weight 'B' from order ROW to order COL
+            # This offline transpose is needed only while using the CUDA EP
+            # TODO: Make this fusion logic EP-agnostic ?
+            q_weight_tensor = self.model.get_initializer(dequantize_q_matmul_weight.input[0])
+            FusionUtils.transpose_2d_int8_tensor(q_weight_tensor)
+
+            k_weight_tensor = self.model.get_initializer(dequantize_k_matmul_weight.input[0])
+            FusionUtils.transpose_2d_int8_tensor(k_weight_tensor)
+
+            v_weight_tensor = self.model.get_initializer(dequantize_v_matmul_weight.input[0])
+            FusionUtils.transpose_2d_int8_tensor(v_weight_tensor)
+
+            # Name and create Attention node
+            attention_node_name = self.model.create_node_name("QOrderedAttention")
+
+            attention_node = helper.make_node(
+                "QOrderedAttention",
+                inputs=attention_inputs,
+                outputs=[reshape_qkv.output[0]],
+                name=attention_node_name,
+            )
+
+            self.model.replace_node_input(dequantize_qkv, dequantize_qkv.input[0], attention_node.output[0])
+            self.model.replace_node_input(projection_matmul, projection_matmul.input[0], dequantize_qkv.output[0])
+
+            attention_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])
+            attention_node.attribute.extend([helper.make_attribute("order_input", 1)])
+            attention_node.attribute.extend([helper.make_attribute("order_weight", 0)])
+            attention_node.attribute.extend([helper.make_attribute("order_output", 1)])
+            attention_node.attribute.extend(
+                [helper.make_attribute("qkv_hidden_sizes", [qw_out_size, kw_out_size, vw_out_size])]
+            )
+
+            attention_node.domain = "com.microsoft"
+
+            self.nodes_to_add.append(attention_node)
+            self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
+
+            self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, quantize_qkv, matmul_qkv])
+            self.nodes_to_remove.extend(qk_nodes)
+            self.nodes_to_remove.extend(q_nodes)
+            self.nodes_to_remove.extend(k_nodes)
+            self.nodes_to_remove.extend(v_nodes)
+            self.nodes_to_remove.extend(
+                [dequantize_q_matmul_weight, dequantize_k_matmul_weight, dequantize_v_matmul_weight]
+            )
+
+            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
+            # self.nodes_to_remove.extend(mask_nodes)
+            self.prune_graph = True
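The fusion above is one pass of the onnxruntime.transformers graph optimizer. A minimal sketch of how such passes are typically driven — the model path and the num_heads/hidden_size fallback values are assumptions for illustration:

from onnxruntime.transformers import optimizer

# num_heads and hidden_size act as fallbacks when detection from the Reshape
# shape tensor fails (see get_num_heads_and_hidden_size above).
opt_model = optimizer.optimize_model(
    "bert_int8.onnx",  # hypothetical quantized model
    model_type="bert",
    num_heads=12,
    hidden_size=768,
)
opt_model.save_model_to_file("bert_int8_fused.onnx")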
onnxruntime/transformers/fusion_qordered_gelu.py
@@ -0,0 +1,118 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+from logging import getLogger
+
+from fusion_base import Fusion
+from fusion_utils import FusionUtils
+from onnx import helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class FusionQOrderedGelu(Fusion):
+    def __init__(self, model: OnnxModel):
+        super().__init__(model, "QOrderedGelu", ["Gelu", "FastGelu"])
+
+    def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
+        """
+        INPUT PATTERN
+        Fuse (quantized) Gelu subgraph into one node QOrderedGelu:
+            -> quantized input -> DQ -> Gelu -> Q ->
+
+        (or)
+
+            -> quantized input -> DQ -> FastGelu -> Q ->
+
+        OUTPUT PATTERN
+            -> QOrderedGelu ->
+        """
+        gelu_children = self.model.get_children(node, input_name_to_nodes)
+
+        # Should only have 1 child - QuantizeLinear (or)
+        # Should have 2 children - QuantizeLinear + Shape
+        if not (
+            (len(gelu_children) == 1 and gelu_children[0].op_type == "QuantizeLinear")
+            or (
+                len(gelu_children) == 2
+                and gelu_children[0].op_type == "QuantizeLinear"
+                and gelu_children[1].op_type == "Shape"
+            )
+        ):
+            return
+
+        downstream_quantize_node = gelu_children[0]
+        downstream_shape_node = None
+
+        if len(gelu_children) == 2:
+            downstream_shape_node = gelu_children[1]
+
+        if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model):
+            return
+
+        # The first input to Gelu should flow through a DequantizeLinear node
+        first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
+            node,
+            [(["DequantizeLinear"], [0])],
+            output_name_to_node,
+        )
+
+        if first_path_id < 0:
+            return
+
+        upstream_dequantize_node = first_input_parent_nodes[0]
+
+        if not FusionUtils.check_qdq_node_for_fusion(upstream_dequantize_node, self.model):
+            return
+
+        # Fusion logic
+        subgraph_nodes = [node]  # Gelu/FastGelu
+        subgraph_nodes.extend([downstream_quantize_node, upstream_dequantize_node])  # Relevant Q, DQ nodes
+
+        if not self.model.is_safe_to_fuse_nodes(
+            subgraph_nodes,
+            (
+                [node.output[0], downstream_quantize_node.output[0]]
+                if downstream_shape_node is not None
+                else downstream_quantize_node.output
+            ),
+            input_name_to_nodes,
+            output_name_to_node,
+        ):
+            logger.debug("It is not safe to fuse QOrderedGelu node. Skip")
+            return
+
+        self.nodes_to_remove.extend(subgraph_nodes)
+
+        ordered_gelu_node = helper.make_node(
+            "QOrderedGelu",
+            inputs=[
+                upstream_dequantize_node.input[0],
+                upstream_dequantize_node.input[1],
+                downstream_quantize_node.input[1],
+            ],
+            outputs=[downstream_quantize_node.output[0]],
+            name=self.model.create_node_name("QOrderedGelu", name_prefix="QOrderedGelu"),
+        )
+
+        # Arrange the downstream Shape's input to be fed from the
+        # downstream QuantizeLinear node, so that fusion will
+        # be deemed safe
+        if downstream_shape_node is not None:
+            self.model.replace_node_input(
+                downstream_shape_node, downstream_shape_node.input[0], downstream_quantize_node.output[0]
+            )
+
+        # TODO: We only support CuBlasLt order ORDER_ROW for now.
+        # Once we start supporting other data ordering format(s), we
+        # will support user configuring the data ordering for the op.
+        ordered_gelu_node.attribute.extend([helper.make_attribute("order_X", 1)])
+        ordered_gelu_node.attribute.extend([helper.make_attribute("order_Y", 1)])
+
+        ordered_gelu_node.domain = "com.microsoft"
+
+        self.nodes_to_add.append(ordered_gelu_node)
+        self.node_name_to_graph_name[ordered_gelu_node.name] = self.this_graph_name
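For reference, a sketch of the DQ -> Gelu -> Q subgraph this pass matches, built with onnx.helper; the tensor names are made up:

from onnx import helper

# Quantized input is dequantized, run through Gelu, and re-quantized.
dq = helper.make_node("DequantizeLinear", ["x_q", "x_scale"], ["x"])
gelu = helper.make_node("Gelu", ["x"], ["y"], domain="com.microsoft")
q = helper.make_node("QuantizeLinear", ["y", "y_scale"], ["y_q"])
# After fusion, these three nodes collapse into a single com.microsoft
# QOrderedGelu node with inputs [x_q, x_scale, y_scale] and output y_q,
# exactly the input list assembled in the make_node call above.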
onnxruntime/transformers/fusion_qordered_layernorm.py
@@ -0,0 +1,122 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from logging import getLogger
+
+from fusion_base import Fusion
+from fusion_utils import FusionUtils
+from onnx import helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class FusionQOrderedLayerNormalization(Fusion):
+    def __init__(self, model: OnnxModel):
+        super().__init__(model, "QOrderedLayerNormalization", "LayerNormalization")
+
+    def fuse(self, node, input_name_to_nodes: dict, output_name_to_node: dict):
+        """
+        Fuse (quantized) Layer Normalization subgraph into one node QOrderedLayerNormalization:
+              quantized input -> DQ
+                                  |
+                                  |
+            (other inputs)-> LayerNormalization --> Q -->
+
+        should become
+
+            (quantized input + other inputs)-> QOrderedLayerNormalization --> Q -->
+        """
+
+        children = self.model.get_children(node, input_name_to_nodes)
+
+        # Should only have 1 child - QuantizeLinear (or)
+        # Should have 2 children - QuantizeLinear + Shape
+        if not (
+            (len(children) == 1 and children[0].op_type == "QuantizeLinear")
+            or (len(children) == 2 and children[0].op_type == "QuantizeLinear" and children[1].op_type == "Shape")
+        ):
+            return
+
+        downstream_quantize_node = children[0]
+        downstream_shape_node = None
+
+        if len(children) == 2:
+            downstream_shape_node = children[1]
+
+        if not FusionUtils.check_qdq_node_for_fusion(downstream_quantize_node, self.model):
+            return
+
+        # The first input to LayerNormalization should flow through a DequantizeLinear node
+        first_path_id, first_input_parent_nodes, _ = self.model.match_parent_paths(
+            node,
+            [(["DequantizeLinear"], [0])],
+            output_name_to_node,
+        )
+
+        if first_path_id < 0:
+            return
+
+        upstream_dequantize_node = first_input_parent_nodes[0]
+
+        if not FusionUtils.check_qdq_node_for_fusion(upstream_dequantize_node, self.model):
+            return
+
+        # Fusion logic
+        subgraph_nodes = [node]  # LayerNormalization
+        subgraph_nodes.extend([downstream_quantize_node])  # Q node after LayerNormalization
+
+        upstream_dequantize_node_children = self.model.get_children(upstream_dequantize_node, input_name_to_nodes)
+
+        # In GPT2, the DQ node will be feeding a residual downstream Add and hence,
+        # we do not want to remove it
+        if len(upstream_dequantize_node_children) == 1:
+            subgraph_nodes.extend([upstream_dequantize_node])  # DQ node before LayerNormalization
+
+        if not self.model.is_safe_to_fuse_nodes(
+            subgraph_nodes,
+            (
+                [node.output[0], downstream_quantize_node.output[0]]
+                if downstream_shape_node is not None
+                else downstream_quantize_node.output
+            ),
+            input_name_to_nodes,
+            output_name_to_node,
+        ):
+            logger.debug("It is not safe to fuse QOrderedLayerNormalization node. Skip")
+            return
+
+        self.nodes_to_remove.extend(subgraph_nodes)
+
+        normalize_node = helper.make_node(
+            "QOrderedLayerNormalization",
+            inputs=[
+                upstream_dequantize_node.input[0],
+                upstream_dequantize_node.input[1],
+                node.input[1],
+                node.input[2],
+                downstream_quantize_node.input[1],
+            ],
+            outputs=[downstream_quantize_node.output[0]],
+            name=self.model.create_node_name("QOrderedLayerNormalization", name_prefix="QOrderedLayerNormalization"),
+        )
+
+        # Arrange the downstream Shape's input to be fed from the
+        # downstream QuantizeLinear node, so that fusion will
+        # be deemed safe
+        if downstream_shape_node is not None:
+            self.model.replace_node_input(
+                downstream_shape_node, downstream_shape_node.input[0], downstream_quantize_node.output[0]
+            )
+
+        # TODO: We only support CuBlasLt order ORDER_ROW for now.
+        # Once we start supporting other data ordering format(s), we
+        # will support user configuring the data ordering for the op.
+        normalize_node.attribute.extend([helper.make_attribute("order_X", 1)])
+        normalize_node.attribute.extend([helper.make_attribute("order_Y", 1)])
+
+        normalize_node.domain = "com.microsoft"
+
+        self.nodes_to_add.append(normalize_node)
+        self.node_name_to_graph_name[normalize_node.name] = self.this_graph_name
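The node this pass emits can be summarized the same way — a sketch mirroring the make_node call above, with made-up tensor names:

from onnx import helper

# Inputs, in order: quantized X, X scale, LayerNorm gamma, beta, output Y scale.
qln = helper.make_node(
    "QOrderedLayerNormalization",
    inputs=["x_q", "x_scale", "gamma", "beta", "y_scale"],
    outputs=["y_q"],
    domain="com.microsoft",
)
# ORDER_ROW (value 1) is currently the only supported CuBlasLt data ordering,
# per the TODO comment in the source above.
qln.attribute.extend([helper.make_attribute("order_X", 1)])
qln.attribute.extend([helper.make_attribute("order_Y", 1)])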