onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,588 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
|
|
6
|
+
import logging
|
|
7
|
+
|
|
8
|
+
import numpy as np
|
|
9
|
+
import onnx
|
|
10
|
+
from onnx import TensorProto, helper, numpy_helper
|
|
11
|
+
from onnx_model_bert import BertOnnxModel
|
|
12
|
+
|
|
13
|
+
logger = logging.getLogger(__name__)
|
|
14
|
+
|
|
15
|
+
|
|
16
|
+
class BertOnnxModelTF(BertOnnxModel):
|
|
17
|
+
def __init__(self, model, num_heads, hidden_size):
    """Initialize a BERT ONNX model that was converted from TensorFlow.

    Args:
        model: the ONNX ModelProto to optimize.
        num_heads: number of attention heads (used by attention fusion in the base class).
        hidden_size: hidden dimension of the model.
    """
    super().__init__(model, num_heads, hidden_size)
|
|
19
|
+
|
|
20
|
+
def remove_identity(self):
    """Remove Identity nodes by rewiring their consumers to the Identity input.

    An Identity node whose output is a graph output is kept, since removing it
    would change the graph's external interface.
    """
    removed = []
    for candidate in self.nodes():
        if candidate.op_type != "Identity":
            continue
        # Keep Identity nodes that produce graph outputs.
        if self.find_graph_output(candidate.output[0]):
            continue
        self.replace_input_of_all_nodes(candidate.output[0], candidate.input[0])
        removed.append(candidate)
    self.remove_nodes(removed)
    logger.info(f"Removed Identity count: {len(removed)}")
|
|
29
|
+
|
|
30
|
+
def match_mask_path(self, add_or_sub_before_softmax):
    """Match the attention-mask parent path above the Add/Sub node before Softmax.

    Tries the known TF-exported mask patterns in order and returns the first
    matched node list, or None when no pattern matches.
    """
    # (op_types, input_indices) pairs, tried in order.
    candidate_patterns = (
        (["Mul", "Sub", "Reshape", "Cast"], [1, None, 1, 0]),
        (["Mul", "Sub", "Cast", "Slice", "Unsqueeze"], [1, 0, 1, 0, 0]),
        (["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze"], [1, None, 1, 0, 0]),
    )
    for op_types, input_indices in candidate_patterns:
        mask_nodes = self.match_parent_path(
            add_or_sub_before_softmax,
            op_types,
            input_indices,
        )
        if mask_nodes is not None:
            return mask_nodes
    return None
|
|
54
|
+
|
|
55
|
+
def get_2d_initializers_from_parent_subgraphs(self, current_node):
    """
    Find initializers that is 2D. Returns a dictionary with name as key and shape as value.
    """
    two_d_initializers = {}
    ancestors = self.get_parent_subgraph_nodes(current_node, [])
    for ancestor in ancestors:
        for input_name in ancestor.input:
            candidate = self.get_initializer(input_name)
            if not candidate:
                continue
            weights = numpy_helper.to_array(candidate)
            # Only two-dimensional initializers (embedding tables, weight matrices).
            if weights.ndim == 2:
                two_d_initializers[candidate.name] = weights.shape
    return two_d_initializers
|
|
70
|
+
|
|
71
|
+
def find_segment_ids(self, segment_embedding, input_ids):
    """Locate the graph input that provides segment (token type) ids.

    Args:
        segment_embedding: name of the segment embedding initializer; its single
            consumer node is used as the anchor for the search.
        input_ids: name of the input_ids graph input. The segment id input must
            be a different graph input.

    Returns:
        The name of the segment ids input/tensor, or None when it cannot be
        determined unambiguously. When segment ids alias input_ids through a
        ConstantOfShape subgraph, that subgraph is simplified and the name of a
        newly created zeros tensor is returned instead.
    """
    input_name_to_nodes = self.input_name_to_nodes()
    if segment_embedding not in input_name_to_nodes:
        return None

    nodes = input_name_to_nodes[segment_embedding]
    # Ambiguous if the embedding table feeds more than one node.
    if len(nodes) != 1:
        return None

    graph_inputs = self.get_graph_inputs(nodes[0], recursive=True)
    if len(graph_inputs) > 1:
        print("Found multiple candidates of segment_ids", graph_inputs)
        return None
    # Find segment ids in graph inputs. The segment id input must not be the same as input_ids.
    if len(graph_inputs) == 1 and graph_inputs[0] != input_ids:
        return graph_inputs[0]

    # If the segment id candidate is the same as the input_ids, try to assign alternative segment ids and simplify the graph if needed.
    segment_ids = nodes[0].input[1]
    _, segment_id_path, _ = self.match_parent_paths(
        nodes[0],
        [
            (
                ["ConstantOfShape", "Cast", "Concat", "Slice", "Cast", "Shape"],
                [1, 0, 0, 0, 0, 0],
            ),
            (
                [
                    "ConstantOfShape",
                    "Cast",
                    "Concat",
                    "Unsqueeze",
                    "Squeeze",
                    "Slice",
                    "Cast",
                    "Shape",
                ],
                [1, 0, 0, 0, 0, 0, 0, 0],
            ),
        ],
        None,
    )

    if segment_id_path and input_ids and input_ids == segment_id_path[-1].input[0]:
        # Fixed typo in the original debug message ("semgent" -> "segment").
        logger.debug("Simplify segment id path...")
        constantofshape_node = segment_id_path[0]
        graph_name = self.get_graph_by_node(constantofshape_node).name
        # Drive the replacement zeros tensor from the dynamic shape of input_ids
        # so segment ids follow the batch/sequence dimensions at runtime.
        self.add_node(
            helper.make_node("Shape", inputs=[input_ids], outputs=["input_shape"]),
            graph_name,
        )
        constantofshape_value = helper.get_attribute_value(constantofshape_node.attribute[0])
        self.add_node(
            helper.make_node(
                "ConstantOfShape",
                inputs=["input_shape"],
                outputs=["zeros_for_input_shape"],
                value=constantofshape_value,
            ),
            graph_name,
        )
        segment_ids = "zeros_for_input_shape"
    return segment_ids
|
|
134
|
+
|
|
135
|
+
def find_input_ids(self, word_embedding):
    """Return the graph input feeding the word embedding lookup, or None.

    The word embedding table must have exactly one consumer, and that
    consumer's recursive graph inputs must resolve to a single candidate.
    """
    consumers_by_input = self.input_name_to_nodes()
    embedding_consumers = consumers_by_input.get(word_embedding)
    # Require exactly one consumer of the embedding table.
    if embedding_consumers is None or len(embedding_consumers) != 1:
        return None

    graph_inputs = self.get_graph_inputs(embedding_consumers[0], recursive=True)
    if len(graph_inputs) == 1:
        return graph_inputs[0]

    print("Found multiple candidates of input_ids", graph_inputs)
    return None
|
|
150
|
+
|
|
151
|
+
def find_mask_input(self, excluded_graph_inputs):
    """Find the attention-mask graph input by walking up from a Softmax node.

    Matches the TF-exported pattern Softmax <- Add <- Mul <- Sub <- Cast <-
    Slice <- Unsqueeze, confirms the Mul/Sub constants (-10000 and 1), and
    returns the single graph input that is not in excluded_graph_inputs.
    When the only candidate is a duplicated input (already excluded), tries to
    simplify a ConstantOfShape subgraph and returns its new output instead.
    Returns None when no unambiguous mask input is found.

    Args:
        excluded_graph_inputs: input names already claimed (e.g. input_ids,
            segment_ids) that must not be treated as the mask.
    """
    for node in self.nodes():
        if node.op_type == "Softmax":
            mask_path = self.match_parent_path(
                node,
                ["Add", "Mul", "Sub", "Cast", "Slice", "Unsqueeze"],
                [0, 1, None, 1, 0, 0],
            )
            if mask_path is None:
                continue
            (
                add_node,
                mul_node,
                sub_node,
                cast_node,
                slice_node,
                unsqueeze_node,
            ) = mask_path
            # (1 - mask) * -10000 is the standard additive attention-mask form.
            if self.has_constant_input(mul_node, -10000) and self.has_constant_input(sub_node, 1):
                graph_inputs = self.get_graph_inputs(sub_node, recursive=True)
                inputs = [input for input in graph_inputs if input not in excluded_graph_inputs]
                if len(inputs) > 1:
                    print("Found multiple candidates of mask input", inputs)
                    return None
                if len(inputs) == 1:
                    return inputs[0]
                # Duplicated input found. Try to simplify the graph.
                path_to_be_simplified = self.match_parent_path(
                    mask_path[-1],
                    [
                        "ConstantOfShape",
                        "Cast",
                        "Concat",
                        "Unsqueeze",
                        "Squeeze",
                        "Slice",
                        "Cast",
                        "Shape",
                    ],
                    [0, 0, 0, 0, 0, 0, 0, 0],
                )
                duplicated_inputs = [input for input in graph_inputs if input in excluded_graph_inputs]
                # Simplify graph for dynamic axes.
                if (
                    path_to_be_simplified
                    and duplicated_inputs
                    and len(duplicated_inputs) == 1
                    and duplicated_inputs[0] == path_to_be_simplified[-1].input[0]
                ):
                    # NOTE(review): message has a typo and says "segment id" but this
                    # simplifies the mask path — confirm before changing the string.
                    logger.debug("Simplify semgent id path...")
                    constantofshape_node = path_to_be_simplified[0]
                    constantofshape_value = helper.get_attribute_value(constantofshape_node.attribute[0])
                    graph_name = self.get_graph_by_node(constantofshape_node).name
                    # Rebuild the mask from the dynamic shape of the duplicated input.
                    self.add_node(
                        helper.make_node(
                            "Shape",
                            inputs=[duplicated_inputs[0]],
                            outputs=["input_shape_for_mask"],
                        ),
                        graph_name,
                    )
                    self.add_node(
                        helper.make_node(
                            "ConstantOfShape",
                            inputs=["input_shape_for_mask"],
                            outputs=[unsqueeze_node.input[0]],
                            value=constantofshape_value,
                        ),
                        graph_name,
                    )
                    return unsqueeze_node.input[0]
    return None
|
|
223
|
+
|
|
224
|
+
def create_embedding_subgraph(self, normalize_node, word_embedding, segment_embedding, position_embedding):
    """Fuse word/segment/position embeddings plus LayerNorm into EmbedLayerNormalization.

    Locates the input_ids, segment_ids and mask graph inputs, casts them to
    int32 where needed, creates a com.microsoft EmbedLayerNormalization node,
    and rewires consumers of the LayerNormalization output to the fused output.

    Args:
        normalize_node: the LayerNormalization node following the embedding sum
            (its inputs 1 and 2 supply gamma and beta).
        word_embedding: name of the word embedding table initializer.
        segment_embedding: name of the segment embedding table initializer.
        position_embedding: name of the position embedding table initializer.

    Returns:
        False when a required input cannot be found; otherwise None after the
        fused node has been added (no explicit success value is returned).
    """
    input_ids = self.find_input_ids(word_embedding)
    if input_ids is None:
        logger.info("Failed to find input_ids. Cannot fuse embedding layer.")
        return False

    segment_ids = self.find_segment_ids(segment_embedding, input_ids)
    if segment_ids is None:
        logger.info("Failed to find segment_ids. Cannot fuse embedding layer.")
        return False

    mask_input = self.find_mask_input([segment_ids, input_ids])
    if mask_input is None:
        logger.info("Failed to find input_mask. Cannot fuse embedding layer.")
        return False

    # Record the resolved model inputs for later stages.
    self.bert_inputs = [input_ids, segment_ids, mask_input]

    mask_index = self.create_node_name("mask_index")
    self.attention_mask.set_mask_indice(mask_input, mask_index)

    # EmbedLayerNormalization requires int32 ids; cast graph inputs when needed.
    if self.find_graph_input(input_ids).type.tensor_type.elem_type != TensorProto.INT32:
        casted, input_ids = self.utils.cast_graph_input_to_int32(input_ids)

    # segment_ids may be a graph input, or an internal tensor created by
    # find_segment_ids; the cast helper differs for the two cases.
    if self.find_graph_input(segment_ids):
        casted, segment_ids = self.utils.cast_graph_input_to_int32(segment_ids)
    else:
        segment_ids, segment_id_cast_node = self.utils.cast_input_to_int32(segment_ids)

    if self.find_graph_input(mask_input):
        casted, mask_input = self.utils.cast_graph_input_to_int32(mask_input)
    else:
        mask_input, mask_input_cast_node = self.utils.cast_input_to_int32(mask_input)

    embed_output = self.create_node_name("embed_output")
    embed_node = onnx.helper.make_node(
        "EmbedLayerNormalization",
        inputs=[
            input_ids,
            segment_ids,
            word_embedding,
            position_embedding,
            segment_embedding,
            normalize_node.input[1],  # gamma
            normalize_node.input[2],  # beta
            mask_input,
        ],
        outputs=[embed_output, mask_index],
        name="EmbedLayer",
    )
    # Contrib op lives in the Microsoft domain.
    embed_node.domain = "com.microsoft"
    self.replace_input_of_all_nodes(normalize_node.output[0], embed_output)
    self.add_node(embed_node, self.get_graph_by_node(normalize_node).name)
|
|
277
|
+
|
|
278
|
+
def process_embedding(self):
    """
    Automatically detect word, segment and position embeddings.

    Scans every LayerNormalization node for the pattern
    Add -> Reshape -> Slice feeding it; the Slice's first input must be a
    2-D initializer, which is taken as the position embedding.  The word
    and segment embeddings are then located among the 2-D initializers of
    the parent Add subgraph (the one with first dimension 2 is treated as
    the segment embedding).  On success the original embedding subgraph is
    replaced via create_embedding_subgraph and pruned.  Only the first
    matching LayerNormalization is fused (the loop breaks afterwards).
    """
    logger.info("start processing embedding layer...")
    output_name_to_node = self.output_name_to_node()

    layer_norm_nodes = self.get_nodes_by_op_type("LayerNormalization")
    for layer_norm_node in layer_norm_nodes:
        # Match the position-embedding lookup: Add(LayerNorm input 0) <- Reshape(input 1) <- Slice.
        pos_embed_path = self.match_parent_path(
            layer_norm_node,
            ["Add", "Reshape", "Slice"],
            [0, 1, 0],
            output_name_to_node,
        )
        if pos_embed_path is None:
            continue

        add_node, reshape_node, slice_node = pos_embed_path
        initializer = self.get_initializer(slice_node.input[0])
        if initializer is None:
            continue

        temp = numpy_helper.to_array(initializer)
        if len(temp.shape) == 2:
            logger.info(f"Found position embedding. name:{initializer.name}, shape:{temp.shape}")
            position_embedding = initializer.name
        else:
            # A non-2-D tensor cannot be a position embedding table; abort entirely.
            logger.info(f"Failed to find position embedding. name:{initializer.name}, shape:{temp.shape}")
            return

        first_parent = self.get_parent(add_node, 0, output_name_to_node)
        if first_parent is not None and first_parent.op_type == "Add":
            # Expect exactly two 2-D initializers upstream: word and segment embeddings.
            embeddings = self.get_2d_initializers_from_parent_subgraphs(first_parent)
            if len(embeddings) != 2:
                logger.warning(
                    f"Failed to find two embeddings (word and segment) from Add node. Found {embeddings}"
                )
                return

            word_embedding = None
            segment_embedding = None
            for name, shape in embeddings.items():
                # Segment (token-type) vocabulary size is 2; anything else is the word table.
                if shape[0] == 2:
                    segment_embedding = name
                    logger.info(f"Found segment embedding. name:{name}, shape:{shape}")
                else:
                    word_embedding = name
                    logger.info(f"Found words embedding. name:{name}, shape:{shape}")

            if word_embedding is None or segment_embedding is None:
                logger.info("Failed to find both word and segment embedding")
                return

            logger.info("Create Embedding node")
            self.create_embedding_subgraph(
                layer_norm_node,
                word_embedding,
                segment_embedding,
                position_embedding,
            )
            # Prune graph to remove those original embedding nodes.
            self.prune_graph()
            break
|
343
|
+
def check_attention_input(self, matmul_q, matmul_k, matmul_v, parent, output_name_to_node):
|
|
344
|
+
for x in [matmul_q, matmul_k, matmul_v]:
|
|
345
|
+
root_input = x.input[0]
|
|
346
|
+
root_node = output_name_to_node[root_input]
|
|
347
|
+
if root_node == parent:
|
|
348
|
+
continue
|
|
349
|
+
logger.debug(f"Check attention input failed:{root_input}, {parent.output[0]}")
|
|
350
|
+
return False
|
|
351
|
+
|
|
352
|
+
return True
|
|
353
|
+
|
|
354
|
+
def fuse_attention(self):
    """Fuse the q/k/v projection + scaled-dot-product subgraphs into Attention nodes.

    Starts from each SkipLayerNormalization (or the Add feeding a plain
    LayerNormalization, which shares the same pattern) and matches the qkv,
    q, k, v and mask paths upstream.  Matched subgraphs are replaced with a
    single fused Attention node; the replaced nodes are collected and
    removed at the end, followed by update_graph().
    """
    output_name_to_node = self.output_name_to_node()

    nodes_to_remove = []
    attention_count = 0

    start_nodes = []
    skip_layer_norm_nodes = self.get_nodes_by_op_type("SkipLayerNormalization")
    layer_norm_nodes = self.get_nodes_by_op_type("LayerNormalization")
    # Sometimes we can not fuse skiplayernormalization since the add before layernorm has an output that used by nodes outside skiplayernorm
    # Conceptually we treat add before layernorm as skiplayernorm node since they share the same pattern
    start_nodes.extend(skip_layer_norm_nodes)
    start_nodes.extend(layer_norm_nodes)

    for normalize_node in start_nodes:
        graph_name = self.get_graph_by_node(normalize_node).name
        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
        if normalize_node.op_type == "LayerNormalization":
            add_before_layernorm = self.match_parent(normalize_node, "Add", 0)
            if add_before_layernorm is not None:
                normalize_node = add_before_layernorm  # noqa: PLW2901
            else:
                continue
        # Find the residual/root parent: try input 1 first, then input 0.
        parent = self.get_parent(normalize_node, 1)
        if parent is None or parent.op_type not in [
            "SkipLayerNormalization",
            "LayerNormalization",
            "Reshape",
        ]:
            parent = self.get_parent(normalize_node, 0)
            if parent is None or parent.op_type not in [
                "SkipLayerNormalization",
                "LayerNormalization",
                "Reshape",
            ]:
                logger.debug("Failed to match parent of normalize_node")
                continue

        # qkv path: output projection after attention (MatMul/Einsum variants).
        qkv_nodes = self.match_parent_path(
            normalize_node,
            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
            [0, 0, 0, 0, 0],
        )
        if qkv_nodes is None:
            qkv_nodes = self.match_parent_path(
                normalize_node,
                ["MatMul", "Reshape", "Transpose", "MatMul"],
                [1, 0, 0, 0],
            )
        if qkv_nodes is None:
            qkv_nodes = self.match_parent_path(normalize_node, ["Add", "Einsum", "Einsum"], [0, 0, 0])
        if qkv_nodes is None:
            logger.debug("Failed to match qkv nodes")
            continue

        matmul_qkv = qkv_nodes[-1]
        # v path: value projection feeding the attention-score MatMul (input 1).
        v_nodes = self.match_parent_path(matmul_qkv, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0])
        if v_nodes is None:
            v_nodes = self.match_parent_path(matmul_qkv, ["Add", "Einsum"], [1, 0])
        if v_nodes is None:
            logger.debug("Failed to match v path")
            continue

        add_v = v_nodes[-2]
        matmul_v = v_nodes[-1]
        # qk path: Softmax over masked, scaled Q*K^T.
        qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Mul", "MatMul"], [0, 0, 0, 0])
        if qk_nodes is None:
            qk_nodes = self.match_parent_path(matmul_qkv, ["Softmax", "Add", "Einsum"], [0, 0, 0])
        if qk_nodes is None:
            logger.debug("Failed to match qk_paths")
            continue
        matmul_qk = qk_nodes[-1]

        # q path: query projection (input 0 of Q*K^T).
        q_nodes = self.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [0, 0, 0, 0])
        if q_nodes is None:
            q_nodes = self.match_parent_path(matmul_qk, ["Add", "Einsum"], [0, 0])
        if q_nodes is None:
            logger.debug("Failed to match q path")
            continue

        add_q = q_nodes[-2]
        matmul_q = q_nodes[-1]

        # k path: key projection (input 1 of Q*K^T).
        k_nodes = self.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Add", "MatMul"], [1, 0, 0, 0])
        if k_nodes is None:
            k_nodes = self.match_parent_path(matmul_qk, ["Mul", "Add", "Einsum"], [1, 0, 0])
        if k_nodes is None:
            logger.debug("Failed to match k path")
            continue
        add_k = k_nodes[-2]
        matmul_k = k_nodes[-1]

        # Mask path starts from the Add that applies the attention mask before Softmax.
        mask_nodes = self.match_mask_path(qk_nodes[1])

        if mask_nodes is None:
            logger.debug("Cannot find mask_nodes.")
            continue

        if not self.has_constant_input(mask_nodes[1], 1):
            logger.debug("Sub node expected to have an input with constant value 1.0.")
            continue

        # add a squeeze node to convert a 3-d mask to 2-d
        squeeze_node = self.match_parent_path(mask_nodes[-1], ["Squeeze"], [0]) or self.match_parent_path(
            mask_nodes[-1], ["Expand"], [0]
        )
        squeeze_node_name = "Squeeze_3d_to_2d_mask"
        squeeze_output_name = squeeze_node_name + "_output"
        if squeeze_node is None and len(mask_nodes) == 5 and self.find_graph_input(mask_nodes[-1].input[0]) is None:
            # NOTE(review): the mask value is read from input[1] but the rewired edge is
            # input[0] — looks intentional for this 5-node pattern, verify against the matcher.
            mask_input = mask_nodes[-1].input[1]
            self.add_node(
                helper.make_node(
                    "Squeeze",
                    [mask_input],
                    [squeeze_output_name],
                    squeeze_node_name,
                    axes=[1],
                ),
                graph_name,
            )
            mask_nodes[-1].input[0] = squeeze_output_name

        is_same_root = self.check_attention_input(matmul_q, matmul_k, matmul_v, parent, output_name_to_node)
        if is_same_root:
            mask_index = self.attention_mask.process_mask(mask_nodes[-1].input[0])
            logger.debug("Create an Attention node.")

            # For tf models, q and v are flipped.
            # NOTE(review): the arguments below actually swap q and k
            # (q_matmul=matmul_k, k_matmul=matmul_q), not q and v — confirm
            # whether the comment or the swap is the intended one.
            attention_node = self.attention_fusion.create_attention_node(
                mask_index=mask_index,
                q_matmul=matmul_k,
                k_matmul=matmul_q,
                v_matmul=matmul_v,
                q_add=add_k,
                k_add=add_q,
                v_add=add_v,
                num_heads=self.num_heads,
                hidden_size=self.hidden_size,
                first_input=parent.output[0],
                output=qkv_nodes[2].output[0],
            )
            if attention_node is None:
                continue

            if qkv_nodes[1].op_type == "Einsum":
                # add reshape before einsum
                tensor = helper.make_tensor(
                    name=qkv_nodes[1].name + "_newshape",
                    data_type=TensorProto.INT64,
                    dims=[4],
                    vals=np.int64(
                        [
                            [
                                0,
                                0,
                                self.num_heads,
                                int(self.hidden_size / self.num_heads),
                            ]
                        ]
                    ).tobytes(),
                    raw=True,
                )
                self.add_initializer(tensor, graph_name)
                reshape_ = helper.make_node(
                    "Reshape",
                    inputs=[
                        attention_node.output[0],
                        qkv_nodes[1].name + "_newshape",
                    ],
                    outputs=[qkv_nodes[1].name + "_reshape_output"],
                    name=qkv_nodes[1].name + "_reshape",
                )
                qkv_nodes[1].input[0] = qkv_nodes[1].name + "_reshape_output"
                self.add_node(reshape_, graph_name)
            if parent.op_type == "Reshape":
                # Temporary work around: we require the skiplayernorm and attention op be fed with 3-d input
                hidden_size = numpy_helper.to_array(self.get_initializer(parent.input[1]))[1]
                tensor = helper.make_tensor(
                    name=parent.name + "_modified",
                    data_type=TensorProto.INT64,
                    dims=[3],
                    vals=np.int64([[1, -1, hidden_size]]).tobytes(),
                    raw=True,
                )
                self.add_initializer(tensor, graph_name)
                parent.input[1] = parent.name + "_modified"

            self.add_node(attention_node, graph_name)
            attention_count += 1

            # qkv_nodes[0:2] (the output projection Add/MatMul) are kept; the rest are fused away.
            nodes_to_remove.extend(qkv_nodes[2:])
            nodes_to_remove.extend(qk_nodes)
            nodes_to_remove.extend(q_nodes)
            nodes_to_remove.extend(k_nodes)
            nodes_to_remove.extend(v_nodes)
            nodes_to_remove.extend(mask_nodes)
        else:
            logger.debug("Root node not matched.")
            continue
    self.remove_nodes(nodes_to_remove)
    self.update_graph()
    logger.info(f"Fused Attention count:{attention_count}")
|
557
|
+
def preprocess(self):
    """Run the graph cleanups required before fusion, in order."""
    for step in (self.remove_identity, self.process_embedding, self.skip_reshape):
        step()
|
562
|
+
def skip_reshape(self):
    """Bypass consecutive Reshape nodes.

    When a Reshape's input-0 producer is itself a Reshape, rewire the child
    directly to the grandparent's output so the upstream Reshape becomes dead.
    """
    count = 0
    for node in self.get_nodes_by_op_type("Reshape"):
        upstream = self.get_parent(node, 0)
        if upstream is None or upstream.op_type != "Reshape":
            continue
        node.input[0] = upstream.input[0]
        count += 1

    if count > 0:
        logger.info(f"Skip consequent Reshape count: {count}")
|
574
|
+
def remove_reshape_before_first_attention(self):
    """Remove the Reshape between EmbedLayerNormalization and the first Attention node.

    Only the first Attention node with such a Reshape parent is handled;
    the Reshape's consumers are rewired to its input and the node is deleted.
    """
    for attention_node in self.get_nodes_by_op_type("Attention"):
        matched = self.match_parent_path(attention_node, ["Reshape", "EmbedLayerNormalization"], [0, 0])
        if matched is None:
            continue
        logger.info("Remove Reshape before first Attention node.")
        reshape = matched[0]
        self.replace_input_of_all_nodes(reshape.output[0], reshape.input[0])
        self.remove_node(reshape)
        break
|
586
|
+
def postprocess(self):
    """Final cleanup after fusion: strip the leading Reshape, then prune dead nodes."""
    for final_step in (self.remove_reshape_before_first_attention, self.prune_graph):
        final_step()
|
|
@@ -0,0 +1,42 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
|
|
6
|
+
from logging import getLogger
|
|
7
|
+
|
|
8
|
+
from fusion_attention_clip import FusionAttentionClip
|
|
9
|
+
from onnx import ModelProto
|
|
10
|
+
from onnx_model_bert import BertOnnxModel
|
|
11
|
+
|
|
12
|
+
logger = getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ClipOnnxModel(BertOnnxModel):
    """BERT-derived ONNX model wrapper that applies CLIP-specific attention fusion."""

    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
        super().__init__(model, num_heads=num_heads, hidden_size=hidden_size)
        # Dedicated fusion pass for the CLIP attention pattern.
        self.clip_attention_fusion = FusionAttentionClip(self, self.hidden_size, self.num_heads)

    def get_fused_operator_statistics(self):
        """
        Returns node count of fused operators.
        """
        ops = [
            "Attention",
            "FastGelu",
            "Gelu",
            "LayerNormalization",
            "QuickGelu",
            "BiasGelu",
            "SkipLayerNormalization",
        ]
        op_count = {op: len(self.get_nodes_by_op_type(op)) for op in ops}
        logger.info(f"Optimized operators:{op_count}")
        return op_count

    def fuse_attention(self):
        """Override the generic BERT attention fusion with the CLIP variant."""
        self.clip_attention_fusion.apply()
@@ -0,0 +1,32 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
import logging
|
|
6
|
+
|
|
7
|
+
from fusion_attention import AttentionMask
|
|
8
|
+
from fusion_conformer_attention import FusionConformerAttention
|
|
9
|
+
from fusion_options import FusionOptions
|
|
10
|
+
from onnx_model_bert import BertOnnxModel
|
|
11
|
+
|
|
12
|
+
logger = logging.getLogger(__name__)
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class ConformerOnnxModel(BertOnnxModel):
    """BERT-derived ONNX model wrapper that applies conformer attention fusion."""

    def __init__(self, model, num_heads, hidden_size):
        super().__init__(model, num_heads, hidden_size)
        self.attention_mask = AttentionMask(self)
        # Conformer-specific attention fusion pass sharing this model's mask helper.
        self.attention_fusion = FusionConformerAttention(self, self.hidden_size, self.num_heads, self.attention_mask)

    def optimize(self, options: FusionOptions | None = None, add_dynamic_axes: bool = False):
        """Configure the conformer fusion from *options*, then run the base optimization."""
        if options is None:
            use_mha = False
            disable_bias = False
        else:
            use_mha = options.use_multi_head_attention
            disable_bias = options.disable_multi_head_attention_bias
        self.attention_fusion.use_multi_head_attention = use_mha
        self.attention_fusion.disable_multi_head_attention_bias = disable_bias
        super().optimize(options, add_dynamic_axes)

    def fuse_attention(self):
        """Override the generic BERT attention fusion with the conformer variant."""
        self.attention_fusion.apply()

    def preprocess(self):
        """Graph cleanup before fusion: normalize Reshape/Expand patterns."""
        self.adjust_reshape_and_expand()