onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects changes between package versions as they appear in their public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
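For context on how this wheel is consumed: the DirectML build registers the `DmlExecutionProvider` alongside the CPU provider (see `DirectML.dll` in `onnxruntime/capi` above). A minimal smoke-test sketch after installing the wheel; the model path is a placeholder, not a file shipped in the package:

```python
# Minimal sketch: open a session on the DirectML execution provider.
# "model.onnx" is a placeholder; any valid ONNX model file works here.
import onnxruntime as ort

session = ort.InferenceSession(
    "model.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)
print(session.get_providers())  # lists the providers actually in use
```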
onnxruntime/transformers/onnx_model_bart.py (new file, +142 lines):

```diff
@@ -0,0 +1,142 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import logging
+from typing import Optional
+
+from fusion_attention import AttentionMask
+from fusion_bart_attention import FusionBartAttention
+from fusion_options import FusionOptions
+from fusion_reshape import FusionReshape
+from onnx import numpy_helper
+from onnx_model import OnnxModel
+from onnx_model_bert import BertOnnxModel
+
+logger = logging.getLogger(__name__)
+
+
+class FusionBartReshape(FusionReshape):
+    def __init__(self, model: OnnxModel):
+        super().__init__(model)
+
+    def fuse(self, reshape_node, input_name_to_nodes, output_name_to_node):
+        if reshape_node.input[1] not in output_name_to_node:
+            return
+
+        concat_node = output_name_to_node[reshape_node.input[1]]
+        if concat_node.op_type != "Concat" or len(concat_node.input) != 4:
+            return
+
+        path0 = self.model.match_parent_path(
+            concat_node,
+            ["Unsqueeze", "Gather", "Shape"],
+            [0, 0, 0],
+            output_name_to_node,
+        )
+        if path0 is None:
+            return
+
+        (_, gather_0, shape_0) = path0
+
+        shape = []
+        gather_value = self.model.get_constant_value(gather_0.input[1])
+        if gather_value == 0:
+            shape.append(0)
+
+        path1 = self.model.match_parent_path(
+            concat_node,
+            ["Unsqueeze", "Gather", "Shape"],
+            [1, 0, 0],
+            output_name_to_node,
+        )
+        if path1 is None:
+            input_1_proto = self.model.get_initializer(concat_node.input[1])
+            input_2_proto = self.model.get_initializer(concat_node.input[2])
+            input_3_proto = self.model.get_initializer(concat_node.input[3])
+            if input_1_proto is None or input_2_proto is None or input_3_proto is None:
+                return
+
+            input_1 = numpy_helper.to_array(input_1_proto)
+            input_2 = numpy_helper.to_array(input_2_proto)
+            input_3 = numpy_helper.to_array(input_3_proto)
+            if len(input_1) != 1 or len(input_2) != 1 or len(input_3) != 1:
+                return
+
+            if not (input_1[0] == -1 and input_2[0] > 0 and input_3[0] > 0):
+                return
+
+            shape.extend(input_1)
+            shape.extend(input_2)
+            shape.extend(input_3)
+            gemm_path_with_bias = self.model.match_parent_path(
+                reshape_node, ["Add", "MatMul"], [0, 1], output_name_to_node
+            )
+            gemm_path_no_bias = self.model.match_parent_path(reshape_node, ["MatMul"], [0], output_name_to_node)
+            if gemm_path_with_bias is not None:
+                gemm_path = gemm_path_with_bias
+            elif gemm_path_no_bias is not None:
+                gemm_path = gemm_path_no_bias
+            else:
+                return
+
+            top_matmul = gemm_path[-1]
+            root_input = top_matmul.input[0]
+
+            self.replace_reshape_node(shape, reshape_node, concat_node)
+        else:
+            (_, gather_1, shape_1) = path1
+
+            gather_value = self.model.get_constant_value(gather_1.input[1])
+            if gather_value == 1:
+                shape.append(0)
+
+            input_2_proto = self.model.get_initializer(concat_node.input[2])
+            input_3_proto = self.model.get_initializer(concat_node.input[3])
+            if input_2_proto is None or input_3_proto is None:
+                return
+
+            input_2 = numpy_helper.to_array(input_2_proto)
+            input_3 = numpy_helper.to_array(input_3_proto)
+            if len(input_2) != 1 or len(input_3) != 1:
+                return
+
+            if not (input_2[0] > 0 and input_3[0] > 0):
+                return
+
+            shape.extend(input_2)
+            shape.extend(input_3)
+            gemm_path = self.model.match_parent_path(
+                reshape_node, ["Mul", "Add", "MatMul"], [0, 0, 1], output_name_to_node
+            )
+            if gemm_path is None:
+                return
+
+            top_matmul = gemm_path[-1]
+            root_input = top_matmul.input[0]
+            if shape_0.input[0] != root_input or shape_1.input[0] != root_input:
+                return
+
+            self.replace_reshape_node(shape, reshape_node, concat_node)
+
+
+class BartOnnxModel(BertOnnxModel):
+    def __init__(self, model, num_heads, hidden_size, model_impl="hf"):
+        super().__init__(model, num_heads, hidden_size)
+        self.attention_mask = AttentionMask(self)
+        self.attention_fusion = FusionBartAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
+        self.bart_reshape_fusion_preprocess = FusionBartReshape(self)
+
+    def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
+        self.attention_fusion.use_multi_head_attention = False if options is None else options.use_multi_head_attention
+        self.attention_fusion.disable_multi_head_attention_bias = (
+            False if options is None else options.disable_multi_head_attention_bias
+        )
+        super().optimize(options, add_dynamic_axes)
+
+    def fuse_attention(self):
+        self.attention_fusion.apply()
+
+    def preprocess(self):
+        self.adjust_reshape_and_expand()
+        self.bart_reshape_fusion_preprocess.apply()
```
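`BartOnnxModel` above is normally reached through the `optimize_model` entry point in `onnxruntime/transformers/optimizer.py` (listed above), which in this version maps `model_type="bart"` to this class. A hedged usage sketch; the paths are placeholders and the head/hidden values are illustrative (16/1024 match bart-large):

```python
# Sketch: run the BART-specific fusions via the optimizer entry point.
# Paths are placeholders; num_heads=16 / hidden_size=1024 are illustrative.
from onnxruntime.transformers.optimizer import optimize_model

opt = optimize_model(
    "bart_model.onnx",
    model_type="bart",   # selects BartOnnxModel and its FusionBartReshape preprocess
    num_heads=16,
    hidden_size=1024,
)
opt.save_model_to_file("bart_model_opt.onnx")
```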
onnxruntime/transformers/onnx_model_bert.py (new file, +481 lines):

```diff
@@ -0,0 +1,481 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+from logging import getLogger
+from typing import List, Optional
+
+from convert_to_packing_mode import PackingMode
+from fusion_attention import AttentionMask, FusionAttention
+from fusion_bart_attention import FusionBartAttention
+from fusion_biasgelu import FusionBiasGelu
+from fusion_embedlayer import FusionEmbedLayerNormalization
+from fusion_fastgelu import FusionFastGelu
+from fusion_gelu import FusionGelu
+from fusion_gelu_approximation import FusionGeluApproximation
+from fusion_gemmfastgelu import FusionGemmFastGelu
+from fusion_layernorm import FusionLayerNormalization, FusionLayerNormalizationTF
+from fusion_options import AttentionMaskFormat, FusionOptions
+from fusion_qordered_attention import FusionQOrderedAttention
+from fusion_qordered_gelu import FusionQOrderedGelu
+from fusion_qordered_layernorm import FusionQOrderedLayerNormalization
+from fusion_qordered_matmul import FusionQOrderedMatMul
+from fusion_quickgelu import FusionQuickGelu
+from fusion_reshape import FusionReshape
+from fusion_rotary_attention import FusionRotaryEmbeddings
+from fusion_shape import FusionShape
+from fusion_simplified_layernorm import FusionSimplifiedLayerNormalization, FusionSkipSimplifiedLayerNormalization
+from fusion_skiplayernorm import FusionBiasSkipLayerNormalization, FusionSkipLayerNormalization
+from fusion_utils import FusionUtils
+from onnx import ModelProto, TensorProto, helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class BertOnnxModel(OnnxModel):
+    def __init__(self, model: ModelProto, num_heads: int = 0, hidden_size: int = 0):
+        """Initialize BERT ONNX Model.
+
+        Args:
+            model (ModelProto): the ONNX model
+            num_heads (int, optional): number of attention heads. Defaults to 0 (detect the parameter automatically).
+            hidden_size (int, optional): hidden dimension. Defaults to 0 (detect the parameter automatically).
+        """
+        assert (num_heads == 0 and hidden_size == 0) or (num_heads > 0 and hidden_size % num_heads == 0)
+
+        super().__init__(model)
+        self.num_heads = num_heads
+        self.hidden_size = hidden_size
+
+        self.attention_mask = AttentionMask(self)
+        self.attention_fusion = FusionAttention(self, self.hidden_size, self.num_heads, self.attention_mask)
+        self.qordered_attention_fusion = FusionQOrderedAttention(
+            self, self.hidden_size, self.num_heads, self.attention_mask
+        )
+        self.utils = FusionUtils(self)
+
+    def fuse_attention(self):
+        self.attention_fusion.apply()
+        # Only relevant in models with Q-DQ nodes
+        self.qordered_attention_fusion.apply()
+
+    def fuse_gelu(self):
+        fusion = FusionGelu(self)
+        fusion.apply()
+        fusion = FusionFastGelu(self)
+        fusion.apply()
+        fusion = FusionQuickGelu(self)
+        fusion.apply()
+        # Only relevant in models with Q-DQ nodes
+        fusion = FusionQOrderedGelu(self)
+        fusion.apply()
+
+    def fuse_bias_gelu(self, is_fastgelu):
+        fusion = FusionBiasGelu(self, is_fastgelu)
+        fusion.apply()
+
+    def gelu_approximation(self):
+        fusion = FusionGeluApproximation(self)
+        fusion.apply()
+
+    def fuse_gemm_fast_gelu(self):
+        fusion = FusionGemmFastGelu(self)
+        fusion.apply()
+
+    def fuse_add_bias_skip_layer_norm(self):
+        fusion = FusionBiasSkipLayerNormalization(self)
+        fusion.apply()
+
+    def fuse_reshape(self):
+        fusion = FusionReshape(self)
+        fusion.apply()
+
+    def fuse_shape(self):
+        fusion = FusionShape(self)
+        fusion.apply()
+
+    def fuse_embed_layer(self, use_mask_index):
+        fusion = FusionEmbedLayerNormalization(self, use_mask_index)
+        fusion.apply()
+
+    def fuse_layer_norm(self):
+        fusion = FusionLayerNormalization(self)
+        fusion.apply()
+
+        fusion = FusionLayerNormalizationTF(self)
+        fusion.apply()
+
+        # Only relevant in models with Q-DQ nodes
+        fusion = FusionQOrderedLayerNormalization(self)
+        fusion.apply()
+
+    def fuse_simplified_layer_norm(self):
+        fusion = FusionSimplifiedLayerNormalization(self)
+        fusion.apply()
+
+    def fuse_skip_layer_norm(self, shape_infer=True):
+        fusion = FusionSkipLayerNormalization(self, shape_infer=shape_infer)
+        fusion.apply()
+
+    def fuse_skip_simplified_layer_norm(self):
+        fusion = FusionSkipSimplifiedLayerNormalization(self)
+        fusion.apply()
+
+    def fuse_rotary_embeddings(self):
+        fusion = FusionRotaryEmbeddings(self)
+        fusion.apply()
+        # Remove non-MS domain functions
+        rot_emb_nodes = list(
+            filter(
+                lambda node: node.op_type == "RotaryEmbedding" and node.domain != "com.microsoft",
+                self.model.graph.node,
+            )
+        )
+        non_ms_domains_to_keep = set(map(lambda node: node.domain, rot_emb_nodes))
+        i = 0
+        while i < len(self.model.functions):
+            fn = self.model.functions[i]
+            if "RotaryEmbedding" in fn.name and fn.domain not in non_ms_domains_to_keep:
+                self.model.functions.remove(fn)
+            else:
+                i += 1
+
+    # Only relevant in models with Q-DQ nodes
+    def fuse_qordered_mamtul(self):
+        fusion = FusionQOrderedMatMul(self)
+        fusion.apply()
+
+    def get_graph_inputs_from_node_type(self, op_type: str, input_indices: List[int], casted: bool):
+        """
+        Get graph inputs that feed into node type (like EmbedLayerNormalization or Attention).
+        Returns a list of the graph input names based on the filter whether it is casted or not.
+        """
+        graph_inputs = []
+
+        output_name_to_node = self.output_name_to_node()
+        nodes = self.get_nodes_by_op_type(op_type)
+        for node in nodes:
+            bert_inputs = [node.input[i] for i in input_indices if i < len(node.input)]
+            for bert_input in bert_inputs:
+                if self.find_graph_input(bert_input):
+                    if not casted:
+                        graph_inputs.append(bert_input)
+                elif bert_input in output_name_to_node:
+                    parent = output_name_to_node[bert_input]
+                    if parent.op_type == "Cast" and self.find_graph_input(parent.input[0]) is not None:
+                        if casted:
+                            graph_inputs.append(parent.input[0])
+        return graph_inputs
+
+    def get_graph_inputs_from_fused_nodes(self, casted: bool):
+        inputs = self.get_graph_inputs_from_node_type("EmbedLayerNormalization", [0, 1, 7], casted)
+        inputs += self.get_graph_inputs_from_node_type("Attention", [3], casted)
+        return inputs
+
+    def change_graph_inputs_to_int32(self):
+        """Change data type of all graph inputs to int32 type, and add Cast node if needed."""
+        graph = self.graph()
+        add_cast_count = 0
+        remove_cast_count = 0
+        for graph_input in graph.input:
+            new_node, removed_nodes = self.change_graph_input_type(graph_input, TensorProto.INT32)
+            if new_node:
+                add_cast_count += 1
+            remove_cast_count += len(removed_nodes)
+        logger.info(
+            f"Graph inputs are changed to int32. Added {add_cast_count} Cast nodes, and removed {remove_cast_count} Cast nodes."
+        )
+
+    def use_dynamic_axes(self, dynamic_batch_dim="batch_size", dynamic_seq_len="max_seq_len"):
+        """
+        Update input and output shape to use dynamic axes.
+        """
+        bert_graph_inputs = self.get_graph_inputs_from_fused_nodes(
+            casted=True
+        ) + self.get_graph_inputs_from_fused_nodes(casted=False)
+
+        for input in self.model.graph.input:
+            if input.name in bert_graph_inputs:
+                dim_proto = input.type.tensor_type.shape.dim[0]
+                dim_proto.dim_param = dynamic_batch_dim
+                if dynamic_seq_len is not None:
+                    dim_proto = input.type.tensor_type.shape.dim[1]
+                    dim_proto.dim_param = dynamic_seq_len
+
+        for output in self.model.graph.output:
+            dim_proto = output.type.tensor_type.shape.dim[0]
+            dim_proto.dim_param = dynamic_batch_dim
+
+    def preprocess(self):
+        self.adjust_reshape_and_expand()
+        return
+
+    def adjust_reshape_and_expand(self):
+        nodes_to_remove = []
+        for node in self.nodes():
+            if node.op_type == "Reshape":
+                # Clean up unnecessary reshape nodes.
+                # Find reshape nodes with no actually data in "shape" attribute and remove.
+                reshape_shape = self.get_constant_value(node.input[1])
+                if reshape_shape is not None and reshape_shape.size == 0:
+                    nodes_to_remove.extend([node])
+                    self.replace_input_of_all_nodes(node.output[0], node.input[0])
+                    continue
+
+                # Find path "Slice" -> "Reshape" -> "Expand" -> "Expand" -> current "Reshape", simplify the graph by
+                # changing current reshape's input to output of slice.
+                reshape_path = self.match_parent_path(
+                    node,
+                    ["Expand", "Expand", "Reshape", "Slice"],
+                    [0, 0, 0, 0],
+                    self.output_name_to_node(),
+                )
+                if reshape_path is not None:
+                    expand_node = reshape_path[-3]
+                    expand_shape_value = self.get_constant_value(expand_node.input[1])
+
+                    reshape_before_expand = reshape_path[-2]
+                    shape_value = self.get_constant_value(reshape_before_expand.input[1])
+
+                    slice_node = reshape_path[-1]
+                    if (
+                        expand_shape_value is not None
+                        and shape_value is not None
+                        and len(expand_shape_value) == 2
+                        and len(shape_value) == 1
+                        and expand_shape_value[1] == shape_value[0]
+                    ):
+                        node.input[0] = slice_node.output[0]
+
+        if nodes_to_remove:
+            self.remove_nodes(nodes_to_remove)
+            logger.info(f"Removed Reshape and Expand count: {len(nodes_to_remove)}")
+
+    def clean_graph(self):
+        output_name_to_node = self.output_name_to_node()
+        nodes_to_remove = []
+        for node in self.nodes():
+            # Before:
+            #  input_ids --> Shape --> Gather(indices=0) --> Unsqueeze ------+
+            #          |                                                     |
+            #          |                                                     v
+            #          +----> Shape --> Gather(indices=1) --> Unsqueeze---> Concat --> ConstantOfShape -->Cast --> EmbedLayerNormaliation/ReduceSum
+            # After:
+            #  input_ids --> Shape --> ConstantOfShape -->Cast --> EmbedLayerNormaliation/ReduceSum
+            # TODO: merge ConstantOfShape -->Cast to ConstantOfShape (need update the data type of value)
+            op_input_id = {"EmbedLayerNormalization": 1, "ReduceSum": 0, "Attention": 3}
+            if node.op_type in op_input_id:
+                i = op_input_id[node.op_type]
+                parent_nodes = self.match_parent_path(
+                    node,
+                    [
+                        "Cast",
+                        "ConstantOfShape",
+                        "Concat",
+                        "Unsqueeze",
+                        "Gather",
+                        "Shape",
+                    ],
+                    [i, 0, 0, 0, 0, 0],
+                    output_name_to_node,
+                )
+                if parent_nodes is not None:
+                    (
+                        cast,
+                        constantOfShape,  # noqa: N806
+                        concat,
+                        unsqueeze,
+                        gather,
+                        shape,
+                    ) = parent_nodes
+                    if shape.input[0] == self.graph().input[0].name:
+                        constantOfShape.input[0] = shape.output[0]
+                        output_name_to_node = self.output_name_to_node()
+
+            if node.op_type == "Attention":
+                # Before:
+                #   input_ids --> Shape -->ConstantOfShape -->Cast --> ReduceSum --> Attention
+                # After:
+                #   remove this path, and remove the optional mask_index input of Attention node.
+                parent_nodes = self.match_parent_path(
+                    node,
+                    ["ReduceSum", "Cast", "ConstantOfShape", "Shape"],
+                    [3, 0, 0, 0],
+                    output_name_to_node,
+                )
+                if parent_nodes is not None:
+                    if parent_nodes[-1].input[0] == self.graph().input[0].name:
+                        attention_node = helper.make_node(
+                            "Attention",
+                            inputs=node.input[0 : len(node.input) - 1],
+                            outputs=node.output,
+                            name=node.name + "_remove_mask",
+                        )
+                        attention_node.domain = "com.microsoft"
+                        attention_node.attribute.extend([helper.make_attribute("num_heads", self.num_heads)])
+                        self.add_node(attention_node, self.get_graph_by_node(attention_node).name)
+                        nodes_to_remove.append(node)
+        self.remove_nodes(nodes_to_remove)
+
+    def postprocess(self):
+        self.clean_graph()
+        self.prune_graph()
+
+    def optimize(self, options: Optional[FusionOptions] = None, add_dynamic_axes: bool = False):
+        if (options is not None) and not options.enable_shape_inference:
+            self.disable_shape_inference()
+
+        self.utils.remove_identity_nodes()
+
+        # Remove cast nodes that having same data type of input and output based on symbolic shape inference.
+        self.utils.remove_useless_cast_nodes()
+
+        if (options is None) or options.enable_layer_norm:
+            self.fuse_layer_norm()
+            self.fuse_simplified_layer_norm()
+
+        if (options is None) or options.enable_gelu:
+            self.fuse_gelu()
+
+        self.preprocess()
+
+        self.fuse_reshape()
+
+        if (options is None) or options.enable_skip_layer_norm:
+            self.fuse_skip_layer_norm(options.enable_shape_inference)
+            self.fuse_skip_simplified_layer_norm()
+
+        if (options is None) or options.enable_rotary_embeddings:
+            self.fuse_rotary_embeddings()
+
+        if options is not None:
+            self.attention_mask.set_mask_format(options.attention_mask_format)
+            if options.use_multi_head_attention and not isinstance(self.attention_fusion, FusionBartAttention):
+                self.attention_fusion = FusionAttention(
+                    self,
+                    self.hidden_size,
+                    self.num_heads,
+                    self.attention_mask,
+                    options.use_multi_head_attention,
+                )
+
+        if (options is None) or options.enable_attention:
+            self.fuse_attention()
+
+        # Perform the MatMul fusion after the Attention fusion as we do not
+        # want to fuse the MatMuls inside the Attention subgraphs
+        if (options is None) or options.enable_qordered_matmul:
+            self.fuse_qordered_mamtul()
+
+        self.fuse_shape()
+
+        if (options is None) or options.enable_embed_layer_norm:
+            use_mask_index = options.attention_mask_format == AttentionMaskFormat.MaskIndexEnd
+            self.fuse_embed_layer(use_mask_index)
+
+        # Remove reshape nodes that having same shape of input and output based on symbolic shape inference.
+        self.utils.remove_useless_reshape_nodes()
+
+        self.postprocess()
+
+        # Bias fusion is done after postprocess to avoid extra Reshape between bias and Gelu/FastGelu/SkipLayerNormalization
+        if (options is None) or options.enable_bias_gelu:
+            # Fuse Gelu and Add Bias before it.
+            self.fuse_bias_gelu(is_fastgelu=True)
+            self.fuse_bias_gelu(is_fastgelu=False)
+
+        if (options is None) or options.enable_bias_skip_layer_norm:
+            # Fuse SkipLayerNormalization and Add Bias before it.
+            self.fuse_add_bias_skip_layer_norm()
+
+        if options is not None and options.enable_gelu_approximation:
+            self.gelu_approximation()
+
+        if options is not None and options.enable_gemm_fast_gelu:
+            self.fuse_gemm_fast_gelu()
+
+        self.remove_unused_constant()
+
+        # Use symbolic batch dimension in input and output.
+        if add_dynamic_axes:
+            self.use_dynamic_axes()
+
+        logger.info(f"opset version: {self.get_opset_version()}")
+
+    def get_fused_operator_statistics(self):
+        """
+        Returns node count of fused operators.
+        """
+        op_count = {}
+        ops = [
+            "EmbedLayerNormalization",
+            "Attention",
+            "MultiHeadAttention",
+            "Gelu",
+            "FastGelu",
+            "BiasGelu",
+            "GemmFastGelu",
+            "LayerNormalization",
+            "SimplifiedLayerNormalization",
+            "SkipLayerNormalization",
+            "SkipSimplifiedLayerNormalization",
+            "RotaryEmbedding",
+        ]
+        q_ops = [
+            "QOrderedAttention",
+            "QOrderedGelu",
+            "QOrderedLayerNormalization",
+            "QOrderedMatMul",
+        ]
+        for op in ops + q_ops:
+            nodes = self.get_nodes_by_op_type(op)
+            op_count[op] = len(nodes)
+
+        logger.info(f"Optimized operators: {op_count}")
+        return op_count
+
+    def is_fully_optimized(self, fused_op_count=None):
+        """
+        Returns True when the model is fully optimized.
+        """
+        if fused_op_count is None:
+            fused_op_count = self.get_fused_operator_statistics()
+
+        def op_count(op_name: str):
+            return fused_op_count.get(op_name) or 0
+
+        embed = op_count("EmbedLayerNormalization")
+        attention = op_count("Attention") + op_count("MultiHeadAttention") + op_count("QOrderedAttention")
+        gelu = op_count("Gelu") + op_count("BiasGelu") + op_count("FastGelu")
+        layer_norm = op_count("LayerNormalization") + op_count("SkipLayerNormalization")
+        simple_layer_norm = op_count("SimplifiedLayerNormalization") + op_count("SkipSimplifiedLayerNormalization")
+
+        is_perfect = (
+            (embed > 0)
+            and (attention > 0)
+            and (attention == gelu)
+            and ((layer_norm >= 2 * attention) or (simple_layer_norm >= 2 * attention))
+        )
+
+        if layer_norm == 0:
+            logger.debug("Layer Normalization not fused")
+
+        if simple_layer_norm == 0:
+            logger.debug("Simple Layer Normalization not fused")
+
+        if gelu == 0:
+            logger.debug("Gelu (or FastGelu) not fused")
+
+        if embed == 0:
+            logger.debug("EmbedLayerNormalization not fused")
+
+        if attention == 0:
+            logger.warning("Attention (or MultiHeadAttention) not fused")
+
+        return is_perfect
+
+    def convert_to_packing_mode(self, use_symbolic_shape_infer: bool = False):
+        packing_mode = PackingMode(self)
+        packing_mode.convert(use_symbolic_shape_infer)
```