onnxruntime-directml 1.24.1 (cp314-cp314-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
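
Note: the wheel above is the DirectML build of ONNX Runtime, consumed through the same InferenceSession API as the other builds. A minimal usage sketch (the model path and the float32 dummy input are placeholder assumptions; "DmlExecutionProvider" is the provider this package registers):

import numpy as np
import onnxruntime as ort

# Request DirectML first, falling back to CPU if no DirectX 12 device is available.
session = ort.InferenceSession(
    "model.onnx",  # placeholder path
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)

# Build a dummy input, treating any symbolic (string) dimension as 1.
meta = session.get_inputs()[0]
shape = [d if isinstance(d, int) else 1 for d in meta.shape]
outputs = session.run(None, {meta.name: np.zeros(shape, dtype=np.float32)})

Listing CPUExecutionProvider after DmlExecutionProvider keeps the session usable on machines where DirectML cannot initialize, since providers are tried in order.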
onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py
@@ -0,0 +1,361 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------

import logging
import os
import tempfile
from pathlib import Path

import numpy
import onnx
import torch
from onnx_model import OnnxModel
from past_helper import PastKeyValuesHelper
from t5_decoder import T5DecoderInit
from t5_encoder import T5Encoder, T5EncoderInputs
from torch_onnx_export_helper import torch_onnx_export
from transformers import MT5Config, T5Config

from onnxruntime import InferenceSession

logger = logging.getLogger(__name__)


class T5EncoderDecoderInit(torch.nn.Module):
    """A combination of T5Encoder and T5DecoderInit."""

    def __init__(
        self,
        encoder: torch.nn.Module,
        decoder: torch.nn.Module,
        lm_head: torch.nn.Linear,
        config: T5Config | MT5Config,
        decoder_start_token_id: int | None = None,
        output_cross_only: bool = False,
    ):
        super().__init__()
        self.config: T5Config | MT5Config = config
        self.t5_encoder = T5Encoder(encoder, config)
        self.t5_decoder_init = T5DecoderInit(decoder, lm_head, config, decoder_start_token_id)
        self.output_cross_only = output_cross_only

    def forward(
        self,
        encoder_input_ids: torch.Tensor,
        encoder_attention_mask: torch.Tensor,
        decoder_input_ids: torch.Tensor | None = None,
    ):
        encoder_hidden_states: torch.FloatTensor = self.t5_encoder(encoder_input_ids, encoder_attention_mask)

        lm_logits, past_self, past_cross = self.t5_decoder_init(
            decoder_input_ids, encoder_attention_mask, encoder_hidden_states
        )

        if self.output_cross_only:
            return past_cross
        else:
            return lm_logits, encoder_hidden_states, past_self, past_cross


class T5EncoderDecoderInitInputs:
    def __init__(self, encoder_input_ids, encoder_attention_mask, decoder_input_ids=None):
        self.encoder_input_ids: torch.LongTensor = encoder_input_ids
        self.encoder_attention_mask: torch.LongTensor = encoder_attention_mask
        self.decoder_input_ids: torch.LongTensor | None = decoder_input_ids

    @staticmethod
    def create_dummy(
        config: T5Config | MT5Config,
        batch_size: int,
        encode_sequence_length: int,
        use_decoder_input_ids: bool,
        device: torch.device,
        use_int32_inputs: bool = False,
    ):  # -> T5EncoderDecoderInitInputs:
        encoder_inputs: T5EncoderInputs = T5EncoderInputs.create_dummy(
            batch_size,
            encode_sequence_length,
            config.vocab_size,
            device,
            use_int32_inputs=use_int32_inputs,
        )
        decoder_input_ids = None
        if use_decoder_input_ids:
            dtype = torch.int32 if use_int32_inputs else torch.int64
            decoder_input_ids = torch.ones((batch_size, 1), dtype=dtype, device=device) * config.decoder_start_token_id

        return T5EncoderDecoderInitInputs(encoder_inputs.input_ids, encoder_inputs.attention_mask, decoder_input_ids)

    def to_list(self) -> list:
        input_list = [self.encoder_input_ids, self.encoder_attention_mask]
        if self.decoder_input_ids is not None:
            input_list.append(self.decoder_input_ids)
        return input_list


class T5EncoderDecoderInitHelper:
    @staticmethod
    def export_onnx(
        model: T5EncoderDecoderInit,
        device: torch.device,
        onnx_model_path: str,
        use_decoder_input_ids: bool = True,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_int32_inputs: bool = False,
    ):
        """Export encoder and decoder-init to ONNX.

        Args:
            model (T5EncoderDecoderInit): the model to export
            device (torch.device): device of decoder object
            onnx_model_path (str): onnx path
            use_decoder_input_ids (bool, optional): include decoder_input_ids as a graph input. Defaults to True.
            verbose (bool, optional): print verbose information. Defaults to True.
            use_external_data_format (bool, optional): use external data format or not. Defaults to False.
            use_int32_inputs (bool, optional): use int32 instead of int64 for integer inputs. Defaults to False.
        """
        assert isinstance(model, T5EncoderDecoderInit)

        # Do not exclude the decoder in torch onnx export so that the cross attention states can show up.
        output_cross_only = model.output_cross_only
        model.output_cross_only = False

        inputs = T5EncoderDecoderInitInputs.create_dummy(
            model.config,
            batch_size=2,
            encode_sequence_length=3,
            use_decoder_input_ids=use_decoder_input_ids,
            device=device,
            use_int32_inputs=use_int32_inputs,
        )
        input_list = inputs.to_list()

        present_names = PastKeyValuesHelper.get_past_names(model.config.num_decoder_layers, present=True)

        output_names = ["logits", "encoder_hidden_states", *present_names]

        # Shape of input tensors (sequence_length==1):
        #    input_ids: (batch_size, sequence_length)
        #    encoder_attention_mask: (batch_size, encode_sequence_length)
        #    encoder_hidden_states: (batch_size, encode_sequence_length, hidden_size)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        # Shape of output tensors:
        #    logits: (batch_size, sequence_length, vocab_size)
        #    past_self_*: (batch_size, num_heads, past_decode_sequence_length + sequence_length, head_size)
        #    past_cross_*: (batch_size, num_heads, encode_sequence_length, head_size)

        input_names = ["encoder_input_ids", "encoder_attention_mask"]

        # ONNX exporter might mark dimension like 'present_value_self_1_dim_2' in shape inference.
        # We use a workaround here: first use dim_param "1" for sequence_length, and later change to dim_value.
        sequence_length = "1"
        num_heads = str(model.config.num_heads)
        hidden_size = str(model.config.d_model)
        head_size = str(model.config.d_kv)

        dynamic_axes = {
            "encoder_input_ids": {0: "batch_size", 1: "encode_sequence_length"},
            "encoder_attention_mask": {0: "batch_size", 1: "encode_sequence_length"},
            "encoder_hidden_states": {
                0: "batch_size",
                1: "encode_sequence_length",
                2: hidden_size,
            },
            "logits": {
                0: "batch_size",
                1: sequence_length,
            },
        }

        if use_decoder_input_ids:
            input_names.append("decoder_input_ids")
            dynamic_axes["decoder_input_ids"] = {
                0: "batch_size",
                1: sequence_length,
            }

        for name in present_names:
            if "cross" in name:
                dynamic_axes[name] = {
                    0: "batch_size",
                    1: num_heads,
                    2: "encode_sequence_length",
                    3: head_size,
                }
            else:  # self attention past state
                dynamic_axes[name] = {
                    0: "batch_size",
                    1: num_heads,
                    2: sequence_length,
                    3: head_size,
                }

        with tempfile.TemporaryDirectory() as tmp_dir_name:
            temp_onnx_model_path = os.path.join(tmp_dir_name, "encoder_decoder_init.onnx")
            Path(temp_onnx_model_path).parent.mkdir(parents=True, exist_ok=True)
            torch_onnx_export(
                model,
                args=tuple(input_list),
                f=temp_onnx_model_path,
                export_params=True,
                input_names=input_names,
                output_names=output_names,
                dynamic_axes=dynamic_axes,
                opset_version=12,
                do_constant_folding=True,
                use_external_data_format=use_external_data_format,
                verbose=verbose,
            )

            # Restore output_cross_only setting.
            model.output_cross_only = output_cross_only

            # Workaround as mentioned earlier: change numeric dim_param to dim_value
            exported_model: onnx.ModelProto = onnx.load(temp_onnx_model_path)
            for tensor in exported_model.graph.output:
                for dim_proto in tensor.type.tensor_type.shape.dim:
                    if dim_proto.HasField("dim_param") and dim_proto.dim_param in [
                        sequence_length,
                        num_heads,
                        hidden_size,
                        head_size,
                    ]:
                        dim_value = int(dim_proto.dim_param)
                        dim_proto.Clear()
                        dim_proto.dim_value = dim_value

            if output_cross_only:
                # Rewrite onnx graph to only keep present_[key|value]_cross_* outputs.
                onnx_model = OnnxModel(exported_model)
                output_name_to_node = onnx_model.output_name_to_node()

                for output in exported_model.graph.output:
                    if "cross" in output.name:
                        assert output.name in output_name_to_node

                        transpose_node = output_name_to_node[output.name]
                        assert transpose_node and transpose_node.op_type == "Transpose"

                        permutation = OnnxModel.get_node_attribute(transpose_node, "perm")
                        assert isinstance(permutation, list)
                        assert permutation == [0, 2, 1, 3]

                        matched_nodes = onnx_model.match_parent_path(
                            transpose_node,
                            ["Reshape", "MatMul"],
                            [0, 0],
                            output_name_to_node,
                        )
                        assert matched_nodes is not None

                        reshape_node, matmul_node = matched_nodes
                        assert "encoder_hidden_states" in matmul_node.input

                        if not onnx_model.get_initializer("cross_reshape_shape"):
                            shape_tensor = onnx.helper.make_tensor(
                                name="cross_reshape_shape",
                                data_type=onnx.TensorProto.INT64,
                                dims=[4],
                                vals=[0, 0, int(num_heads), int(head_size)],
                                raw=False,
                            )
                            onnx_model.add_initializer(shape_tensor)

                        reshape_node.input[1] = "cross_reshape_shape"

                cross_outputs = [output.name for output in exported_model.graph.output if "cross" in output.name]
                onnx_model.prune_graph(cross_outputs, allow_remove_graph_inputs=True)

            OnnxModel.save(
                exported_model,
                onnx_model_path,
                save_as_external_data=use_external_data_format,
                all_tensors_to_one_file=True,
            )

    @staticmethod
    def onnxruntime_inference(ort_session, inputs: T5EncoderDecoderInitInputs):
        """Run inference of ONNX model."""
        logger.debug("start onnxruntime_inference")

        ort_inputs = {
            "encoder_input_ids": numpy.ascontiguousarray(inputs.encoder_input_ids.cpu().numpy()),
            "encoder_attention_mask": numpy.ascontiguousarray(inputs.encoder_attention_mask.cpu().numpy()),
        }
        if inputs.decoder_input_ids is not None:
            ort_inputs["decoder_input_ids"] = numpy.ascontiguousarray(inputs.decoder_input_ids.cpu().numpy())

        ort_outputs = ort_session.run(None, ort_inputs)
        return ort_outputs

    @staticmethod
    def verify_onnx(
        model: T5EncoderDecoderInit,
        ort_session: InferenceSession,
        device: torch.device,
        use_int32_inputs: bool,
        max_cases: int = 4,
    ):
        """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
        ort_inputs = ort_session.get_inputs()
        use_decoder_input_ids = len(ort_inputs) == 3

        test_cases = [(4, 11), (1, 2), (3, 1), (8, 5)]
        test_cases_max_diff = []
        for batch_size, encode_sequence_length in test_cases[:max_cases]:
            inputs = T5EncoderDecoderInitInputs.create_dummy(
                model.config,
                batch_size,
                encode_sequence_length,
                use_decoder_input_ids=use_decoder_input_ids,
                device=device,
                use_int32_inputs=use_int32_inputs,
            )

            ort_outputs = T5EncoderDecoderInitHelper.onnxruntime_inference(ort_session, inputs)

            # Run inference of PyTorch model
            input_list = inputs.to_list()
            torch_outputs = model(*input_list)

            num_decoder_layers = model.config.num_decoder_layers

            if not model.output_cross_only:
                assert torch_outputs[0].cpu().numpy().shape == ort_outputs[0].shape
                max_diff = numpy.amax(numpy.abs(torch_outputs[0].cpu().numpy() - ort_outputs[0]))
                logger.debug(f"logits max_diff={max_diff}")
                max_diff_all = max_diff

                assert torch_outputs[1].cpu().numpy().shape == ort_outputs[1].shape
                max_diff = numpy.amax(numpy.abs(torch_outputs[1].cpu().numpy() - ort_outputs[1]))
                logger.debug(f"encoder_hidden_states max_diff={max_diff}")
                max_diff_all = max(max_diff_all, max_diff)

                for i in range(2 * num_decoder_layers):
                    max_diff = numpy.amax(numpy.abs(torch_outputs[2][i].cpu().numpy() - ort_outputs[2 + i]))
                    logger.debug(f"self attention past state {i} max_diff={max_diff}")
                    max_diff_all = max(max_diff_all, max_diff)

                for i in range(2 * num_decoder_layers):
                    max_diff = numpy.amax(
                        numpy.abs(torch_outputs[3][i].cpu().numpy() - ort_outputs[2 + 2 * num_decoder_layers + i])
                    )
                    logger.debug(f"cross attention past state {i} max_diff={max_diff}")
                    max_diff_all = max(max_diff_all, max_diff)
            else:
                max_diff_all = -float("inf")
                for i in range(2 * num_decoder_layers):
                    max_diff = numpy.amax(numpy.abs(torch_outputs[i].cpu().numpy() - ort_outputs[i]))
                    logger.debug(f"cross attention past state {i} max_diff={max_diff}")
                    max_diff_all = max(max_diff_all, max_diff)

            test_cases_max_diff.append(max_diff_all)
            logger.info(
                f"batch_size={batch_size} encode_sequence_length={encode_sequence_length}, max_diff={max_diff_all}"
            )

        return max(test_cases_max_diff)
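
A minimal sketch of driving the helper above end to end, assuming the sibling t5 scripts are importable (the package's t5/__init__.py, shown in the last hunk below, appends them to sys.path) and using placeholder file names:

import torch
from transformers import T5ForConditionalGeneration

from onnxruntime import InferenceSession
from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper

device = torch.device("cpu")
model = T5ForConditionalGeneration.from_pretrained("t5-small")

# Wrap encoder + decoder-init so one ONNX graph emits logits, encoder states, and KV cache.
wrapper = T5EncoderDecoderInit(model.encoder, model.decoder, model.lm_head, model.config)
wrapper.eval().to(device)

onnx_path = "encoder_decoder_init.onnx"  # placeholder output path
T5EncoderDecoderInitHelper.export_onnx(wrapper, device, onnx_path)

# Verify returns the largest PyTorch-vs-ONNX Runtime output difference over its test cases.
session = InferenceSession(onnx_path, providers=["CPUExecutionProvider"])
max_diff = T5EncoderDecoderInitHelper.verify_onnx(wrapper, session, device, use_int32_inputs=False)
print(f"PyTorch vs ONNX Runtime max difference: {max_diff}")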
onnxruntime/transformers/models/t5/t5_helper.py
@@ -0,0 +1,302 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# -------------------------------------------------------------------------

import logging
import os
from pathlib import Path

import torch
from float16 import float_to_float16_max_diff
from onnx_model import OnnxModel
from optimizer import optimize_model
from t5_decoder import T5Decoder, T5DecoderHelper
from t5_encoder_decoder_init import T5EncoderDecoderInit, T5EncoderDecoderInitHelper
from transformers import MT5ForConditionalGeneration, T5ForConditionalGeneration

from onnxruntime import InferenceSession

logger = logging.getLogger(__name__)

PRETRAINED_T5_MODELS = ["t5-small", "t5-base", "t5-large", "t5-3b", "t5-11b"]
PRETRAINED_MT5_MODELS = [
    "google/mt5-small",
    "google/mt5-base",
    "google/mt5-large",
    "google/mt5-xl",
    "google/mt5-xxl",
]


class T5Helper:
    @staticmethod
    def get_onnx_path(
        output_dir: str,
        model_name_or_path: str,
        suffix: str = "",
        new_folder: bool = False,
    ) -> str:
        """Build onnx path.

        Args:
            output_dir (str): output directory
            model_name_or_path (str): pretrained model name, or path to the model checkpoint
            suffix (str, optional): suffix like "_encoder" or "_decoder_fp16" appended to the file name. Defaults to "".
            new_folder (bool, optional): create a new directory for the model. Defaults to False.

        Returns:
            str: path of onnx model
        """
        model_name = model_name_or_path
        if os.path.isdir(model_name_or_path):
            model_name = Path(model_name_or_path).parts[-1]
        else:
            model_name = model_name.split("/")[-1]

        model_name += suffix

        directory = os.path.join(output_dir, model_name) if new_folder else output_dir
        return os.path.join(directory, model_name + ".onnx")

    @staticmethod
    def load_model(
        model_name_or_path: str,
        cache_dir: str,
        device: torch.device,
        model_type: str = "t5",
        state_dict_path: str = "",
        encoder_decoder_init: bool = False,
    ) -> dict[str, T5EncoderDecoderInit | T5Decoder]:
        """Load model given a pretrained name or path, then build models for ONNX conversion.

        Args:
            model_name_or_path (str): pretrained model name or path
            cache_dir (str): cache directory
            device (torch.device): device to run the model
            model_type (str, optional): model type "t5" or "mt5"
            state_dict_path (str, optional): state dictionary path
            encoder_decoder_init (bool, optional): combine encoder and decoder kv cache initialization into one model.

        Returns:
            Dict[str, torch.nn.Module]: mapping from name to modules for ONNX conversion.
        """
        if model_type == "t5":
            model = T5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
        elif model_type == "mt5":
            model = MT5ForConditionalGeneration.from_pretrained(model_name_or_path, cache_dir=cache_dir)
        else:
            raise ValueError("only support model_type t5 or mt5")

        if state_dict_path:
            model.load_state_dict(torch.load(state_dict_path))

        decoder = T5Decoder(model.decoder, model.lm_head, model.config)
        decoder.eval().to(device)

        encoder = T5EncoderDecoderInit(
            model.encoder,
            model.decoder,
            model.lm_head,
            model.config,
            decoder_start_token_id=None,
            output_cross_only=not encoder_decoder_init,
        )

        encoder_name = "encoder_decoder_init" if encoder_decoder_init else "encoder"
        return {encoder_name: encoder, "decoder": decoder}

    @staticmethod
    def export_onnx(
        model: T5Decoder | T5EncoderDecoderInit,
        device: torch.device,
        onnx_model_path: str,
        verbose: bool = True,
        use_external_data_format: bool = False,
        use_decoder_input_ids: bool = True,
        use_int32_inputs: bool = False,
    ):
        if isinstance(model, T5EncoderDecoderInit):
            T5EncoderDecoderInitHelper.export_onnx(
                model,
                device,
                onnx_model_path,
                use_decoder_input_ids,
                verbose,
                use_external_data_format,
                use_int32_inputs,
            )
        else:
            T5DecoderHelper.export_onnx(
                model,
                device,
                onnx_model_path,
                verbose,
                use_external_data_format,
                use_int32_inputs,
            )

    @staticmethod
    def auto_mixed_precision(
        onnx_model: OnnxModel,
        op_block_list: list[str] | None = None,
        force_fp16_logits: bool = False,
        use_symbolic_shape_infer: bool = True,
    ):
        """Convert model to mixed precision.

        It detects whether the original model has fp16 precision weights, and sets parameters for float16 conversion automatically.

        Args:
            onnx_model (OnnxModel): optimized ONNX model
            op_block_list (List[str], optional): operators that need to run in fp32.
            force_fp16_logits (bool, optional): force logits and last MatMul node to be in float16. Defaults to False.
            use_symbolic_shape_infer (bool, optional): use symbolic shape inference to convert float to float16. Defaults to True.

        Returns:
            parameters (dict): a dictionary of parameters used in float16 conversion
        """
        if op_block_list is None:
            op_block_list = [
                "SimplifiedLayerNormalization",
                "SkipSimplifiedLayerNormalization",
                "Relu",
                "Add",
            ]

        op_full_set = {node.op_type for node in onnx_model.nodes()}
        fp32_op_set = set(op_block_list)
        fp16_op_set = op_full_set.difference(fp32_op_set)
        logger.info(f"fp32 op: {fp32_op_set} fp16 op: {fp16_op_set}")

        # logits is the first output
        logits_output_name = onnx_model.graph().output[0].name

        # We use the weight in the last MatMul node to detect whether the model is stored with float16 weights from training.
        is_weight_fp16_precision = False
        output_name_to_node = onnx_model.output_name_to_node()
        assert logits_output_name in output_name_to_node
        node = output_name_to_node[logits_output_name]
        last_matmul_node = None
        if node.op_type == "MatMul":
            last_matmul_node = node
            logger.info(f"Found last MatMul node for logits: {node.name}")
            initializer = None
            for input in node.input:
                initializer = onnx_model.get_initializer(input)
                if initializer is not None:
                    break

            # When the max difference of values after converting float to float16 is lower than a threshold (1e-6),
            # we can deduce that the weights are stored in float16 precision.
            max_diff = float_to_float16_max_diff(initializer)
            logger.debug(f"max diff of converting weights in last MatMul node {node.name}: {max_diff}")
            is_weight_fp16_precision = max_diff < 1e-6
        else:
            logger.warning(f"Failed to find MatMul node for logits. Found {node.op_type} of node {node.name}")

        keep_io_types = []
        node_block_list = []
        if (not is_weight_fp16_precision) and (last_matmul_node is not None) and not force_fp16_logits:
            # When the original weights are float32 precision, keeping logits and the last MatMul in float32 can give better precision.
            keep_io_types = [logits_output_name]
            node_block_list = [last_matmul_node.name]

        if "Add" not in op_block_list:
            input_name_to_nodes = onnx_model.input_name_to_nodes()
            fp32_add = 0
            changed = True
            add_nodes = onnx_model.get_nodes_by_op_type("Add")
            while changed:
                changed = False
                for node in add_nodes:
                    if node.name not in node_block_list:
                        parents = onnx_model.get_parents(node, output_name_to_node)
                        children = onnx_model.get_children(node, input_name_to_nodes)
                        blocked_children = [
                            child for child in children if child.op_type in op_block_list or child.name in node_block_list
                        ]
                        blocked_parents = [
                            parent for parent in parents if parent.op_type in op_block_list or parent.name in node_block_list
                        ]
                        # If any child or parent is in fp32, we place the Add node to fp32.
                        if (len(blocked_children) + len(blocked_parents)) > 0:
                            node_block_list.append(node.name)
                            fp32_add += 1
                            changed = True
            fp16_add = len(add_nodes) - fp32_add
            logger.info(f"node counter of Add operator: fp32={fp32_add} fp16={fp16_add}")

        logger.info(f"node_block_list: {node_block_list}")

        parameters = {
            "keep_io_types": keep_io_types,
            "op_block_list": op_block_list,
            "node_block_list": node_block_list,
            "force_fp16_initializers": is_weight_fp16_precision,
        }

        logger.info(f"auto_mixed_precision parameters: {parameters}")
        if use_symbolic_shape_infer:
            onnx_model.convert_float_to_float16(use_symbolic_shape_infer=True, **parameters)
        else:
            # Workaround when symbolic shape inference fails.
            # Need to enable shape_infer_before_optimization in convert_to_onnx.py as well.
            from float16 import convert_float_to_float16  # noqa: PLC0415

            convert_float_to_float16(
                onnx_model.model,
                disable_shape_infer=True,
                **parameters,
            )

        return parameters

    @staticmethod
    def optimize_onnx(
        onnx_model_path: str,
        optimized_model_path: str,
        is_float16: bool,
        num_attention_heads: int,
        hidden_size: int,
        use_external_data_format: bool = False,
        auto_mixed_precision: bool = True,
        use_gpu: bool = False,
        force_fp16_io: bool = False,
    ):
        """Optimize ONNX model with an option to convert it to use mixed precision."""

        from fusion_options import FusionOptions  # noqa: PLC0415

        optimization_options = None
        if is_float16:
            optimization_options = FusionOptions("t5")
            # SkipLayerNormalization is faster but might cause an accuracy drop since it uses fp16 accumulation.
            optimization_options.enable_skip_layer_norm = not auto_mixed_precision

        m = optimize_model(
            onnx_model_path,
            model_type="t5",
            num_heads=num_attention_heads,
            hidden_size=hidden_size,
            opt_level=0,
            optimization_options=optimization_options,
            use_gpu=use_gpu,
        )

        if is_float16:
            if auto_mixed_precision:
                T5Helper.auto_mixed_precision(m, force_fp16_logits=force_fp16_io)
            else:
                m.convert_model_float32_to_float16(cast_input_output=force_fp16_io)

        m.save_model_to_file(optimized_model_path, use_external_data_format, all_tensors_to_one_file=True)

    @staticmethod
    def verify_onnx(
        model: T5Decoder | T5EncoderDecoderInit,
        ort_session: InferenceSession,
        device: torch.device,
        use_int32_inputs: bool,
    ):
        """Compare the result from PyTorch and OnnxRuntime to verify the ONNX model is good."""
        if isinstance(model, T5EncoderDecoderInit):
            return T5EncoderDecoderInitHelper.verify_onnx(model, ort_session, device, use_int32_inputs)

        return T5DecoderHelper.verify_onnx(model, ort_session, device, use_int32_inputs)
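
T5Helper above composes into a load, export, optimize pipeline; a minimal sketch with placeholder directories (num_heads and d_model are read from the Hugging Face T5Config, matching the num_attention_heads and hidden_size parameters of optimize_onnx):

import os

import torch
from transformers import T5Config

from t5_helper import T5Helper

device = torch.device("cpu")
models = T5Helper.load_model("t5-small", cache_dir="./cache", device=device)
config = T5Config.from_pretrained("t5-small")

os.makedirs("./onnx_models", exist_ok=True)  # placeholder output directory
for name, module in models.items():
    # Export each wrapper ("encoder" and "decoder" by default) to its own ONNX file.
    onnx_path = T5Helper.get_onnx_path("./onnx_models", "t5-small", suffix="_" + name)
    T5Helper.export_onnx(module, device, onnx_path)

    # Fuse, then convert to mixed precision via auto_mixed_precision.
    fp16_path = T5Helper.get_onnx_path("./onnx_models", "t5-small", suffix="_" + name + "_fp16")
    T5Helper.optimize_onnx(
        onnx_path,
        fp16_path,
        is_float16=True,
        num_attention_heads=config.num_heads,
        hidden_size=config.d_model,
    )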
onnxruntime/transformers/models/t5/__init__.py
@@ -0,0 +1,12 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import os
import sys

sys.path.append(os.path.dirname(__file__))

transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
if transformers_dir not in sys.path:
    sys.path.append(transformers_dir)