onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/fusion_bart_attention.py
@@ -0,0 +1,435 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import logging
+
+import numpy as np
+from fusion_attention import AttentionMask, FusionAttention
+from onnx import helper
+from onnx_model import OnnxModel
+
+logger = logging.getLogger(__name__)
+
+
+class FusionBartAttention(FusionAttention):
+    """
+    Fuse Bart Attention subgraph into one Attention node.
+    """
+
+    def __init__(
+        self,
+        model: OnnxModel,
+        hidden_size: int,
+        num_heads: int,
+        attention_mask: AttentionMask,
+    ):
+        super().__init__(model, hidden_size, num_heads, attention_mask)
+
+    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
+        qkv_nodes = self.model.match_parent_path(
+            normalize_node,
+            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+            [1, 1, 0, 0, 0],
+        )
+        if qkv_nodes is not None:
+            (
+                add_out,
+                matmul_out,
+                reshape_qkv,
+                transpose_qkv,
+                matmul_qkv,
+            ) = qkv_nodes
+        else:
+            logger.debug("fuse_attention: failed to match qkv path")
+            return
+
+        other_inputs = []
+        for input_ in normalize_node.input:
+            if input_ not in output_name_to_node:
+                continue
+            if input_ == qkv_nodes[0].output[0]:
+                continue
+            other_inputs.append(input_)
+        if len(other_inputs) != 1:
+            return
+        root_input = other_inputs[0]
+
+        # Sometimes the input name to the attention MatMul nodes does not match the input name to the end
+        # SkipLayerNormalization node (name saved in root_input). We find the true input name to the MatMul
+        # nodes by getting the initial SkipLayerNormalization node and checking how many MatMul nodes are
+        # children nodes for each of its output names.
+        """
+                       root_input
+            +---------------------------------------------------+
+            |                                                   |
+            |                                                   |
+        SkipLayerNormalization --> Attention --> MatMul --> SkipLayerNormalization
+        """
+        skip_layernorm = output_name_to_node[root_input]
+        # For some attention blocks, the end SkipLayerNormalization node may point to another node whose
+        # child is the LayerNormalization node.
+        if skip_layernorm.op_type in {"Add", "Clip"}:
+            skip_layernorm = self.model.get_children(skip_layernorm)[0]
+        for output in skip_layernorm.output:
+            if not output:
+                continue
+            children = input_name_to_nodes[output]
+            children_types = [child.op_type for child in children]
+            if children_types.count("MatMul") >= 1:
+                root_input = output
+                break
+
+        graph_input_names = {node.name for node in self.model.graph().input}
+        graph_output_names = {node.name for node in self.model.graph().output}
+
+        v_nodes_past_or_present = self.model.match_parent_path(
+            matmul_qkv,
+            ["Transpose", "Reshape", "Add", "MatMul"],
+            [1, 0, 0, None],
+        )
+        v_nodes_with_past = self.model.match_parent_path(
+            matmul_qkv,
+            ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
+            [1, 1, 0, 0, None],
+        )
+        v_nodes_past_only_oai = self.model.match_parent_path(
+            matmul_qkv,
+            ["Transpose", "Reshape", "Reshape", "Transpose"],
+            [1, 0, 0, 0],
+        )
+        past_v, present_v = "", ""
+        v_nodes, add_v, matmul_v = [], None, None
+        if v_nodes_past_or_present is not None:
+            v_nodes = v_nodes_past_or_present
+            (transpose_v, reshape_v, add_v, matmul_v) = v_nodes
+
+            # Find past_v input name
+            start_child_nodes = input_name_to_nodes[add_v.output[0]]
+            for start_child_node in start_child_nodes:
+                if start_child_node.op_type == "Concat":
+                    concat_v_nodes = self.model.match_parent_path(
+                        start_child_node,
+                        ["Reshape", "Transpose"],
+                        [0, 0],
+                    )
+                    if concat_v_nodes is not None:
+                        past_v = concat_v_nodes[-1].input[0]
+                        start_child_nodes = input_name_to_nodes[start_child_node.output[0]]
+                        break
+
+            # Find present_v output name
+            for start_child_node in start_child_nodes:
+                start_grandchild_nodes = input_name_to_nodes[start_child_node.output[0]]
+                for start_grandchild_node in start_grandchild_nodes:
+                    if start_grandchild_node.output[0] in graph_output_names:
+                        present_v = start_grandchild_node.output[0]
+                        break
+                if present_v != "":
+                    break
+        elif v_nodes_with_past is not None:
+            v_nodes = v_nodes_with_past
+            (concat_v, transpose_v, reshape_v, add_v, matmul_v) = v_nodes
+            past_v = concat_v.input[0]
+            present_v = concat_v.output[0]
+        elif matmul_qkv.input[1] in graph_input_names:
+            # Hugging Face's cross-attention where past_v is used directly as value
+            past_v = matmul_qkv.input[1]
+        elif v_nodes_past_only_oai is not None:
+            # OpenAI's cross-attention where past_v is used directly as value
+            v_nodes = v_nodes_past_only_oai
+            past_v = v_nodes[-1].input[0]
+        else:
+            logger.debug("fuse_attention: failed to match v path")
+            return
+        past_v = past_v if past_v in graph_input_names else ""
+        present_v = present_v if present_v in graph_output_names else ""
+
+        qk_nodes_no_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "MatMul"], [0, 0])
+        qk_nodes_with_mask = self.model.match_parent_path(matmul_qkv, ["Softmax", "Add", "MatMul"], [0, 0, 0])
+        qk_nodes, add_qk = [], None
+        if qk_nodes_no_mask is not None:
+            _, matmul_qk = qk_nodes_no_mask
+            qk_nodes = qk_nodes_no_mask
+        elif qk_nodes_with_mask is not None:
+            _, add_qk, matmul_qk = qk_nodes_with_mask
+            qk_nodes = qk_nodes_with_mask
+        else:
+            logger.debug("fuse_attention: failed to match qk path")
+            return
+
+        q_nodes_hf = self.model.match_parent_path(
+            matmul_qk,
+            ["Transpose", "Reshape", "Mul", "Add", "MatMul"],
+            [0, 0, 0, 0, 1],
+        )
+        q_nodes_oai = self.model.match_parent_path(
+            matmul_qk,
+            ["Mul", "Transpose", "Reshape", "Add", "MatMul"],
+            [0, 0, 0, 0, 1],
+        )
+        q_nodes = []
+        if q_nodes_hf is not None:
+            q_nodes = q_nodes_hf
+            (transpose_q, reshape_q, mul_q, add_q, matmul_q) = q_nodes
+        elif q_nodes_oai is not None:
+            q_nodes = q_nodes_oai
+            (mul_q, transpose_q, reshape_q, add_q, matmul_q) = q_nodes
+        else:
+            logger.debug("fuse_attention: failed to match q path")
+            return
+
+        k_nodes_no_past_hf = self.model.match_parent_path(
+            matmul_qk,
+            ["Transpose", "Reshape", "MatMul"],
+            [1, 0, 0],
+        )
+        k_nodes_with_past_hf = self.model.match_parent_path(
+            matmul_qk,
+            ["Transpose", "Concat", "Transpose", "Reshape", "MatMul"],
+            [1, 0, 1, 0, 0],
+        )
+        k_nodes_past_or_present_oai = self.model.match_parent_path(
+            matmul_qk,
+            ["Mul", "Transpose", "Reshape", "MatMul"],
+            [1, 0, 0, 0],
+        )
+        k_nodes_past_only_oai = self.model.match_parent_path(
+            matmul_qk,
+            ["Mul", "Transpose", "Reshape", "Reshape", "Transpose"],
+            [1, 0, 0, 0, 0],
+        )
+        past_k, present_k = "", ""
+        k_nodes, add_k, matmul_k = [], None, None
+        if k_nodes_no_past_hf is not None:
+            k_nodes = k_nodes_no_past_hf
+            (transpose_k, reshape_k, matmul_k) = k_nodes
+
+            # Find present_k output name
+            transpose_k_nodes = input_name_to_nodes[reshape_k.output[0]]
+            for transpose_k_node in transpose_k_nodes:
+                if transpose_k_node.output[0] in graph_output_names:
+                    present_k = transpose_k_node.output[0]
+                    break
+        elif k_nodes_with_past_hf is not None:
+            k_nodes = k_nodes_with_past_hf
+            (_, concat_k, transpose_k, reshape_k, matmul_k) = k_nodes
+            past_k = concat_k.input[0]
+            present_k = concat_k.output[0]
+        elif output_name_to_node[matmul_qk.input[1]].input[0] in graph_input_names:
+            # Hugging Face's cross-attention where past_k is used directly as key
+            k_nodes = [output_name_to_node[matmul_qk.input[1]]]
+            past_k = k_nodes[0].input[0]
+        elif k_nodes_past_or_present_oai is not None:
+            k_nodes = k_nodes_past_or_present_oai
+            (_, transpose_k, reshape_k, matmul_k) = k_nodes
+
+            # Find past_k input name
+            start_child_nodes = input_name_to_nodes[matmul_k.output[0]]
+            for start_child_node in start_child_nodes:
+                if start_child_node.op_type == "Concat":
+                    concat_k_nodes = self.model.match_parent_path(
+                        start_child_node,
+                        ["Reshape", "Transpose"],
+                        [0, 0],
+                    )
+                    if concat_k_nodes is not None:
+                        past_k = concat_k_nodes[-1].input[0]
+                        start_child_nodes = input_name_to_nodes[start_child_node.output[0]]
+                        break
+
+            # Find present_k output name
+            for start_child_node in start_child_nodes:
+                start_grandchild_nodes = input_name_to_nodes[start_child_node.output[0]]
+                for start_grandchild_node in start_grandchild_nodes:
+                    if start_grandchild_node.output[0] in graph_output_names:
+                        present_k = start_grandchild_node.output[0]
+                        break
+                if present_k != "":
+                    break
+        elif k_nodes_past_only_oai is not None:
+            # OpenAI's cross-attention where past_k is used directly as key
+            k_nodes = k_nodes_past_only_oai
+            past_k = k_nodes[-1].input[0]
+        else:
+            logger.debug("fuse_attention: failed to match k path")
+            return
+        past_k = past_k if past_k in graph_input_names else ""
+        present_k = present_k if present_k in graph_output_names else ""
+
+        if matmul_k is not None and add_k is None:
+            # Create empty Add node for attention graph
+            add_v_tensor = self.model.get_initializer(add_v.input[0])
+            bias_dim = add_v_tensor.dims[0]
+            dtype = add_v_tensor.data_type
+            empty_bias_name = "empty_bias"
+            empty_tensor = self.model.get_initializer(empty_bias_name)
+            if empty_tensor is None:
+                self.add_initializer(
+                    empty_bias_name,
+                    dtype,
+                    dims=[bias_dim],
+                    vals=np.array([0.0] * bias_dim, dtype=helper.tensor_dtype_to_np_dtype(dtype)),
+                )
+
+            add_name = self.model.create_node_name("Add")
+            add_k = helper.make_node("Add", [empty_bias_name, matmul_k.output[0]], [reshape_k.name], add_name)
+
+        three_root_inputs = bool(past_k) and bool(past_v) and matmul_k is None and matmul_v is None
+        one_root_input = (
+            not three_root_inputs
+            and matmul_q.input[0] == root_input
+            and matmul_k.input[0] == root_input
+            and matmul_v.input[0] == root_input
+        )
+        two_root_inputs = (
+            not three_root_inputs
+            and matmul_q.input[0] == root_input
+            and matmul_k.input[0] == matmul_v.input[0]
+            and matmul_k.input[0] != matmul_q.input[0]
+        )
+
+        # There are 5 types of attention:
+        # 1) Encoder attention with one_root_input=True and qk_nodes=qk_nodes_no_mask
+        # 2) Decoder self attention with one_root_input=True and qk_nodes=qk_nodes_with_mask
+        # 3) Decoder cross attention with two_root_inputs=True and qk_nodes=qk_nodes_no_mask
+        # 4) Decoder self attention with past with one_root_input=True and qk_nodes=qk_nodes_with_mask and past_k=past_decoder_key and past_v=past_decoder_value
+        # 5) Decoder cross attention with past with three_root_inputs=True and qk_nodes=qk_nodes_no_mask
+        encoder_attention = one_root_input and qk_nodes == qk_nodes_no_mask
+        decoder_self_attention = one_root_input and qk_nodes == qk_nodes_with_mask
+        decoder_cross_attention = two_root_inputs and qk_nodes == qk_nodes_no_mask
+        decoder_self_attention_with_past = decoder_self_attention and bool(past_k) and bool(past_v)
+        decoder_cross_attention_with_past = three_root_inputs and qk_nodes == qk_nodes_no_mask
+
+        # For decoder self-attentions, the attention mask needs to be included in the attention node
+        causal_mask = qk_nodes == qk_nodes_with_mask
+        mask_nodes = []
+        if causal_mask:
+            mask_nodes_bart = self.model.match_parent_path(
+                add_qk,
+                ["Where"],
+                [1],
+            )
+            mask_nodes_whisper_hf = self.model.match_parent_path(
+                add_qk,
+                ["Slice", "Expand", "Where"],
+                [1, 0, 1],
+            )
+            mask_nodes_whisper_oai = self.model.match_parent_path(
+                add_qk,
+                ["Slice", "Unsqueeze", "Gather", "Shape", "Add"],
+                [1, 2, 0, 0, 0],
+            )
+            mask_nodes_whisper_oai_unit_test = self.model.match_parent_path(
+                add_qk,
+                ["Slice", "Slice"],
+                [1, 0],
+            )
+            if mask_nodes_whisper_hf is not None:
+                mask_nodes = mask_nodes_whisper_hf
+            elif mask_nodes_whisper_oai is not None:
+                mask_nodes = mask_nodes_whisper_oai
+            elif mask_nodes_whisper_oai_unit_test is not None:
+                mask_nodes = mask_nodes_whisper_oai_unit_test
+            elif mask_nodes_bart is not None:
+                mask_nodes = mask_nodes_bart
+            else:
+                logger.debug("fuse_attention: failed to match mask nodes")
+                return
+            assert len(mask_nodes) > 0
+
+        if (
+            encoder_attention
+            or decoder_self_attention
+            or decoder_cross_attention
+            or decoder_self_attention_with_past
+            or decoder_cross_attention_with_past
+        ):
+            attention_last_node = reshape_qkv
+            num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
+
+            if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
+                logger.debug("fuse_attention: failed to detect num_heads or hidden_size")
+                return
+
+            new_node = None
+            if decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past:
+                # Note: Decoder attention with past key and past value is fused as multi-head attention
+                # rather than attention because multi-head attention supports separate past key and past
+                # value whereas attention supports concatenated past key and past value.
+                new_node = (
+                    self.create_multihead_attention_node(
+                        q_matmul=matmul_q,
+                        k_matmul=matmul_k if decoder_cross_attention or decoder_self_attention_with_past else past_k,
+                        v_matmul=matmul_v if decoder_cross_attention or decoder_self_attention_with_past else past_v,
+                        q_add=add_q,
+                        k_add=add_k if decoder_cross_attention or decoder_self_attention_with_past else None,
+                        v_add=add_v if decoder_cross_attention or decoder_self_attention_with_past else None,
+                        num_heads=num_heads,
+                        hidden_size=hidden_size,
+                        output=attention_last_node.output[0],
+                        unidirectional=causal_mask,
+                        past_k=past_k if decoder_self_attention_with_past else "",
+                        past_v=past_v if decoder_self_attention_with_past else "",
+                        present_k=present_k,
+                        present_v=present_v,
+                    )
+                    if self.use_multi_head_attention
+                    else None
+                )
+            else:
+                # Temporarily set multi-head attention flag to false
+                use_multi_head_attention_ground_truth = self.use_multi_head_attention
+                self.use_multi_head_attention = False
+                new_node = self.create_attention_node(
+                    mask_index=None,
+                    q_matmul=matmul_q,
+                    k_matmul=matmul_k,
+                    v_matmul=matmul_v,
+                    q_add=add_q,
+                    k_add=add_k,
+                    v_add=add_v,
+                    num_heads=num_heads,
+                    hidden_size=hidden_size,
+                    first_input=root_input,
+                    output=attention_last_node.output[0],
+                    causal=causal_mask,
+                    past_k=past_k,
+                    past_v=past_v,
+                    present_k=present_k,
+                    present_v=present_v,
+                )
+                self.use_multi_head_attention = use_multi_head_attention_ground_truth
+            if new_node is None:
+                logger.debug("fuse_attention: failed to create fused node")
+                return
+
+            self.nodes_to_add.append(new_node)
+            self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+            self.nodes_to_remove.extend([attention_last_node, transpose_qkv, matmul_qkv])
+            self.nodes_to_remove.extend(qk_nodes)
+
+            # When using multi-head attention, keep MatMul nodes in original graph
+            if decoder_self_attention_with_past or decoder_cross_attention or decoder_cross_attention_with_past:
+                if len(q_nodes) > 0 and q_nodes[-1].op_type == "MatMul":
+                    q_nodes.pop()
+                if len(k_nodes) > 0 and k_nodes[-1].op_type == "MatMul":
+                    k_nodes.pop()
+                if len(v_nodes) > 0 and v_nodes[-1].op_type == "MatMul":
+                    v_nodes.pop()
+                if self.disable_multi_head_attention_bias:
+                    if len(q_nodes) > 0 and q_nodes[-1].op_type == "Add":
+                        q_nodes.pop()
+                    if len(k_nodes) > 0 and k_nodes[-1].op_type == "Add":
+                        k_nodes.pop()
+                    if len(v_nodes) > 0 and v_nodes[-1].op_type == "Add":
+                        v_nodes.pop()
+
+            self.nodes_to_remove.extend(q_nodes)
+            self.nodes_to_remove.extend(k_nodes)
+            self.nodes_to_remove.extend(v_nodes)
+
+            # Use prune graph to remove mask nodes since they are shared by all attention nodes.
+            self.prune_graph = True
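Note: this fusion is not usually invoked directly; in the package it is wired up inside BartOnnxModel (onnx_model_bart.py) and reached through the transformers optimizer, with fuse() called once per matched SkipLayerNormalization node by the apply() loop of the Fusion base class (fusion_base.py below). A minimal sketch of that entry point; the model path and the bart-base sizes (num_heads=12, hidden_size=768) are illustrative, not part of the package:

    from onnxruntime.transformers import optimizer

    # model_type="bart" routes to BartOnnxModel, which instantiates
    # FusionBartAttention with the configured sizes and AttentionMask helper.
    optimized = optimizer.optimize_model(
        "bart_decoder.onnx",  # hypothetical path to an exported BART graph
        model_type="bart",
        num_heads=12,
        hidden_size=768,
    )
    optimized.save_model_to_file("bart_decoder_opt.onnx")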
onnxruntime/transformers/fusion_base.py
@@ -0,0 +1,141 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from collections import defaultdict
+from collections.abc import Sequence
+from logging import getLogger
+from typing import Any
+
+import numpy as np
+from onnx import NodeProto, TensorProto, helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class Fusion:
+    """
+    Base class for Graph Fusion
+    """
+
+    def __init__(
+        self,
+        model: OnnxModel,
+        fused_op_type: str,
+        search_op_types: str | list[str],
+        description: str = "",
+    ):
+        self.search_op_types: list[str] = [search_op_types] if isinstance(search_op_types, str) else search_op_types
+        self.fused_op_type: str = fused_op_type
+        self.description: str = f"{fused_op_type}({description})" if description else fused_op_type
+        self.model: OnnxModel = model
+        self.nodes_to_remove: list = []
+        self.nodes_to_add: list = []
+        self.prune_graph: bool = False
+        self.node_name_to_graph_name: dict = {}
+        self.this_graph_name: str | None = None
+        # It is optional for a subclass to update fused_count since we also check nodes_to_add to get the counter.
+        self.fused_count: defaultdict = defaultdict(int)
+
+    def increase_counter(self, fused_op_name: str):
+        """
+        Increase counter of a fused operator.
+        """
+        self.fused_count[fused_op_name] += 1
+
+    def fuse(
+        self,
+        node: NodeProto,
+        input_name_to_nodes: dict[str, list[NodeProto]],
+        output_name_to_node: dict[str, NodeProto],
+    ):
+        """Interface for fusion that starts from a node"""
+        raise NotImplementedError
+
+    def apply(self):
+        """
+        Apply graph fusion on the whole model graph.
+        It searches nodes of the given operators, and starts fusion on each of those nodes.
+        """
+        logger.debug(f"start {self.description} fusion...")
+        input_name_to_nodes = self.model.input_name_to_nodes()
+        output_name_to_node = self.model.output_name_to_node()
+
+        # This assumes that two search ops will not be fused at the same time!
+        for search_op_type in self.search_op_types:
+            for node in self.model.get_nodes_by_op_type(search_op_type):
+                graph = self.model.get_graph_by_node(node)
+                if graph is None:
+                    raise Exception("Can not find node in any graph")
+                self.this_graph_name = graph.name
+                self.fuse(node, input_name_to_nodes, output_name_to_node)
+
+        op_list = [node.op_type for node in self.nodes_to_add]
+        if self.fused_count:
+            for key, value in self.fused_count.items():
+                if value:
+                    logger.info(f"Fused {key}: {value}")
+        else:
+            count = op_list.count(self.fused_op_type)
+            if count > 0:
+                logger.info(f"Fused {self.description}: {count}")
+
+        self.model.remove_nodes(self.nodes_to_remove)
+        self.model.add_nodes(self.nodes_to_add, self.node_name_to_graph_name)
+
+        if self.prune_graph:
+            self.model.prune_graph()
+        elif self.nodes_to_remove or self.nodes_to_add:
+            self.model.update_graph()
+
+    def add_initializer(self, name: str, data_type: int, dims: Sequence[int], vals: Any, raw: bool = True):
+        if raw:
+            if not isinstance(vals, np.ndarray):
+                np_type = helper.tensor_dtype_to_np_dtype(data_type)
+                bytes = np.array(vals, dtype=np_type).tobytes()
+            else:
+                bytes = vals.tobytes()
+            tensor = helper.make_tensor(
+                name=name,
+                data_type=data_type,
+                dims=dims,
+                vals=bytes,
+                raw=True,
+            )
+        else:
+            tensor = helper.make_tensor(
+                name=name,
+                data_type=data_type,
+                dims=dims,
+                vals=vals,
+                raw=False,
+            )
+
+        self.model.add_initializer(tensor, self.this_graph_name)
+        return tensor
+
+    def remove_initializer(self, tensor: TensorProto):
+        self.model.remove_initializer(tensor)
+
+    def add_nodes_to_remove(self, nodes: list[NodeProto]):
+        # Some nodes are shared between paths (e.g. rotary embedding nodes in the Q and K paths).
+        # When path A is fused, its shared nodes are added to `self.nodes_to_remove`. But when path B
+        # is fused, its shared nodes are also added to `self.nodes_to_remove`. When the nodes are
+        # iteratively removed from `self.nodes_to_remove`, path A's shared nodes are removed first.
+        # Since path A's shared nodes are removed, path B's shared nodes are not removed because they
+        # were previously removed for path A. This causes an error to print in remove_node that a node
+        # has failed to be removed.
+        #
+        # To avoid this error, we pre-emptively check if the shared nodes are already in `self.nodes_to_remove`.
+        # We could alternatively convert `self.nodes_to_remove` to a set to avoid this issue, but there could
+        # be scenarios where the nodes need to be removed in a specific order and converting to a set would
+        # lose this order.
+        for node in nodes:
+            if node not in self.nodes_to_remove:
+                self.nodes_to_remove.append(node)
+
+    def add_nodes_to_remove_with_nodes_to_keep(self, nodes: list[NodeProto], nodes_to_keep: list[NodeProto]):
+        for node in nodes:
+            if node not in self.nodes_to_remove and node not in nodes_to_keep:
+                self.nodes_to_remove.append(node)
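Note: a concrete fusion only needs to implement fuse(); apply() builds the two lookup maps once per pass, visits every node whose op type is in search_op_types, and then commits nodes_to_add/nodes_to_remove. A minimal sketch of that contract, using only the base-class API shown above; the Log(Exp(x)) -> Identity rewrite is hypothetical and not part of the package:

    from fusion_base import Fusion
    from onnx import helper
    from onnx_model import OnnxModel


    class FusionLogExp(Fusion):
        """Illustrative only: rewrite Log(Exp(x)) into a single Identity node."""

        def __init__(self, model: OnnxModel):
            # apply() will call fuse() once for every Log node in the graph.
            super().__init__(model, "Identity", "Log", "Log(Exp)")

        def fuse(self, node, input_name_to_nodes, output_name_to_node):
            exp_path = self.model.match_parent_path(node, ["Exp"], [0])
            if exp_path is None:
                return
            exp_node = exp_path[0]
            # Only fuse when the Exp output feeds this Log node alone.
            if len(input_name_to_nodes[exp_node.output[0]]) > 1:
                return
            identity = helper.make_node(
                "Identity",
                inputs=[exp_node.input[0]],
                outputs=[node.output[0]],
                name=self.model.create_node_name("Identity"),
            )
            self.nodes_to_remove.extend([exp_node, node])
            self.nodes_to_add.append(identity)
            self.node_name_to_graph_name[identity.name] = self.this_graph_name
            self.increase_counter("Identity")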
onnxruntime/transformers/fusion_bias_add.py
@@ -0,0 +1,57 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from logging import getLogger
+
+from fusion_base import Fusion
+from numpy import ndarray
+from onnx import helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class FusionBiasAdd(Fusion):
+    def __init__(self, model: OnnxModel):
+        super().__init__(model, "BiasAdd", "Add")
+
+    def fuse(self, add_node, input_name_to_nodes: dict, output_name_to_node: dict):
+        """
+        Fuse Add bias and Add skip connection into BiasAdd
+        """
+
+        nodes = self.model.match_parent_path(
+            add_node,
+            ["Add", "MatMul", "BiasSplitGelu", "MatMul", "SkipLayerNormalization"],
+            [0, None, 0, 0, 0],
+            output_name_to_node,
+        )
+
+        if nodes is None:
+            return
+
+        bias_node = nodes[0]
+        skip_layer_norm = nodes[-1]
+
+        # Check that the skip connection comes from a SkipLayerNormalization output
+        if add_node.input[1] not in skip_layer_norm.output:
+            return
+
+        bias_index, bias_value = self.model.get_constant_input(bias_node)
+        if not (isinstance(bias_index, int) and (bias_value is not None) and isinstance(bias_value, ndarray)):
+            return
+        if bias_value.ndim != 1:
+            return
+
+        self.nodes_to_remove.extend([add_node, bias_node])
+        node_name = self.model.create_node_name("BiasAdd")
+        fused_node = helper.make_node(
+            "BiasAdd",
+            inputs=[bias_node.input[1 - bias_index], bias_node.input[bias_index], add_node.input[1]],
+            outputs=[add_node.output[0]],
+            name=node_name,
+        )
+        fused_node.domain = "com.microsoft"
+        self.nodes_to_add.append(fused_node)
+        self.node_name_to_graph_name[node_name] = self.this_graph_name
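Note: the make_node call above fixes the input order of the fused com.microsoft BiasAdd: the MatMul output, the 1-D bias, then the skip tensor from the residual Add. A small numpy sketch of the intended numerics (shapes are illustrative; the actual kernel lives in the ONNX Runtime contrib ops):

    import numpy as np

    x = np.random.rand(2, 64, 320).astype(np.float32)     # MatMul output
    bias = np.random.rand(320).astype(np.float32)         # 1-D bias, broadcast over leading dims
    skip = np.random.rand(2, 64, 320).astype(np.float32)  # residual from SkipLayerNormalization

    unfused = (x + bias) + skip  # the Add(bias) -> Add(skip) pair matched by fuse()
    fused = x + bias + skip      # what BiasAdd(x, bias, skip) computes
    assert np.allclose(unfused, fused)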