onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/tools/pytorch_export_helpers.py
@@ -0,0 +1,131 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import inspect
+from collections import abc
+
+import torch
+
+
+def _parse_inputs_for_onnx_export(all_input_parameters, inputs, kwargs):
+    # extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L433
+
+    def _add_input(name, input):
+        """Returns number of expanded inputs that _add_input processed"""
+
+        if input is None:
+            # Drop all None inputs and return 0.
+            return 0
+
+        num_expanded_non_none_inputs = 0
+        if isinstance(input, abc.Sequence):
+            # If the input is a sequence (like a list), expand the list so that
+            # each element of the list is an input by itself.
+            for i, val in enumerate(input):
+                # Name each input with the index appended to the original name of the
+                # argument.
+                num_expanded_non_none_inputs += _add_input(f"{name}_{i}", val)
+
+            # Return here since the list by itself is not a valid input.
+            # All the elements of the list have already been added as inputs individually.
+            return num_expanded_non_none_inputs
+        elif isinstance(input, abc.Mapping):
+            # If the input is a mapping (like a dict), expand the dict so that
+            # each element of the dict is an input by itself.
+            for key, val in input.items():
+                num_expanded_non_none_inputs += _add_input(f"{name}_{key}", val)
+
+            # Return here since the dict by itself is not a valid input.
+            # All the elements of the dict have already been added as inputs individually.
+            return num_expanded_non_none_inputs
+
+        # InputInfo should contain all the names irrespective of whether they are
+        # a part of the onnx graph or not.
+        input_names.append(name)
+
+        # A single input non none input was processed, return 1
+        return 1
+
+    input_names = []
+    var_positional_idx = 0
+    num_expanded_non_none_positional_inputs = 0
+
+    for input_idx, input_parameter in enumerate(all_input_parameters):
+        if input_parameter.kind == inspect.Parameter.VAR_POSITIONAL:
+            # VAR_POSITIONAL parameter carries all *args parameters from original forward method
+            for args_i in range(input_idx, len(inputs)):
+                name = f"{input_parameter.name}_{var_positional_idx}"
+                var_positional_idx += 1
+                inp = inputs[args_i]
+                num_expanded_non_none_positional_inputs += _add_input(name, inp)
+        elif (
+            input_parameter.kind == inspect.Parameter.POSITIONAL_ONLY
+            or input_parameter.kind == inspect.Parameter.POSITIONAL_OR_KEYWORD
+            or input_parameter.kind == inspect.Parameter.KEYWORD_ONLY
+        ):
+            # All positional non-*args and non-**kwargs are processed here
+            name = input_parameter.name
+            inp = None
+            input_idx += var_positional_idx  # noqa: PLW2901
+            is_positional = True
+            if input_idx < len(inputs) and inputs[input_idx] is not None:
+                inp = inputs[input_idx]
+            elif name in kwargs and kwargs[name] is not None:
+                inp = kwargs[name]
+                is_positional = False
+            num_expanded_non_none_inputs_local = _add_input(name, inp)
+            if is_positional:
+                num_expanded_non_none_positional_inputs += num_expanded_non_none_inputs_local
+        elif input_parameter.kind == inspect.Parameter.VAR_KEYWORD:
+            # **kwargs is always the last argument of forward()
+            for name, inp in kwargs.items():
+                if name not in input_names:
+                    _add_input(name, inp)
+
+    return input_names
+
+
+def _flatten_module_input(names, args, kwargs):
+    """Flatten args and kwargs in a single tuple of tensors."""
+    # extracted from https://github.com/microsoft/onnxruntime/blob/239c6ad3f021ff7cc2e6247eb074bd4208dc11e2/orttraining/orttraining/python/training/ortmodule/_io.py#L110
+
+    def is_primitive_type(value):
+        return type(value) in {int, bool, float}
+
+    def to_tensor(value):
+        return torch.tensor(value)
+
+    ret = [to_tensor(arg) if is_primitive_type(arg) else arg for arg in args]
+    ret += [
+        to_tensor(kwargs[name]) if is_primitive_type(kwargs[name]) else kwargs[name] for name in names if name in kwargs
+    ]
+
+    # if kwargs is empty, append an empty dictionary at the end of the sample inputs to make exporter
+    # happy. This is because the exporter is confused with kwargs and dictionary inputs otherwise.
+    if not kwargs:
+        ret.append({})
+
+    return tuple(ret)
+
+
+def infer_input_info(module: torch.nn.Module, *inputs, **kwargs):
+    """
+    Infer the input names and order from the arguments used to execute a PyTorch module for usage exporting
+    the model via torch.onnx.export.
+    Assumes model is on CPU. Use `module.to(torch.device('cpu'))` if it isn't.
+
+    Example usage:
+    input_names, inputs_as_tuple = infer_input_info(module, ...)
+    torch.onnx.export(module, inputs_as_type, 'model.onnx', input_names=input_names, output_names=[...], ...)
+
+    :param module: Module
+    :param inputs: Positional inputs
+    :param kwargs: Keyword argument inputs
+    :return: Tuple of ordered input names and input values. These can be used directly with torch.onnx.export as the
+        `input_names` and `inputs` arguments.
+    """
+    module_parameters = inspect.signature(module.forward).parameters.values()
+    input_names = _parse_inputs_for_onnx_export(module_parameters, inputs, kwargs)
+    inputs_as_tuple = _flatten_module_input(input_names, inputs, kwargs)
+
+    return input_names, inputs_as_tuple
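A minimal usage sketch for the helper above. The module class, tensor shapes, and file name below are illustrative only and not part of the package; the sketch assumes `onnxruntime` and `torch` are installed and that the module is importable as `onnxruntime.tools.pytorch_export_helpers`, per the file listing.

    import torch
    from onnxruntime.tools.pytorch_export_helpers import infer_input_info

    class TinyModel(torch.nn.Module):  # hypothetical example module
        def __init__(self):
            super().__init__()
            self.linear = torch.nn.Linear(4, 2)

        def forward(self, x, scale=1.0):
            return self.linear(x) * scale

    model = TinyModel().eval()
    x = torch.randn(1, 4)

    # infer_input_info returns the ordered input names and the inputs flattened
    # into a tuple, which can be passed straight to torch.onnx.export.
    input_names, inputs_as_tuple = infer_input_info(model, x, scale=2.0)
    torch.onnx.export(model, inputs_as_tuple, "tiny_model.onnx", input_names=input_names)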
onnxruntime/tools/qdq_helpers/__init__.py
File without changes
onnxruntime/tools/qdq_helpers/optimize_qdq_model.py
@@ -0,0 +1,37 @@
+#!/usr/bin/env python3
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import argparse
+import os
+import pathlib
+
+import onnx
+
+
+def optimize_qdq_model():
+    parser = argparse.ArgumentParser(
+        os.path.basename(__file__),
+        description="Update a QDQ format ONNX model to ensure optimal performance when executed using ONNX Runtime.",
+    )
+
+    parser.add_argument("input_model", type=pathlib.Path, help="Provide path to ONNX model to update.")
+    parser.add_argument("output_model", type=pathlib.Path, help="Provide path to write updated ONNX model to.")
+
+    args = parser.parse_args()
+
+    model = onnx.load(str(args.input_model.resolve(strict=True)))
+
+    # run QDQ model optimizations here
+
+    # Originally, the fixing up of DQ nodes with multiple consumers was implemented as one such optimization.
+    # That was moved to an ORT graph transformer.
+    print("As of ORT 1.15, the fixing up of DQ nodes with multiple consumers is done by an ORT graph transformer.")
+
+    # There are no optimizations being run currently but we expect that there may be in the future.
+
+    onnx.save(model, str(args.output_model.resolve()))
+
+
+if __name__ == "__main__":
+    optimize_qdq_model()
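The tool above is a thin command-line wrapper: it parses an input and an output path, loads the model, and currently saves it back without applying optimizations. A hedged invocation sketch, assuming the module is run directly via `python -m`; the model paths and the use of subprocess are illustrative only.

    import subprocess
    import sys

    # Runs the script's argparse-based entry point; the two positional
    # arguments are input_model and output_model.
    subprocess.run(
        [
            sys.executable,
            "-m",
            "onnxruntime.tools.qdq_helpers.optimize_qdq_model",
            "model.onnx",
            "model.updated.onnx",
        ],
        check=True,
    )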
onnxruntime/tools/qnn/add_trans_cast.py
@@ -0,0 +1,292 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import json
+from argparse import ArgumentParser
+
+import onnx
+from onnx import TensorProto, helper
+
+
+def graph_topological_sort(graph):
+    deps_count = [0] * len(graph.node)  # dependency count of each node
+    deps_to_nodes = {}  # input to node indice
+    sorted_nodes = []  # initialize sorted_nodes
+    for node_idx, node in enumerate(graph.node):
+        # CANNOT use len(node.input) directly because input can be optional
+        deps_count[node_idx] = sum(1 for _ in node.input if _)
+        if deps_count[node_idx] == 0:  # Constant doesn't depend on any inputs
+            sorted_nodes.append(graph.node[node_idx])
+            continue
+
+        for input_name in node.input:
+            if input_name not in deps_to_nodes:
+                deps_to_nodes[input_name] = [node_idx]
+            else:
+                deps_to_nodes[input_name].append(node_idx)
+
+    # Note: this logic only applies to top level graph since a sub graph could use intializer from parent graph
+    initializer_names = [init.name for init in graph.initializer]
+    graph_input_names = [input.name for input in graph.input]
+    input_names = initializer_names + graph_input_names
+    input_names.sort()
+    prev_input_name = None
+    for input_name in input_names:
+        if prev_input_name == input_name:
+            continue
+
+        prev_input_name = input_name
+        if input_name in deps_to_nodes:
+            for node_idx in deps_to_nodes[input_name]:
+                deps_count[node_idx] = deps_count[node_idx] - 1
+                if deps_count[node_idx] == 0:
+                    sorted_nodes.append(graph.node[node_idx])
+
+    start = 0
+    end = len(sorted_nodes)
+
+    while start < end:
+        for output in sorted_nodes[start].output:
+            if output in deps_to_nodes:
+                for node_idx in deps_to_nodes[output]:
+                    deps_count[node_idx] = deps_count[node_idx] - 1
+                    if deps_count[node_idx] == 0:
+                        sorted_nodes.append(graph.node[node_idx])
+                        end = end + 1
+        start = start + 1
+
+    assert end == len(graph.node), "Graph is not a DAG"
+    graph.ClearField("node")
+    graph.node.extend(sorted_nodes)
+
+
+class QnnTensorStruct:
+    def __init__(self):
+        self.name = ""
+        self.onnx_data_type = TensorProto.FLOAT
+        self.dim = []
+
+
+def qnn_data_type_to_onnx_data_type(qnn_data_type):
+    # QNN_DATATYPE_UFIXED_POINT_8 QNN_DATATYPE_UINT_8
+    if qnn_data_type == 0x0408 or qnn_data_type == 0x0108:
+        return TensorProto.UINT8
+    # QNN_DATATYPE_UFIXED_POINT_16 QNN_DATATYPE_UINT_16
+    elif qnn_data_type == 0x0416 or qnn_data_type == 0x0116:
+        return TensorProto.UINT16
+    # QNN_DATATYPE_UFIXED_POINT_32 QNN_DATATYPE_UINT_32
+    elif qnn_data_type == 0x0432 or qnn_data_type == 0x0132:
+        return TensorProto.UINT32
+    # QNN_DATATYPE_UINT_64
+    elif qnn_data_type == 0x0164:
+        return TensorProto.UINT64
+    # QNN_DATATYPE_FIXED_POINT_8 QNN_DATATYPE_INT_8
+    elif qnn_data_type == 0x0308 or qnn_data_type == 0x0008:
+        return TensorProto.INT8
+    # QNN_DATATYPE_FIXED_POINT_16 QNN_DATATYPE_INT_16
+    elif qnn_data_type == 0x0316 or qnn_data_type == 0x0016:
+        return TensorProto.INT16
+    # QNN_DATATYPE_FIXED_POINT_32 QNN_DATATYPE_INT_32
+    elif qnn_data_type == 0x0332 or qnn_data_type == 0x0032:
+        return TensorProto.INT32
+    # QNN_DATATYPE_INT_64
+    elif qnn_data_type == 0x0064:
+        return TensorProto.INT64
+    # QNN_DATATYPE_FLOAT_16
+    elif qnn_data_type == 0x0216:
+        return TensorProto.FLOAT16
+    # QNN_DATATYPE_FLOAT_32
+    elif qnn_data_type == 0x0232:
+        return TensorProto.FLOAT
+    # QNN_DATATYPE_BOOL_8
+    elif qnn_data_type == 0x0508:
+        return TensorProto.BOOL
+    else:
+        return TensorProto.UNDEFINED
+
+
+def parse_qnn_json_file(qnn_json_file_path, qnn_input_output_tensor_dic):
+    with open(qnn_json_file_path) as qnn_json_file:
+        qnn_json = json.load(qnn_json_file)
+        assert "graph" in qnn_json, "QNN converted json file not valid. Can't find graph."
+        assert "tensors" in qnn_json["graph"], "QNN converted json file not valid. Can't find tensors."
+        for qnn_tensor_name, qnn_tensor_attribute in qnn_json["graph"]["tensors"].items():
+            # type:0 - QNN input tensor, type:1 - QNN output tensor
+            assert (
+                "type" in qnn_tensor_attribute
+                and "data_type" in qnn_tensor_attribute
+                and "dims" in qnn_tensor_attribute
+            ), "QNN converted json file not valid. Can't find some keys from tensors"
+            if qnn_tensor_attribute["type"] == 0 or qnn_tensor_attribute["type"] == 1:
+                qnn_tensor = QnnTensorStruct()
+                qnn_tensor.name = qnn_tensor_name
+                qnn_tensor.onnx_data_type = qnn_data_type_to_onnx_data_type(qnn_tensor_attribute["data_type"])
+                qnn_tensor.dim = qnn_tensor_attribute["dims"]
+                qnn_input_output_tensor_dic[qnn_tensor_name] = qnn_tensor
+
+    assert len(qnn_input_output_tensor_dic) > 1, (
+        "Converted QNN model not valid. It should have at least 1 input & 1 output."
+    )
+
+
+def compare_onnx_shape_with_qnn_shape(onnx_dims, qnn_dims):
+    assert len(onnx_dims) == len(qnn_dims), "Onnx shape and Qnn shape has different rank."
+    return all(onnx_dims[i].dim_value == qnn_dims[i] for i in range(len(onnx_dims)))
+
+
+def gen_to_channel_first_perm(rank):
+    assert rank > 2, "Shape rank should >2 for the Transpose node."
+    perm = []
+    perm.append(0)
+    perm.append(rank - 1)
+    for i in range(1, rank - 1):
+        perm.append(i)  # noqa: PERF402
+
+    return perm
+
+
+def gen_to_channel_last_perm(rank):
+    assert rank > 2, "Shape rank should >2 for the Transpose node."
+    perm = []
+    perm.append(0)
+    for i in range(2, rank):
+        perm.append(i)  # noqa: PERF402
+    perm.append(1)
+
+    return perm
+
+
+# Onnxruntime QNN EP can support context binary file generated by QNN tool chain. However QNN generated context binary file
+# uses channel last data layout and 8 bits or 16 bits for input and output.
+# This script gets the QNN model input & output information from QNN converted model_net.json file, compare them with Onnx model
+# and inserts Cast, Transpose nodes to Onnx model if required
+def main():
+    parser = ArgumentParser(
+        "Insert Cast, Transpose nodes into Onnx model to make it aligned with QNN generated context binary."
+    )
+    parser.add_argument("-m", "--onnx_model", help="Required. Path to Onnx model file.", required=True, type=str)
+    parser.add_argument(
+        "-q", "--qnn_json", help="Required. Path to Qnn converted model_net.json file.", required=True, type=str
+    )
+    args = parser.parse_args()
+
+    # Parse Qnn model_net.json file to get the graph input output information
+    qnn_input_output_tensor_dic = {}
+    parse_qnn_json_file(args.qnn_json, qnn_input_output_tensor_dic)
+
+    model = onnx.load(args.onnx_model)
+
+    nodes_to_add = []
+    # Tranch the tensor name change to update the consumer nodes
+    graph_input_output_name_dic = {}
+    for graph_input in model.graph.input:
+        if graph_input.name in qnn_input_output_tensor_dic:
+            input_name_fater_node_insert = graph_input.name
+            qnn_input_tensor = qnn_input_output_tensor_dic[graph_input.name]
+            # Insert Cast node if Onnx input and Qnn input has different data type
+            if graph_input.type.tensor_type.elem_type != qnn_input_tensor.onnx_data_type:
+                # Insert Cast node
+                cast_input_name = input_name_fater_node_insert
+                cast_output_name = cast_input_name + "_qnn_cast"
+                input_cast_node = helper.make_node(
+                    "Cast",
+                    name=cast_output_name,
+                    inputs=[cast_input_name],
+                    outputs=[cast_output_name],
+                    to=graph_input.type.tensor_type.elem_type,
+                )
+                # Change input data type to Qnn input data type
+                graph_input.type.tensor_type.elem_type = qnn_input_tensor.onnx_data_type
+                nodes_to_add.extend([input_cast_node])
+                input_name_fater_node_insert = cast_output_name
+                graph_input_output_name_dic[graph_input.name] = cast_output_name
+
+            if not compare_onnx_shape_with_qnn_shape(graph_input.type.tensor_type.shape.dim, qnn_input_tensor.dim):
+                # Add Transpose node (channel last to channel first)
+                transpose_perm = gen_to_channel_first_perm(len(graph_input.type.tensor_type.shape.dim))
+                transpose_input_name = input_name_fater_node_insert
+                transpose_output_name = transpose_input_name + "_qnn_trans"
+                input_transpose_node = helper.make_node(
+                    "Transpose",
+                    name=transpose_output_name,
+                    inputs=[transpose_input_name],
+                    outputs=[transpose_output_name],
+                    perm=transpose_perm,
+                )
+                nodes_to_add.extend([input_transpose_node])
+                graph_input_output_name_dic[graph_input.name] = transpose_output_name
+
+                # Change input shape to Qnn input shape
+                for i in range(len(graph_input.type.tensor_type.shape.dim)):
+                    graph_input.type.tensor_type.shape.dim[i].dim_value = qnn_input_tensor.dim[i]
+        else:
+            raise AssertionError("Error: Onnx model input: " + graph_input.name + " not exist from QNN model input.")
+
+    for graph_output in model.graph.output:
+        if graph_output.name in qnn_input_output_tensor_dic:
+            output_name_after_node_insert = graph_output.name
+            # Insert Cast node if Onnx input and Qnn input has idfferent data type
+            qnn_output_tensor = qnn_input_output_tensor_dic[graph_output.name]
+            if graph_output.type.tensor_type.elem_type != qnn_output_tensor.onnx_data_type:
+                # Insert Cast node
+                cast_output_name = output_name_after_node_insert
+                cast_input_name = cast_output_name + "_qnn_cast"
+                output_cast_node = helper.make_node(
+                    "Cast",
+                    name=cast_input_name,
+                    inputs=[cast_input_name],
+                    outputs=[cast_output_name],
+                    to=qnn_output_tensor.onnx_data_type,
+                )
+                # Change output data type to Onn output data type
+                graph_output.type.tensor_type.elem_type = qnn_output_tensor.onnx_data_type
+                nodes_to_add.extend([output_cast_node])
+                output_name_after_node_insert = cast_input_name
+                graph_input_output_name_dic[graph_output.name] = cast_input_name
+
+            if not compare_onnx_shape_with_qnn_shape(graph_output.type.tensor_type.shape.dim, qnn_output_tensor.dim):
+                # Add Transpose node (channel first to channel last)
+                transpose_perm = gen_to_channel_last_perm(len(graph_output.type.tensor_type.shape.dim))
+                transpose_output_name = output_name_after_node_insert
+                transpose_input_name = transpose_output_name + "_qnn_trans"
+                output_transpose_node = helper.make_node(
+                    "Transpose",
+                    name=transpose_input_name,
+                    inputs=[transpose_input_name],
+                    outputs=[transpose_output_name],
+                    perm=transpose_perm,
+                )
+                nodes_to_add.extend([output_transpose_node])
+                graph_input_output_name_dic[graph_output.name] = transpose_input_name
+
+                # Change output shape to Qnn output shape
+                for i in range(len(graph_output.type.tensor_type.shape.dim)):
+                    graph_output.type.tensor_type.shape.dim[i].dim_value = qnn_input_output_tensor_dic[
+                        graph_output.name
+                    ].dim[i]
+        else:
+            raise AssertionError("Error: Onnx model output: " + graph_output.name + " not exist from QNN model output.")
+
+    for node in model.graph.node:
+        for node_input_index, node_input in enumerate(node.input):
+            # update consumer node for graph inputs to connect to inserted node
+            if node_input in graph_input_output_name_dic:
+                node.input[node_input_index] = graph_input_output_name_dic[node_input]
+
+        for node_output_index, node_output in enumerate(node.output):
+            # update producer node for graph outputs to connect to inserted node
+            if node_output in graph_input_output_name_dic:
+                node.output[node_output_index] = graph_input_output_name_dic[node_output]
+
+    model.graph.node.extend(nodes_to_add)
+    graph_topological_sort(model.graph)
+
+    # Add extra parameter all_tensors_to_one_file=False, size_threshold=5000 if the model exceeds protobuf 2GB limit e.g below
+    # onnx.save(model, args.onnx_model.replace(".onnx", "_add_trans.onnx"), all_tensors_to_one_file=False, size_threshold=5000)
+    onnx.save(model, args.onnx_model.replace(".onnx", "_add_trans.onnx"))
+
+
+if __name__ == "__main__":
+    main()
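To make the layout-alignment step above concrete: for a rank-4 tensor, the two permutation helpers produce the standard NHWC-to-NCHW and NCHW-to-NHWC axis orders that feed the inserted Transpose nodes. A small standalone sketch of the same logic, reimplemented here for illustration rather than imported from the script:

    def to_channel_first_perm(rank):
        # Move the last (channel) axis to position 1: NHWC -> NCHW when rank is 4.
        return [0, rank - 1] + list(range(1, rank - 1))

    def to_channel_last_perm(rank):
        # Move axis 1 (channel) to the end: NCHW -> NHWC when rank is 4.
        return [0] + list(range(2, rank)) + [1]

    assert to_channel_first_perm(4) == [0, 3, 1, 2]
    assert to_channel_last_perm(4) == [0, 2, 3, 1]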