onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,667 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
from logging import getLogger
|
|
6
|
+
|
|
7
|
+
import numpy as np
|
|
8
|
+
from fusion_base import Fusion
|
|
9
|
+
from fusion_utils import FusionUtils
|
|
10
|
+
from onnx import NodeProto, TensorProto, helper, numpy_helper
|
|
11
|
+
from onnx_model import OnnxModel
|
|
12
|
+
|
|
13
|
+
logger = getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class FusionMultiHeadAttentionMMDit(Fusion):
    """Fuse attention subgraphs into a MultiHeadAttention node for
    Multimodal Diffusion Transformer (MMDiT) models such as Stable
    Diffusion 3.x and Flux 1.x.

    The fusion anchors its pattern matching on Softmax nodes.
    """

    def __init__(self, model: OnnxModel):
        # Anchor on Softmax: every attention block contains exactly one.
        super().__init__(model, fused_op_type="MultiHeadAttention", search_op_types=["Softmax"])
        # Maps an original Unsqueeze node name to the output of its
        # axes-adjusted replacement, so shared nodes are rewritten once.
        self.unsqueeze_update_map = {}
|
|
24
|
+
|
|
25
|
+
def get_num_heads(self, start_node: NodeProto, output_name_to_node, input_index=0) -> int:
|
|
26
|
+
"""
|
|
27
|
+
Detect num_heads from Reshape & Transpose of q/k/v for both Stable Diffusion 3.x and Flux 1.x:
|
|
28
|
+
|
|
29
|
+
MatMul .. [-1] [24] ..
|
|
30
|
+
| | | / /
|
|
31
|
+
Add Concat(axis=0)
|
|
32
|
+
| /
|
|
33
|
+
Reshape
|
|
34
|
+
|
|
|
35
|
+
Transpose(perm=0,1,3,2)
|
|
36
|
+
|
|
|
37
|
+
(start_node)
|
|
38
|
+
"""
|
|
39
|
+
nodes = self.model.match_parent_path(
|
|
40
|
+
start_node, ["Transpose", "Reshape", "Concat"], [input_index, 0, 1], output_name_to_node=output_name_to_node
|
|
41
|
+
)
|
|
42
|
+
if nodes is None:
|
|
43
|
+
return 0
|
|
44
|
+
|
|
45
|
+
concat_shape = nodes[-1]
|
|
46
|
+
if len(concat_shape.input) != 4:
|
|
47
|
+
return 0
|
|
48
|
+
|
|
49
|
+
value = self.model.get_constant_value(concat_shape.input[2])
|
|
50
|
+
if value is None:
|
|
51
|
+
return 0
|
|
52
|
+
|
|
53
|
+
if len(value.shape) != 1:
|
|
54
|
+
return 0
|
|
55
|
+
|
|
56
|
+
return int(value[0])
|
|
57
|
+
|
|
58
|
+
def get_num_heads_from_k(self, transpose_k: NodeProto, output_name_to_node, concat_before_transpose: bool) -> int:
|
|
59
|
+
"""
|
|
60
|
+
Detect num_heads from subgraph like the following (num_heads=24 in this example):
|
|
61
|
+
MatMu .. [-1] [24] ..
|
|
62
|
+
| | | / /
|
|
63
|
+
Add Concat
|
|
64
|
+
| /
|
|
65
|
+
Reshape
|
|
66
|
+
|
|
|
67
|
+
Transpose(perm=0,2,1,3)
|
|
68
|
+
|
|
|
69
|
+
SimplifiedLayerNormalization
|
|
70
|
+
|
|
|
71
|
+
Transpose(perm=0,1,3,2)
|
|
72
|
+
|
|
73
|
+
Another variant is to an extra Concat node to join two symmetrical subgraphs:
|
|
74
|
+
|
|
75
|
+
| |
|
|
76
|
+
MatMul MatMul .. [-1] [24] ..
|
|
77
|
+
| | | | / /
|
|
78
|
+
Add Concat Add Concat
|
|
79
|
+
| / | /
|
|
80
|
+
Reshape Reshape
|
|
81
|
+
| |
|
|
82
|
+
Transpose Transpose(perm=0,2,1,3)
|
|
83
|
+
| |
|
|
84
|
+
SimplifiedLayerNormalization SimplifiedLayerNormalization
|
|
85
|
+
| /
|
|
86
|
+
Concat
|
|
87
|
+
|
|
|
88
|
+
Transpose(perm=0,1,3,2)
|
|
89
|
+
|
|
90
|
+
Both patterns are used in stable diffusion 3.5 model.
|
|
91
|
+
"""
|
|
92
|
+
if concat_before_transpose:
|
|
93
|
+
nodes = self.model.match_parent_path(
|
|
94
|
+
transpose_k, ["Concat", "SimplifiedLayerNormalization"], [0, 1], output_name_to_node=output_name_to_node
|
|
95
|
+
)
|
|
96
|
+
if nodes:
|
|
97
|
+
return self.get_num_heads(nodes[1], output_name_to_node)
|
|
98
|
+
else:
|
|
99
|
+
nodes = self.model.match_parent_path(
|
|
100
|
+
transpose_k, ["SimplifiedLayerNormalization"], [0], output_name_to_node=output_name_to_node
|
|
101
|
+
)
|
|
102
|
+
if nodes:
|
|
103
|
+
return self.get_num_heads(nodes[0], output_name_to_node)
|
|
104
|
+
|
|
105
|
+
return 0
|
|
106
|
+
|
|
107
|
+
def reshape_to_3d(self, input_name: str, output_name: str) -> str:
|
|
108
|
+
"""Add a Reshape node to convert 4D BxSxNxH to 3D BxSxD.
|
|
109
|
+
|
|
110
|
+
Args:
|
|
111
|
+
input_name (str): input name for the 4D tensor of shape BxSxNxH.
|
|
112
|
+
output_name (str): output name for the 3D tensor of shape BxSxD, where D = N * H.
|
|
113
|
+
|
|
114
|
+
Returns:
|
|
115
|
+
str: the output name
|
|
116
|
+
"""
|
|
117
|
+
|
|
118
|
+
new_dims_name = "bsnh_to_bsd_reshape_dims"
|
|
119
|
+
new_dims = self.model.get_initializer(new_dims_name)
|
|
120
|
+
if new_dims is None:
|
|
121
|
+
new_dims = numpy_helper.from_array(np.array([0, 0, -1], dtype="int64"), name=new_dims_name)
|
|
122
|
+
self.model.add_initializer(new_dims, self.this_graph_name)
|
|
123
|
+
reshape_q = helper.make_node(
|
|
124
|
+
"Reshape",
|
|
125
|
+
inputs=[input_name, new_dims_name],
|
|
126
|
+
outputs=[output_name],
|
|
127
|
+
name=self.model.create_node_name("Reshape"),
|
|
128
|
+
)
|
|
129
|
+
self.nodes_to_add.append(reshape_q)
|
|
130
|
+
self.node_name_to_graph_name[reshape_q.name] = self.this_graph_name
|
|
131
|
+
return reshape_q.output[0]
|
|
132
|
+
|
|
133
|
+
def adjust_query_from_bnsh_to_bsd_no_concat(self, mul_q: NodeProto, output_name_to_node) -> str | None:
|
|
134
|
+
"""
|
|
135
|
+
MultiHeadAttenion requires query in BSD format. This function adjusts query from BNSH to BSD format.
|
|
136
|
+
|
|
137
|
+
Before:
|
|
138
|
+
MatMul
|
|
139
|
+
|
|
|
140
|
+
Add Concat
|
|
141
|
+
| /
|
|
142
|
+
Reshape
|
|
143
|
+
|
|
|
144
|
+
Transpose(perm=0,2,1,3)
|
|
145
|
+
|
|
|
146
|
+
SimplifiedLayerNorm
|
|
147
|
+
|
|
|
148
|
+
Mul
|
|
149
|
+
|
|
150
|
+
After:
|
|
151
|
+
MatMul
|
|
152
|
+
|
|
|
153
|
+
Add Concat
|
|
154
|
+
| /
|
|
155
|
+
Reshape
|
|
156
|
+
|
|
|
157
|
+
SimplifiedLayerNorm
|
|
158
|
+
|
|
|
159
|
+
Reshape (shape=[0, 0, -1])
|
|
160
|
+
"""
|
|
161
|
+
|
|
162
|
+
path = self.model.match_parent_path(
|
|
163
|
+
mul_q,
|
|
164
|
+
["SimplifiedLayerNormalization", "Transpose"],
|
|
165
|
+
[0, 0],
|
|
166
|
+
)
|
|
167
|
+
if path is None:
|
|
168
|
+
return None
|
|
169
|
+
sln_a, transpose_a = path
|
|
170
|
+
|
|
171
|
+
if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]):
|
|
172
|
+
return None
|
|
173
|
+
|
|
174
|
+
# Update the graph
|
|
175
|
+
sln_a.input[0] = transpose_a.input[0]
|
|
176
|
+
sln_output = sln_a.output[0]
|
|
177
|
+
sln_a.output[0] = sln_output + "_BSNH"
|
|
178
|
+
|
|
179
|
+
return self.reshape_to_3d(sln_a.output[0], sln_output + "_BSD")
|
|
180
|
+
|
|
181
|
+
def adjust_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> str | None:
|
|
182
|
+
"""
|
|
183
|
+
MultiHeadAttenion requires query in BSD format. This function adjusts query from BNSH to BSD format.
|
|
184
|
+
|
|
185
|
+
Before:
|
|
186
|
+
MatMul MatMul
|
|
187
|
+
| |
|
|
188
|
+
Add Concat Add Concat
|
|
189
|
+
| / | /
|
|
190
|
+
Reshape Reshape
|
|
191
|
+
| |
|
|
192
|
+
Transpose(perm=0,2,1,3) Transpose(perm=0,2,1,3)
|
|
193
|
+
| |
|
|
194
|
+
SimplifiedLayerNorm SimplifiedLayerNorm
|
|
195
|
+
| /
|
|
196
|
+
Concat(axis=2)
|
|
197
|
+
|
|
|
198
|
+
Mul
|
|
199
|
+
|
|
200
|
+
After:
|
|
201
|
+
MatMul MatMul
|
|
202
|
+
| |
|
|
203
|
+
Add Concat Add Concat
|
|
204
|
+
| / | /
|
|
205
|
+
Reshape Reshape
|
|
206
|
+
| |
|
|
207
|
+
SimplifiedLayerNorm SimplifiedLayerNorm
|
|
208
|
+
| /
|
|
209
|
+
Concat(axis=1)
|
|
210
|
+
|
|
|
211
|
+
Reshape (shape=[0, 0, -1])
|
|
212
|
+
"""
|
|
213
|
+
|
|
214
|
+
path = self.model.match_parent_path(
|
|
215
|
+
mul_q,
|
|
216
|
+
["Concat", "SimplifiedLayerNormalization", "Transpose"],
|
|
217
|
+
[0, 0, 0],
|
|
218
|
+
)
|
|
219
|
+
if path is None:
|
|
220
|
+
return None
|
|
221
|
+
concat, sln_a, transpose_a = path
|
|
222
|
+
|
|
223
|
+
if len(concat.input) != 2:
|
|
224
|
+
return None
|
|
225
|
+
|
|
226
|
+
path = self.model.match_parent_path(
|
|
227
|
+
concat,
|
|
228
|
+
["SimplifiedLayerNormalization", "Transpose"],
|
|
229
|
+
[1, 0],
|
|
230
|
+
)
|
|
231
|
+
if path is None:
|
|
232
|
+
return None
|
|
233
|
+
sln_b, transpose_b = path
|
|
234
|
+
|
|
235
|
+
if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]):
|
|
236
|
+
return None
|
|
237
|
+
|
|
238
|
+
if not FusionUtils.check_node_attribute(transpose_b, "perm", [0, 2, 1, 3]):
|
|
239
|
+
return None
|
|
240
|
+
|
|
241
|
+
if not FusionUtils.check_node_attribute(concat, "axis", 2):
|
|
242
|
+
return None
|
|
243
|
+
|
|
244
|
+
# Update the graph
|
|
245
|
+
sln_a.input[0] = transpose_a.input[0]
|
|
246
|
+
sln_b.input[0] = transpose_b.input[0]
|
|
247
|
+
|
|
248
|
+
new_concat_node = helper.make_node(
|
|
249
|
+
"Concat",
|
|
250
|
+
inputs=[sln_a.output[0], sln_b.output[0]],
|
|
251
|
+
outputs=[concat.output[0] + "_BSNH"],
|
|
252
|
+
name=self.model.create_node_name("Concat"),
|
|
253
|
+
axis=1,
|
|
254
|
+
)
|
|
255
|
+
self.nodes_to_add.append(new_concat_node)
|
|
256
|
+
self.node_name_to_graph_name[new_concat_node.name] = self.this_graph_name
|
|
257
|
+
|
|
258
|
+
return self.reshape_to_3d(new_concat_node.output[0], concat.output[0] + "_BSD")
|
|
259
|
+
|
|
260
|
+
def update_unsqueeze_axes_1_to_2(self, unsqueeze: NodeProto) -> str:
|
|
261
|
+
updated_unsqueeze_output = self.unsqueeze_update_map.get(unsqueeze.name)
|
|
262
|
+
if updated_unsqueeze_output is None:
|
|
263
|
+
if len(unsqueeze.input) == 1:
|
|
264
|
+
new_node = helper.make_node(
|
|
265
|
+
"Unsqueeze",
|
|
266
|
+
inputs=unsqueeze.input,
|
|
267
|
+
outputs=[unsqueeze.output[0] + "_BSNH"],
|
|
268
|
+
name=self.model.create_node_name("Unsqueeze"),
|
|
269
|
+
axes=[2],
|
|
270
|
+
)
|
|
271
|
+
else:
|
|
272
|
+
initializer_name = "unsqueeze_axes_2"
|
|
273
|
+
if self.model.get_initializer(initializer_name) is None:
|
|
274
|
+
unsqueeze_axes_2 = helper.make_tensor(
|
|
275
|
+
name=initializer_name,
|
|
276
|
+
data_type=TensorProto.INT64,
|
|
277
|
+
dims=[1], # Shape of the tensor
|
|
278
|
+
vals=[2], # Tensor values
|
|
279
|
+
)
|
|
280
|
+
self.model.add_initializer(unsqueeze_axes_2, self.this_graph_name)
|
|
281
|
+
|
|
282
|
+
new_node = helper.make_node(
|
|
283
|
+
"Unsqueeze",
|
|
284
|
+
inputs=[unsqueeze.input[0], initializer_name],
|
|
285
|
+
outputs=[unsqueeze.output[0] + "_BSNH"],
|
|
286
|
+
name=self.model.create_node_name("Unsqueeze"),
|
|
287
|
+
)
|
|
288
|
+
|
|
289
|
+
self.nodes_to_add.append(new_node)
|
|
290
|
+
self.node_name_to_graph_name[new_node.name] = self.this_graph_name
|
|
291
|
+
updated_unsqueeze_output = new_node.output[0]
|
|
292
|
+
self.unsqueeze_update_map[unsqueeze.name] = updated_unsqueeze_output
|
|
293
|
+
|
|
294
|
+
return updated_unsqueeze_output
|
|
295
|
+
|
|
296
|
+
def _match_unsqueeze_pair(
    self, add: NodeProto, input_index: int, output_name_to_node: dict[str, NodeProto]
) -> list[NodeProto] | None:
    # Match Add.input[input_index] <- Mul <- Unsqueeze(axes=[1]) <- Unsqueeze(axes=[0]);
    # return [mul, unsqueeze_axes_1, unsqueeze_axes_0] or None when the pattern does not match.
    nodes = self.model.match_parent_path(
        add, ["Mul", "Unsqueeze", "Unsqueeze"], [input_index, 1, 0], output_name_to_node
    )
    if nodes is None:
        return None

    fusion_utils = FusionUtils(self.model)
    if fusion_utils.get_squeeze_or_unsqueeze_axes(nodes[1]) != [1]:
        return None
    if fusion_utils.get_squeeze_or_unsqueeze_axes(nodes[2]) != [0]:
        return None
    return nodes

def update_unsqueeze_axes(self, add: NodeProto, output_name_to_node: dict[str, NodeProto]) -> bool:
    """
    Update axes of Unsqueeze from [1] to [2] in the following pattern:

               Unsqueeze              Unsqueeze
               (axes=[0])             (axes=[0])
                   |                      |
               Unsqueeze              Unsqueeze
        ...    (axes=[1])      ...    (axes=[1])
         |     /                |      /
          Mul                     Mul
           |                     /
                    Add

    Both parents of Add must match the Mul <- Unsqueeze(axes=[1]) <- Unsqueeze(axes=[0])
    pattern; only then are both Mul nodes rewired to axes=[2] Unsqueeze replacements
    (no mutation happens on a partial match).

    Args:
        add (NodeProto): the Add node
        output_name_to_node (Dict[str, NodeProto]): mapping from output name to node

    Returns:
        bool: True if the pattern is matched and updated successfully, False otherwise.
    """
    if len(add.input) != 2:
        return False

    # The matching (including the axes checks) is identical for both Add inputs,
    # so it is factored into _match_unsqueeze_pair.
    nodes_b = self._match_unsqueeze_pair(add, 1, output_name_to_node)
    if nodes_b is None:
        return False

    nodes_a = self._match_unsqueeze_pair(add, 0, output_name_to_node)
    if nodes_a is None:
        return False

    # Rewire each Mul's second input to an Unsqueeze with axes=[2] so the tensors
    # feeding the Mul nodes are BSNH instead of BNSH.
    nodes_a[0].input[1] = self.update_unsqueeze_axes_1_to_2(nodes_a[1])
    nodes_b[0].input[1] = self.update_unsqueeze_axes_1_to_2(nodes_b[1])
    return True
def adjust_flux_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> str | None:
    """
    Adjust graph to change query format from BNSH to BSD for Flux model.
    Note that the graph pattern is complex, and we only do a shallow match here.

    Before:
            |                          |
        Transpose(perm=0,2,1,3)    Transpose(perm=0,2,1,3)
            |                          |
        SimplifiedLayerNorm        SimplifiedLayerNorm
            |                         /
              Concat(axis=2)
                |
          Mul       Mul
            |      /
              Add
               |
              Mul

    After (Transpose nodes are removed, and a Reshape is added):
            |                          |
        SimplifiedLayerNorm        SimplifiedLayerNorm
            |                         /
              Concat(axis=1)
                |
          Mul       Mul
            |      /
              Add
               |
            Reshape (shape=[0, 0, -1])

    Returns the name of the BSD-shaped reshape output, or None when the
    pattern does not match (graph is left untouched in that case).
    """
    branch_a = self.model.match_parent_path(
        mul_q,
        ["Add", "Mul", "Concat", "SimplifiedLayerNormalization", "Transpose"],
        [0, 0, 0, 0, 0],
    )
    if branch_a is None:
        return None
    add, _mul_a, concat, sln_a, transpose_a = branch_a

    if len(concat.input) != 2:
        return None

    branch_b = self.model.match_parent_path(
        concat,
        ["SimplifiedLayerNormalization", "Transpose"],
        [1, 0],
    )
    if branch_b is None:
        return None
    sln_b, transpose_b = branch_b

    # Both branches must enter through a BNSH->BSNH style Transpose.
    for transpose in (transpose_a, transpose_b):
        if not FusionUtils.check_node_attribute(transpose, "perm", [0, 2, 1, 3]):
            return None

    if not FusionUtils.check_node_attribute(concat, "axis", 2):
        return None

    # Need adjust axes of Unsqueeze nodes from [1] to [2] so that the tensors to Mul nodes are BSNH instead of BNSH.
    if not self.update_unsqueeze_axes(add, output_name_to_node):
        return None

    # Update the graph: bypass both Transpose nodes...
    sln_a.input[0] = transpose_a.input[0]
    sln_b.input[0] = transpose_b.input[0]

    # ...and concatenate along axis 1 instead of axis 2.
    bsnh_concat = helper.make_node(
        "Concat",
        inputs=[sln_a.output[0], sln_b.output[0]],
        outputs=[concat.output[0] + "_BSNH"],
        name=self.model.create_node_name("Concat"),
        axis=1,
    )
    self.nodes_to_add.append(bsnh_concat)
    self.node_name_to_graph_name[bsnh_concat.name] = self.this_graph_name
    self.model.replace_input_of_all_nodes(concat.output[0], bsnh_concat.output[0])

    return self.reshape_to_3d(add.output[0], add.output[0] + "_BSD")
def adjust_flux_single_query_from_bnsh_to_bsd(self, mul_q: NodeProto, output_name_to_node) -> str | None:
    """
    Adjust graph to change query format from BNSH to BSD for Flux model.
    Note that the graph pattern is complex, and we only do a shallow match here.

    Before:
               |
        Transpose(perm=0,2,1,3)
               |
        SimplifiedLayerNorm
               |
          Mul       Mul
            |      /
              Add
               |
              Mul

    After (Transpose is removed, and a Reshape is added):
               |
        SimplifiedLayerNorm
               |
          Mul       Mul
            |      /
              Add
               |
            Reshape (shape=[0, 0, -1])

    Returns the name of the BSD-shaped reshape output, or None on no match.
    """
    matched = self.model.match_parent_path(
        mul_q,
        ["Add", "Mul", "SimplifiedLayerNormalization", "Transpose"],
        [0, 0, 0, 0],
    )
    if matched is None:
        return None
    add, _mul_a, sln_a, transpose_a = matched

    if not FusionUtils.check_node_attribute(transpose_a, "perm", [0, 2, 1, 3]):
        return None

    # Need adjust axes of Unsqueeze nodes from [1] to [2] so that the tensors to Mul nodes are BSNH instead of BNSH.
    if not self.update_unsqueeze_axes(add, output_name_to_node):
        return None

    # Update the graph: bypass the Transpose and rename the Add output to mark
    # the new BSNH layout before reshaping it to BSD.
    sln_a.input[0] = transpose_a.input[0]
    add.output[0] = add.output[0] + "_BSNH"

    return self.reshape_to_3d(add.output[0], add.output[0] + "_BSD")
def transpose_reshape_bnsh_to_bsd(self, q: str, output_name_to_node) -> str | None:
    """Convert a BNSH tensor *q* to BSD by inserting an explicit Transpose followed by a Reshape.

    This is the generic fallback path: it always succeeds structurally, at the
    cost of an extra Transpose node in the graph.
    (The output_name_to_node argument is unused here; it is kept for a uniform
    signature with the other adjust_* helpers.)
    """
    bsnh_name = q + "_BSNH"
    bnsh_to_bsnh = helper.make_node(
        "Transpose",
        [q],
        [bsnh_name],
        name=self.model.create_node_name("Transpose", name_prefix="Transpose_BNSH_to_BSNH"),
        perm=[0, 2, 1, 3],
    )
    self.nodes_to_add.append(bnsh_to_bsnh)
    self.node_name_to_graph_name[bnsh_to_bsnh.name] = self.this_graph_name

    return self.reshape_to_3d(bsnh_name, q + "_BSD")
def create_multihead_attention_node(
    self,
    q: str,
    k: str,
    v: str,
    output: str,
    num_heads: int,
) -> NodeProto:
    """
    Create a MultiHeadAttention node.

    Args:
        q (str): name of q
        k (str): name of k
        v (str): name of v
        output (str): output name of MHA
        num_heads (int): number of attention heads. If a model is pruned, it is the number of heads after pruning.

    Returns:
        NodeProto: the node created.
    """
    assert num_heads > 0

    # Inputs for MHA are Query, Key, Value; the optional inputs
    # (Proj_Bias, Mask, Attention_Bias, Past_K, Past_V) are omitted,
    # as are the optional outputs (Present_K, Present_V).
    mha_node = helper.make_node(
        "MultiHeadAttention",
        inputs=[q, k, v],
        outputs=[output],
        name=self.model.create_node_name("MultiHeadAttention"),
    )
    mha_node.domain = "com.microsoft"
    mha_node.attribute.extend([helper.make_attribute("num_heads", num_heads)])

    # No mask is used in MMDit model, so we need not set the optional mask_filter_value attribute.
    return mha_node
def _parent_is_bnsh_transpose(self, node: NodeProto, input_index: int, output_name_to_node) -> bool:
    # True iff node's parent at input_index is a Transpose with perm=[0,2,1,3]
    # (i.e. a BSNH -> BNSH layout change). Factored out of fuse(), where this
    # check was repeated three times in the V-path validation.
    transpose = self.model.match_parent(
        node, "Transpose", input_index=input_index, output_name_to_node=output_name_to_node
    )
    if transpose is None:
        return False
    return FusionUtils.check_node_attribute(transpose, "perm", [0, 2, 1, 3])

def fuse(self, node, input_name_to_nodes, output_name_to_node):
    """Fuse a Softmax-anchored attention subgraph into a single MultiHeadAttention node.

    Matches the Q (scaled), K (scaled/transposed) and V paths around the given
    Softmax node, adjusts the query layout from BNSH to BSD (MHA limitation),
    then appends the fused node and schedules the replaced nodes for pruning.
    """
    assert node.op_type == "Softmax"
    softmax = node

    # Softmax output shall not be graph output.
    if self.model.find_graph_output(softmax.output[0]):
        return

    nodes = self.model.match_child_path(
        softmax, ["MatMul", "Transpose", "Reshape"], [(0, 0), (0, 0), (0, 0)], input_name_to_nodes
    )
    if nodes is None:
        return

    matmul_s_v, transpose_out, reshape_out = nodes
    if not FusionUtils.check_node_attribute(transpose_out, "perm", [0, 2, 1, 3]):
        return

    # Q path: Mul scales q by 1/sqrt(sqrt(head_size)) derived from the shape of q.
    q_nodes = self.model.match_parent_path(
        softmax,
        ["MatMul", "Mul", "Sqrt", "Div", "Sqrt", "Cast", "Slice", "Shape"],
        [0, 0, 1, 0, 1, 0, 0, 0],
    )
    if q_nodes is None:
        return

    matmul_qk, mul_q, sqrt_q_2, _div_q, _sqrt_q, _, _, shape_q = q_nodes

    q_bnsh = mul_q.input[0]
    # The Shape node must read the same tensor that feeds the scaling Mul.
    if q_bnsh != shape_q.input[0]:
        return

    k_nodes = self.model.match_parent_path(matmul_qk, ["Mul", "Transpose"], [1, 0])
    if k_nodes is None:
        return

    mul_k, transpose_k = k_nodes
    k = transpose_k.input[0]
    if not FusionUtils.check_node_attribute(transpose_k, "perm", [0, 1, 3, 2]):
        return

    # K must be scaled by the same 1/sqrt(...) factor as Q.
    k_scale_nodes = self.model.match_parent_path(mul_k, ["Sqrt", "Div"], [1, 0])
    if k_scale_nodes is None:
        return
    if k_scale_nodes[0].input[0] != sqrt_q_2.input[0]:
        return

    v = matmul_s_v.input[1]

    # Here we sanity check the v path to make sure it is in the expected BNSH format.
    concat_v = self.model.match_parent(matmul_s_v, "Concat", input_index=1, output_name_to_node=output_name_to_node)
    if concat_v is not None:
        # Match v path like:
        #     -- Transpose (perm=[0,2,1,3]) ----+
        #                                       |
        #                                       v
        #     -- Transpose (perm=[0,2,1,3]) -> Concat -> (v)
        if not self._parent_is_bnsh_transpose(concat_v, 0, output_name_to_node):
            return
        if not self._parent_is_bnsh_transpose(concat_v, 1, output_name_to_node):
            return
    else:
        # Match v path like:
        #     -- Transpose (perm=[0,2,1,3]) -> (v)
        if not self._parent_is_bnsh_transpose(matmul_s_v, 1, output_name_to_node):
            return

    # Match patterns for Flux.
    num_heads = (
        self.get_num_heads(concat_v, output_name_to_node)
        if concat_v
        else self.get_num_heads(matmul_s_v, output_name_to_node, input_index=1)
    )

    if num_heads == 0:
        # Match patterns for Stable Diffusion 3.5.
        num_heads = self.get_num_heads_from_k(transpose_k, output_name_to_node, concat_v is not None)
        if num_heads <= 0:
            return

    # Q is in BNSH format, we need to adjust it to BSD format due to limitation of MHA op.
    # TODO: MHA op support BNSH format to reduce the effort in fusion.
    if concat_v is not None:
        query = self.adjust_query_from_bnsh_to_bsd(mul_q, output_name_to_node)
    else:
        query = self.adjust_query_from_bnsh_to_bsd_no_concat(mul_q, output_name_to_node)

    if query is None:
        query = self.adjust_flux_query_from_bnsh_to_bsd(mul_q, output_name_to_node)
        if query is None:
            query = self.adjust_flux_single_query_from_bnsh_to_bsd(mul_q, output_name_to_node)
            if query is None:
                # fallback to use Transpose and Add to adjust query from BNSH to BSD
                # This is more general approach.
                # However, it might be slower if the extra Transpose node cannot be removed by ORT optimizer.
                query = self.transpose_reshape_bnsh_to_bsd(q_bnsh, output_name_to_node)

    new_node = self.create_multihead_attention_node(
        q=query,
        k=k,
        v=v,
        output=reshape_out.output[0],
        num_heads=num_heads,
    )
    self.nodes_to_add.append(new_node)
    self.node_name_to_graph_name[new_node.name] = self.this_graph_name

    self.nodes_to_remove.extend([matmul_s_v, transpose_out, reshape_out])

    # Use prune graph to remove nodes
    self.prune_graph = True