onnxruntime-directml 1.24.1 (cp314-cp314-win_amd64.whl)
This diff shows the content of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between package versions as they appear in their public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0

onnxruntime/transformers/fusion_gpt_attention_megatron.py
@@ -0,0 +1,355 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from logging import getLogger
+
+import numpy as np
+from fusion_gpt_attention import FusionGptAttentionPastBase
+from onnx import helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+def is_close(value, expected_value):
+    return abs(value - expected_value) <= 1e-6
+
+
+class FusionGptAttentionMegatron(FusionGptAttentionPastBase):
+    """
+    Fuse GPT-2 Attention with past state subgraph from Megatron into one Attention node.
+    """
+
+    def __init__(self, model: OnnxModel, num_heads: int):
+        super().__init__(model, num_heads)
+
+    def fuse_attention_node(
+        self,
+        matmul_before_split,
+        add_before_split,
+        past,
+        present,
+        input,
+        reshape_qkv,
+        mask,
+    ):
+        attention_node_name = self.model.create_node_name("GptAttention")
+        int32_mask = self.cast_attention_mask(mask)
+        output = reshape_qkv.output[0]
+        i = 1 if (add_before_split.input[0] == matmul_before_split.output[0]) else 0
+        attention_node = helper.make_node(
+            "Attention",
+            inputs=[
+                input,
+                matmul_before_split.input[1],
+                add_before_split.input[i],
+                int32_mask,
+                past,
+            ],
+            outputs=[output, present],
+            name=attention_node_name,
+        )
+        attention_node.domain = "com.microsoft"
+        attention_node.attribute.extend(
+            [
+                helper.make_attribute("num_heads", self.num_heads),
+                helper.make_attribute("unidirectional", 0),  # unidirectional shall not be ON for 4D attention mask
+            ]
+        )
+        if self.mask_filter_value is not None:
+            attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
+
+        nodes_to_add = [attention_node]
+        self.nodes_to_add.extend(nodes_to_add)
+
+        for node in nodes_to_add:
+            self.node_name_to_graph_name[node.name] = self.this_graph_name
+
+        self.nodes_to_remove.append(reshape_qkv)
+
+        # we rely on prune_graph() to clean old subgraph nodes
+        self.prune_graph = True
+
+    def match_mask(self, sub_qk, mul_qk, matmul_qk, layernorm_before_attention):
+        mask_nodes = self.model.match_parent_path(sub_qk, ["Mul", "Sub", "Slice", "Slice"], [1, 0, 1, 0])
+        if mask_nodes is None:
+            logger.debug("fuse_attention: failed to match unidirectional mask path")
+            return None
+        (mul_mask, sub_mask, last_slice_mask, slice_mask) = mask_nodes
+
+        if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
+            _, mul_val = self.model.get_constant_input(mask_nodes[0])
+            if mul_val != 10000:
+                self.mask_filter_value = -mul_val
+
+        if mul_qk.input[1] != last_slice_mask.output[0]:
+            logger.debug("fuse_attention failed: mul_qk.input[1] != last_slice_mask.output[0]")
+            return None
+
+        if not self.utils.check_node_input_value(mul_mask, 1, 10000.0):
+            logger.debug("fuse_attention failed: mul_mask input 1 is not constant 10000.0")
+            return None
+
+        if not self.utils.check_node_input_value(sub_mask, 0, 1.0):
+            logger.debug("fuse_attention failed: sub_mask input 0 is not constant 1.0")
+            return None
+
+        if not self.model.find_graph_input(slice_mask.input[0]):
+            logger.info("expect slice_mask input 0 to be graph input")
+            return None
+
+        if not self.utils.check_node_input_value(last_slice_mask, 1, [0]):
+            logger.debug("fuse_attention failed: last_slice_mask input 1 (starts) is not constant [0]")
+            return None
+
+        if not self.utils.check_node_input_value(last_slice_mask, 3, [3]):
+            logger.debug("fuse_attention failed: last_slice_mask input 3 (axes) is not constant [3]")
+            return None
+
+        if not self.utils.check_node_input_value(last_slice_mask, 4, [1]):
+            logger.debug("fuse_attention failed: last_slice_mask input 4 (steps) is not constant [1]")
+            return None
+
+        if not self.utils.check_node_input_value(slice_mask, 3, [2]):
+            logger.debug("fuse_attention failed: slice_mask input 3 (axes) is not constant [2]")
+            return None
+
+        if not self.utils.check_node_input_value(slice_mask, 4, [1]):
+            logger.debug("fuse_attention failed: slice_mask input 4 (steps) is not constant [1]")
+            return None
+
+        last_slice_path = self.model.match_parent_path(
+            last_slice_mask, ["Unsqueeze", "Gather", "Shape", "MatMul"], [2, 0, 0, 0]
+        )
+        if last_slice_path is None or last_slice_path[-1] != matmul_qk:
+            logger.debug("fuse_attention: failed to match last slice path")
+            return None
+
+        first_slice_path = self.model.match_parent_path(
+            slice_mask, ["Unsqueeze", "Gather", "Shape", "MatMul"], [2, 0, 0, 0]
+        )
+        if first_slice_path is None or first_slice_path[-1] != matmul_qk:
+            logger.debug("fuse_attention: failed to match first slice path")
+            return None
+
+        first_slice_sub = self.model.match_parent_path(
+            slice_mask,
+            ["Unsqueeze", "Sub", "Gather", "Shape", "MatMul"],
+            [1, 0, 0, 0, 0],
+        )
+        if first_slice_sub is None or first_slice_sub[-1] != matmul_qk:
+            logger.debug("fuse_attention: failed to match first slice sub path")
+            return None
+
+        first_slice_sub_1 = self.model.match_parent_path(
+            slice_mask,
+            ["Unsqueeze", "Sub", "Gather", "Shape", "LayerNormalization"],
+            [1, 0, 1, 0, 0],
+        )
+
+        if first_slice_sub_1 is None:
+            first_slice_sub_1 = self.model.match_parent_path(
+                slice_mask,
+                ["Unsqueeze", "Sub", "Gather", "Shape", "SkipLayerNormalization"],
+                [1, 0, 1, 0, 0],
+            )
+
+        if first_slice_sub_1 is None or first_slice_sub_1[-1] != layernorm_before_attention:
+            logger.debug("fuse_attention: failed to match first slice sub path 1")
+            return None
+
+        return slice_mask.input[0]
+
+    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+        past = None
+        present = None
+
+        is_normalize_node_skiplayernorm = normalize_node.op_type == "SkipLayerNormalization"
+        qkv_nodes = None
+
+        if not is_normalize_node_skiplayernorm:
+            qkv_nodes = self.model.match_parent_path(
+                normalize_node,
+                ["Add", "Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+                [0, 1, None, 0, 0, 0],
+                output_name_to_node=output_name_to_node,
+            )
+        else:
+            qkv_nodes = self.model.match_parent_path(
+                normalize_node,
+                ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+                [1, None, 0, 0, 0],
+                output_name_to_node=output_name_to_node,
+            )
+
+        if qkv_nodes is None:
+            return
+
+        skip_input = None
+        if not is_normalize_node_skiplayernorm:
+            (
+                add_skip,
+                add_after_attention,
+                matmul_after_attention,
+                reshape_qkv,
+                transpose_qkv,
+                matmul_qkv,
+            ) = qkv_nodes
+
+            skip_input = add_skip.input[0]
+        else:
+            (
+                add_after_attention,
+                matmul_after_attention,
+                reshape_qkv,
+                transpose_qkv,
+                matmul_qkv,
+            ) = qkv_nodes
+
+            skip_input = normalize_node.input[0]
+
+        v_nodes = self.model.match_parent_path(
+            matmul_qkv,
+            [
+                "Concat",
+                "Transpose",
+                "Reshape",
+                "Split",
+                "Add",
+                "MatMul",
+                "LayerNormalization",
+            ],
+            [1, 1, 0, 0, 0, None, 0],
+        )
+
+        if v_nodes is None:
+            v_nodes = self.model.match_parent_path(
+                matmul_qkv,
+                [
+                    "Concat",
+                    "Transpose",
+                    "Reshape",
+                    "Split",
+                    "Add",
+                    "MatMul",
+                    "SkipLayerNormalization",
+                ],
+                [1, 1, 0, 0, 0, None, 0],
+            )
+
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match v path")
+            return
+        (
+            concat_v,
+            transpose_v,
+            reshape_v,
+            split_v,
+            add_before_split,
+            matmul_before_split,
+            layernorm_before_attention,
+        ) = v_nodes
+
+        if (
+            layernorm_before_attention.op_type == "LayerNormalization"
+            and skip_input != layernorm_before_attention.input[0]
+        ):
+            logger.debug("fuse_attention: skip_input != layernorm_before_attention.input[0]")
+            return
+
+        if (
+            layernorm_before_attention.op_type == "SkipLayerNormalization"
+            and skip_input != layernorm_before_attention.output[3]
+        ):
+            logger.debug("fuse_attention: skip_input != layernorm_before_attention.output[3]")
+            return
+
+        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "MatMul"], [0, 0, 0, 0])
+        if qk_nodes is None:
+            logger.debug("fuse_attention: failed to match qk path")
+            return None
+        (softmax_qk, sub_qk, mul_qk, matmul_qk) = qk_nodes
+        if self.model.get_node_attribute(softmax_qk, "axis") != 3:
+            logger.debug("fuse_attention failed: softmax_qk axis != 3")
+            return None
+
+        attention_mask = self.match_mask(sub_qk, mul_qk, matmul_qk, layernorm_before_attention)
+        if attention_mask is None: return  # mask pattern not matched; cannot build Attention inputs
+        q_nodes = self.model.match_parent_path(matmul_qk, ["Div", "Transpose", "Reshape", "Split"], [0, 0, 0, 0])
+        if q_nodes is None:
+            logger.debug("fuse_attention: failed to match q path")
+            return
+        (div_q, transpose_q, reshape_q, split_q) = q_nodes
+        if split_v != split_q:
+            logger.debug("fuse_attention: skip since split_v != split_q")
+            return
+
+        k_nodes = self.model.match_parent_path(
+            matmul_qk,
+            ["Div", "Transpose", "Concat", "Transpose", "Reshape", "Split"],
+            [1, 0, 0, 1, 0, 0],
+        )
+        if k_nodes is None:
+            logger.debug("fuse_attention: failed to match k path")
+            return
+        (div_k, _, concat_k, transpose_k, reshape_k, split_k) = k_nodes
+        if split_v != split_k:
+            logger.debug("fuse_attention: skip since split_v != split_k")
+            return
+
+        i, value = self.model.get_constant_input(reshape_k)
+        if not (
+            isinstance(value, np.ndarray)
+            and list(value.shape) == [4]
+            and value[0] == 0
+            and value[1] == 0
+            and value[2] > 0
+            and value[3] > 0
+        ):
+            logger.debug("fuse_attention: reshape constant input is not [0, 0, N, H]")
+            return
+
+        num_heads = value[2]
+        if num_heads != self.num_heads:
+            logger.info(f"Detected num_heads={num_heads}. Ignore user specified value {self.num_heads}")
+            self.num_heads = num_heads
+
+        hidden_size_per_head = value[3]
+        i, value = self.model.get_constant_input(div_k)
+        expected_value = float(np.sqrt(np.sqrt(hidden_size_per_head)))
+        if not is_close(value, expected_value):
+            logger.debug(f"fuse_attention: div_k value={value} expected={expected_value}")
+            return
+
+        i, value = self.model.get_constant_input(div_q)
+        if not is_close(value, expected_value):
+            logger.debug(f"fuse_attention: div_q value={value} expected={expected_value}")
+            return
+
+        # Match past and present paths
+        past = self.match_past_pattern_2(concat_k, concat_v, output_name_to_node)
+        if past is None:
+            logger.debug("fuse_attention: match past failed")
+            return
+        if not self.model.find_graph_input(past):
+            logger.debug("fuse_attention: past is not graph input.")
+            # For GPT2LMHeadModel_BeamSearchStep, there is an extra Gather node to select beam index so it is not graph input.
+
+        present = self.match_present(concat_v, input_name_to_nodes)
+        if present is None:
+            logger.debug("fuse_attention: match present failed")
+            return
+        if not self.model.find_graph_output(present):
+            logger.info("fuse_attention: expect present to be graph output")
+            return
+
+        self.fuse_attention_node(
+            matmul_before_split,
+            add_before_split,
+            past,
+            present,
+            layernorm_before_attention.output[0],
+            reshape_qkv,
+            attention_mask,
+        )
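
A note on the Div checks above: the pattern expects Q and K to each be scaled by hidden_size_per_head ** 0.25, so their product QK^T carries the standard 1/sqrt(d) softmax scaling that the fused Attention op applies internally. The fusion class is normally not instantiated directly; it is driven by the transformers optimizer, which visits every (Skip)LayerNormalization node and calls fuse() on each. Below is a minimal sketch of exercising it through the public optimizer entry point; the model path, num_heads, and hidden_size are placeholder values, not part of this package's diff.

# Minimal sketch: run the GPT-2 fusion set (which includes the Megatron
# attention fusion above) through the public optimizer entry point.
# "gpt2_past.onnx", num_heads, and hidden_size are placeholder values.
from onnxruntime.transformers import optimizer

opt_model = optimizer.optimize_model(
    "gpt2_past.onnx",   # ONNX export of a GPT-2 model with past state
    model_type="gpt2",  # selects the GPT-2 fusion passes
    num_heads=16,
    hidden_size=1024,
)
print(opt_model.get_fused_operator_statistics())  # counts fused Attention nodes
opt_model.save_model_to_file("gpt2_past_fused.onnx")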

onnxruntime/transformers/fusion_gpt_attention_no_past.py
@@ -0,0 +1,260 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from logging import getLogger
+
+from fusion_base import Fusion
+from onnx import helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class FusionGptAttentionNoPast(Fusion):
+    """
+    Fuse GPT-2 Attention without past state into one Attention node.
+    This does not support attention_mask graph input right now.
+    """
+
+    def __init__(self, model: OnnxModel, num_heads: int):
+        super().__init__(model, "Attention", ["LayerNormalization", "SkipLayerNormalization"], "without past")
+        # TODO: detect num_heads from graph like FusionAttention
+        self.num_heads = num_heads
+        self.mask_filter_value = None
+
+    def create_attention_node(self, gemm, gemm_qkv, input, output):
+        attention_node_name = self.model.create_node_name("Attention")
+        attention_node = helper.make_node(
+            "Attention",
+            inputs=[input, gemm.input[1], gemm.input[2]],
+            outputs=[attention_node_name + "_output"],
+            name=attention_node_name,
+        )
+        attention_node.domain = "com.microsoft"
+        attention_node.attribute.extend(
+            [
+                helper.make_attribute("num_heads", self.num_heads),
+                helper.make_attribute("unidirectional", 1),
+            ]
+        )
+        if self.mask_filter_value is not None:
+            attention_node.attribute.extend([helper.make_attribute("mask_filter_value", float(self.mask_filter_value))])
+
+        matmul_node = helper.make_node(
+            "MatMul",
+            inputs=[attention_node_name + "_output", gemm_qkv.input[1]],
+            outputs=[attention_node_name + "_matmul_output"],
+            name=attention_node_name + "_matmul",
+        )
+
+        add_node = helper.make_node(
+            "Add",
+            inputs=[attention_node_name + "_matmul_output", gemm_qkv.input[2]],
+            outputs=[output],
+            name=attention_node_name + "_add",
+        )
+
+        self.nodes_to_add.extend([attention_node, matmul_node, add_node])
+        self.node_name_to_graph_name[attention_node.name] = self.this_graph_name
+        self.node_name_to_graph_name[matmul_node.name] = self.this_graph_name
+        self.node_name_to_graph_name[add_node.name] = self.this_graph_name
+
+    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+        # TODO(hasesh/tlwu): Investigate what fixes the following logic needs in order
+        # to fuse the Attention subgraph. With some changes to other fusions, this
+        # stopped working.
+        return_indice = []
+
+        is_normalize_node_skiplayernorm = normalize_node.op_type == "SkipLayerNormalization"
+        qkv_nodes = None
+
+        if not is_normalize_node_skiplayernorm:
+            qkv_nodes = self.model.match_parent_path(
+                normalize_node,
+                ["Add", "Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
+                [0, None, 0, 0, 0, 0, 0],
+                output_name_to_node=output_name_to_node,
+                return_indice=return_indice,
+            )
+        else:
+            qkv_nodes = self.model.match_parent_path(
+                normalize_node,
+                ["Reshape", "Gemm", "Reshape", "Reshape", "Transpose", "MatMul"],
+                [None, 0, 0, 0, 0, 0],
+                output_name_to_node=output_name_to_node,
+                return_indice=return_indice,
+            )
+
+        if qkv_nodes is None:
+            return
+
+        another_input = None
+        if not is_normalize_node_skiplayernorm:
+            (
+                add_qkv,
+                reshape_qkv,
+                gemm_qkv,
+                reshape_1,
+                reshape_2,
+                transpose_qkv,
+                matmul_qkv,
+            ) = qkv_nodes
+
+            another_input = add_qkv.input[1 - return_indice[0]]
+        else:
+            (
+                reshape_qkv,
+                gemm_qkv,
+                reshape_1,
+                reshape_2,
+                transpose_qkv,
+                matmul_qkv,
+            ) = qkv_nodes
+
+        v_nodes = self.model.match_parent_path(
+            matmul_qkv,
+            ["Transpose", "Reshape", "Split", "Reshape", "Gemm", "Reshape"],
+            [1, 0, 0, 0, 0, 0],
+        )
+        if v_nodes is None:
+            logger.debug("fuse_attention: failed to match v path")
+            return
+        (
+            transpose_v,
+            reshape_v,
+            split_v,
+            reshape_after_gemm,
+            gemm,
+            reshape_before_gemm,
+        ) = v_nodes
+
+        layernorm_before_attention = self.model.get_parent(reshape_before_gemm, 0, output_name_to_node)
+        if layernorm_before_attention is None or (
+            layernorm_before_attention.op_type != "LayerNormalization"
+            and layernorm_before_attention.op_type != "SkipLayerNormalization"
+        ):
+            if layernorm_before_attention is None or layernorm_before_attention.op_type != "Add":
+                logger.debug(f"failed to get (skip)layernorm before gemm. Got {layernorm_before_attention.op_type if layernorm_before_attention else None}")
+                return
+
+        # `another_input` will be non-None only if
+        # (1) SkipLayerNorm fusion wasn't turned ON
+        # (2) SkipLayerNorm fusion was turned ON but upstream layer's LayerNorm + Add was not
+        # fused into a SkipLayerNorm. This can happen if the shapes to the Add node are different.
+        # So, keep the following check whether SkipLayerNorm fusion is turned ON or OFF.
+        if another_input is not None:
+            if another_input not in layernorm_before_attention.input:
+                # match openai-gpt
+                if another_input not in layernorm_before_attention.output:
+                    logger.debug("Add and (Skip)LayerNormalization shall have one same input")
+                    return
+
+        qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Sub", "Mul", "Div", "MatMul"], [0, 0, 0, 0, 0])
+        if qk_nodes is not None:
+            (softmax_qk, sub_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
+            mask_nodes = self.model.match_parent_path(
+                sub_qk,
+                [
+                    "Mul",
+                    "Sub",
+                    "Slice",
+                    "Slice",
+                    "Unsqueeze",
+                    "Sub",
+                    "Squeeze",
+                    "Slice",
+                    "Shape",
+                    "Div",
+                ],
+                [1, 0, 1, 0, 1, 0, 0, 0, 0, 0],
+            )
+            if mask_nodes is None:
+                logger.debug("fuse_attention: failed to match mask path")
+                return
+            div_mask = mask_nodes[-1]
+
+            if div_qk != div_mask:
+                logger.debug("fuse_attention: skip since div_qk != div_mask")
+                return
+            if len(mask_nodes) > 1 and mask_nodes[0].op_type == "Mul":
+                _, mul_val = self.model.get_constant_input(mask_nodes[0])
+                if mul_val != -10000:
+                    self.mask_filter_value = mul_val
+
+        else:
+            # New pattern for gpt2 from PyTorch 1.5.0 and Transformers 2.9.0.
+            qk_nodes = self.model.match_parent_path(matmul_qkv, ["Softmax", "Where", "Div", "MatMul"], [0, 0, 1, 0])
+            if qk_nodes is not None:
+                (softmax_qk, where_qk, div_qk, matmul_qk) = qk_nodes
+                mask_nodes = self.model.match_parent_path(
+                    where_qk,
+                    [
+                        "Cast",
+                        "Slice",
+                        "Slice",
+                        "Unsqueeze",
+                        "Sub",
+                        "Squeeze",
+                        "Slice",
+                        "Shape",
+                        "Div",
+                    ],
+                    [0, 0, 0, 1, 0, 0, 0, 0, 0],
+                )
+                if mask_nodes is None:
+                    logger.debug("fuse_attention: failed to match mask path")
+                    return
+                div_mask = mask_nodes[-1]
+
+                if div_qk != div_mask:
+                    logger.debug("fuse_attention: skip since div_qk != div_mask")
+                    return
+            else:
+                # match openai-gpt
+                qk_nodes = self.model.match_parent_path(
+                    matmul_qkv,
+                    ["Softmax", "Add", "Mul", "Div", "MatMul"],
+                    [0, 0, 0, 0, 0],
+                )
+                if qk_nodes is None:
+                    logger.debug("fuse_attention: failed to match qk path")
+                    return
+                (softmax_qk, add_qk, mul_qk, div_qk, matmul_qk) = qk_nodes
+                mask_nodes = self.model.match_parent_path(
+                    mul_qk,
+                    ["Slice", "Slice", "Unsqueeze", "Squeeze", "Slice", "Shape", "Div"],
+                    [1, 0, 2, 0, 0, 0, 0],
+                )
+                if mask_nodes is None:
+                    logger.debug("fuse_attention: failed to match mask path")
+                    return
+                div_mask = mask_nodes[-1]
+
+                if div_qk != div_mask:
+                    logger.debug("fuse_attention: skip since div_qk != div_mask")
+                    return
+
+        q_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Split"], [0, 0, 0])
+        if q_nodes is None:
+            logger.debug("fuse_attention: failed to match q path")
+            return
+        (transpose_q, reshape_q, split_q) = q_nodes
+        if split_v != split_q:
+            logger.debug("fuse_attention: skip since split_v != split_q")
+            return
+
+        k_nodes = self.model.match_parent_path(matmul_qk, ["Transpose", "Reshape", "Split"], [1, 0, 0])
+        if k_nodes is None:
+            logger.debug("fuse_attention: failed to match k path")
+            return
+        (transpose_k, reshape_k, split_k) = k_nodes
+        if split_v != split_k:
+            logger.debug("fuse_attention: skip since split_v != split_k")
+            return
+
+        self.create_attention_node(gemm, gemm_qkv, layernorm_before_attention.output[0], reshape_qkv.output[0])
+
+        # we rely on prune_graph() to clean old subgraph nodes:
+        # qk_nodes + q_nodes + k_nodes + v_nodes + mask_nodes + [reshape_qkv, transpose_qkv, matmul_qkv]
+        self.prune_graph = True
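
Since these fusion modules import each other by bare file name, a single fusion can also be exercised in isolation once the package's transformers directory is on sys.path. A minimal sketch under those assumptions; "gpt2_no_past.onnx" and num_heads=12 are placeholders, not values taken from this diff.

# Minimal sketch: apply the no-past attention fusion directly.
import os
import sys

import onnx
import onnxruntime.transformers as ort_transformers

# These modules use bare file-name imports (e.g. "from fusion_base import
# Fusion"), so put the package directory itself on sys.path first.
sys.path.append(os.path.dirname(ort_transformers.__file__))

from fusion_gpt_attention_no_past import FusionGptAttentionNoPast  # noqa: E402
from onnx_model import OnnxModel  # noqa: E402

model = OnnxModel(onnx.load("gpt2_no_past.onnx"))  # placeholder path
fusion = FusionGptAttentionNoPast(model, num_heads=12)
fusion.apply()  # visits (Skip)LayerNormalization nodes; on a match, fuses and prunes the replaced subgraph
onnx.save(model.model, "gpt2_no_past_fused.onnx")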