onnxruntime-directml 1.20.0 (onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
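Before the per-file diffs, a minimal usage sketch (an editorial illustration, not part of the package diff) of what this wheel provides at runtime: the bundled capi/DirectML.dll backs the DmlExecutionProvider, and the sample model sigmoid.onnx listed under onnxruntime/datasets makes the example self-contained.

```python
import onnxruntime as ort
from onnxruntime.datasets import get_example

# The DirectML wheel registers DmlExecutionProvider (backed by the bundled
# capi/DirectML.dll) alongside the default CPUExecutionProvider.
model_path = get_example("sigmoid.onnx")  # sample model shipped in this wheel
session = ort.InferenceSession(
    model_path,
    providers=["DmlExecutionProvider", "CPUExecutionProvider"],
)
print(session.get_providers())  # shows which providers were actually enabled
```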
onnxruntime/transformers/models/stable_diffusion/trt_utilities.py
@@ -0,0 +1,12 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import tensorrt as trt
+
+TRT_LOGGER = trt.Logger(trt.Logger.ERROR)
+
+
+def init_trt_plugins():
+    # Register TensorRT plugins
+    trt.init_libnvinfer_plugins(TRT_LOGGER, "")
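A short usage sketch for the helper above (assumptions: TensorRT is installed, and "unet.engine" is a placeholder for a serialized engine that uses plugin ops; plugins must be registered before deserializing such an engine):

```python
import tensorrt as trt

from trt_utilities import TRT_LOGGER, init_trt_plugins

init_trt_plugins()  # register the standard libnvinfer plugins once per process

# "unet.engine" is a placeholder path for illustration only.
runtime = trt.Runtime(TRT_LOGGER)
with open("unet.engine", "rb") as f:
    engine = runtime.deserialize_cuda_engine(f.read())
```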
onnxruntime/transformers/models/t5/__init__.py
@@ -0,0 +1,12 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import os.path
+import sys
+
+sys.path.append(os.path.dirname(__file__))
+
+transformers_dir = os.path.normpath(os.path.join(os.path.dirname(__file__), "..", ".."))
+if transformers_dir not in sys.path:
+    sys.path.append(transformers_dir)
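The effect of the sys.path manipulation above is that the helper modules shipped next to this package resolve as top-level imports once the package is imported; a minimal sketch (module paths taken from the file list above):

```python
import onnxruntime.transformers.models.t5  # executes the __init__.py above

# Both imports now resolve thanks to the sys.path.append calls:
import benchmark_helper  # onnxruntime/transformers/benchmark_helper.py
import t5_helper  # onnxruntime/transformers/models/t5/t5_helper.py
```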
onnxruntime/transformers/models/t5/convert_to_onnx.py
@@ -0,0 +1,278 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import argparse
+import copy
+import logging
+import os
+
+import torch
+from benchmark_helper import Precision, create_onnxruntime_session, prepare_environment, setup_logger
+from t5_helper import PRETRAINED_MT5_MODELS, PRETRAINED_T5_MODELS, T5Helper
+
+logger = logging.getLogger("")
+
+
+def parse_arguments():
+    parser = argparse.ArgumentParser()
+
+    pretrained_models = PRETRAINED_T5_MODELS + PRETRAINED_MT5_MODELS
+    parser.add_argument(
+        "-m",
+        "--model_name_or_path",
+        required=False,
+        default=PRETRAINED_T5_MODELS[0],
+        type=str,
+        help="Model path, or pretrained model name in the list: " + ", ".join(pretrained_models),
+    )
+
+    parser.add_argument(
+        "--model_type",
+        required=False,
+        type=str,
+        default="t5",
+        choices=["t5", "mt5"],
+        help="Model type: either t5 (default) or mt5",
+    )
+
+    parser.add_argument(
+        "--cache_dir",
+        required=False,
+        type=str,
+        default=os.path.join(".", "cache_models"),
+        help="Directory to cache pre-trained models",
+    )
+
+    parser.add_argument(
+        "--output",
+        required=False,
+        type=str,
+        default=os.path.join(".", "onnx_models"),
+        help="Output directory",
+    )
+
+    parser.add_argument(
+        "-o",
+        "--optimize_onnx",
+        required=False,
+        action="store_true",
+        help="Use optimizer.py to optimize onnx model",
+    )
+    parser.set_defaults(optimize_onnx=False)
+
+    parser.add_argument("--use_gpu", required=False, action="store_true", help="use GPU for inference")
+    parser.set_defaults(use_gpu=False)
+
+    parser.add_argument(
+        "-p",
+        "--precision",
+        required=False,
+        type=Precision,
+        default=Precision.FLOAT32,
+        choices=[Precision.FLOAT32, Precision.FLOAT16],
+        help="Precision of model to run. fp32 for full precision, fp16 for half precision",
+    )
+
+    parser.add_argument("--verbose", required=False, action="store_true")
+    parser.set_defaults(verbose=False)
+
+    parser.add_argument("-e", "--use_external_data_format", required=False, action="store_true")
+    parser.set_defaults(use_external_data_format=False)
+
+    parser.add_argument(
+        "-s",
+        "--use_decoder_start_token",
+        required=False,
+        action="store_true",
+        help="Use config.decoder_start_token_id. Otherwise, add an extra graph input for decoder_input_ids.",
+    )
+    parser.set_defaults(use_decoder_start_token=False)
+
+    parser.add_argument(
+        "-w",
+        "--overwrite",
+        required=False,
+        action="store_true",
+        help="overwrite existing ONNX model",
+    )
+    parser.set_defaults(overwrite=False)
+
+    parser.add_argument(
+        "--disable_auto_mixed_precision",
+        required=False,
+        action="store_true",
+        help="use pure fp16 instead of mixed precision",
+    )
+    parser.set_defaults(disable_auto_mixed_precision=False)
+
+    parser.add_argument(
+        "--separate_encoder_and_decoder_init",
+        required=False,
+        action="store_true",
+        help="Do not merge encode and decoder init. Output 3 instead of 2 onnx models.",
+    )
+    parser.set_defaults(separate_encoder_and_decoder_init=False)
+
+    parser.add_argument(
+        "--use_int64_inputs",
+        required=False,
+        action="store_true",
+        help="Use int64 instead of int32 for input_ids, position_ids and attention_mask.",
+    )
+    parser.set_defaults(use_int64_inputs=False)
+
+    parser.add_argument(
+        "--state_dict_path",
+        type=str,
+        default="",
+        help="filepath to load pre-trained model with custom state dictionary (e.g. pytorch_model.bin)",
+    )
+
+    args = parser.parse_args()
+
+    return args
+
+
+def export_onnx_models(
+    model_name_or_path,
+    cache_dir,
+    output_dir,
+    use_gpu,
+    use_external_data_format,
+    optimize_onnx,
+    precision,
+    verbose,
+    use_decoder_start_token: bool = False,
+    merge_encoder_and_decoder_init: bool = True,
+    overwrite: bool = False,
+    disable_auto_mixed_precision: bool = False,
+    use_int32_inputs: bool = True,
+    model_type: str = "t5",
+    state_dict_path: str = "",
+):
+    device = torch.device("cuda:0" if use_gpu else "cpu")
+
+    models = T5Helper.load_model(
+        model_name_or_path, cache_dir, device, merge_encoder_and_decoder_init, model_type, state_dict_path
+    )
+    config = models["decoder"].config
+
+    if (not use_external_data_format) and (config.num_layers > 24):
+        logger.info("Try use_external_data_format when model size > 2GB")
+
+    output_paths = []
+    for name, model in models.items():
+        model.to(device)
+        filename_suffix = "_" + name
+
+        onnx_path = T5Helper.get_onnx_path(
+            output_dir,
+            model_name_or_path,
+            suffix=filename_suffix,
+            new_folder=False,
+        )
+
+        if overwrite or not os.path.exists(onnx_path):
+            logger.info(f"Exporting ONNX model to {onnx_path}")
+            # We have to clone model before exporting onnx, otherwise verify_onnx will report large difference.
+            cloned_model = copy.deepcopy(model).to(device)
+            T5Helper.export_onnx(
+                cloned_model,
+                device,
+                onnx_path,
+                verbose,
+                use_external_data_format,
+                use_decoder_input_ids=not use_decoder_start_token,
+                use_int32_inputs=use_int32_inputs,
+            )
+        else:
+            logger.info(f"Skip exporting: existed ONNX model {onnx_path}")
+
+        # Optimize ONNX graph. Note that we have not implemented graph optimization for T5 yet.
+        if optimize_onnx or precision != Precision.FLOAT32:
+            output_path = T5Helper.get_onnx_path(
+                output_dir,
+                model_name_or_path,
+                suffix=filename_suffix + "_" + str(precision),
+                new_folder=False,
+            )
+
+            if overwrite or not os.path.exists(output_path):
+                logger.info(f"Optimizing model to {output_path}")
+                T5Helper.optimize_onnx(
+                    onnx_path,
+                    output_path,
+                    precision == Precision.FLOAT16,
+                    config.num_heads,
+                    config.hidden_size,
+                    use_external_data_format,
+                    auto_mixed_precision=not disable_auto_mixed_precision,
+                    use_gpu=use_gpu,
+                )
+            else:
+                logger.info(f"Skip optimizing: existed ONNX model {onnx_path}")
+        else:
+            output_path = onnx_path
+
+        ort_session = create_onnxruntime_session(
+            output_path,
+            use_gpu=use_gpu,
+            provider=["CUDAExecutionProvider", "CPUExecutionProvider"] if use_gpu else ["CPUExecutionProvider"],
+        )
+
+        with torch.no_grad():
+            max_diff = T5Helper.verify_onnx(model, ort_session, device, use_int32_inputs)
+        logger.info(f"PyTorch and OnnxRuntime results max difference = {max_diff}")
+        if max_diff > 1e-4:
+            logger.warning("PyTorch and OnnxRuntime results are NOT close")
+
+        output_paths.append(output_path)
+
+    return output_paths
+
+
+def main():
+    args = parse_arguments()
+
+    setup_logger(args.verbose)
+
+    logger.info(f"Arguments:{args}")
+
+    cache_dir = args.cache_dir
+    output_dir = args.output if not args.output.endswith(".onnx") else os.path.dirname(args.output)
+    prepare_environment(cache_dir, output_dir, args.use_gpu)
+
+    if args.precision != Precision.FLOAT32:
+        assert args.optimize_onnx, "fp16/int8 requires --optimize_onnx"
+
+    if args.precision == Precision.FLOAT16:
+        assert args.use_gpu, "fp16 requires --use_gpu"
+
+    if args.optimize_onnx:
+        logger.warning("Graph optimization for T5 is not implemented yet.")
+
+    output_paths = export_onnx_models(
+        args.model_name_or_path,
+        cache_dir,
+        output_dir,
+        args.use_gpu,
+        args.use_external_data_format,
+        args.optimize_onnx,
+        args.precision,
+        args.verbose,
+        args.use_decoder_start_token,
+        not args.separate_encoder_and_decoder_init,
+        args.overwrite,
+        args.disable_auto_mixed_precision,
+        not args.use_int64_inputs,
+        args.model_type,
+    )
+
+    logger.info(f"Done! Outputs: {output_paths}")
+
+
+if __name__ == "__main__":
+    main()
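For reference, two illustrative invocations based on the argument definitions above (t5-small is an arbitrary example model; the script name assumes it is run from its own directory):

    # FP32 export; with the default merged encoder/decoder-init, 2 ONNX models are written
    python convert_to_onnx.py -m t5-small --output ./onnx_models

    # FP16 export; per the checks in main(), this requires both --optimize_onnx and --use_gpu
    python convert_to_onnx.py -m t5-small --output ./onnx_models -o -p fp16 --use_gpu -e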
onnxruntime/transformers/models/t5/past_helper.py
@@ -0,0 +1,150 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See License.txt in the project root for
+# license information.
+# --------------------------------------------------------------------------
+
+import logging
+from typing import List, Tuple
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+class PastKeyValuesHelper:
+    """Helper functions to process past key values for encoder-decoder model"""
+
+    @staticmethod
+    def get_past_names(num_layers, present: bool = False):
+        past_self_names = []
+        past_cross_names = []
+        for i in range(num_layers):
+            past_self_names.extend(
+                [f"present_key_self_{i}", f"present_value_self_{i}"]
+                if present
+                else [f"past_key_self_{i}", f"past_value_self_{i}"]
+            )
+            past_cross_names.extend(
+                [f"present_key_cross_{i}", f"present_value_cross_{i}"]
+                if present
+                else [f"past_key_cross_{i}", f"past_value_cross_{i}"]
+            )
+        return past_self_names + past_cross_names
+
+    @staticmethod
+    def group_by_self_or_cross(present_key_values):
+        """Split present state from grouped by layer to grouped by self/cross attention.
+        Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
+        After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...), (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)
+
+        """
+        present_self = []
+        present_cross = []
+        for _i, present_layer_i in enumerate(present_key_values):
+            assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
+            (
+                present_key_self,
+                present_value_self,
+                present_key_cross,
+                present_value_cross,
+            ) = present_layer_i
+            present_self.extend([present_key_self, present_value_self])
+            present_cross.extend([present_key_cross, present_value_cross])
+        return present_self, present_cross
+
+    @staticmethod
+    def group_by_layer(past, num_layers):
+        """Reorder past state from grouped by self/cross attention to grouped by layer.
+        Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ..., past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
+        After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0), (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
+        """
+        assert len(past) == 4 * num_layers
+        return tuple(
+            [
+                past[2 * i],
+                past[2 * i + 1],
+                past[2 * num_layers + 2 * i],
+                past[2 * num_layers + 2 * i + 1],
+            ]
+            for i in range(num_layers)
+        )
+
+    @staticmethod
+    def back_group_by_layer(past_key_values: Tuple[Tuple[torch.Tensor]]):
+        """Categorize present_key_values from self and cross attention to layer by layer.
+
+        Reorder past state from grouped by self/cross attention to grouped by layer.
+        Before: past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...,
+            past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...
+        After: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
+            (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1),
+
+        Args:
+            present_key_values: From past_key_values of a model (group by self and cross attention)
+
+        Returns:
+            past_tuples: present key and values grouped by layer.
+        """
+        past_tuples = ()
+        half_idx = len(past_key_values) // 2
+        for i in range(len(past_key_values) // 4):
+            idx = 2 * i
+            past_tuples += (
+                (
+                    past_key_values[idx],
+                    past_key_values[idx + 1],
+                    past_key_values[half_idx + idx],
+                    past_key_values[half_idx + idx + 1],
+                ),
+            )
+        return past_tuples
+
+    @staticmethod
+    def group_by_self_and_cross(present_key_values: Tuple[torch.Tensor], concat: bool = False):
+        """Categorize present_key_values into self and cross attention.
+
+        Split present state from grouped by layer to grouped by self/cross attention.
+        Before: (past_key_self_0, past_value_self_0, past_key_cross_0, past_value_cross_0),
+            (past_key_self_1, past_value_self_1, past_key_cross_1, past_value_cross_1), ...
+        After: (past_key_self_0, past_value_self_0, past_key_self_1, past_value_self_1, ...),
+            (past_key_cross_0, past_value_cross_0, past_key_cross_1, past_value_cross_1, ...)
+
+        Args:
+            present_key_values: From past_key_values of a model (group by layer)
+            concat: If concat self attention with cross attention key/value to return
+
+        Returns:
+            present_self (Tuple[torch.Tensor]): present key and values from self attention
+            present_cross (Tuple[torch.Tensor]): present key and values from cross attention
+        """
+        present_self: List[torch.Tensor] = []
+        present_cross: List[torch.Tensor] = []
+        for _, present_layer_i in enumerate(present_key_values):
+            assert len(present_layer_i) == 4, f"Expected to have four items. Got {len(present_layer_i)}"
+            present_key_self, present_value_self, present_key_cross, present_value_cross = present_layer_i
+            present_self.extend([present_key_self, present_value_self])
+            present_cross.extend([present_key_cross, present_value_cross])
+        if concat:
+            return present_self + present_cross
+        else:
+            return present_self, present_cross
+
+    @staticmethod
+    def get_input_names(past_key_values: Tuple[Tuple[torch.Tensor]], encoder=True):
+        """Process input names of model wrapper.
+
+        Args:
+            past_key_values: Consider `self` and `cross` past_key_values
+
+        Returns:
+            names (List[string]): input names
+        """
+        names = []
+        num_layers = len(past_key_values) // 4 if encoder else len(past_key_values)
+        prefix = "past_" if not encoder else "present_"
+        for i in range(num_layers):
+            names.extend([prefix + s for s in [f"key_self_{i}", f"value_self_{i}"]])
+        for i in range(num_layers):
+            names.extend([prefix + s for s in [f"key_cross_{i}", f"value_cross_{i}"]])
+        return names
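A minimal sketch (editorial, not part of the package) of the round trip between the two past-state layouts that PastKeyValuesHelper converts between; tensor shapes below are arbitrary placeholders:

```python
import torch

from past_helper import PastKeyValuesHelper

num_layers = 2
# Layout "grouped by layer": one 4-tuple (key_self, value_self, key_cross,
# value_cross) per layer, as in an encoder-decoder model's past_key_values.
per_layer = tuple(tuple(torch.zeros(1, 8, 3, 64) for _ in range(4)) for _ in range(num_layers))

# Flatten to "grouped by self/cross": all self K/V first, then all cross K/V.
flat = PastKeyValuesHelper.group_by_self_and_cross(per_layer, concat=True)
assert len(flat) == 4 * num_layers

# back_group_by_layer restores the per-layer grouping (same tensor objects).
regrouped = PastKeyValuesHelper.back_group_by_layer(flat)
assert all(a is b for la, lb in zip(regrouped, per_layer) for a, b in zip(la, lb))

# Matching graph input names for the flattened layout.
print(PastKeyValuesHelper.get_past_names(num_layers))
```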