onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,311 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
from collections import deque
|
|
9
|
+
|
|
10
|
+
import onnx
|
|
11
|
+
|
|
12
|
+
from ..onnx_model import ONNXModel
|
|
13
|
+
|
|
14
|
+
|
|
15
|
+
class Fusion:
|
|
16
|
+
"""
|
|
17
|
+
Base class for fusions.
|
|
18
|
+
"""
|
|
19
|
+
|
|
20
|
+
def __init__(self, model: ONNXModel, fused_op_type: str, search_op_type: str):
|
|
21
|
+
self.search_op_type: str = search_op_type
|
|
22
|
+
self.fused_op_type: str = fused_op_type
|
|
23
|
+
self.model: ONNXModel = model
|
|
24
|
+
self.nodes_to_remove: list = []
|
|
25
|
+
self.nodes_to_add: list = []
|
|
26
|
+
|
|
27
|
+
self._new_node_name_prefix = self.fused_op_type + "_fused_" + self.search_op_type + "_"
|
|
28
|
+
self._new_node_name_suffix = None # int|None used to create unique node names for the fused ops.
|
|
29
|
+
|
|
30
|
+
def fuse(
|
|
31
|
+
self,
|
|
32
|
+
node: onnx.NodeProto,
|
|
33
|
+
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
|
|
34
|
+
output_name_to_node: dict[str, onnx.NodeProto],
|
|
35
|
+
):
|
|
36
|
+
"""
|
|
37
|
+
Interface function for derived fusion classes. Tries to fuse a node sequence containing
|
|
38
|
+
the specified node.
|
|
39
|
+
"""
|
|
40
|
+
raise NotImplementedError
|
|
41
|
+
|
|
42
|
+
def apply(self) -> bool:
|
|
43
|
+
"""
|
|
44
|
+
Apply graph fusion on the entire model graph.
|
|
45
|
+
"""
|
|
46
|
+
input_name_to_nodes = self.model.input_name_to_nodes()
|
|
47
|
+
output_name_to_node = self.model.output_name_to_node()
|
|
48
|
+
|
|
49
|
+
for node in self.model.nodes():
|
|
50
|
+
if node.op_type == self.search_op_type:
|
|
51
|
+
self.fuse(node, input_name_to_nodes, output_name_to_node)
|
|
52
|
+
|
|
53
|
+
self.model.remove_nodes(self.nodes_to_remove)
|
|
54
|
+
self.model.add_nodes(self.nodes_to_add)
|
|
55
|
+
|
|
56
|
+
graph_updated = bool(self.nodes_to_remove or self.nodes_to_add)
|
|
57
|
+
|
|
58
|
+
if graph_updated:
|
|
59
|
+
self.model.remove_unused_constant()
|
|
60
|
+
|
|
61
|
+
return graph_updated
|
|
62
|
+
|
|
63
|
+
def create_unique_node_name(self):
|
|
64
|
+
prefix = self._new_node_name_prefix
|
|
65
|
+
|
|
66
|
+
if self._new_node_name_suffix is None:
|
|
67
|
+
largest_suffix: int = self.model.get_largest_node_name_suffix(prefix)
|
|
68
|
+
self._new_node_name_suffix = largest_suffix + 1
|
|
69
|
+
|
|
70
|
+
new_name = f"{prefix}{self._new_node_name_suffix!s}"
|
|
71
|
+
self._new_node_name_suffix += 1
|
|
72
|
+
|
|
73
|
+
return new_name
|
|
74
|
+
|
|
75
|
+
@staticmethod
|
|
76
|
+
def is_safe_to_fuse_nodes(
|
|
77
|
+
nodes_to_remove: list[onnx.NodeProto],
|
|
78
|
+
keep_outputs: list[str],
|
|
79
|
+
input_name_to_nodes: dict[str, list[onnx.NodeProto]],
|
|
80
|
+
output_name_to_node: dict[str, onnx.NodeProto],
|
|
81
|
+
) -> bool:
|
|
82
|
+
for node_to_remove in nodes_to_remove:
|
|
83
|
+
for output_to_remove in node_to_remove.output:
|
|
84
|
+
if output_to_remove in keep_outputs:
|
|
85
|
+
continue
|
|
86
|
+
|
|
87
|
+
if output_to_remove in input_name_to_nodes:
|
|
88
|
+
for impacted_node in input_name_to_nodes[output_to_remove]:
|
|
89
|
+
if impacted_node not in nodes_to_remove:
|
|
90
|
+
# Not safe to remove nodes since output is used by impacted_node
|
|
91
|
+
return False
|
|
92
|
+
return True
|
|
93
|
+
|
|
94
|
+
@staticmethod
|
|
95
|
+
def get_node_attribute(node: onnx.NodeProto, attribute_name: str):
|
|
96
|
+
for attr in node.attribute:
|
|
97
|
+
if attr.name == attribute_name:
|
|
98
|
+
value = onnx.helper.get_attribute_value(attr)
|
|
99
|
+
return value
|
|
100
|
+
return None
|
|
101
|
+
|
|
102
|
+
@staticmethod
|
|
103
|
+
def input_index(node_output: str, child_node: onnx.NodeProto) -> int:
|
|
104
|
+
for index, input_name in enumerate(child_node.input):
|
|
105
|
+
if input_name == node_output:
|
|
106
|
+
return index
|
|
107
|
+
return -1
|
|
108
|
+
|
|
109
|
+
@staticmethod
|
|
110
|
+
def tensor_shape_to_list(tensor_type) -> list[int]:
|
|
111
|
+
shape_list = []
|
|
112
|
+
for d in tensor_type.shape.dim:
|
|
113
|
+
if d.HasField("dim_value"):
|
|
114
|
+
shape_list.append(d.dim_value) # known dimension
|
|
115
|
+
elif d.HasField("dim_param"):
|
|
116
|
+
shape_list.append(d.dim_param) # unknown dimension with symbolic name
|
|
117
|
+
else:
|
|
118
|
+
shape_list.append("?") # shall not happen
|
|
119
|
+
return shape_list
|
|
120
|
+
|
|
121
|
+
def get_constant_input(self, node: onnx.NodeProto):
|
|
122
|
+
for i, inp in enumerate(node.input):
|
|
123
|
+
value = self.model.get_constant_value(inp)
|
|
124
|
+
if value is not None:
|
|
125
|
+
return i, value
|
|
126
|
+
|
|
127
|
+
return None, None
|
|
128
|
+
|
|
129
|
+
def find_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> int:
|
|
130
|
+
i, value = self.get_constant_input(node)
|
|
131
|
+
if value is not None and value.size == 1 and abs(value - expected_value) < delta:
|
|
132
|
+
return i
|
|
133
|
+
|
|
134
|
+
return -1
|
|
135
|
+
|
|
136
|
+
def has_constant_input(self, node: onnx.NodeProto, expected_value: float, delta: float = 0.000001) -> bool:
|
|
137
|
+
return self.find_constant_input(node, expected_value, delta) >= 0
|
|
138
|
+
|
|
139
|
+
def is_constant_with_specified_rank(self, output_name: str, rank: int) -> bool:
|
|
140
|
+
value = self.model.get_constant_value(output_name)
|
|
141
|
+
if value is None:
|
|
142
|
+
return False # Not an initializer
|
|
143
|
+
|
|
144
|
+
if len(value.shape) != rank:
|
|
145
|
+
return False # Wrong dimensions
|
|
146
|
+
|
|
147
|
+
return True
|
|
148
|
+
|
|
149
|
+
def match_first_parent(
|
|
150
|
+
self,
|
|
151
|
+
node: onnx.NodeProto,
|
|
152
|
+
parent_op_type: str,
|
|
153
|
+
output_name_to_node: dict[str, onnx.NodeProto] | None = None,
|
|
154
|
+
exclude: list[onnx.NodeProto] = [], # noqa: B006
|
|
155
|
+
) -> tuple[onnx.NodeProto | None, int | None]:
|
|
156
|
+
"""
|
|
157
|
+
Find parent node based on constraints on op_type.
|
|
158
|
+
|
|
159
|
+
Args:
|
|
160
|
+
node: current node.
|
|
161
|
+
parent_op_type (str): constraint of parent node op_type.
|
|
162
|
+
output_name_to_node (dict): dictionary with output name as key, and node as value.
|
|
163
|
+
exclude (list): list of nodes that are excluded (not allowed to match as parent).
|
|
164
|
+
|
|
165
|
+
Returns:
|
|
166
|
+
parent: The matched parent node. None if not found.
|
|
167
|
+
index: The input index of matched parent node. None if not found.
|
|
168
|
+
"""
|
|
169
|
+
if output_name_to_node is None:
|
|
170
|
+
output_name_to_node = self.model.output_name_to_node()
|
|
171
|
+
|
|
172
|
+
for i, inp in enumerate(node.input):
|
|
173
|
+
if inp in output_name_to_node:
|
|
174
|
+
parent = output_name_to_node[inp]
|
|
175
|
+
if parent.op_type == parent_op_type and parent not in exclude:
|
|
176
|
+
return parent, i
|
|
177
|
+
|
|
178
|
+
return None, None
|
|
179
|
+
|
|
180
|
+
def match_parent(
|
|
181
|
+
self,
|
|
182
|
+
node: onnx.NodeProto,
|
|
183
|
+
parent_op_type: str,
|
|
184
|
+
input_index: int | None = None,
|
|
185
|
+
output_name_to_node: dict[str, onnx.NodeProto] | None = None,
|
|
186
|
+
exclude: list[onnx.NodeProto] = [], # noqa: B006
|
|
187
|
+
return_indice: list[int] | None = None,
|
|
188
|
+
) -> onnx.NodeProto | None:
|
|
189
|
+
"""
|
|
190
|
+
Find parent node based on constraints on op_type and index.
|
|
191
|
+
When input_index is None, we will find the first parent node based on constraints,
|
|
192
|
+
and return_indice will be appended the corresponding input index.
|
|
193
|
+
|
|
194
|
+
Args:
|
|
195
|
+
node (str): current node name.
|
|
196
|
+
parent_op_type (str): constraint of parent node op_type.
|
|
197
|
+
input_index (int or None): only check the parent given input index of current node.
|
|
198
|
+
output_name_to_node (dict): dictionary with output name as key, and node as value.
|
|
199
|
+
exclude (list): list of nodes that are excluded (not allowed to match as parent).
|
|
200
|
+
return_indice (list): a list to append the input index when input_index is None.
|
|
201
|
+
|
|
202
|
+
Returns:
|
|
203
|
+
parent: The matched parent node.
|
|
204
|
+
"""
|
|
205
|
+
assert node is not None
|
|
206
|
+
assert input_index is None or input_index >= 0
|
|
207
|
+
|
|
208
|
+
if output_name_to_node is None:
|
|
209
|
+
output_name_to_node = self.model.output_name_to_node()
|
|
210
|
+
|
|
211
|
+
if input_index is None:
|
|
212
|
+
parent, index = self.match_first_parent(node, parent_op_type, output_name_to_node, exclude)
|
|
213
|
+
if return_indice is not None:
|
|
214
|
+
return_indice.append(index)
|
|
215
|
+
return parent
|
|
216
|
+
|
|
217
|
+
if input_index >= len(node.input):
|
|
218
|
+
# Input index out of bounds.
|
|
219
|
+
return None
|
|
220
|
+
|
|
221
|
+
parent = self.model.get_parent(node, input_index, output_name_to_node)
|
|
222
|
+
if parent is not None and parent.op_type == parent_op_type and parent not in exclude:
|
|
223
|
+
return parent
|
|
224
|
+
|
|
225
|
+
return None
|
|
226
|
+
|
|
227
|
+
def match_parent_path(
|
|
228
|
+
self,
|
|
229
|
+
node: onnx.NodeProto,
|
|
230
|
+
parent_op_types: list[str],
|
|
231
|
+
parent_input_index: list[int] | None = None,
|
|
232
|
+
output_name_to_node: dict[str, onnx.NodeProto] | None = None,
|
|
233
|
+
return_indice: list[int] | None = None,
|
|
234
|
+
) -> list[onnx.NodeProto] | None:
|
|
235
|
+
"""
|
|
236
|
+
Find a sequence of input edges based on constraints on parent op_type and index.
|
|
237
|
+
When input_index is None, we will find the first parent node based on constraints,
|
|
238
|
+
and return_indice will be appended the corresponding input index.
|
|
239
|
+
|
|
240
|
+
Args:
|
|
241
|
+
node (str): current node name.
|
|
242
|
+
parent_op_types (str): constraint of parent node op_type of each input edge.
|
|
243
|
+
parent_input_index (list): constraint of input index of each input edge. None means no constraint.
|
|
244
|
+
output_name_to_node (dict): dictionary with output name as key, and node as value.
|
|
245
|
+
return_indice (list): a list to append the input index
|
|
246
|
+
When there is no constraint on input index of an edge.
|
|
247
|
+
|
|
248
|
+
Returns:
|
|
249
|
+
parents: a list of matched parent node.
|
|
250
|
+
"""
|
|
251
|
+
if parent_input_index is not None:
|
|
252
|
+
assert len(parent_input_index) == len(parent_op_types)
|
|
253
|
+
|
|
254
|
+
if output_name_to_node is None:
|
|
255
|
+
output_name_to_node = self.model.output_name_to_node()
|
|
256
|
+
|
|
257
|
+
current_node = node
|
|
258
|
+
matched_parents = []
|
|
259
|
+
for i, op_type in enumerate(parent_op_types):
|
|
260
|
+
matched_parent = self.match_parent(
|
|
261
|
+
current_node,
|
|
262
|
+
op_type,
|
|
263
|
+
parent_input_index[i] if parent_input_index is not None else None,
|
|
264
|
+
output_name_to_node,
|
|
265
|
+
exclude=[],
|
|
266
|
+
return_indice=return_indice,
|
|
267
|
+
)
|
|
268
|
+
if matched_parent is None:
|
|
269
|
+
return None
|
|
270
|
+
|
|
271
|
+
matched_parents.append(matched_parent)
|
|
272
|
+
current_node = matched_parent
|
|
273
|
+
|
|
274
|
+
return matched_parents
|
|
275
|
+
|
|
276
|
+
def match_parent_paths(
    self,
    node: onnx.NodeProto,
    paths: list[tuple[list[str], list[int]]],
    output_name_to_node: dict[str, onnx.NodeProto],
) -> tuple[int, list[onnx.NodeProto] | None, list[int] | None]:
    """
    Try each candidate parent path in turn and report the first one that matches.

    Args:
        node: node whose ancestry is examined.
        paths: candidate paths, each a (parent_op_types, parent_input_index) pair.
        output_name_to_node: dictionary with output name as key, and producer node as value.

    Returns:
        Tuple of (index of the matching path, matched parent nodes, input indices
        recorded during matching), or (-1, None, None) when no path matches.
    """
    for path_index, path in enumerate(paths):
        recorded_indices: list[int] = []
        parents = self.match_parent_path(node, path[0], path[1], output_name_to_node, recorded_indices)
        if parents:
            return path_index, parents, recorded_indices
    return -1, None, None
|
|
291
|
+
|
|
292
|
+
def find_first_child_by_type(
    self,
    node: onnx.NodeProto,
    child_type: str,
    input_name_to_nodes: dict[str, list[onnx.NodeProto]] | None = None,
    recursive: bool = True,
) -> onnx.NodeProto | None:
    """
    Search the consumers of `node` for the first node whose op_type is `child_type`.

    Args:
        node: node whose children (and, optionally, deeper descendants) are searched.
        child_type: op_type to look for.
        input_name_to_nodes: optional map from input name to consumer nodes,
            forwarded to the model's get_children lookup.
        recursive: when True, descendants of descendants are searched as well;
            when False, only direct children are considered.

    Returns:
        The first matching node encountered, or None when there is no match.
    """
    # Deque preserves the original traversal order: candidates are taken from
    # the right, newly discovered children are queued on the left.
    pending = deque(self.model.get_children(node, input_name_to_nodes))
    while pending:
        candidate = pending.pop()
        if candidate.op_type == child_type:
            return candidate

        if recursive:
            # extendleft pushes each child to the left, one by one — identical
            # ordering to appending them individually with appendleft.
            pending.extendleft(self.model.get_children(candidate, input_name_to_nodes))

    return None
|
|
@@ -0,0 +1,272 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
from __future__ import annotations
|
|
7
|
+
|
|
8
|
+
import onnx
|
|
9
|
+
|
|
10
|
+
from ..onnx_model import ONNXModel
|
|
11
|
+
from .fusion import Fusion
|
|
12
|
+
|
|
13
|
+
|
|
14
|
+
class FusionGelu(Fusion):
    """Fuses Erf-based Gelu subgraphs into a single `com.microsoft` Gelu node."""

    def __init__(self, model: ONNXModel):
        super().__init__(model, "Gelu", "Erf")

    def fuse(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ):
        """
        Interface function that tries to fuse a node sequence containing an Erf node into a single
        Gelu node.
        """
        if (
            self.fuse_1(erf_node, input_name_to_nodes, output_name_to_node)
            or self.fuse_2(erf_node, input_name_to_nodes, output_name_to_node)
            or self.fuse_3(erf_node, input_name_to_nodes, output_name_to_node)
        ):
            # The fused Gelu node lives in the com.microsoft domain; make sure the
            # model declares that opset import.
            self.model.set_opset_import("com.microsoft", 1)

    def fuse_1(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        This pattern is from PyTorch model
        Fuse Gelu with Erf into one node:
        Pattern 1:
                   +-------Mul(0.5)---------------------+
                   |                                    |
                   |                                    v
                [root] --> Div -----> Erf --> Add --> Mul -->
                          (B=1.4142...)       (1)

        Pattern 2:
                   +------------------------------------+
                   |                                    |
                   |                                    v
                [root] --> Div -----> Erf --> Add --> Mul -->Mul -->
                          (B=1.4142...)       (1)     (0.5)

        Note that constant input for Add and Mul could be first or second input:
        like either A=0.5 or B=0.5 is fine.

        Returns True when the subgraph was matched and queued for fusion.
        """
        if erf_node.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[erf_node.output[0]]
        if len(children) != 1 or children[0].op_type != "Add":
            return False
        add_after_erf = children[0]

        if not self.has_constant_input(add_after_erf, 1):
            return False

        if add_after_erf.output[0] not in input_name_to_nodes:
            return False

        children = input_name_to_nodes[add_after_erf.output[0]]
        if len(children) != 1 or children[0].op_type != "Mul":
            return False

        mul_after_erf = children[0]

        div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
        if div is None:
            return False

        # Divisor must be sqrt(2) ~ 1.4142, supplied as the second Div input.
        if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
            return False

        subgraph_input = div.input[0]

        # `another` is the Mul input that is NOT the Add output.
        another = 1 if mul_after_erf.input[0] == add_after_erf.output[0] else 0
        if subgraph_input == mul_after_erf.input[another]:  # pattern 2
            # Fix: guard the consumer lookup. When mul_after_erf's output feeds no
            # other node (e.g. it is a graph output), indexing directly raised
            # KeyError; the fusion should simply be declined instead.
            if mul_after_erf.output[0] not in input_name_to_nodes:
                return False
            children = input_name_to_nodes[mul_after_erf.output[0]]
            if len(children) != 1 or children[0].op_type != "Mul":
                return False
            mul_half = children[0]
            if not self.has_constant_input(mul_half, 0.5):
                return False
            subgraph_output = mul_half.output[0]
        else:  # pattern 1
            mul_half = self.match_parent(mul_after_erf, "Mul", another, output_name_to_node)
            if mul_half is None:
                return False

            if not self.has_constant_input(mul_half, 0.5):
                return False

            # The 0.5-Mul must consume the same root input as Div.
            if subgraph_input not in mul_half.input:
                return False

            subgraph_output = mul_after_erf.output[0]

        subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul_half]
        if not self.is_safe_to_fuse_nodes(subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node):
            return False

        self.nodes_to_remove.extend(subgraph_nodes)
        fused_node = onnx.helper.make_node(
            "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[subgraph_output]
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        return True

    def fuse_2(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        This pattern is from Keras model
        Fuse Gelu with Erf into one node:
                   +------------------------------------------+
                   |                                          |
                   |                                          v
                [root] --> Div -----> Erf --> Add --> Mul -->Mul
                          (B=1.4142...)      (A=1)   (A=0.5)

        Note that constant input for Add and Mul could be first or second input:
        like either A=0.5 or B=0.5 is fine.

        Returns True when the subgraph was matched and queued for fusion.
        """
        if erf_node.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[erf_node.output[0]]
        if len(children) != 1 or children[0].op_type != "Add":
            return False
        add_after_erf = children[0]

        if not self.has_constant_input(add_after_erf, 1):
            return False

        if add_after_erf.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[add_after_erf.output[0]]
        if len(children) != 1 or children[0].op_type != "Mul":
            return False
        mul_after_erf = children[0]

        if not self.has_constant_input(mul_after_erf, 0.5):
            return False

        if mul_after_erf.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[mul_after_erf.output[0]]
        if len(children) != 1 or children[0].op_type != "Mul":
            return False
        mul = children[0]

        div = self.match_parent(erf_node, "Div", 0, output_name_to_node)
        if div is None:
            return False

        # Divisor is either the literal sqrt(2) ~ 1.4142 or a Sqrt node over 2.0.
        sqrt_node = None
        if self.find_constant_input(div, 1.4142, delta=0.001) != 1:
            sqrt_node = self.match_parent(div, "Sqrt", 1, output_name_to_node)
            if sqrt_node is None:
                return False
            if not self.has_constant_input(sqrt_node, 2.0):
                return False

        subgraph_input = div.input[0]

        # The final Mul must consume the same root input as Div.
        if subgraph_input not in mul.input:
            return False

        subgraph_nodes = [div, erf_node, add_after_erf, mul_after_erf, mul]
        if sqrt_node:
            subgraph_nodes.append(sqrt_node)

        if not self.is_safe_to_fuse_nodes(subgraph_nodes, [mul.output[0]], input_name_to_nodes, output_name_to_node):
            return False

        self.nodes_to_remove.extend(subgraph_nodes)
        fused_node = onnx.helper.make_node(
            "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[mul.output[0]]
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        return True

    def fuse_3(
        self,
        erf_node: onnx.NodeProto,
        input_name_to_nodes: dict[str, list[onnx.NodeProto]],
        output_name_to_node: dict[str, onnx.NodeProto],
    ) -> bool:
        """
        This pattern is from TensorFlow model
        Fuse Gelu with Erf into one node:
                   +----------------------------------------------+
                   |                                              |
                   |                                              v
                [root] --> Mul -----> Erf --> Add --> Mul -->Mul
                          (A=0.7071067690849304)  (B=1)  (B=0.5)

        Note that constant input for Add and Mul could be first or second input:
        like either A=0.5 or B=0.5 is fine.

        Returns True when the subgraph was matched and queued for fusion.
        """

        if erf_node.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[erf_node.output[0]]
        if len(children) != 1 or children[0].op_type != "Add":
            return False
        add_after_erf = children[0]

        if not self.has_constant_input(add_after_erf, 1):
            return False

        if add_after_erf.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[add_after_erf.output[0]]
        if len(children) != 1 or children[0].op_type != "Mul":
            return False
        mul_half = children[0]

        if not self.has_constant_input(mul_half, 0.5):
            return False

        first_mul = self.match_parent(erf_node, "Mul", 0, output_name_to_node)
        if first_mul is None:
            return False

        # First Mul scales the root input by 1/sqrt(2) ~ 0.7071.
        i = self.find_constant_input(first_mul, 0.7071067690849304, delta=0.001)
        if i < 0:
            return False

        root_input_index = 1 - i
        subgraph_input = first_mul.input[root_input_index]

        if mul_half.output[0] not in input_name_to_nodes:
            return False
        children = input_name_to_nodes[mul_half.output[0]]
        if len(children) != 1 or children[0].op_type != "Mul":
            return False
        last_mul = children[0]

        # Consistency: same membership test as fuse_1/fuse_2 — the final Mul must
        # consume the root input (equivalent to checking both inputs with ==/or).
        if subgraph_input not in last_mul.input:
            return False

        subgraph_nodes = [first_mul, erf_node, add_after_erf, mul_half, last_mul]
        if not self.is_safe_to_fuse_nodes(
            subgraph_nodes,
            [last_mul.output[0]],
            input_name_to_nodes,
            output_name_to_node,
        ):
            return False

        self.nodes_to_remove.extend(subgraph_nodes)
        fused_node = onnx.helper.make_node(
            "Gelu", name=self.create_unique_node_name(), inputs=[subgraph_input], outputs=[last_mul.output[0]]
        )
        fused_node.domain = "com.microsoft"
        self.nodes_to_add.append(fused_node)
        return True
|