onnxruntime-directml 1.24.1 (cp314-cp314-win_amd64.whl)
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0

onnxruntime/transformers/fusion_biasgelu.py
@@ -0,0 +1,66 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+from logging import getLogger
+
+from fusion_base import Fusion
+from fusion_utils import NumpyHelper
+from onnx import helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class FusionBiasGelu(Fusion):
+    def __init__(self, model: OnnxModel, is_fastgelu):
+        if is_fastgelu:
+            super().__init__(model, "FastGelu", "FastGelu", "add bias")
+        else:
+            super().__init__(model, "BiasGelu", "Gelu")
+
+    def fuse(self, node, input_name_to_nodes, output_name_to_node):
+        gelu_op_type = node.op_type
+        fuse_op_type = "BiasGelu" if gelu_op_type == "Gelu" else "FastGelu"
+
+        if len(node.input) != 1:
+            return
+
+        nodes = self.model.match_parent_path(node, ["Add", "MatMul"], [0, None])
+        if nodes is None:
+            return
+        (add, matmul) = nodes
+
+        bias_weight = None
+        # bias should be one dimension
+        bias_index = -1
+        for i, input in enumerate(add.input):
+            initializer = self.model.get_initializer(input)
+            if initializer is None:
+                continue
+            bias_index = i
+            bias_weight = NumpyHelper.to_array(initializer)
+            break
+        if bias_weight is None:
+            return
+        if len(bias_weight.shape) != 1:
+            return
+
+        subgraph_nodes = [node, add]
+        if not self.model.is_safe_to_fuse_nodes(
+            subgraph_nodes, [node.output[0]], input_name_to_nodes, output_name_to_node
+        ):
+            return
+
+        self.nodes_to_remove.extend(subgraph_nodes)
+
+        fused_node = helper.make_node(
+            fuse_op_type,
+            inputs=[matmul.output[0], add.input[bias_index]],
+            outputs=node.output,
+            name=self.model.create_node_name(fuse_op_type, gelu_op_type + "_AddBias_"),
+        )
+        fused_node.domain = "com.microsoft"
+        self.nodes_to_add.append(fused_node)
+        self.node_name_to_graph_name[fused_node.name] = self.this_graph_name
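
These fusion passes are normally driven through the transformers optimizer, but each one can also be applied on its own. A minimal sketch, assuming the onnxruntime/transformers directory is on sys.path (the module imports fusion_base and onnx_model as top-level names) and using placeholder model paths:

    # Hedged sketch: fold MatMul -> Add(bias) -> Gelu into MatMul -> BiasGelu.
    import onnx

    from fusion_biasgelu import FusionBiasGelu
    from onnx_model import OnnxModel

    model = OnnxModel(onnx.load("model.onnx"))         # placeholder path
    fusion = FusionBiasGelu(model, is_fastgelu=False)  # target Gelu, not FastGelu
    fusion.apply()                                     # runs fuse() over every Gelu node
    model.save_model_to_file("model_bias_gelu.onnx")

Fusion.apply walks every node of the search operator type ("Gelu" here) and applies the rewrites that fuse() records in nodes_to_remove and nodes_to_add.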

onnxruntime/transformers/fusion_biassplitgelu.py
@@ -0,0 +1,110 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+from logging import getLogger
+
+from fusion_base import Fusion
+from onnx import helper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class FusionBiasSplitGelu(Fusion):
+    def __init__(self, model: OnnxModel):
+        super().__init__(model, "BiasSplitGelu", "Gelu")
+
+    def fuse(self, gelu_node, input_name_to_nodes: dict, output_name_to_node: dict):
+        """
+        [root] --->Add --------------------> Slice ---------------> Mul -->
+                    |                          ^                     ^
+                    |                          |                     |
+                    +----------------------------+---Slice --> Gelu---+
+                    |                          |     ^
+                    |                          |-----|
+                    |                          |     |
+                    |                          Mul   Mul
+                    |                           ^     ^
+                    v                           |     |
+                   Shape ---> Gather --> Add --> Div --+
+        """
+        if gelu_node.output[0] not in input_name_to_nodes:
+            return
+        children = input_name_to_nodes[gelu_node.output[0]]
+        if len(children) != 1 or children[0].op_type != "Mul":
+            return
+        mul_after_gelu = children[0]
+
+        slice_before_gelu = self.model.match_parent(gelu_node, "Slice", 0, output_name_to_node)
+        if slice_before_gelu is None:
+            return
+
+        if self.model.find_constant_input(slice_before_gelu, -1, delta=0.001) != 3:
+            return
+
+        add_output = slice_before_gelu.input[0]
+
+        start_index_nodes = self.model.match_parent_path(
+            slice_before_gelu,
+            ["Div", "Add", "Gather", "Shape", "Add"],
+            [1, 0, 0, 0, 0],
+            output_name_to_node,  # Mul(1) is optional
+        )
+        if start_index_nodes is None:
+            start_index_nodes = self.model.match_parent_path(
+                slice_before_gelu,
+                ["Mul", "Div", "Add", "Gather", "Shape", "Add"],
+                [1, 0, 0, 0, 0, 0],
+                output_name_to_node,
+            )
+
+        if start_index_nodes is None or start_index_nodes[-2].input[0] != add_output:
+            return
+
+        end_index_nodes = self.model.match_parent_path(slice_before_gelu, ["Mul", "Div"], [2, 0], output_name_to_node)
+
+        if (
+            end_index_nodes is None or end_index_nodes[1] not in start_index_nodes
+        ):  # the Div is parent of both two Mul nodes
+            return
+
+        slice_before_mul = self.model.match_parent(mul_after_gelu, "Slice", 0, output_name_to_node)
+        if slice_before_mul is None:
+            return
+
+        if (
+            slice_before_mul.input[2] != slice_before_gelu.input[1]
+        ):  # end index of slice_before_mul is start index of slice_before_gelu
+            return
+
+        subgraph_nodes = [
+            *start_index_nodes,
+            end_index_nodes[0],
+            mul_after_gelu,
+            gelu_node,
+            slice_before_mul,
+            slice_before_gelu,
+        ]
+        subgraph_output = mul_after_gelu.output[0]
+        if not self.model.is_safe_to_fuse_nodes(
+            subgraph_nodes, [subgraph_output], input_name_to_nodes, output_name_to_node
+        ):
+            logger.info("Skip fuse BiasSplitGelu since it is not safe to fuse the subgraph.")
+            return
+
+        add_node = start_index_nodes[-1]
+        bias_index, _value = self.model.get_constant_input(add_node)
+        if not isinstance(bias_index, int):
+            return
+        self.nodes_to_remove.extend(subgraph_nodes)
+        node_name = self.model.create_node_name("BiasSplitGelu", name_prefix="BiasSplitGelu")
+        fused_node = helper.make_node(
+            "BiasSplitGelu",
+            inputs=[add_node.input[1 - bias_index], add_node.input[bias_index]],
+            outputs=[subgraph_output],
+            name=node_name,
+        )
+        fused_node.domain = "com.microsoft"
+        self.nodes_to_add.append(fused_node)
+        self.node_name_to_graph_name[node_name] = self.this_graph_name
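
For orientation, the subgraph above amounts to a bias add followed by a gated Gelu over the two halves of the last dimension, which is what the fused BiasSplitGelu contrib op computes. A numpy reference of that semantics (an illustrative sketch, not code from this package):

    # y = x + bias; split y in half on the last axis; output = first_half * Gelu(second_half)
    import numpy as np
    from scipy.special import erf  # exact (erf-based) Gelu

    def bias_split_gelu_reference(x: np.ndarray, bias: np.ndarray) -> np.ndarray:
        y = x + bias                                      # the Add node
        a, b = np.split(y, 2, axis=-1)                    # the two Slice nodes
        gelu_b = 0.5 * b * (1.0 + erf(b / np.sqrt(2.0)))  # the Gelu node
        return a * gelu_b                                 # the final Mul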

onnxruntime/transformers/fusion_conformer_attention.py
@@ -0,0 +1,222 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import logging
+
+from fusion_attention import AttentionMask, FusionAttention
+from onnx_model import OnnxModel
+
+logger = logging.getLogger(__name__)
+
+
+class FusionConformerAttention(FusionAttention):
+    """
+    Fuse Conformer Attention subgraph into one MultiHeadAttention node.
+    """
+
+    def __init__(
+        self,
+        model: OnnxModel,
+        hidden_size: int,
+        num_heads: int,
+        attention_mask: AttentionMask,
+    ):
+        super().__init__(model, hidden_size, num_heads, attention_mask)
+
+    def fuse(self, normalize_node, input_name_to_nodes, output_name_to_node):
+        # SkipLayerNormalization has two inputs, and one of them is the root input for attention.
+        qkv_nodes = self.model.match_parent_path(
+            normalize_node,
+            ["Add", "MatMul", "Reshape", "Transpose", "MatMul"],
+            [1, None, 0, 0, 0],
+        )
+        if qkv_nodes is None:
+            logger.debug("fuse_conformer_attention: failed to match qkv path")
+            return
+
+        reshape_qkv, transpose_qkv, matmul_qkv = qkv_nodes[-3], qkv_nodes[-2], qkv_nodes[-1]
+
+        past_v, present_v = "", ""
+        v_nodes = self.model.match_parent_path(
+            matmul_qkv,
+            ["Concat", "Transpose", "Reshape", "Add", "MatMul"],
+            [1, 1, 0, 0, 1],
+        )
+        if v_nodes is None:
+            v_nodes = self.model.match_parent_path(
+                matmul_qkv,
+                ["Transpose", "Reshape", "Add", "MatMul"],
+                [1, 0, 0, 0],
+            )
+            if v_nodes is None:
+                logger.debug("fuse_conformer_attention: failed to match v path")
+                return
+        else:
+            concat_v = v_nodes[0]
+            concat_parent = self.model.get_parent(concat_v, 0, None)
+            present_v = concat_v.output[0]
+            past_v = concat_parent.output[0]
+
+        add_v, matmul_v = v_nodes[-2], v_nodes[-1]
+
+        attn_mask = ""
+        qk_nodes = self.model.match_parent_path(
+            matmul_qkv,
+            ["Softmax", "Add", "MatMul"],
+            [0, 0, 0],
+        )
+        if qk_nodes is None:
+            qk_nodes = self.model.match_parent_path(
+                matmul_qkv,
+                ["Where", "Softmax", "Where", "Add", "MatMul"],
+                [0, 2, 0, 2, 0],
+            )
+            if qk_nodes is None:
+                logger.debug("fuse_conformer_attention: failed to match qk path")
+                return
+
+            where_qk = qk_nodes[2]
+            mask_nodes = self.model.match_parent_path(
+                where_qk,
+                ["Equal", "Unsqueeze", "Cast"],
+                [0, 0, 0],
+            )
+            if mask_nodes is not None:
+                attn_mask = mask_nodes[-1].output[0]
+
+        add_qk, matmul_qk = qk_nodes[-2], qk_nodes[-1]
+
+        q_nodes = self.model.match_parent_path(
+            matmul_qk,
+            ["Div", "Transpose", "Reshape", "Add", "MatMul"],
+            [0, 0, 0, 0, 1],
+        )
+        if q_nodes is None:
+            q_nodes = self.model.match_parent_path(
+                matmul_qk,
+                ["Mul", "Transpose", "Reshape", "Add", "MatMul"],
+                [0, 0, 0, 0, 0],
+            )
+            if q_nodes is None:
+                logger.debug("fuse_conformer_attention: failed to match q path")
+                return
+
+        reshape_q, add_q, matmul_q = q_nodes[-3], q_nodes[-2], q_nodes[-1]
+
+        extra_q_nodes = self.model.match_parent_path(
+            add_qk,
+            ["Reshape", "Transpose", "MatMul", "Transpose", "Reshape", "Div"],
+            [1, 0, 0, 0, 0, 0],
+        )
+        if extra_q_nodes is not None and q_nodes[0] != extra_q_nodes[-1]:
+            logger.debug("fuse_conformer_attention: failed to match extra q path")
+            return
+
+        past_k, present_k = "", ""
+        k_nodes = self.model.match_parent_path(
+            matmul_qk,
+            ["Transpose", "Concat", "Transpose", "Reshape", "Add", "MatMul"],
+            [1, 0, 1, 0, 0, 1],
+        )
+        if k_nodes is None:
+            k_nodes = self.model.match_parent_path(
+                matmul_qk,
+                ["Transpose", "Transpose", "Reshape", "Add", "MatMul"],
+                [1, 0, 0, 0, 0],
+            )
+            if k_nodes is None:
+                k_nodes = self.model.match_parent_path(
+                    matmul_qk,
+                    ["Transpose", "Reshape", "Add", "MatMul"],
+                    [1, 0, 0, 0],
+                )
+                if k_nodes is None:
+                    logger.debug("fuse_conformer_attention: failed to match k path")
+                    return
+        else:
+            concat_k = k_nodes[1]
+            concat_parent = self.model.get_parent(concat_k, 0, None)
+            past_k = concat_parent.output[0]
+            present_k = concat_k.output[0]
+
+        add_k, matmul_k = k_nodes[-2], k_nodes[-1]
+
+        num_heads, hidden_size = self.get_num_heads_and_hidden_size(reshape_q)
+        if num_heads <= 0 or hidden_size <= 0 or (hidden_size % num_heads) != 0:
+            logger.debug("fuse_conformer_attention: failed to detect num_heads or hidden_size")
+            return
+
+        new_node = None
+        use_packed_attention_op = (
+            matmul_q.input[0] == matmul_k.input[0] and matmul_k.input[0] == matmul_v.input[0] and extra_q_nodes is None
+        )
+        if use_packed_attention_op:
+            # Self-attention, use Attention op
+            new_node = self.create_attention_node(
+                mask_index=attn_mask,
+                q_matmul=matmul_q,
+                k_matmul=matmul_k,
+                v_matmul=matmul_v,
+                q_add=add_q,
+                k_add=add_k,
+                v_add=add_v,
+                num_heads=num_heads,
+                hidden_size=hidden_size,
+                first_input=matmul_q.input[0],
+                output=reshape_qkv.output[0],
+                add_qk_str=add_qk.input[1],
+                past_k=past_k,
+                past_v=past_v,
+                present_k=present_k,
+                present_v=present_v,
+            )
+        else:
+            new_node = self.create_multihead_attention_node(
+                q_matmul=matmul_q,
+                k_matmul=matmul_k,
+                v_matmul=matmul_v,
+                q_add=add_q,
+                k_add=add_k,
+                v_add=add_v,
+                num_heads=num_heads,
+                hidden_size=hidden_size,
+                output=reshape_qkv.output[0],
+                key_padding_mask=attn_mask,
+                add_qk=add_qk.input[1],
+                past_k=past_k,
+                past_v=past_v,
+                present_k=present_k,
+                present_v=present_v,
+            )
+
+        if new_node is None:
+            logger.debug("fuse_conformer_attention: MultiHeadAttention node creation failed")
+            return
+
+        self.nodes_to_add.append(new_node)
+        self.node_name_to_graph_name[new_node.name] = self.this_graph_name
+
+        self.nodes_to_remove.extend([reshape_qkv, transpose_qkv, matmul_qkv])
+        self.nodes_to_remove.extend(qk_nodes)
+
+        # When using MultiHeadAttention, keep MatMul nodes unfused in original graph
+        if not use_packed_attention_op:
+            if q_nodes[-1].op_type == "MatMul":
+                q_nodes.pop()
+            if k_nodes[-1].op_type == "MatMul":
+                k_nodes.pop()
+            if v_nodes[-1].op_type == "MatMul":
+                v_nodes.pop()
+
+        if extra_q_nodes is None:
+            # Don't remove Q nodes for conformer-transducer (CT) model since it has
+            # an extra set of nodes attached to the output of the Q path that are not
+            # part of the attention computation
+            self.nodes_to_remove.extend(q_nodes)
+
+        self.nodes_to_remove.extend(k_nodes)
+        self.nodes_to_remove.extend(v_nodes)
+
+        # Use prune graph to remove mask nodes since they are shared by all attention nodes.
+        self.prune_graph = True
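
The usual entry point for this fusion is the transformers optimizer rather than direct instantiation. A hedged sketch (assuming the "conformer" model_type maps to onnx_model_conformer.py in the listing above; paths and shape values are placeholders):

    # Sketch: optimize a conformer ONNX model so FusionConformerAttention can
    # collapse the Q/K/V subgraphs into Attention/MultiHeadAttention nodes.
    from onnxruntime.transformers import optimizer

    opt_model = optimizer.optimize_model(
        "conformer_encoder.onnx",  # placeholder input path
        model_type="conformer",
        num_heads=8,               # must match the exported model
        hidden_size=512,
    )
    opt_model.save_model_to_file("conformer_encoder_opt.onnx")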

onnxruntime/transformers/fusion_constant_fold.py
@@ -0,0 +1,144 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+from logging import getLogger
+
+from fusion_base import Fusion
+from fusion_utils import NumpyHelper
+from onnx_model import OnnxModel
+
+logger = getLogger(__name__)
+
+
+class FusionConstantFold(Fusion):
+    def __init__(self, model: OnnxModel):
+        super().__init__(model, "", ["Transpose"])
+        self.count = 0
+
+    def apply(self):
+        super().apply()
+        if self.count > 0:
+            logger.info(f"Constant Folded: {self.count}")
+
+    def fuse(self, node, input_name_to_nodes, output_name_to_node):
+        """
+        Apply multiple fusions on Transpose nodes that can be constant folded.
+        """
+        self.fuse_1(node, input_name_to_nodes, output_name_to_node)
+        self.fuse_2(node, input_name_to_nodes, output_name_to_node)
+
+    def fuse_1(self, node, input_name_to_nodes, output_name_to_node):
+        """
+        Constant fold any initializer data representing a MatMul's
+        weights that are stored in a Transpose op
+
+        Ex: Transpose --> Gemm or Transpose --> MatMul
+        """
+        # Check if Transpose node only has one input and one output
+        if len(node.input) != 1 or len(node.output) != 1:
+            logger.debug("fuse_constant_fold: node has more than one input or output")
+            return
+
+        # Check if input is initializer data
+        proto = self.model.get_initializer(node.input[0])
+        if proto is None:
+            logger.debug("fuse_constant_fold: failed to identify initializer input")
+            return
+
+        # Check that all nodes using input are Transpose ops that also only use the initializer data as input
+        skip = False
+        for child_node in input_name_to_nodes[node.input[0]]:
+            if not (child_node.op_type == "Transpose" and len(node.input) == 1):
+                skip = True
+                break
+        if skip:
+            logger.debug("fuse_constant_fold: other non-Transpose nodes use the initializer")
+            return
+
+        # Check that all nodes using output are Gemm or MatMul ops
+        for child_node in input_name_to_nodes[node.output[0]]:
+            if not (child_node.op_type == "Gemm" or child_node.op_type == "MatMul"):
+                skip = True
+                break
+        if skip:
+            logger.debug("fuse_constant_fold: other non-Gemm and non-MatMul nodes use the transposed data")
+            return
+
+        # Check if initializer data is 2D
+        weight = NumpyHelper.to_array(proto)
+        if len(weight.shape) != 2:
+            logger.debug("fuse_constant_fold: shape of initializer data is not 2D")
+            return
+
+        # Remove old TensorProto and add new TensorProto while re-using same name
+        name = proto.name
+        dtype = proto.data_type
+        self.remove_initializer(proto)
+        self.add_initializer(
+            name=name,
+            data_type=dtype,
+            dims=[weight.shape[1], weight.shape[0]],
+            vals=weight.T,
+        )
+
+        # Update weights input to be the initializer name and not
+        # the output of the Transpose op
+        for child_node in input_name_to_nodes[node.output[0]]:
+            for i in range(len(child_node.input)):
+                if child_node.input[i] == node.output[0]:
+                    child_node.input[i] = node.input[0]
+
+                    if child_node.op_type == "Gemm" and (i == 0 or i == 1):
+                        # Ensure that transA/transB is set to 0 in Gemm
+                        key = "transA" if i == 0 else "transB"
+                        for j, attr_key in enumerate(child_node.attribute):
+                            if attr_key.name == key:
+                                child_node.attribute[j].i = 0
+
+        # Add node to list of nodes to remove
+        self.nodes_to_remove.append(node)
+        self.count += 1
+
+    def fuse_2(self, node, input_name_to_nodes, output_name_to_node):
+        """
+        Constant fold any Transpose --> Transpose ops since the root input
+        is the final result
+
+        Ex: root_input --> Transpose --> Transpose --> next_node   to   root_input --> next_node
+        """
+        # Check if Transpose node only has one input and one output
+        if len(node.input) != 1 or len(node.output) != 1:
+            logger.debug("fuse_constant_fold: node has more than one input or output")
+            return
+
+        # Check if parent node is Transpose node with only one input and one output
+        parent_node = self.model.match_parent(node, "Transpose", 0)
+        if parent_node is None:
+            logger.debug("fuse_constant_fold: failed to identify parent Transpose node")
+            return
+        if len(parent_node.input) != 1 or len(parent_node.output) != 1:
+            logger.debug("fuse_constant_fold: parent node has more than one input or output")
+            return
+
+        node_perm = node.attribute[0].ints
+        parent_node_perm = parent_node.attribute[0].ints
+
+        if node_perm != parent_node_perm:
+            logger.debug("fuse_constant_fold: Transpose node permutations aren't identical")
+            return
+
+        # For nodes that use output of child Transpose node as an input,
+        # replace that input with root_input
+        root_input = parent_node.input[0]
+        output_nodes = input_name_to_nodes[node.output[0]]
+        for output_node in output_nodes:
+            for i, input_ in enumerate(output_node.input):
+                if input_ == node.output[0]:
+                    output_node.input[i] = root_input
+
+        # Add node to list of nodes to remove
+        self.nodes_to_remove.append(node)
+        self.nodes_to_remove.append(parent_node)
+        self.count += 1