onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
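
Taken together, the listing shows the shape of the DirectML build: onnxruntime/capi carries the native binaries (onnxruntime.dll, DirectML.dll) and the pybind11 bindings, while quantization, tools, and transformers are pure-Python utilities layered on top. As a quick orientation, a minimal sketch of how an installed copy of this wheel is typically exercised follows; the model path is a placeholder, and DmlExecutionProvider is the execution provider this build registers (with CPUExecutionProvider as the usual fallback):

    import onnxruntime as ort

    # Prefer DirectML; fall back to CPU when no compatible GPU/driver is present.
    session = ort.InferenceSession(
        "model.onnx",  # placeholder: any ONNX model file
        providers=["DmlExecutionProvider", "CPUExecutionProvider"],
    )

    # Inspect the model's first input rather than guessing shapes.
    meta = session.get_inputs()[0]
    print("input:", meta.name, meta.shape, meta.type)

The two hunks below show the full content of two of the added transformer-fusion modules.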
onnxruntime/transformers/fusion_gemmfastgelu.py
@@ -0,0 +1,122 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

from logging import getLogger
from typing import Dict, List, Union

from fusion_base import Fusion
from fusion_utils import NumpyHelper
from onnx import NodeProto, TensorProto, helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionGemmFastGelu(Fusion):
    def __init__(self, model: OnnxModel):
        super().__init__(model, "GemmFastGelu", "FastGelu", "GemmFastGelu")
        self.shape_infer = None
        self.shape_infer_done = False

    def get_dimensions_from_tensor_proto(self, tensor_proto: TensorProto) -> Union[int, None]:
        if tensor_proto.type.tensor_type.HasField("shape"):
            return len(tensor_proto.type.tensor_type.shape.dim)
        else:
            return None

    def get_dimensions(self, input_name: str) -> Union[int, None]:
        graph_input = self.model.find_graph_input(input_name)
        if graph_input:
            return self.get_dimensions_from_tensor_proto(graph_input)

        if not self.shape_infer_done:
            self.shape_infer = self.model.infer_runtime_shape(update=True)
            self.shape_infer_done = True

        if self.shape_infer is not None:
            return self.get_dimensions_from_tensor_proto(self.shape_infer.known_vi_[input_name])

        return None

    def fuse(
        self,
        node: NodeProto,
        input_name_to_nodes: Dict[str, List[NodeProto]],
        output_name_to_node: Dict[str, NodeProto],
    ):
        """
        This pattern is from PyTorch bert model
        Fuse MatMul with FastGelu into one node:

        [root] --> MatMul --> FastGelu -->

        """
        has_bias = False
        if len(node.input) == 2:
            has_bias = True

        match_nodes = self.model.match_parent_path(node, ["MatMul"], [0])
        if match_nodes is None:
            return
        matmul = match_nodes[0]

        # matmul input X should >= two dimension, input weight should be two dimension
        weight_index = -1
        x_dims = 0
        weight = None

        for i, input in enumerate(matmul.input):
            initializer = self.model.get_initializer(input)
            if initializer is None:
                x_dims = self.get_dimensions(matmul.input[i])
            else:
                weight_index = i
                weight = NumpyHelper.to_array(initializer)
        if weight is None:
            return
        if len(weight.shape) != 2:
            return
        if x_dims < len(weight.shape):
            return

        # bias weight should be one dimension
        bias_index = -1
        if has_bias:
            bias_weight = None
            for i, input in enumerate(node.input):
                initializer = self.model.get_initializer(input)
                if initializer is None:
                    continue
                bias_index = i
                bias_weight = NumpyHelper.to_array(initializer)
                break
            if bias_weight is None:
                return
            if len(bias_weight.shape) != 1:
                return

        subgraph_nodes = [node, matmul]
        if not self.model.is_safe_to_fuse_nodes(
            subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node
        ):
            return

        self.nodes_to_remove.extend(subgraph_nodes)

        inputs = (
            [matmul.input[1 - weight_index], matmul.input[weight_index], node.input[bias_index]]
            if has_bias
            else [matmul.input[1 - weight_index], matmul.input[weight_index]]
        )

        fused_node = helper.make_node(
            "GemmFastGelu",
            inputs=inputs,
            outputs=node.output,
            name=self.model.create_node_name("GemmFastGelu"),
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
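
To make the control flow above concrete: Fusion.apply() (defined in the sibling fusion_base.py shipped in the same directory) walks the graph and calls fuse() on every FastGelu node, and each successful match replaces the MatMul + FastGelu pair with a single com.microsoft GemmFastGelu node. A minimal driver sketch, assuming it runs with onnxruntime/transformers on sys.path so the flat imports above resolve, and with a placeholder model path:

    import onnx
    from fusion_gemmfastgelu import FusionGemmFastGelu
    from onnx_model import OnnxModel

    model = OnnxModel(onnx.load("bert.onnx"))  # placeholder: a model containing MatMul -> FastGelu
    fusion = FusionGemmFastGelu(model)
    fusion.apply()  # matches each FastGelu and fuses it with its upstream MatMul where safe
    model.topological_sort()
    onnx.save(model.model, "bert_fused.onnx")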
onnxruntime/transformers/fusion_gpt_attention.py
@@ -0,0 +1,546 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
from logging import getLogger

import numpy as np
from fusion_base import Fusion
from fusion_utils import FusionUtils
from onnx import helper
from onnx_model import OnnxModel

logger = getLogger(__name__)


class FusionGptAttentionPastBase(Fusion):
    """Base class for GPT Attention Fusion with past state"""

    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, "Attention", ["LayerNormalization", "SkipLayerNormalization"], "with past")
        self.num_heads = num_heads
        self.utils = FusionUtils(model)
        self.casted_attention_mask = {}  # map from name of attention mask to the name that casted to int32
        self.mask_filter_value = None

    def match_past_pattern_1(self, concat_k, concat_v, output_name_to_node):
        # Pattern 1:
        #                   {past}
        #                  /      \
        #                 /        \
        #    Gather(axes=0, indices=0)   Gather(indices=1)
        #        |                           |
        #    Transpose (perm=0,1,3,2)        |
        #        |                           |
        #     Concat_k                   Concat_v
        #        |                          /
        #    Transpose (perm=0,1,3,2)      /
        #        |                        /
        #     Unsqueeze            Unsqueeze
        #         \                  /
        #          \                /
        #               Concat
        #                 |
        #             {present}
        gather = self.model.get_parent(concat_v, 0, output_name_to_node)
        if gather is None or gather.op_type != "Gather":
            logger.debug("match_past_pattern_1: expect Gather for past")
            return None

        if self.model.find_constant_input(gather, 1) != 1:
            logger.debug("match_past_pattern_1: expect indices=1 for Gather of past")
            return None
        past = gather.input[0]

        parent = self.model.get_parent(concat_k, 0, output_name_to_node)
        if parent and parent.op_type == "Gather":
            gather_past_k = parent
        else:
            past_k_nodes = self.model.match_parent_path(concat_k, ["Transpose", "Gather"], [0, 0])
            if past_k_nodes is None:
                logger.debug("match_past_pattern_1: failed match Transpose and Gather")
                return None
            gather_past_k = past_k_nodes[-1]

        if self.model.find_constant_input(gather_past_k, 0) != 1:
            logger.debug("match_past_pattern_1: expect indices=0 for Gather k of past")
            return None
        past_k = gather_past_k.input[0]
        if past != past_k:
            logger.debug("match_past_pattern_1: expect past to be same")
            return None

        return past

    def match_past_pattern_2(self, concat_k, concat_v, output_name_to_node):
        # Pattern 2:
        #          Split (QKV)
        #         /    |     |
        #        /     |     +----------------------+
        #              |                             |
        #              |        {past}               |
        #              |          |                  |
        #           Reshape     Split             Reshape
        #              |       /     \               |
        #       Transpose_k  Squeeze  Squeeze   Transpose_v
        #              |       |         \          /
        #              +------|---+       \        /
        #                     |            \      /
        #                 Concat_k         Concat_v
        #                     |                |
        #                 Unsqueeze        Unsqueeze
        #                      \            /
        #                         Concat
        #                           |
        #                       {present}
        #
        squeeze = self.model.get_parent(concat_v, 0, output_name_to_node)
        if squeeze is None or squeeze.op_type != "Squeeze":
            logger.debug("match_past_pattern_2: expect Squeeze as parent of concat_v")
            return None

        split = self.model.get_parent(squeeze, 0, output_name_to_node)
        if split is None or split.op_type != "Split":
            logger.debug("match_past_pattern_2: expect Split for past path")
            return None

        opset_version = self.model.get_opset_version()
        if opset_version < 13:
            if not FusionUtils.check_node_attribute(squeeze, "axes", [0]):
                logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
                return None

            if not FusionUtils.check_node_attribute(split, "split", [1, 1]):
                logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
                return None
        else:
            if not self.utils.check_node_input_value(squeeze, 1, [0]):
                logger.debug("match_past_pattern_2: axes != [0] for Squeeze in past path")
                return None

            if not self.utils.check_node_input_value(split, 1, [1, 1]):
                logger.debug("match_past_pattern_2: split != [1, 1] for Split in past path")
                return None

        if not FusionUtils.check_node_attribute(split, "axis", 0, default_value=0):
            logger.debug("match_past_pattern_2: attribute axis of Split are not expected in past path")
            return None
        past = split.input[0]

        past_k_nodes = self.model.match_parent_path(concat_k, ["Squeeze", "Split"], [0, 0])
        if past_k_nodes is None:
            logger.debug("match_past_pattern_2: failed to match past_k_nodes path")
            return None
        past_k = past_k_nodes[-1].input[0]

        if past != past_k:
            logger.info("match_past_pattern_2: expect past to be same")
            return None

        return past

    def match_present(self, concat_v, input_name_to_nodes):
        unsqueeze_present_v = self.model.find_first_child_by_type(
            concat_v, "Unsqueeze", input_name_to_nodes, recursive=False
        )
        if not unsqueeze_present_v:
            logger.info("expect unsqueeze for present")
            return None
        concat_present = self.model.find_first_child_by_type(
            unsqueeze_present_v, "Concat", input_name_to_nodes, recursive=False
        )
        if not concat_present:
            logger.info("expect concat for present")
            return None

        present = concat_present.output[0]
        return present

    def cast_attention_mask(self, input_name):
        if input_name in self.casted_attention_mask:
            attention_mask_input_name = self.casted_attention_mask[input_name]
        elif self.model.find_graph_input(input_name):
            casted, attention_mask_input_name = self.utils.cast_graph_input_to_int32(input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        else:
            attention_mask_input_name, cast_node = self.utils.cast_input_to_int32(input_name)
            self.casted_attention_mask[input_name] = attention_mask_input_name
        return attention_mask_input_name


class FusionGptAttention(FusionGptAttentionPastBase):
    """
    Fuse GPT-2 Attention with past state subgraph into one Attention node.
    """

    def __init__(self, model: OnnxModel, num_heads: int):
        super().__init__(model, num_heads)

    def create_attention_node(
        self,
        fc_weight,
        fc_bias,
        gemm_qkv,
        past,
        present,
        input,
        output,
        mask,
        is_unidirectional,
    ):
        attention_node_name = self.model.create_node_name("GptAttention")
        attention_node = helper.make_node(
            "Attention",
            inputs=[input, fc_weight, fc_bias, mask, past],
            outputs=[attention_node_name + "_output", present],
            name=attention_node_name,
        )
        attention_node.domain = "com.microsoft"
        attention_node.attribute.extend(
            [
                helper.make_attribute("num_heads", self.num_heads),
                helper.make_attribute("unidirectional", 1 if is_unidirectional else 0),
            ]
        )

        if self.mask_filter_value is not None:
            attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])

        matmul_node = helper.make_node(
            "MatMul",
            inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
            outputs=[attention_node_name + "_matmul_output"],
            name=attention_node_name + "_matmul",
        )

        add_node = helper.make_node(
            "Add",
            inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
            outputs=[output],
            name=attention_node_name + "_add",
        )
        self.nodes_to_add.extend([attention_node, matmul_node, add_node])
        self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
        self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
        self.node_name_to_graph_name[add_node.name] = self.this_graph_name

    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
        past = None
        present = None
        return_indice = []

        is_normalize_node_skiplayernorm = normalize_node.op_type == "SkipLayerNormalization"
        qkv_nodes = None

        if not is_normalize_node_skiplayernorm:
            qkv_nodes = self.model.match_parent_path(
                normalize_node,
                ["Add", "Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
                [0, None, 0, 0, 0, 0, 0],
                output_name_to_node=output_name_to_node,
                return_indice=return_indice,
            )
        else:
            qkv_nodes = self.model.match_parent_path(
                normalize_node,
                ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
                [None, 0, 0, 0, 0, 0],
                output_name_to_node=output_name_to_node,
                return_indice=return_indice,
            )

        if qkv_nodes is None:
            return

        another_input = None
        if not is_normalize_node_skiplayernorm:
            (
                add_qkv,
                reshape_qkv,
                gemm_qkv,
                reshape_1,
                reshape_2,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes

            another_input = add_qkv.input[1 - return_indice[0]]
        else:
            (
                reshape_qkv,
                gemm_qkv,
                reshape_1,
                reshape_2,
                transpose_qkv,
                matmul_qkv,
            ) = qkv_nodes

        v_nodes = self.model.match_parent_path(matmul_qkv, ["Concat", "Transpose", "Reshape", "Split"], [1, 1, 0, 0])
        if v_nodes is None:
            logger.debug("fuse_attention: failed to match v path")
            return
        (concat_v, transpose_v, reshape_v, split_fc) = v_nodes

        # Try match pattern using Gemm + LayerNormalization
        fc_nodes = self.model.match_parent_path(
            split_fc,
            ["Reshape", "Gemm", "Reshape", "LayerNormalization"],
            [0, 0, 0, 0],
            output_name_to_node,
        )

        # Try match pattern using Gemm + SkipLayerNormalization
        if fc_nodes is None:
            fc_nodes = self.model.match_parent_path(
                split_fc,
                ["Reshape", "Gemm", "Reshape", "SkipLayerNormalization"],
                [0, 0, 0, 0],
                output_name_to_node,
            )

        # Try match pattern using MatMul
        if fc_nodes is None:
            # LayerNormalization
            fc_nodes = self.model.match_parent_path(
                split_fc,
                ["Add", "MatMul", "LayerNormalization"],
                [0, None, 0],
                output_name_to_node,
            )

            # SkipLayerNormalization
            if fc_nodes is None:
                fc_nodes = self.model.match_parent_path(
                    split_fc,
                    ["Add", "MatMul", "SkipLayerNormalization"],
                    [0, None, 0],
                    output_name_to_node,
                )

            if fc_nodes is None:
                logger.debug("fuse_attention: failed to match fc path")
                return

            fc_weight = fc_nodes[1].input[1]
            i, _ = self.model.get_constant_input(fc_nodes[0])
            fc_bias = fc_nodes[0].input[i]
        else:
            fc_weight = fc_nodes[1].input[1]
            fc_bias = fc_nodes[1].input[2]

        layernorm_before_attention = fc_nodes[-1]

        # `another_input` will be non-None only if
        # (1) SkipLayerNorm fusion wasn't turned ON
        # (2) SkipLayerNorm fusion was turned ON but upstream layer's LayerNorm + Add was not
        # fused into a SkipLayerNorm. This can happen if the shapes to the Add node are different.
        # So, keep the following check if SkipLayerNorm fusion is turned ON or OFF.
        if another_input is not None and another_input not in layernorm_before_attention.input:
            logger.debug("Upstream Add and (Skip)LayerNormalization shall have one same input")
            return

        is_unidirectional = True
        slice_mask = None
        input_mask_nodes = None
        concat_k_to_match = None
        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "Div", "MatMul"], [0, 0, 0, 0, 0])
        if qk_nodes is not None:
            (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
            mask_nodes = self.model.match_parent_path(
                sub_qk,
                [
                    "Mul",
                    "Sub",
                    "Slice",
                    "Slice",
                    "Unsqueeze",
                    "Sub",
                    "Squeeze",
                    "Slice",
                    "Shape",
                    "Div",
                ],
                [1, 0, 1, 0, 1, 0, 0, 0, 0, 0],
            )
            if mask_nodes is None:
                logger.debug("fuse_attention: failed to match unidirectional mask path")
                return
            div_mask = mask_nodes[-1]
            slice_mask = mask_nodes[3]

            if div_qk != div_mask:
                logger.debug("fuse_attention: skip since div_qk != div_mask")
                return

            if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
                _, mul_val = self.model.get_constant_input(mask_nodes[0])
                if mul_val != -10000:
                    self.mask_filter_value = -mul_val

        else:
            # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
            i, qk_nodes, _ = self.model.match_parent_paths(
                matmul_qkv,
                [
                    (["Softmax", "Where", "Div", "MatMul"], [0, 0, 1, 0]),
                    (["Softmax", "Add", "Where", "Div", "MatMul"], [0, 0, None, 1, 0]),
                ],
                output_name_to_node,
            )
            if qk_nodes is None:
                logger.debug("fuse_attention: failed to match qk nodes")
                return

            where_qk = qk_nodes[-3]
            div_qk = qk_nodes[-2]
            matmul_qk = qk_nodes[-1]

            if i == 1:
                add_qk = qk_nodes[1]
                _, input_mask_nodes, _ = self.model.match_parent_paths(
                    add_qk,
                    [
                        (
                            ["Mul", "Sub", "Cast", "Unsqueeze", "Unsqueeze", "Reshape"],
                            [None, 0, 1, 0, 0, 0],
                        ),
                        (
                            ["Mul", "Sub", "Unsqueeze", "Unsqueeze", "Reshape"],
                            [None, 0, 1, 0, 0],
                        ),
                        (
                            ["Mul", "Sub", "Unsqueeze", "Unsqueeze"],
                            [None, 0, 1, 0],
                        ),  # useless cast and reshape are removed.
                    ],
                    output_name_to_node,
                )
                if input_mask_nodes is None:
                    logger.debug("fuse_attention: failed to match input attention mask path")
                    return
                if len(input_mask_nodes) > 1 and input_mask_nodes[0].op_type == "Mul":
                    _, mul_val = self.model.get_constant_input(input_mask_nodes[0])
                    if mul_val != -10000:
                        self.mask_filter_value = mul_val

            i, mask_nodes, _ = self.model.match_parent_paths(
                where_qk,
                [
                    (
                        ["Cast", "Slice", "Slice", "Unsqueeze", "Sub", "Squeeze", "Slice", "Shape"],
                        [0, 0, 0, 1, 0, 0, 0, 0],
                    ),
                    # For Transformers >= 4.27, causal mask uses torch.bool instead of torch.uint8, so no Cast to bool.
                    (
                        ["Slice", "Slice", "Unsqueeze", "Sub", "Squeeze", "Slice", "Shape"],
                        [0, 0, 1, 0, 0, 0, 0],
                    ),
                ],
                output_name_to_node,
            )
            if mask_nodes is None:
                # TODO: match mask path for GPT2LMHeadModel_BeamSearchStep.
                logger.debug("fuse_attention: failed to match mask path")
                return

            slice_mask = mask_nodes[2 if i == 0 else 1]

            div_or_concat = self.model.get_parent(mask_nodes[-1], 0, output_name_to_node)
            if div_or_concat.op_type == "Div":
                div_mask = div_or_concat
                if div_qk != div_mask:
                    logger.debug("fuse_attention: skip since div_qk != div_mask")
                    return
            elif div_or_concat.op_type == "Concat":
                concat_k_to_match = div_or_concat
            else:
                logger.debug("fuse_attention: failed to match mask path")

        # Validate that the mask data is either lower triangular (unidirectional) or all ones
        mask_data = self.model.get_constant_value(slice_mask.input[0])
        if not (
            isinstance(mask_data, np.ndarray)
            and len(mask_data.shape) == 4
            and mask_data.shape[:2] == (1, 1)
            and mask_data.shape[2] == mask_data.shape[3]
        ):
            logger.debug("fuse_attention: skip since mask shape is not 1x1xWxW")
            return

        if np.allclose(mask_data, np.ones_like(mask_data)):
            is_unidirectional = False
        elif not np.allclose(mask_data, np.tril(np.ones_like(mask_data))):
            logger.debug("fuse_attention: skip since mask is neither lower triangular nor ones")
            return

        q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Split"], [0, 0, 0])
        if q_nodes is None:
            logger.debug("fuse_attention: failed to match q path")
            return
        (transpose_q, reshape_q, split_q) = q_nodes
        if split_fc != split_q:
            logger.debug("fuse_attention: skip since split_fc != split_q")
            return

        k_nodes = self.model.match_parent_path(matmul_qk, ["Concat", "Transpose", "Reshape", "Split"], [1, 1, 0, 0])
        if k_nodes is None:
            # This pattern is from pytorch 1.7.1 and transformers 4.6.1
            k_nodes = self.model.match_parent_path(
                matmul_qk,
                ["Transpose", "Concat", "Transpose", "Reshape", "Split"],
                [1, 0, 1, 0, 0],
            )
            if k_nodes is None:
                logger.debug("fuse_attention: failed to match k path")
                return
            else:
                (_, concat_k, transpose_k, reshape_k, split_k) = k_nodes
        else:
            (concat_k, transpose_k, reshape_k, split_k) = k_nodes
        if split_fc != split_k:
            logger.debug("fuse_attention: skip since split_fc != split_k")
            return

        if concat_k_to_match and concat_k != concat_k_to_match:
            logger.debug("fuse_attention: skip since concat_k != concat_k_to_match")
            return

        attention_mask_input_name = ""
        if input_mask_nodes is not None:
            input_name = input_mask_nodes[-1].input[0]
            attention_mask_input_name = self.cast_attention_mask(input_name)

        # Match past and present paths
        past = self.match_past_pattern_1(concat_k, concat_v, output_name_to_node) or self.match_past_pattern_2(
            concat_k, concat_v, output_name_to_node
        )
        if past is None:
            logger.info("fuse_attention: failed to match past path")
            return
        if not self.model.find_graph_input(past):
            logger.debug("past is not graph input.")
            # For GPT2LMHeadModel_BeamSearchStep, there is an extra Gather node to select beam index so it is not graph input.

        present = self.match_present(concat_v, input_name_to_nodes)
        if present is None:
            logger.info("fuse_attention: failed to match present path")
            return
        if not self.model.find_graph_output(present):
            logger.info("expect present to be graph output")
            return

        self.create_attention_node(
            fc_weight,
            fc_bias,
            gemm_qkv,
            past,
            present,
            layernorm_before_attention.output[0],
            reshape_qkv.output[0],
            attention_mask_input_name,
            is_unidirectional,
        )

        # we rely on prune_graph() to clean old subgraph nodes:
        # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
        self.prune_graph = True
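
In the shipped tooling, these attention fusions are not normally instantiated by hand; onnxruntime/transformers/optimizer.py selects them (through onnx_model_gpt2.py) when the model type is gpt2. A sketch of that entry point, with a placeholder model path and GPT-2-base sizes as assumed parameters:

    from onnxruntime.transformers import optimizer

    # optimize_model loads the file, runs the fusion passes registered for the
    # chosen model type (gpt2 includes the attention-with-past fusion above),
    # and returns the optimized OnnxModel wrapper.
    opt = optimizer.optimize_model(
        "gpt2.onnx",      # placeholder: an exported GPT-2 ONNX model
        model_type="gpt2",
        num_heads=12,     # placeholder: GPT-2 base
        hidden_size=768,  # placeholder: GPT-2 base
    )
    opt.save_model_to_file("gpt2_opt.onnx")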