onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
The diff below shows the one file added in full by this release that is rendered on this page: onnxruntime/transformers/onnx_exporter.py (new file, +719 lines).

@@ -0,0 +1,719 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See License.txt in the project root for
# license information.
# --------------------------------------------------------------------------

import logging
import os
from pathlib import Path

import numpy
import torch
from affinity_helper import AffinitySetting
from benchmark_helper import OptimizerInfo, Precision, create_onnxruntime_session
from huggingface_models import MODEL_CLASSES
from quantize_helper import QuantizeHelper
from torch_onnx_export_helper import torch_onnx_export
from transformers import AutoConfig, AutoFeatureExtractor, AutoTokenizer, LxmertConfig, TransfoXLConfig

from onnxruntime.transformers.models.gpt2.gpt2_helper import (
    PRETRAINED_GPT2_MODELS,
    GPT2ModelNoPastState,
    TFGPT2ModelNoPastState,
)

os.environ["TF_CPP_MIN_LOG_LEVEL"] = "2"

logger = logging.getLogger(__name__)

# Workaround: replace torch.triu with a self-defined op, since torch.triu
# cannot be exported to ONNX. See https://github.com/pytorch/pytorch/issues/32968
torch_func = {"triu": torch.triu}


def triu_onnx(x, diagonal=0, out=None):
    assert out is None
    assert len(x.shape) == 2 and x.size(0) == x.size(1)

    torch_triu = torch_func["triu"]
    template = torch_triu(torch.ones((1024, 1024), dtype=torch.uint8), diagonal)
    mask = template[: x.size(0), : x.size(1)]
    return torch.where(mask.bool(), x, torch.zeros_like(x))


def replace_torch_functions():
    torch.triu = triu_onnx


def restore_torch_functions():
    torch.triu = torch_func["triu"]
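# --- Illustrative usage sketch (reviewer's addition, not part of the packaged
# file): the three helpers above are meant to bracket an export call, so that
# any torch.triu inside the model is routed through the ONNX-exportable
# triu_onnx. Note triu_onnx only handles square 2-D inputs up to 1024x1024,
# the size of the precomputed mask template.
#
#   replace_torch_functions()
#   try:
#       torch.onnx.export(model, example_args, "model.onnx")
#   finally:
#       restore_torch_functions()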

def create_onnxruntime_input(vocab_size, batch_size, sequence_length, input_names, config, data_type=numpy.int64):
    if config.model_type in ["vit", "swin"]:
        input_ids = numpy.random.rand(batch_size, 3, config.image_size, config.image_size).astype(numpy.float32)
        inputs = {"pixel_values": input_ids}
        return inputs

    input_ids = numpy.random.randint(low=0, high=vocab_size - 1, size=(batch_size, sequence_length), dtype=data_type)
    inputs = {"input_ids": input_ids}

    if "attention_mask" in input_names:
        attention_mask = numpy.ones([batch_size, sequence_length], dtype=data_type)
        inputs["attention_mask"] = attention_mask

    if "token_type_ids" in input_names:
        segment_ids = numpy.zeros([batch_size, sequence_length], dtype=data_type)
        inputs["token_type_ids"] = segment_ids

    if config.is_encoder_decoder:
        inputs["decoder_input_ids"] = input_ids

    if isinstance(config, LxmertConfig):
        inputs["visual_feats"] = numpy.random.randn(1, 1, config.visual_feat_dim).astype(numpy.float32)
        inputs["visual_pos"] = numpy.random.randn(1, 1, config.visual_pos_dim).astype(numpy.float32)
    if isinstance(config, TransfoXLConfig):
        inputs["tf_transfo_xl_model/transformer/pos_emb/einsum/Einsum/inputs_1:0"] = numpy.zeros(
            [config.hidden_size], dtype=numpy.float32
        )
    return inputs
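# Illustrative call (reviewer's addition, assuming a BERT-like config): this
# returns int64 arrays of shape (batch_size, sequence_length); attention_mask
# is all ones and token_type_ids all zeros.
#
#   ort_inputs = create_onnxruntime_input(
#       vocab_size=30522,
#       batch_size=1,
#       sequence_length=128,
#       input_names=["input_ids", "attention_mask", "token_type_ids"],
#       config=config,
#   )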

def filter_inputs(inputs, input_names):
    remaining_model_inputs = {}
    for input_name in input_names:
        if input_name in inputs:
            remaining_model_inputs[input_name] = inputs[input_name]
    return remaining_model_inputs


def flatten(inputs):
    return [[flatten(i) for i in inputs] if isinstance(inputs, (list, tuple)) else inputs]


def update_flatten_list(inputs, res_list):
    for i in inputs:
        res_list.append(i) if not isinstance(i, (list, tuple)) else update_flatten_list(i, res_list)
    return res_list
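# Illustrative example (reviewer's addition): flatten() rewraps nested tuples
# such as GPT-2's (logits, (past_0, past_1, ...)) into nested lists, and
# update_flatten_list() then linearizes them into a flat list of tensors.
#
#   nested = flatten((logits, (past_0, past_1)))
#   tensors = update_flatten_list(nested, [])  # [logits, past_0, past_1]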

def build_dynamic_axes(example_inputs, outputs_flatten):
    sequence_length = example_inputs["input_ids"].shape[-1]

    dynamic_axes = {key: {0: "batch_size", 1: "seq_len"} for key in example_inputs}

    output_names = ["output_" + str(i + 1) for i in range(len(outputs_flatten))]
    for i, output_name in enumerate(output_names):
        dynamic_axes[output_name] = {0: "batch_size"}
        dims = outputs_flatten[i].shape
        for j, dim in enumerate(dims):
            if dim == sequence_length:
                dynamic_axes[output_name].update({j: "seq_len"})
    return dynamic_axes, output_names
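# Illustrative result (reviewer's addition): for example_inputs with keys
# input_ids and attention_mask and a single output whose dimension 1 equals
# the sequence length, build_dynamic_axes() returns approximately:
#
#   ({"input_ids": {0: "batch_size", 1: "seq_len"},
#     "attention_mask": {0: "batch_size", 1: "seq_len"},
#     "output_1": {0: "batch_size", 1: "seq_len"}},
#    ["output_1"])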

def validate_onnx_model(
    onnx_model_path,
    example_inputs,
    example_outputs_flatten,
    use_gpu,
    fp16,
    output_names=None,
):
    test_session = create_onnxruntime_session(onnx_model_path, use_gpu, enable_all_optimization=False)
    if test_session is None:
        logger.error(f"{onnx_model_path} is an invalid ONNX model")
        return False

    logger.info(f"{onnx_model_path} is a valid ONNX model")

    # Compare the inference result with PyTorch or Tensorflow
    example_ort_inputs = {k: t.numpy() for k, t in example_inputs.items()}
    example_ort_outputs = test_session.run(output_names, example_ort_inputs)
    if len(example_outputs_flatten) != len(example_ort_outputs):
        logger.error(
            f"Number of output tensors expected {len(example_outputs_flatten)}, got {len(example_ort_outputs)}"
        )
        return False

    for i in range(len(example_outputs_flatten)):
        abs_diff = numpy.amax(numpy.abs(example_ort_outputs[i] - example_outputs_flatten[i].cpu().numpy()))
        if abs_diff > 1e-4:
            logger.info(f"Max absolute diff={abs_diff} for output tensor {i}")

        rtol = 5e-02 if fp16 else 1e-4
        atol = 1e-01 if fp16 else 1e-4
        if not numpy.allclose(
            example_ort_outputs[i],
            example_outputs_flatten[i].cpu().numpy(),
            rtol=rtol,
            atol=atol,
        ):
            logger.error(f"Output tensor {i} is not close: rtol={rtol}, atol={atol}")
            return False

    logger.info(f"inference result of onnxruntime is validated on {onnx_model_path}")
    return True


def get_onnx_file_path(
    onnx_dir: str,
    model_name: str,
    input_count: int,
    optimized_by_script: bool,
    use_gpu: bool,
    precision: Precision,
    optimized_by_onnxruntime: bool,
    use_external_data: bool,
):
    from re import sub  # noqa: PLC0415

    normalized_model_name = sub(r"[^a-zA-Z0-9_]", "_", model_name)

    if not optimized_by_script:
        filename = f"{normalized_model_name}_{input_count}"
    else:
        device = "gpu" if use_gpu else "cpu"
        filename = f"{normalized_model_name}_{input_count}_{precision}_{device}"

    if optimized_by_onnxruntime:
        filename += "_ort"

    directory = onnx_dir
    # ONNXRuntime will not write external data, so the raw and optimized models shall be in the same directory.
    if use_external_data and not optimized_by_onnxruntime:
        directory = os.path.join(onnx_dir, filename)
        if not os.path.exists(directory):
            os.makedirs(directory)

    return os.path.join(directory, f"{filename}.onnx")
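# Illustrative paths (reviewer's addition), assuming onnx_dir="./onnx",
# model_name="bert-base-cased", 3 inputs, and precision rendering as "fp16":
#   raw export:              ./onnx/bert_base_cased_3.onnx
#   script-optimized on gpu: ./onnx/bert_base_cased_3_fp16_gpu.onnx
#   ORT-optimized:           ./onnx/bert_base_cased_3_ort.onnx
# With use_external_data=True, the raw model gets its own subdirectory so its
# external weight files stay next to it.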

def add_filename_suffix(file_path: str, suffix: str) -> str:
    """
    Append a suffix to the filename (before the extension).
    Args:
        file_path: the path whose filename the suffix is appended to
        suffix: the suffix to add
    Returns: path with the suffix appended at the end of the filename, before the extension
    """
    path = Path(file_path)
    return str(path.parent.joinpath(path.stem + suffix).with_suffix(path.suffix))
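# Example (reviewer's addition):
#   add_filename_suffix("onnx/model.onnx", "_ort")  # -> "onnx/model_ort.onnx"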

def optimize_onnx_model_by_ort(onnx_model_path, ort_model_path, use_gpu, overwrite, model_fusion_statistics):
    if overwrite or not os.path.exists(ort_model_path):
        Path(ort_model_path).parent.mkdir(parents=True, exist_ok=True)
        from optimizer import get_fusion_statistics, optimize_by_onnxruntime  # noqa: PLC0415

        # Use onnxruntime to optimize model, which will be saved to *_ort.onnx
        _ = optimize_by_onnxruntime(
            onnx_model_path,
            use_gpu=use_gpu,
            optimized_model_path=ort_model_path,
            opt_level=99,
        )
        model_fusion_statistics[ort_model_path] = get_fusion_statistics(ort_model_path)
    else:
        logger.info(f"Skip optimization since model existed: {ort_model_path}")


def optimize_onnx_model(
    onnx_model_path,
    optimized_model_path,
    model_type,
    num_attention_heads,
    hidden_size,
    use_gpu,
    precision,
    use_raw_attention_mask,
    overwrite,
    model_fusion_statistics,
    use_external_data_format,
    optimization_options=None,
):
    if overwrite or not os.path.exists(optimized_model_path):
        Path(optimized_model_path).parent.mkdir(parents=True, exist_ok=True)

        from fusion_options import FusionOptions  # noqa: PLC0415
        from optimizer import optimize_model  # noqa: PLC0415

        if optimization_options is None:
            optimization_options = FusionOptions(model_type)
        optimization_options.use_raw_attention_mask(use_raw_attention_mask)
        if precision == Precision.FLOAT16:
            optimization_options.enable_gelu_approximation = True
        if precision == Precision.INT8:
            optimization_options.enable_embed_layer_norm = False

        # For swin models, num_attention_heads is a list, which isn't supported yet, so set it to 0 for now.
        if model_type == "swin":
            num_attention_heads = 0
            hidden_size = 0

        # Use script to optimize model.
        # Use opt_level <= 1 for models to be converted to fp16, because some fused ops (like FusedGemm) have only
        # fp32 kernels and no fp16. It is better to be conservative, so we use opt_level=0 here, in case
        # MemcpyFromHost is added to the graph by OnnxRuntime.
        opt_model = optimize_model(
            onnx_model_path,
            model_type,
            num_heads=num_attention_heads,
            hidden_size=hidden_size,
            opt_level=0,
            optimization_options=optimization_options,
            use_gpu=use_gpu,
            only_onnxruntime=False,
        )
        if model_type == "bert_keras" or model_type == "bert_tf":
            opt_model.use_dynamic_axes()

        model_fusion_statistics[optimized_model_path] = opt_model.get_fused_operator_statistics()

        if precision == Precision.FLOAT16:
            opt_model.convert_float_to_float16(keep_io_types=True)

        opt_model.save_model_to_file(optimized_model_path, use_external_data_format)
    else:
        logger.info(f"Skip optimization since model existed: {optimized_model_path}")

def modelclass_dispatcher(model_name, custom_model_class):
    if custom_model_class is not None:
        if custom_model_class in MODEL_CLASSES:
            return custom_model_class
        else:
            raise Exception("Valid model class: " + " ".join(MODEL_CLASSES))

    if model_name in PRETRAINED_GPT2_MODELS:
        return "GPT2ModelNoPastState"

    import re  # noqa: PLC0415

    if re.search("-squad$", model_name) is not None:
        return "AutoModelForQuestionAnswering"
    elif re.search("-mprc$", model_name) is not None:
        return "AutoModelForSequenceClassification"
    elif re.search("gpt2", model_name) is not None:
        return "AutoModelWithLMHead"

    return "AutoModel"
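# Illustrative dispatch results (reviewer's addition):
#   "gpt2" (in PRETRAINED_GPT2_MODELS) -> "GPT2ModelNoPastState"
#   a name ending in "-squad"          -> "AutoModelForQuestionAnswering"
#   "roberta-base"                     -> "AutoModel"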

def load_pretrained_model(model_name, config, cache_dir, custom_model_class, is_tf_model=False):
    model_class_name = modelclass_dispatcher(model_name, custom_model_class)

    if model_class_name == "GPT2ModelNoPastState":
        if is_tf_model:
            return TFGPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)
        else:
            return GPT2ModelNoPastState.from_pretrained(model_name, config=config, cache_dir=cache_dir)

    if is_tf_model:
        model_class_name = "TF" + model_class_name

    transformers_module = __import__("transformers", fromlist=[model_class_name])
    logger.info(f"Model class name: {model_class_name}")
    model_class = getattr(transformers_module, model_class_name)

    return model_class.from_pretrained(model_name, config=config, cache_dir=cache_dir)


def load_pt_model(model_name, model_class, cache_dir, config_modifier):
    config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)
    if hasattr(config, "return_dict"):
        config.return_dict = False

    config_modifier.modify(config)

    model = load_pretrained_model(model_name, config=config, cache_dir=cache_dir, custom_model_class=model_class)

    return config, model


def load_tf_model(model_name, model_class, cache_dir, config_modifier):
    config = AutoConfig.from_pretrained(model_name, cache_dir=cache_dir)

    config_modifier.modify(config)
    # Loading a tf model from transformers limits the cpu affinity to {0} when KMP_AFFINITY is set.
    # Restore the affinity after model loading for expected ORT performance.
    affinity_setting = AffinitySetting()
    affinity_setting.get_affinity()
    model = load_pretrained_model(
        model_name,
        config=config,
        cache_dir=cache_dir,
        custom_model_class=model_class,
        is_tf_model=True,
    )
    affinity_setting.set_affinity()

    return config, model


# For test only
def load_pt_model_from_tf(model_name):
    # Note: a PyTorch model converted from tf differs in source and structure from one loaded directly with
    # load_pt_model() or load_tf_model(), even under the same name, so it should not be used for comparing with them.
    from convert_tf_models_to_pytorch import tf2pt_pipeline  # noqa: PLC0415

    config, model = tf2pt_pipeline(model_name)

    return config, model


def validate_and_optimize_onnx(
    model_name,
    use_external_data_format,
    model_type,
    onnx_dir,
    input_names,
    use_gpu,
    precision,
    optimize_info,
    validate_onnx,
    use_raw_attention_mask,
    overwrite,
    config,
    model_fusion_statistics,
    onnx_model_path,
    example_inputs,
    example_outputs_flatten,
    output_names,
    fusion_options,
):
    is_valid_onnx_model = True
    if validate_onnx:
        is_valid_onnx_model = validate_onnx_model(
            onnx_model_path,
            example_inputs,
            example_outputs_flatten,
            use_gpu,
            False,
            output_names,
        )
    if optimize_info.name == OptimizerInfo.NOOPT.name:
        return onnx_model_path, is_valid_onnx_model, config.vocab_size

    if (
        optimize_info.name == OptimizerInfo.BYSCRIPT.name
        or precision == Precision.FLOAT16
        or precision == Precision.INT8
    ):  # Use script (optimizer.py) to optimize
        optimized_model_path = get_onnx_file_path(
            onnx_dir,
            model_name,
            len(input_names),
            True,
            use_gpu,
            precision,
            False,
            use_external_data_format,
        )
        optimize_onnx_model(
            onnx_model_path,
            optimized_model_path,
            model_type,
            config.num_attention_heads,
            config.hidden_size,
            use_gpu,
            precision,
            use_raw_attention_mask,
            overwrite,
            model_fusion_statistics,
            use_external_data_format,
            fusion_options,
        )

        onnx_model_path = optimized_model_path
        if validate_onnx:
            is_valid_onnx_model = validate_onnx_model(
                onnx_model_path,
                example_inputs,
                example_outputs_flatten,
                use_gpu,
                precision == Precision.FLOAT16,
                output_names,
            )

        if precision == Precision.INT8:
            logger.info(f"Quantizing model: {onnx_model_path}")
            QuantizeHelper.quantize_onnx_model(onnx_model_path, onnx_model_path, use_external_data_format)
            logger.info(f"Finished quantizing model: {onnx_model_path}")

    if optimize_info.name == OptimizerInfo.BYORT.name:  # Use OnnxRuntime to optimize
        if is_valid_onnx_model:
            ort_model_path = add_filename_suffix(onnx_model_path, "_ort")
            optimize_onnx_model_by_ort(
                onnx_model_path,
                ort_model_path,
                use_gpu,
                overwrite,
                model_fusion_statistics,
            )

    return (
        onnx_model_path,
        is_valid_onnx_model,
        config.num_labels if model_type in ["vit", "swin"] else config.vocab_size,
    )


def export_onnx_model_from_pt(
    model_name,
    opset_version,
    use_external_data_format,
    model_type,
    model_class,
    config_modifier,
    cache_dir,
    onnx_dir,
    input_names,
    use_gpu,
    precision,
    optimizer_info,
    validate_onnx,
    use_raw_attention_mask,
    overwrite,
    model_fusion_statistics,
    fusion_options,
):
    config, model = load_pt_model(model_name, model_class, cache_dir, config_modifier)
    # config, model = load_pt_model_from_tf(model_name)
    model.cpu()

    example_inputs = None
    max_input_size = None

    if model_type in ["vit", "swin"]:
        image_processor = AutoFeatureExtractor.from_pretrained(model_name, cache_dir=cache_dir)
        data = numpy.random.randint(
            low=0, high=256, size=config.image_size * config.image_size * 3, dtype=numpy.uint8
        ).reshape(config.image_size, config.image_size, 3)

        example_inputs = image_processor(data, return_tensors="pt")
    else:
        tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
        max_input_size = tokenizer.model_max_length
        example_inputs = tokenizer.encode_plus("This is a sample input", return_tensors="pt")

    example_inputs = filter_inputs(example_inputs, input_names)

    example_outputs = model(**example_inputs)

    assert isinstance(example_outputs, (list, tuple)), f"type of output is not list or tuple: {type(example_outputs)}"

    # Flatten is needed for gpt2 and distilgpt2.
    example_outputs_flatten = flatten(example_outputs)
    example_outputs_flatten = update_flatten_list(example_outputs_flatten, [])

    onnx_model_path = get_onnx_file_path(
        onnx_dir,
        model_name,
        len(input_names),
        False,
        use_gpu,
        precision,
        False,
        use_external_data_format,
    )

    if overwrite or not os.path.exists(onnx_model_path):
        logger.info(f"Exporting ONNX model to {onnx_model_path}")
        Path(onnx_model_path).parent.mkdir(parents=True, exist_ok=True)

        dynamic_axes = None
        output_names = None

        if model_type in ["vit", "swin"]:
            dynamic_axes, output_names = {key: {0: "pixel_values"} for key in example_inputs}, ["logits"]
        else:
            dynamic_axes, output_names = build_dynamic_axes(example_inputs, example_outputs_flatten)

        replace_torch_functions()
        torch_onnx_export(
            model=model,
            args=tuple(example_inputs.values()),
            f=onnx_model_path,
            input_names=list(example_inputs.keys()),
            output_names=output_names,
            dynamic_axes=dynamic_axes,
            do_constant_folding=True,
            opset_version=opset_version,
            use_external_data_format=use_external_data_format,
        )
        restore_torch_functions()
    else:
        logger.info(f"Skip export since model existed: {onnx_model_path}")

    onnx_model_file, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx(
        model_name,
        use_external_data_format,
        model_type,
        onnx_dir,
        input_names,
        use_gpu,
        precision,
        optimizer_info,
        validate_onnx,
        use_raw_attention_mask,
        overwrite,
        config,
        model_fusion_statistics,
        onnx_model_path,
        example_inputs,
        example_outputs_flatten,
        None,
        fusion_options,
    )

    return onnx_model_file, is_valid_onnx_model, vocab_size, max_input_size
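# Minimal driver sketch (reviewer's addition; config_modifier is assumed to be
# an object exposing modify(config), as used by load_pt_model above):
#
#   stats = {}
#   path, is_valid, vocab_size, max_len = export_onnx_model_from_pt(
#       model_name="bert-base-cased",
#       opset_version=14,
#       use_external_data_format=False,
#       model_type="bert",
#       model_class=None,
#       config_modifier=config_modifier,
#       cache_dir="./cache_models",
#       onnx_dir="./onnx_models",
#       input_names=["input_ids", "attention_mask"],
#       use_gpu=False,
#       precision=Precision.FLOAT32,
#       optimizer_info=OptimizerInfo.BYSCRIPT,
#       validate_onnx=True,
#       use_raw_attention_mask=False,
#       overwrite=False,
#       model_fusion_statistics=stats,
#       fusion_options=None,
#   )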

def export_onnx_model_from_tf(
    model_name,
    opset_version,
    use_external_data_format,
    model_type,
    model_class,
    config_modifier,
    cache_dir,
    onnx_dir,
    input_names,
    use_gpu,
    precision,
    optimizer_info,
    validate_onnx,
    use_raw_attention_mask,
    overwrite,
    model_fusion_statistics,
    fusion_options,
):
    # Use CPU to export
    import tensorflow as tf  # noqa: PLC0415

    tf.config.set_visible_devices([], "GPU")

    tokenizer = AutoTokenizer.from_pretrained(model_name, cache_dir=cache_dir)
    # Fix "Using pad_token, but it is not set yet" error.
    if tokenizer.pad_token is None:
        tokenizer.add_special_tokens({"pad_token": "[PAD]"})
    max_input_size = tokenizer.model_max_length

    config, model = load_tf_model(model_name, model_class, cache_dir, config_modifier)
    model.resize_token_embeddings(len(tokenizer))

    example_inputs = tokenizer.encode_plus(
        "This is a sample input",
        return_tensors="tf",
        max_length=max_input_size,
        padding="max_length",
        truncation=True,
    )
    example_inputs = filter_inputs(example_inputs, input_names)

    if config.is_encoder_decoder:
        example_inputs["decoder_input_ids"] = tokenizer.encode_plus(
            "This is a sample input",
            return_tensors="tf",
            max_length=max_input_size,
            padding="max_length",
            truncation=True,
        ).input_ids
    if model_name == "unc-nlp/lxmert-base-uncased":
        example_inputs["visual_feats"] = tf.random.normal([1, 1, config.visual_feat_dim])
        example_inputs["visual_pos"] = tf.random.normal([1, 1, config.visual_pos_dim])

    try:
        # Use no past state for these models
        if config.use_cache:
            config.use_cache = False
    except Exception:
        pass

    example_outputs = model(example_inputs, training=False)
    output_names = None

    # For xlnet models, only compare the last_hidden_state output.
    if model_name == "xlnet-base-cased" or model_name == "xlnet-large-cased":
        output_names = ["last_hidden_state"]
        example_outputs = example_outputs["last_hidden_state"]

    # Flatten is needed for gpt2 and distilgpt2. Output name sorting is needed for tf2onnx outputs to match onnx outputs.
    from tensorflow.python.util import nest  # noqa: PLC0415

    example_outputs_flatten = nest.flatten(example_outputs)

    onnx_model_path = get_onnx_file_path(
        onnx_dir,
        model_name,
        len(input_names),
        False,
        use_gpu,
        precision,
        False,
        use_external_data_format,
    )
    tf_internal_model_path = onnx_model_path[:-5] if use_external_data_format else onnx_model_path

    if overwrite or not os.path.exists(tf_internal_model_path):
        logger.info(f"Exporting ONNX model to {onnx_model_path}")
        if not use_external_data_format:
            Path(tf_internal_model_path).parent.mkdir(parents=True, exist_ok=True)

        import zipfile  # noqa: PLC0415

        import tf2onnx  # noqa: PLC0415

        tf2onnx.logging.set_level(tf2onnx.logging.ERROR)
        specs = []
        for name, value in example_inputs.items():
            dims = [None] * len(value.shape)
            specs.append(tf.TensorSpec(tuple(dims), value.dtype, name=name))
        _, _ = tf2onnx.convert.from_keras(
            model,
            input_signature=tuple(specs),
            opset=opset_version,
            large_model=use_external_data_format,
            output_path=tf_internal_model_path,
        )
        if use_external_data_format:
            # need to unpack the zip for run_onnxruntime()
            with zipfile.ZipFile(tf_internal_model_path, "r") as z:
                z.extractall(os.path.dirname(tf_internal_model_path))
            tf_internal_model_path = os.path.join(os.path.dirname(tf_internal_model_path), "__MODEL_PROTO.onnx")
            if os.path.exists(onnx_model_path):
                os.remove(onnx_model_path)
            os.rename(tf_internal_model_path, onnx_model_path)
    else:
        logger.info(f"Skip export since model existed: {onnx_model_path}")

    model_type = model_type + "_tf"
    optimized_onnx_path, is_valid_onnx_model, vocab_size = validate_and_optimize_onnx(
        model_name,
        use_external_data_format,
        model_type,
        onnx_dir,
        input_names,
        use_gpu,
        precision,
        optimizer_info,
        validate_onnx,
        use_raw_attention_mask,
        overwrite,
        config,
        model_fusion_statistics,
        onnx_model_path,
        example_inputs,
        example_outputs_flatten,
        output_names,
        fusion_options,
    )

    return (
        optimized_onnx_path,
        is_valid_onnx_model,
        vocab_size,
        max_input_size,
    )