onnxruntime-directml 1.20.0 (onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
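
For orientation before the code hunks below (this note and sketch are editorial, not part of the published diff): the `capi` directory ships `DirectML.dll` next to the ONNX Runtime binaries, so the wheel is self-contained on Windows and exposes the DirectML execution provider. A minimal usage sketch, assuming a local `model.onnx` whose first input is float32:

```python
# Minimal sketch of using this build; "model.onnx" and the float32 dtype are assumptions.
import numpy as np
import onnxruntime as ort

# DirectML first, CPU as fallback; both providers ship with this wheel.
session = ort.InferenceSession(
    "model.onnx",
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)

# Feed zeros shaped to the first input's declaration (dynamic dims -> 1).
inp = session.get_inputs()[0]
shape = [d if isinstance(d, int) else 1 for d in inp.shape]
outputs = session.run(None, {inp.name: np.zeros(shape, dtype=np.float32)})
print([o.shape for o in outputs])
```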
--- /dev/null
+++ b/onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py
@@ -0,0 +1,299 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import logging
+import os
+import tempfile
+from pathlib import Path
+from typing import List, Optional, Union
+
+import numpy
+import onnx
+import torch
+from onnx_model import OnnxModel
+from past_helper import PastKeyValuesHelper
+from t5_decoder import T5DecoderInit
+from t5_encoder import T5Encoder, T5EncoderInputs
+from torch_onnx_export_helper import torch_onnx_export
+from transformers import MT5Config, T5Config
+
+from onnxruntime import InferenceSession
+
+logger = logging.getLogger(__name__)
+
+
+class T5EncoderDecoderInit(torch.nn.Module):
+    """A combination of T5Encoder and T5DecoderInit."""
+
+    def __init__(
+        self,
+        encoder: torch.nn.Module,
+        decoder: torch.nn.Module,
+        lm_head: torch.nn.Module,
+        config: Union[T5Config, MT5Config],
+        decoder_start_token_id: Optional[int] = None,
+    ):
+        super().__init__()
+        self.config = config
+        self.t5_encoder = T5Encoder(encoder, config)
+        self.t5_decoder_init = T5DecoderInit(decoder, lm_head, config, decoder_start_token_id)
+
+    def forward(
+        self,
+        encoder_input_ids: torch.Tensor,
+        encoder_attention_mask: torch.Tensor,
+        decoder_input_ids: Optional[torch.Tensor] = None,
+    ):
+        encoder_hidden_states: torch.FloatTensor = self.t5_encoder(encoder_input_ids, encoder_attention_mask)
+        lm_logits, past_self, past_cross = self.t5_decoder_init(
+            decoder_input_ids, encoder_attention_mask, encoder_hidden_states
+        )
+        return lm_logits, encoder_hidden_states, past_self, past_cross
+
+
+class T5EncoderDecoderInitInputs:
+    def __init__(self, encoder_input_ids, encoder_attention_mask, decoder_input_ids=None):
+        self.encoder_input_ids: torch.LongTensor = encoder_input_ids
+        self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
+        self.decoder_input_ids: torch.LongTensor = decoder_input_ids
+
+    @staticmethod
+    def create_dummy(
+        config: Union[T5Config, MT5Config],
+        batch_size: int,
+        encode_sequence_length: int,
+        use_decoder_input_ids: bool,
+        device: torch.device,
+        use_int32_inputs: bool = False,
+    ):  # -> T5EncoderDecoderInitInputs:
+        encoder_inputs: T5EncoderInputs = T5EncoderInputs.create_dummy(
+            batch_size,
+            encode_sequence_length,
+            config.vocab_size,
+            device,
+            use_int32_inputs=use_int32_inputs,
+        )
+        decoder_input_ids = None
+        if use_decoder_input_ids:
+            dtype = torch.int32 if use_int32_inputs else torch.int64
+            decoder_input_ids = torch.ones((batch_size, 1), dtype=dtype, device=device) * config.decoder_start_token_id
+
+        return T5EncoderDecoderInitInputs(encoder_inputs.input_ids, encoder_inputs.attention_mask, decoder_input_ids)
+
+    def to_list(self) -> List:
+        input_list = [self.encoder_input_ids, self.encoder_attention_mask]
+        if self.decoder_input_ids is not None:
+            input_list.append(self.decoder_input_ids)
+        return input_list
+
+
+class T5EncoderDecoderInitHelper:
+    @staticmethod
+    def export_onnx(
+        model: T5EncoderDecoderInit,
+        device: torch.device,
+        onnx_model_path: str,
+        use_decoder_input_ids: bool = True,
+        verbose: bool = True,
+        use_external_data_format: bool = False,
+        use_int32_inputs: bool = False,
+    ):
+        """Export encoder and decoder initialization to ONNX
+
+        Args:
+            model (T5EncoderDecoderInit): the model to export
+            device (torch.device): device of decoder object
+            onnx_model_path (str): onnx path
+            verbose (bool, optional): print verbose information. Defaults to True.
+            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
+        """
+        assert isinstance(model, T5EncoderDecoderInit)
+
+        inputs = T5EncoderDecoderInitInputs.create_dummy(
+            model.config,
+            batch_size=2,
+            encode_sequence_length=3,
+            use_decoder_input_ids=use_decoder_input_ids,
+            device=device,
+            use_int32_inputs=use_int32_inputs,
+        )
+        input_list = inputs.to_list()
+
+        present_names = PastKeyValuesHelper.get_past_names(model.config.num_decoder_layers, present=True)
+
+        output_names = ["logits", "encoder_hidden_states", *present_names]
+
+        # Shape of input tensors (sequence_length==1):
+        #    input_ids: (batch_size, sequence_length)
+        #    encoder_attention_mask: (batch_size, encode_sequence_length)
+        #    encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
+        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
+        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
+
+        # Shape of output tensors:
+        #    logits: (batch_size, sequence_length, vocab_size)
+        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
+        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)
+
+        input_names = ["encoder_input_ids", "encoder_attention_mask"]
+
+        # ONNX exporter might mark dimension like 'Transposepresent_value_self_1_dim_2' in shape inference.
+        # We use a workaround here: first use dim_param "1" for sequence_length, and later change to dim_value.
+        sequence_length = "1"
+        num_heads = str(model.config.num_heads)
+        hidden_size = str(model.config.d_model)
+        head_size = str(model.config.d_kv)
+
+        dynamic_axes = {
+            "encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"},
+            "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
+            "encoder_hidden_states": {
+                0: "batch_size",
+                1: "encode_sequence_length",
+                2: hidden_size,
+            },
+            "logits": {
+                0: "batch_size",
+                1: sequence_length,
+            },
+        }
+
+        if use_decoder_input_ids:
+            input_names.append("decoder_input_ids")
+            dynamic_axes["decoder_input_ids"] = {
+                0: "batch_size",
+                1: sequence_length,
+            }
+
+        for name in present_names:
+            if "cross" in name:
+                dynamic_axes[name] = {
+                    0: "batch_size",
+                    1: num_heads,
+                    2: "encode_sequence_length",
+                    3: head_size,
+                }
+
+            else:  # self attention past state
+                dynamic_axes[name] = {
+                    0: "batch_size",
+                    1: num_heads,
+                    2: sequence_length,
+                    3: head_size,
+                }
+
+        with tempfile.TemporaryDirectory() as tmp_dir_name:
+            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder_decoder_init.onnx")
+            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
+            torch_onnx_export(
+                model,
+                args=tuple(input_list),
+                f=temp_onnx_model_path,
+                export_params=True,
+                input_names=input_names,
+                output_names=output_names,
+                dynamic_axes=dynamic_axes,
+                opset_version=12,
+                do_constant_folding=True,
+                use_external_data_format=use_external_data_format,
+                verbose=verbose,
+            )
+
+            # Workaround as mentioned earlier: change numeric dim_param to dim_value
+            model = onnx.load(temp_onnx_model_path)
+            for tensor in model.graph.output:
+                for dim_proto in tensor.type.tensor_type.shape.dim:
+                    if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
+                        sequence_length,
+                        num_heads,
+                        hidden_size,
+                        head_size,
+                    ]:
+                        dim_value = int(dim_proto.dim_param)
+                        dim_proto.Clear()
+                        dim_proto.dim_value = dim_value
+
+            OnnxModel.save(
+                model,
+                onnx_model_path,
+                save_as_external_data=use_external_data_format,
+                all_tensors_to_one_file=True,
+            )
+
+    @staticmethod
+    def onnxruntime_inference(ort_session, inputs: T5EncoderDecoderInitInputs):
+        """Run inference of ONNX model."""
+        logger.debug("start onnxruntime_inference")
+
+        ort_inputs = {
+            "encoder_input_ids": numpy.ascontiguousarray(inputs.encoder_input_ids.cpu().numpy()),
+            "encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
+        }
+        if inputs.decoder_input_ids is not None:
+            ort_inputs["decoder_input_ids"] = numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy())
+
+        ort_outputs = ort_session.run(None, ort_inputs)
+        return ort_outputs
+
+    @staticmethod
+    def verify_onnx(
+        model: T5EncoderDecoderInit,
+        ort_session: InferenceSession,
+        device: torch.device,
+        use_int32_inputs: bool,
+        max_cases: int = 4,
+    ):
+        """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
+        ort_inputs = ort_session.get_inputs()
+        use_decoder_input_ids = len(ort_inputs) == 3
+
+        test_cases = [(4, 11), (1, 2), (3, 1), (8, 5)]
+        test_cases_max_diff = []
+        for batch_size, encode_sequence_length in test_cases[:max_cases]:
+            inputs = T5EncoderDecoderInitInputs.create_dummy(
+                model.config,
+                batch_size,
+                encode_sequence_length,
+                use_decoder_input_ids=use_decoder_input_ids,
+                device=device,
+                use_int32_inputs=use_int32_inputs,
+            )
+
+            ort_outputs = T5EncoderDecoderInitHelper.onnxruntime_inference(ort_session, inputs)
+
+            # Run inference of PyTorch model
+            input_list = inputs.to_list()
+            torch_outputs = model(*input_list)
+
+            num_decoder_layers = model.config.num_decoder_layers
+
+            assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape
+            max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
+            logger.debug(f"logits max_diff={max_diff}")
+            max_diff_all = max_diff
+
+            assert torch_outputs[1].cpu().numpy().shape == ort_outputs[1].shape
+            max_diff = numpy.amax(numpy.abs(torch_outputs[1].cpu().numpy() - ort_outputs[1]))
+            logger.debug(f"encoder_hidden_states max_diff={max_diff}")
+            max_diff_all = max(max_diff_all, max_diff)
+
+            for i in range(2 * num_decoder_layers):
+                max_diff = numpy.amax(numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[2 + i]))
+                logger.debug(f"self attention past state {i} max_diff={max_diff}")
+                max_diff_all = max(max_diff_all, max_diff)
+            for i in range(2 * num_decoder_layers):
+                max_diff = numpy.amax(
+                    numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * num_decoder_layers + i])
+                )
+                logger.debug(f"cross attention past state {i} max_diff={max_diff}")
+                max_diff_all = max(max_diff_all, max_diff)
+
+            test_cases_max_diff.append(max_diff_all)
+            logger.info(
+                f"batch_size={batch_size} encode_sequence_length={encode_sequence_length}, max_diff={max_diff_all}"
+            )
+
+        return max(test_cases_max_diff)
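
The hunk above is `onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py` (the only +299 entry in the file listing, and the classes match the file name). A hypothetical end-to-end use of the helper: wrap a Hugging Face T5 model, export the merged encoder/decoder-init graph, and check parity against PyTorch. The model name and output path are placeholders:

```python
# Hypothetical usage sketch (not part of the diff). Assumes models/t5 is on
# sys.path so the bare import resolves, as arranged by the package __init__.
import torch
from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper
from transformers import T5ForConditionalGeneration

from onnxruntime import InferenceSession

device = torch.device("cpu")
hf = T5ForConditionalGeneration.from_pretrained("t5-small")

wrapper = T5EncoderDecoderInit(hf.encoder, hf.decoder, hf.lm_head, hf.config)
wrapper.eval().to(device)

onnx_path = "t5-small_encoder_decoder_init.onnx"
T5EncoderDecoderInitHelper.export_onnx(wrapper, device, onnx_path, use_int32_inputs=True)

# verify_onnx returns the largest PyTorch/ORT difference across its test cases.
session = InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
max_diff = T5EncoderDecoderInitHelper.verify_onnx(wrapper, session, device, use_int32_inputs=True)
print(f"max parity difference: {max_diff}")
```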
--- /dev/null
+++ b/onnxruntime/transformers/models/t5/t5_helper.py
@@ -0,0 +1,272 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import logging
+import os
+from pathlib import Path
+from typing import Dict, List, Union
+
+import torch
+from float16 import float_to_float16_max_diff
+from onnx_model import OnnxModel
+from optimizer import optimize_model
+from t5_decoder import T5Decoder, T5DecoderHelper, T5DecoderInit
+from t5_encoder import T5Encoder, T5EncoderHelper
+from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper
+from transformers import MT5ForConditionalGeneration, T5ForConditionalGeneration
+
+from onnxruntime import InferenceSession
+
+logger = logging.getLogger(__name__)
+
+PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]
+PRETRAINED_MT5_MODELS = ["google/mt5-small", "google/mt5-base", "google/mt5-large", "google/mt5-xl", "google/mt5-xxl"]
+
+
+class T5Helper:
+    @staticmethod
+    def get_onnx_path(
+        output_dir: str,
+        model_name_or_path: str,
+        suffix: str = "",
+        new_folder: bool = False,
+    ) -> str:
+        """Build onnx path
+
+        Args:
+            output_dir (str): output directory
+            model_name_or_path (str): pretrained model name, or path to the model checkpoint
+            suffix (str, optional): suffix like "_encoder" or "_decoder_fp16" to append to the file name. Defaults to "".
+            new_folder (bool, optional): create a new directory for the model. Defaults to False.
+
+        Returns:
+            str: path of onnx model
+        """
+        model_name = model_name_or_path
+        if os.path.isdir(model_name_or_path):
+            model_name = Path(model_name_or_path).parts[-1]
+        else:
+            model_name = model_name.split("/")[-1]
+
+        model_name += suffix
+
+        directory = os.path.join(output_dir, model_name) if new_folder else output_dir
+        return os.path.join(directory, model_name + ".onnx")
+
+    @staticmethod
+    def load_model(
+        model_name_or_path: str,
+        cache_dir: str,
+        device: torch.device,
+        merge_encoder_and_decoder_init: bool = True,
+        model_type: str = "t5",
+        state_dict_path: str = "",
+    ) -> Dict[str, torch.nn.Module]:
+        """Load model given a pretrained name or path, then build models for ONNX conversion.
+
+        Args:
+            model_name_or_path (str): pretrained model name or path
+            cache_dir (str): cache directory
+            device (torch.device): device to run the model
+            merge_encoder_and_decoder_init (bool, optional): whether to merge encoder and decoder initialization into one ONNX model. Defaults to True.
+            model_type (str, optional): model type, "t5" or "mt5". Defaults to "t5".
+        Returns:
+            Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
+        """
+        if model_type == "t5":
+            model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+        elif model_type == "mt5":
+            model = MT5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
+        else:
+            raise ValueError("only support model_type=t5 or mt5")
+
+        if state_dict_path:
+            model.load_state_dict(torch.load(state_dict_path))
+
+        decoder = T5Decoder(model.decoder, model.lm_head, model.config)
+        decoder.eval().to(device)
+
+        if merge_encoder_and_decoder_init:
+            encoder_decoder_init = T5EncoderDecoderInit(
+                model.encoder,
+                model.decoder,
+                model.lm_head,
+                model.config,
+                decoder_start_token_id=None,
+            )
+            return {"encoder_decoder_init": encoder_decoder_init, "decoder": decoder}
+        else:
+            encoder = T5Encoder(model.encoder, model.config)
+            encoder.eval().to(device)
+            decoder_init = T5DecoderInit(model.decoder, model.lm_head, model.config)
+            decoder_init.eval().to(device)
+            return {
+                "encoder": encoder,
+                "decoder": decoder,
+                "decoder_init": decoder_init,
+            }
+
+    @staticmethod
+    def export_onnx(
+        model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit],
+        device: torch.device,
+        onnx_model_path: str,
+        verbose: bool = True,
+        use_external_data_format: bool = False,
+        use_decoder_input_ids: bool = True,
+        use_int32_inputs: bool = False,
+    ):
+        if isinstance(model, T5Encoder):
+            T5EncoderHelper.export_onnx(
+                model,
+                device,
+                onnx_model_path,
+                verbose,
+                use_external_data_format,
+                use_int32_inputs,
+            )
+        elif isinstance(model, T5EncoderDecoderInit):
+            T5EncoderDecoderInitHelper.export_onnx(
+                model,
+                device,
+                onnx_model_path,
+                use_decoder_input_ids,
+                verbose,
+                use_external_data_format,
+                use_int32_inputs,
+            )
+        else:
+            T5DecoderHelper.export_onnx(
+                model,
+                device,
+                onnx_model_path,
+                verbose,
+                use_external_data_format,
+                use_int32_inputs,
+            )
+
+    @staticmethod
+    def auto_mixed_precision(
+        onnx_model: OnnxModel,
+        op_block_list: List[str] = [  # noqa: B006
+            "SimplifiedLayerNormalization",
+            "SkipSimplifiedLayerNormalization",
+            "Relu",
+            "Add",
+        ],
+    ):
+        """Convert model to mixed precision.
+        It detects whether the original model has fp16 precision weights, and sets parameters for float16 conversion automatically.
+        Args:
+            onnx_model (OnnxModel): optimized ONNX model
+            op_block_list (List[str], optional): operators to exclude from float16 conversion. Defaults to ["SimplifiedLayerNormalization", "SkipSimplifiedLayerNormalization", "Relu", "Add"]
+        Returns:
+            parameters(dict): a dictionary of parameters used in float16 conversion
+        """
+        op_full_set = {node.op_type for node in onnx_model.nodes()}
+        fp32_op_set = set(op_block_list)
+        fp16_op_set = op_full_set.difference(fp32_op_set)
+        logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")
+
+        # logits is the first output
+        logits_output_name = onnx_model.graph().output[0].name
+
+        # We use the weight in the last MatMul node to detect whether the model is stored with float16 weights from training.
+        is_weight_fp16_precision = False
+        output_name_to_node = onnx_model.output_name_to_node()
+        assert logits_output_name in output_name_to_node
+        node = output_name_to_node[logits_output_name]
+        last_matmul_node = None
+        if node.op_type == "MatMul":
+            last_matmul_node = node
+            logger.info(f"Found last MatMul node for logits: {node.name}")
+            initializer = None
+            for input in node.input:
+                initializer = onnx_model.get_initializer(input)
+                if initializer is not None:
+                    break
+
+            # when the max difference of value after converting float to float16 is lower than a threshold (1e-6),
+            # we can deduce that the weights are stored in float16 precision.
+            max_diff = float_to_float16_max_diff(initializer)
+            logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
+            is_weight_fp16_precision = max_diff < 1e-6
+        else:
+            logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")
+
+        keep_io_types = []
+        node_block_list = []
+        if (not is_weight_fp16_precision) and (last_matmul_node is not None):
+            # When the original weights are float32, keeping logits and the last MatMul in float32 can give better precision.
+            keep_io_types = [logits_output_name]
+            node_block_list = [last_matmul_node.name]
+
+        parameters = {
+            "keep_io_types": keep_io_types,
+            "op_block_list": op_block_list,
+            "node_block_list": node_block_list,
+            "force_fp16_initializers": is_weight_fp16_precision,
+        }
+
+        logger.info(f"auto_mixed_precision parameters: {parameters}")
+        onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)
+
+        return parameters
+
+    @staticmethod
+    def optimize_onnx(
+        onnx_model_path: str,
+        optimized_model_path: str,
+        is_float16: bool,
+        num_attention_heads: int,
+        hidden_size: int,
+        use_external_data_format: bool = False,
+        auto_mixed_precision: bool = True,
+        use_gpu: bool = False,
+    ):
+        """Optimize ONNX model with an option to convert it to use mixed precision."""
+
+        from fusion_options import FusionOptions
+
+        optimization_options = None
+        if is_float16:
+            optimization_options = FusionOptions("t5")
+            optimization_options.enable_skip_layer_norm = False
+
+        m = optimize_model(
+            onnx_model_path,
+            model_type="t5",
+            num_heads=num_attention_heads,
+            hidden_size=hidden_size,
+            opt_level=2 if not use_external_data_format else 0,
+            optimization_options=optimization_options,
+            use_gpu=False,
+            only_onnxruntime=not use_gpu,
+        )
+
+        if is_float16:
+            if auto_mixed_precision:
+                T5Helper.auto_mixed_precision(m)
+            else:
+                m.convert_model_float32_to_float16(cast_input_output=False)
+
+        m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)
+
+    @staticmethod
+    def verify_onnx(
+        model: Union[T5Encoder, T5Decoder, T5DecoderInit, T5EncoderDecoderInit],
+        ort_session: InferenceSession,
+        device: torch.device,
+        use_int32_inputs: bool,
+    ):
+        """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
+        if isinstance(model, T5Encoder):
+            return T5EncoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
+
+        if isinstance(model, T5EncoderDecoderInit):
+            return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
+
+        return T5DecoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
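
This hunk is `onnxruntime/transformers/models/t5/t5_helper.py` (the `T5Helper` class and the +272 count match that entry). A hedged sketch of the intended flow: load the wrapped PyTorch modules, export each to ONNX, then produce a mixed-precision copy. Directories and the model name are illustrative:

```python
# Hypothetical flow through T5Helper (illustrative paths; same sys.path assumption).
import os

import torch
from t5_helper import T5Helper  # the module above
from transformers import T5Config

device = torch.device("cpu")
config = T5Config.from_pretrained("t5-small")
os.makedirs("./onnx_models", exist_ok=True)

# Default merge_encoder_and_decoder_init=True returns "encoder_decoder_init" and "decoder".
modules = T5Helper.load_model("t5-small", cache_dir="./cache", device=device)

for name, module in modules.items():
    fp32_path = T5Helper.get_onnx_path("./onnx_models", "t5-small", suffix="_" + name)
    T5Helper.export_onnx(module, device, fp32_path, use_int32_inputs=True)

    fp16_path = T5Helper.get_onnx_path("./onnx_models", "t5-small", suffix="_" + name + "_fp16")
    T5Helper.optimize_onnx(
        fp32_path,
        fp16_path,
        is_float16=True,  # triggers auto_mixed_precision by default
        num_attention_heads=config.num_heads,
        hidden_size=config.d_model,
    )
```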
@@ -0,0 +1,12 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import os
+import sys
+
+sys.path.append(os.path.dirname(__file__))
+
+transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
+if transformers_dir not in sys.path:
+    sys.path.append(transformers_dir)
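
This final 12-line hunk is one of the `models/*/__init__.py` files from the listing (by listing order most likely `models/whisper/__init__.py`; the t5 one is the same length and follows the same pattern). It appends the model directory and the `transformers` root, two levels up, to `sys.path`, which is what lets the modules above use bare imports such as `from onnx_model import OnnxModel` and `from t5_decoder import T5DecoderInit`. A small hypothetical check of that effect, assuming Python resolves the `__init__`-less `models` directory as a namespace package:

```python
# Hypothetical check of the bootstrap (exact paths depend on the install location).
import os
import sys

import onnxruntime.transformers.models.t5 as pkg  # each model package runs the same bootstrap

pkg_dir = os.path.dirname(pkg.__file__)
transformers_dir = os.path.normpath(os.path.join(pkg_dir, "..", ".."))
assert pkg_dir in sys.path and transformers_dir in sys.path
```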