onnxruntime_directml-1.24.1-cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0

onnxruntime/tools/offline_tuning.py
@@ -0,0 +1,169 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.

import argparse
import copy
import json
import sys
from collections import OrderedDict
from pprint import pprint
from typing import Any

import onnx

TuningResults = dict[str, Any]

_TUNING_RESULTS_KEY = "tuning_results"


def _find_tuning_results_in_props(metadata_props):
    for idx, prop in enumerate(metadata_props):
        if prop.key == _TUNING_RESULTS_KEY:
            return idx
    return -1


def extract(model: onnx.ModelProto):
    idx = _find_tuning_results_in_props(model.metadata_props)
    if idx < 0:
        return None

    tuning_results_prop = model.metadata_props[idx]
    return json.loads(tuning_results_prop.value)


def embed(model: onnx.ModelProto, tuning_results: list[TuningResults], overwrite=False):
    idx = _find_tuning_results_in_props(model.metadata_props)
    assert overwrite or idx < 0, "the supplied onnx file already has tuning results embedded!"

    if idx >= 0:
        model.metadata_props.pop(idx)

    entry = model.metadata_props.add()
    entry.key = _TUNING_RESULTS_KEY
    entry.value = json.dumps(tuning_results)
    return model


class Merger:
    class EpAndValidators:
        def __init__(self, ep: str, validators: dict[str, str]):
            self.ep = ep
            self.validators = copy.deepcopy(validators)
            self.key = (ep, tuple(sorted(validators.items())))

        def __hash__(self):
            return hash(self.key)

        def __eq__(self, other):
            return self.ep == other.ep and self.key == other.key

    def __init__(self):
        self.ev_to_results = OrderedDict()

    def merge(self, tuning_results: list[TuningResults]):
        for trs in tuning_results:
            self._merge_one(trs)

    def get_merged(self):
        tuning_results = []
        for ev, flat_results in self.ev_to_results.items():
            results = {}
            trs = {
                "ep": ev.ep,
                "validators": ev.validators,
                "results": results,
            }
            for (op_sig, params_sig), kernel_id in flat_results.items():
                kernel_map = results.setdefault(op_sig, {})
                kernel_map[params_sig] = kernel_id
            tuning_results.append(trs)
        return tuning_results

    def _merge_one(self, trs: TuningResults):
        ev = Merger.EpAndValidators(trs["ep"], trs["validators"])
        flat_results = self.ev_to_results.setdefault(ev, {})
        for op_sig, kernel_map in trs["results"].items():
            for params_sig, kernel_id in kernel_map.items():
                if (op_sig, params_sig) not in flat_results:
                    flat_results[(op_sig, params_sig)] = kernel_id


def parse_args():
    parser = argparse.ArgumentParser()
    sub_parsers = parser.add_subparsers(help="Command to execute", dest="cmd")

    extract_parser = sub_parsers.add_parser("extract", help="Extract embedded tuning results from an onnx file.")
    extract_parser.add_argument("input_onnx")
    extract_parser.add_argument("output_json")

    embed_parser = sub_parsers.add_parser("embed", help="Embed the tuning results into an onnx file.")
    embed_parser.add_argument("--force", "-f", action="store_true", help="Overwrite the tuning results if they already exist.")
    embed_parser.add_argument("output_onnx", help="Path of the output onnx file.")
    embed_parser.add_argument("input_onnx", help="Path of the input onnx file.")
    embed_parser.add_argument("input_json", nargs="+", help="Path(s) of the tuning results file(s) to be embedded.")

    merge_parser = sub_parsers.add_parser("merge", help="Merge multiple tuning results files into a single one.")
    merge_parser.add_argument("output_json", help="Path of the output tuning results file.")
    merge_parser.add_argument("input_json", nargs="+", help="Paths of the tuning results files to be merged.")

    pprint_parser = sub_parsers.add_parser("pprint", help="Pretty print the tuning results.")
    pprint_parser.add_argument("json_or_onnx", help="A tuning results json file or an onnx file.")

    args = parser.parse_args()
    if args.cmd is None:
        parser.print_help()
        sys.exit(-1)
    return args


def main():
    args = parse_args()
    if args.cmd == "extract":
        tuning_results = extract(onnx.load_model(args.input_onnx))
        if tuning_results is None:
            sys.stderr.write(f"{args.input_onnx} does not have tuning results embedded!\n")
            sys.exit(-1)
        json.dump(tuning_results, open(args.output_json, "w"))  # noqa: SIM115
    elif args.cmd == "embed":
        model = onnx.load_model(args.input_onnx)
        merger = Merger()
        for tuning_results in [json.load(open(f)) for f in args.input_json]:  # noqa: SIM115
            merger.merge(tuning_results)
        model = embed(model, merger.get_merged(), args.force)
        onnx.save_model(model, args.output_onnx)
    elif args.cmd == "merge":
        merger = Merger()
        for tuning_results in [json.load(open(f)) for f in args.input_json]:  # noqa: SIM115
            merger.merge(tuning_results)
        json.dump(merger.get_merged(), open(args.output_json, "w"))  # noqa: SIM115
    elif args.cmd == "pprint":
        tuning_results = None
        try:  # noqa: SIM105
            tuning_results = json.load(open(args.json_or_onnx))  # noqa: SIM115
        except Exception:
            # it might be an onnx file instead; try that later
            pass

        if tuning_results is None:
            try:
                model = onnx.load_model(args.json_or_onnx)
                tuning_results = extract(model)
                if tuning_results is None:
                    sys.stderr.write(f"{args.json_or_onnx} does not have tuning results embedded!\n")
                    sys.exit(-1)
            except Exception:
                pass

        if tuning_results is None:
            sys.stderr.write(f"{args.json_or_onnx} is not a valid tuning results file or onnx file!\n")
            sys.exit(-1)

        pprint(tuning_results)
    else:
        # an invalid choice is handled by the parser
        pass


if __name__ == "__main__":
    main()
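
A minimal sketch of driving the helpers above from Python; the module path onnxruntime.tools.offline_tuning follows the file listing, while model.onnx, gpu0.json, and gpu1.json are hypothetical local files:

    import json

    import onnx

    from onnxruntime.tools import offline_tuning

    # merge tuning results collected from two runs into one payload
    merger = offline_tuning.Merger()
    for path in ("gpu0.json", "gpu1.json"):  # hypothetical result files
        with open(path) as f:
            merger.merge(json.load(f))

    # embed the merged results in the model's metadata_props and save
    model = onnx.load_model("model.onnx")
    model = offline_tuning.embed(model, merger.get_merged(), overwrite=True)
    onnx.save_model(model, "model.tuned.onnx")

    # extract() returns the parsed JSON payload, or None if nothing is embedded
    assert offline_tuning.extract(model) is not None

The same round trip should be reachable from the command line, e.g. python -m onnxruntime.tools.offline_tuning embed -f model.tuned.onnx model.onnx gpu0.json gpu1.json (note that the embed subcommand takes the output path before the input path).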

onnxruntime/tools/onnx_model_utils.py
@@ -0,0 +1,416 @@
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
from __future__ import annotations

import logging
import pathlib

import onnx
from onnx import version_converter

import onnxruntime as ort


def iterate_graph_per_node_func(graph, per_node_func, **func_args):
    """
    Iterate the graph including subgraphs calling the per_node_func for each node.
    :param graph: Graph to iterate
    :param per_node_func: Function to call for each node. Signature is fn(node: onnx.NodeProto, **kwargs)
    :param func_args: The keyword args to pass through.
    """

    for node in graph.node:
        per_node_func(node, **func_args)
        # recurse into subgraph for control flow nodes (Scan/Loop/If)
        for attr in node.attribute:
            if attr.HasField("g"):
                iterate_graph_per_node_func(attr.g, per_node_func, **func_args)


def iterate_graph_per_graph_func(graph, per_graph_func, **func_args):
    """
    Iterate the graph including subgraphs calling the per_graph_func for each Graph.
    :param graph: Graph to iterate
    :param per_graph_func: Function to call for each graph. Signature is fn(graph: onnx.GraphProto, **kwargs)
    :param func_args: The keyword args to pass through.
    """

    per_graph_func(graph, **func_args)

    for node in graph.node:
        # recurse into subgraph for control flow nodes (Scan/Loop/If)
        for attr in node.attribute:
            if attr.HasField("g"):
                iterate_graph_per_graph_func(attr.g, per_graph_func, **func_args)


def get_opsets_imported(model: onnx.ModelProto):
    """
    Get the opsets imported by the model
    :param model: Model to check.
    :return: Map of domain to opset.
    """
    opsets = {}
    for entry in model.opset_import:
        # if empty it's ai.onnx
        domain = entry.domain or "ai.onnx"
        opsets[domain] = entry.version

    return opsets


def update_onnx_opset(
    model_path: pathlib.Path,
    opset: int,
    out_path: pathlib.Path | None = None,
    logger: logging.Logger | None = None,
):
    """
    Helper to update the opset of a model using onnx version_converter. Target opset must be greater than current opset.
    :param model_path: Path to model to update
    :param opset: Opset to update model to
    :param out_path: Optional output path for updated model to be saved to.
    :param logger: Optional logger for diagnostic output
    :returns: Updated onnx.ModelProto
    """

    model_path_str = str(model_path.resolve(strict=True))
    if logger:
        logger.info("Updating %s to opset %d", model_path_str, opset)

    model = onnx.load(model_path_str)

    new_model = version_converter.convert_version(model, opset)

    if out_path:
        onnx.save(new_model, str(out_path))
        if logger:
            logger.info("Saved updated model to %s", out_path)

    return new_model


def optimize_model(
    model_path: pathlib.Path,
    output_path: pathlib.Path,
    level: ort.GraphOptimizationLevel = ort.GraphOptimizationLevel.ORT_ENABLE_BASIC,
    log_level: int = 3,
    use_external_initializers: bool = False,
):
    """
    Optimize an ONNX model using ONNX Runtime to the specified level
    :param model_path: Path to ONNX model
    :param output_path: Path to save optimized model to.
    :param level: onnxruntime.GraphOptimizationLevel to use. Default is ORT_ENABLE_BASIC.
    :param log_level: Log level. Defaults to Error (3) so we don't get output about unused initializers being removed.
                      Warning (2) or Info (1) may be desirable in some scenarios.
    :param use_external_initializers: Set flag to write initializers to an external file. Required if model > 2GB.
                                      Requires onnxruntime 1.17+
    """
    so = ort.SessionOptions()
    so.optimized_model_filepath = str(output_path.resolve())
    so.graph_optimization_level = level
    so.log_severity_level = log_level

    # save using external initializers so models > 2 GB are handled
    if use_external_initializers:
        major, minor, rest = ort.__version__.split(".", 2)
        if (int(major), int(minor)) >= (1, 17):
            so.add_session_config_entry("session.optimized_model_external_initializers_file_name", "external_data.pb")
        else:
            raise ValueError(
                "ONNX Runtime 1.17 or higher required to save initializers as external data when optimizing model. "
                f"Current ONNX Runtime version is {ort.__version__}"
            )

    # create session to optimize. this will write the updated model to output_path
    _ = ort.InferenceSession(str(model_path.resolve(strict=True)), so, providers=["CPUExecutionProvider"])


def _replace_symbolic_dim_value(graph: onnx.GraphProto, **kwargs):
    param_to_replace = kwargs["dim_param"]
    value = kwargs["value"]

    def update_dim_values(value_infos):
        for vi in value_infos:
            if vi.type.HasField("tensor_type"):
                shape = vi.type.tensor_type.shape
                if shape:
                    for dim in shape.dim:
                        if dim.HasField("dim_param") and dim.dim_param == param_to_replace:
                            dim.Clear()
                            dim.dim_value = value

    update_dim_values(graph.input)
    update_dim_values(graph.output)
    update_dim_values(graph.value_info)


def _remove_invalid_dim_values_impl(graph: onnx.GraphProto):
    def clear_invalid_values(value):
        if value.type.HasField("tensor_type"):
            shape = value.type.tensor_type.shape
            if shape:
                for dim in shape.dim:
                    if dim.HasField("dim_value") and dim.dim_value < 1:
                        dim.Clear()

    for i in graph.input:
        clear_invalid_values(i)

    for o in graph.output:
        clear_invalid_values(o)

    for vi in graph.value_info:
        clear_invalid_values(vi)


def remove_invalid_dim_values(graph: onnx.GraphProto):
    """
    Iterate the graph and subgraphs, unsetting any dim_value entries that have a value of less than 1.
    These are typically erroneously inserted by a converter to represent a dynamic dimension.
    :param graph: GraphProto to update
    """
    iterate_graph_per_graph_func(graph, _remove_invalid_dim_values_impl)


def make_dim_param_fixed(graph: onnx.GraphProto, param_name: str, value: int):
    """
    Iterate all values in the graph, replacing dim_param in a tensor shape with the provided value.
    :param graph: GraphProto to update
    :param param_name: dim_param to set
    :param value: value to use
    """
    iterate_graph_per_graph_func(graph, _replace_symbolic_dim_value, dim_param=param_name, value=value)


def make_input_shape_fixed(graph: onnx.GraphProto, input_name: str, fixed_shape: list[int]):
    """
    Update the named graph input to set shape to the provided value. This can be used to set unknown dims as well
    as to replace dim values.
    If setting the input shape replaces a dim_param, update any other values in the graph that use the dim_param.
    :param graph: Graph to update
    :param input_name: Name of graph input to update.
    :param fixed_shape: Shape to use.
    """

    # remove any invalid dim values first. typically this is a dim_value of -1.
    remove_invalid_dim_values(graph)

    for i in graph.input:
        if i.name == input_name:
            if not i.type.HasField("tensor_type"):
                raise ValueError(f"Input {input_name} is not a tensor")

            # graph inputs are required to have a shape to provide the rank
            shape = i.type.tensor_type.shape
            if len(shape.dim) != len(fixed_shape):
                raise ValueError(f"Rank mismatch. Existing:{len(shape.dim)} Replacement:{len(fixed_shape)}")

            for idx, dim in enumerate(shape.dim):
                # check any existing fixed dims match
                if dim.HasField("dim_value"):
                    if dim.dim_value != fixed_shape[idx]:
                        raise ValueError(
                            f"Can't replace existing fixed size of {dim.dim_value} with {fixed_shape[idx]} "
                            f"for dimension {idx + 1}"
                        )
                elif dim.HasField("dim_param"):
                    # replacing a dim_param so have to do that through the entire graph
                    make_dim_param_fixed(graph, dim.dim_param, fixed_shape[idx])
                else:
                    # replacing an unknown dim
                    dim.Clear()
                    dim.dim_value = fixed_shape[idx]

            return

    raise ValueError(
        f"Input {input_name} was not found in graph inputs. "
        f"Valid input names are: {','.join([i.name for i in graph.input])}"
    )


def fix_output_shapes(model: onnx.ModelProto):
    """
    Update the output shapes of a model where the input shape(s) were made fixed, if possible.
    This is mainly to make the model usage clearer if the output shapes can be inferred from the new input shapes.
    :param model: Model that had input shapes fixed.
    """

    # get a version of the model with shape inferencing info in it. this will provide fixed output shapes if possible.
    m2 = onnx.shape_inference.infer_shapes(model)
    onnx.checker.check_model(m2)

    for idx, o in enumerate(model.graph.output):
        if not is_fixed_size_tensor(o):
            new_o = m2.graph.output[idx]
            if is_fixed_size_tensor(new_o):
                o.type.tensor_type.shape.CopyFrom(new_o.type.tensor_type.shape)


def _create_producer_consumer_link(
    node_to_producers: dict, node_to_consumers: dict, producer: onnx.NodeProto, consumer: onnx.NodeProto
):
    """
    Create links between two nodes for a value produced by one and consumed by the other.
    :param node_to_producers: Map of NodeProto to set of nodes that produce values the node consumes as inputs.
    :param node_to_consumers: Map of NodeProto to set of nodes that consume values the node produces as outputs.
    :param producer: Producer node
    :param consumer: Consumer node
    """

    if consumer not in node_to_producers:
        node_to_producers[consumer] = set()

    if producer not in node_to_consumers:
        node_to_consumers[producer] = set()

    # add entry mapping this node to the producer of this input
    node_to_producers[consumer].add(producer)
    node_to_consumers[producer].add(consumer)


def _map_node_dependencies(graph: onnx.GraphProto, node_to_producers: dict, node_to_consumers: dict):
    graph_inputs = {i.name for i in graph.input}
    initializers = {i.name for i in graph.initializer}

    # map of value name to node that creates it. copy parent values but override if values get shadowed
    producers = {}

    implicit_inputs = set()

    def is_local_value(value):
        return value in producers or value in initializers or value in graph_inputs

    for node in graph.node:
        inputs = list(node.input)

        for attr in node.attribute:
            if attr.HasField("g"):
                subgraph_implicit_inputs = _map_node_dependencies(attr.g, node_to_producers, node_to_consumers)
                inputs += subgraph_implicit_inputs

        for i in inputs:
            if not i:
                # missing optional input
                continue

            if is_local_value(i):
                if i in producers:
                    producer = producers[i]
                    _create_producer_consumer_link(node_to_producers, node_to_consumers, producer, node)
            else:
                implicit_inputs.add(i)

        for o in node.output:
            producers[o] = node

    return implicit_inputs


def get_producer_consumer_maps(graph: onnx.GraphProto):
    """
    Get maps for connections between the node that produces each value and the nodes that consume the value.
    Processing includes subgraphs. As the map key is a Node instance from the Graph there should be no ambiguity.
    :param graph: Graph to process.
    :return: Tuple with two maps.
             First is node_to_producers map of a node to set of all nodes producing input it consumes.
             Second is node_to_consumers map of a node to set of all nodes consuming output it creates.
             e.g. NodeA and NodeB provide inputs to NodeC. NodeC provides input to NodeD
             node_to_consumers[NodeA] = set([NodeC])
             node_to_consumers[NodeB] = set([NodeC])
             node_to_producers[NodeC] = set([NodeA, NodeB])
             node_to_consumers[NodeC] = set([NodeD])
             node_to_producers[NodeD] = set([NodeC])
    """

    # use a hash of the object id for NodeProto.
    # we need this for the partitioning checker where we keep maps with nodes as the key.
    onnx.NodeProto.__hash__ = lambda self: id(self)

    node_to_producers = {}  # map of node instance to nodes producing input values it consumes
    node_to_consumers = {}  # map of node instance to nodes consuming output values it produces

    implicit_inputs = _map_node_dependencies(graph, node_to_producers, node_to_consumers)

    # top level graph should have no implicit inputs
    if implicit_inputs:
        raise ValueError(
            f"This appears to be an invalid model with missing inputs of {','.join(sorted(implicit_inputs))}"
        )

    return node_to_producers, node_to_consumers


def is_fixed_size_tensor(value: onnx.ValueInfoProto):
    """
    Check if value is a tensor with a fixed shape.
    :param value: onnx.ValueInfoProto to check
    :return: True if value is a tensor, with a shape, where all dimensions have fixed values.
    """

    is_fixed = False
    if value.type.HasField("tensor_type"):
        shape = value.type.tensor_type.shape
        if shape:
            is_fixed = True  # scalar has no dims so set to True and unset if we hit a dim without a valid value
            for dim in shape.dim:
                if dim.HasField("dim_value") and dim.dim_value > 0:
                    continue

                # anything else means it's a dynamic value
                is_fixed = False
                break

    return is_fixed


def get_optimization_level(level):
    """Convert string to GraphOptimizationLevel."""
    if level == "disable":
        return ort.GraphOptimizationLevel.ORT_DISABLE_ALL
    if level == "basic":
        # Constant folding and other optimizations that only use ONNX operators
        return ort.GraphOptimizationLevel.ORT_ENABLE_BASIC
    if level == "extended":
        # Optimizations using custom operators, excluding NCHWc and NHWC layout optimizers
        return ort.GraphOptimizationLevel.ORT_ENABLE_EXTENDED
    if level == "layout":
        # NCHWc and NHWC layout optimizers
        return ort.GraphOptimizationLevel.ORT_ENABLE_LAYOUT
    if level == "all":
        return ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    raise ValueError(f"Invalid optimization level of {level}")


class ModelProtoWithShapeInfo:
    """
    Class to load an ONNX model and run shape inferencing on it to populate the ValueInfo.
    The model_with_shape_info property will contain the updated model.
    If the model is > 2GB and uses external data a temporary file is required to run shape inferencing successfully.
    This helper class handles automatic removal of the temporary file.
    """

    def __init__(self, model_path: pathlib.Path):
        """
        :param model_path: Path to ONNX model to load and run shape inferencing on.
        """

        self.model_path = model_path

        model = onnx.load(str(model_path))
        self.model_with_shape_info = onnx.shape_inference.infer_shapes(model, strict_mode=True)

        # ONNX has a silent failure from the call to infer_shapes when the model is > 2GB.
        # We detect that by checking the nodes in the returned model.
        self._tmp_model_path = None
        if len(model.graph.node) > 0 and len(self.model_with_shape_info.graph.node) == 0:
            self._tmp_model_path = pathlib.Path(model_path).with_suffix(".temp_with_shapeinf.onnx")
            onnx.shape_inference.infer_shapes_path(str(model_path), str(self._tmp_model_path), strict_mode=True)
            self.model_with_shape_info = onnx.load(str(self._tmp_model_path))

    def __del__(self):
        if self._tmp_model_path:
            self._tmp_model_path.unlink(missing_ok=True)
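
A short sketch of the intended flow through these utilities, pinning a symbolic input dimension and then optimizing the result; the input name input_0 and the shapes are hypothetical:

    import pathlib

    import onnx

    from onnxruntime.tools import onnx_model_utils

    model = onnx.load("model.onnx")  # hypothetical input model

    # fix the dynamic dims of one input; any other value_info that shares a
    # replaced dim_param is updated across the graph and its subgraphs
    onnx_model_utils.make_input_shape_fixed(model.graph, "input_0", [1, 3, 224, 224])

    # propagate the now-fixed shapes to the graph outputs where inferable
    onnx_model_utils.fix_output_shapes(model)
    onnx.save(model, "model.fixed.onnx")

    # run ORT's basic graph optimizations and write the optimized model
    onnx_model_utils.optimize_model(
        pathlib.Path("model.fixed.onnx"),
        pathlib.Path("model.opt.onnx"),
        level=onnx_model_utils.get_optimization_level("basic"),
    )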

onnxruntime/tools/onnx_randomizer.py
@@ -0,0 +1,85 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

# An offline standalone script to declassify an ONNX model by randomizing the tensor data in initializers.
# ORT performance may change, especially on generative models.

import argparse
from pathlib import Path

import numpy as np
from onnx import load_model, numpy_helper, onnx_pb, save_model

# An experimental small value for differentiating shape data and weights.
# Tensor data with a larger size can't be shape data.
# Users may adjust this value as needed.
SIZE_THRESHOLD = 10


def graph_iterator(model, func):
    graph_queue = [model.graph]
    while graph_queue:
        graph = graph_queue.pop(0)
        func(graph)
        for node in graph.node:
            for attr in node.attribute:
                if attr.type == onnx_pb.AttributeProto.AttributeType.GRAPH:
                    assert isinstance(attr.g, onnx_pb.GraphProto)
                    graph_queue.append(attr.g)
                if attr.type == onnx_pb.AttributeProto.AttributeType.GRAPHS:
                    for g in attr.graphs:
                        assert isinstance(g, onnx_pb.GraphProto)
                        graph_queue.append(g)


def randomize_graph_initializer(graph):
    for i_tensor in graph.initializer:
        array = numpy_helper.to_array(i_tensor)
        # TODO: need to find a better way to differentiate shape data and weights.
        if array.size > SIZE_THRESHOLD:
            random_array = np.random.uniform(array.min(), array.max(), size=array.shape).astype(array.dtype)
            o_tensor = numpy_helper.from_array(random_array, i_tensor.name)
            i_tensor.CopyFrom(o_tensor)


def main():
    parser = argparse.ArgumentParser(description="Randomize the weights of an ONNX model")
    parser.add_argument("-m", type=str, required=True, help="input onnx model path")
    parser.add_argument("-o", type=str, required=True, help="output onnx model path")
    parser.add_argument(
        "--use_external_data_format",
        required=False,
        action="store_true",
        help="Store or save in external data format",
    )
    parser.add_argument(
        "--all_tensors_to_one_file",
        required=False,
        action="store_true",
        help="Save all tensors to one file",
    )
    args = parser.parse_args()

    data_path = None
    if args.use_external_data_format:
        if Path(args.m).parent == Path(args.o).parent:
            raise RuntimeError("Please specify an output directory with a different parent path to the input directory.")
        if args.all_tensors_to_one_file:
            data_path = Path(args.o).name + ".data"

    Path(args.o).parent.mkdir(parents=True, exist_ok=True)
    onnx_model = load_model(args.m, load_external_data=args.use_external_data_format)
    graph_iterator(onnx_model, randomize_graph_initializer)
    save_model(
        onnx_model,
        args.o,
        save_as_external_data=args.use_external_data_format,
        all_tensors_to_one_file=args.all_tensors_to_one_file,
        location=data_path,
    )


if __name__ == "__main__":
    main()
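
A hypothetical in-process use of the randomizer, with made-up paths; the output directory is created first since save_model does not create it:

    from pathlib import Path

    from onnx import load_model, save_model

    from onnxruntime.tools.onnx_randomizer import graph_iterator, randomize_graph_initializer

    model = load_model("model.onnx")                    # hypothetical input
    graph_iterator(model, randomize_graph_initializer)  # scrub large initializers in place
    Path("randomized").mkdir(parents=True, exist_ok=True)
    save_model(model, "randomized/model.onnx")

The script should also be runnable directly, e.g. python -m onnxruntime.tools.onnx_randomizer -m model.onnx -o randomized/model.onnx; note that with --use_external_data_format the output directory must differ from the input directory, per the check in main() above.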