onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,74 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License. See License.txt in the project root for
|
|
4
|
+
# license information.
|
|
5
|
+
# --------------------------------------------------------------------------
|
|
6
|
+
|
|
7
|
+
# Names of the model classes known to this module, in a fixed order. All of them share the
# "AutoModel" prefix. NOTE(review): these look like Hugging Face `transformers` Auto* class
# names that callers resolve by name — confirm against the call sites.
MODEL_CLASSES = [
    f"AutoModel{suffix}"
    for suffix in (
        "",
        "WithLMHead",
        "ForSequenceClassification",
        "ForQuestionAnswering",
        "ForCausalLM",
    )
]
|
|
15
|
+
|
|
16
|
+
# Pretrained model name -> tuple of
#   (input names, opset_version, use_external_data_format, optimization model type).
# Some models, such as GPT, T5 and BART, have their own convert_to_onnx.py in the models
# sub-directory and are therefore excluded here.
# Entries that are commented out below are not part of the table.
MODELS = {
    # BERT
    "bert-base-cased": (["input_ids", "attention_mask", "token_type_ids"], 16, False, "bert"),
    "bert-large-cased": (["input_ids", "attention_mask", "token_type_ids"], 16, False, "bert"),
    # Transformer-XL (the model uses Einsum, which needs opset version 16 or later).
    "transfo-xl-wt103": (["input_ids", "mems"], 16, False, "bert"),
    # XLNet
    "xlnet-base-cased": (["input_ids"], 16, False, "bert"),
    "xlnet-large-cased": (["input_ids"], 16, False, "bert"),
    # XLM
    "xlm-mlm-en-2048": (["input_ids"], 16, True, "bert"),
    "xlm-mlm-ende-1024": (["input_ids"], 16, False, "bert"),
    "xlm-mlm-enfr-1024": (["input_ids"], 16, False, "bert"),
    # RoBERTa
    "roberta-base": (["input_ids", "attention_mask"], 16, False, "bert"),
    "roberta-large": (["input_ids", "attention_mask"], 16, False, "bert"),
    "roberta-large-mnli": (["input_ids", "attention_mask"], 16, False, "bert"),
    "deepset/roberta-base-squad2": (["input_ids", "attention_mask"], 16, False, "bert"),
    "distilroberta-base": (["input_ids", "attention_mask"], 16, False, "bert"),
    # DistilBERT
    "distilbert-base-uncased": (["input_ids", "attention_mask"], 16, False, "bert"),
    "distilbert-base-uncased-distilled-squad": (["input_ids", "attention_mask"], 16, False, "bert"),
    # CTRL
    "ctrl": (["input_ids"], 16, True, "bert"),
    # CamemBERT
    "camembert-base": (["input_ids"], 16, False, "bert"),
    # ALBERT
    "albert-base-v1": (["input_ids"], 16, False, "bert"),
    "albert-large-v1": (["input_ids"], 16, False, "bert"),
    "albert-xlarge-v1": (["input_ids"], 16, True, "bert"),
    # "albert-xxlarge-v1": (["input_ids"], 16, True, "bert"),
    "albert-base-v2": (["input_ids"], 16, False, "bert"),
    "albert-large-v2": (["input_ids"], 16, False, "bert"),
    "albert-xlarge-v2": (["input_ids"], 16, True, "bert"),
    # "albert-xxlarge-v2": (["input_ids"], 16, True, "bert"),
    # XLM-RoBERTa
    "xlm-roberta-base": (["input_ids"], 16, False, "bert"),
    "xlm-roberta-large": (["input_ids"], 16, True, "bert"),
    # FlauBERT
    "flaubert/flaubert_small_cased": (["input_ids"], 16, False, "bert"),
    "flaubert/flaubert_base_cased": (["input_ids"], 16, False, "bert"),
    # "flaubert/flaubert_large_cased": (["input_ids"], 16, False, "bert"),
    # LayoutLM
    "microsoft/layoutlm-base-uncased": (["input_ids"], 16, False, "bert"),
    "microsoft/layoutlm-large-uncased": (["input_ids"], 16, False, "bert"),
    # SqueezeBERT
    "squeezebert/squeezebert-uncased": (["input_ids"], 16, False, "bert"),
    "squeezebert/squeezebert-mnli": (["input_ids"], 16, False, "bert"),
    "squeezebert/squeezebert-mnli-headless": (["input_ids"], 16, False, "bert"),
    # LXMERT (takes visual features in addition to token ids)
    "unc-nlp/lxmert-base-uncased": (["input_ids", "visual_feats", "visual_pos"], 16, False, "bert"),
    # ViT
    "google/vit-base-patch16-224": (["pixel_values"], 16, False, "vit"),
    # Swin
    "microsoft/swin-base-patch4-window7-224": (["pixel_values"], 16, False, "swin"),
    "microsoft/swin-small-patch4-window7-224": (["pixel_values"], 16, False, "swin"),
    "microsoft/swin-tiny-patch4-window7-224": (["pixel_values"], 16, False, "swin"),
}
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
import importlib.metadata
|
|
6
|
+
import importlib.util
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
def is_installed(package):
    """Return True if ``package`` is installed as a distribution or importable as a module.

    The installed-distribution metadata is consulted first. If no distribution of that name
    exists, the name is looked up as a (possibly dotted) module with
    ``importlib.util.find_spec``.

    Args:
        package: distribution name or module name.

    Returns:
        bool: True when either lookup finds the package, False otherwise (including for an
        empty name or a dotted name whose parent package is missing).
    """
    # An empty name is never a valid package. Reject it up front rather than relying on how
    # each importlib lookup happens to treat "".
    if not package:
        return False

    try:
        importlib.metadata.distribution(package)
    except importlib.metadata.PackageNotFoundError:
        pass
    else:
        return True

    try:
        spec = importlib.util.find_spec(package)
    except ImportError:
        # Includes ModuleNotFoundError, e.g. "missing_parent.child" or a relative name.
        return False
    except ValueError:
        # find_spec raises ValueError when the module is already in sys.modules but its
        # __spec__ is None or unset. The module is loaded, so report it as installed.
        return True

    return spec is not None
|
|
@@ -0,0 +1,487 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import logging
|
|
3
|
+
from collections import OrderedDict
|
|
4
|
+
from collections.abc import Mapping
|
|
5
|
+
from typing import Any
|
|
6
|
+
|
|
7
|
+
import numpy
|
|
8
|
+
import torch
|
|
9
|
+
from onnx import TensorProto
|
|
10
|
+
|
|
11
|
+
from onnxruntime import InferenceSession, RunOptions
|
|
12
|
+
|
|
13
|
+
# Type alias
|
|
14
|
+
ShapeDict = Mapping[str, tuple | list[int]]
|
|
15
|
+
|
|
16
|
+
logger = logging.getLogger(__name__)
|
|
17
|
+
|
|
18
|
+
|
|
19
|
+
class TypeHelper:
    """Translate tensor element types between ONNX Runtime, ONNX, NumPy and PyTorch.

    ONNX Runtime describes the element type of each session input/output as a string such
    as ``"tensor(float)"``. The static methods below map those strings, and NumPy / PyTorch
    dtypes, onto one another. Every conversion raises ``ValueError("<type> not found in map")``
    for a type that is not in its table.
    """

    @staticmethod
    def _lookup(table: dict, key, label=None):
        """Return ``table[key]``; raise ``ValueError("<label> not found in map")`` if absent.

        ``label`` defaults to ``key`` and only changes how the missing type is reported.
        """
        try:
            return table[key]
        except KeyError:
            raise ValueError(f"{key if label is None else label} not found in map") from None

    @staticmethod
    def _io_type_map(ort_session: InferenceSession, convert) -> dict:
        """Map every input name, then every output name, to ``convert(node_arg.type)``.

        An output that shares a name with an input overwrites the input's entry.
        """
        node_args = [*ort_session.get_inputs(), *ort_session.get_outputs()]
        return {node_arg.name: convert(node_arg.type) for node_arg in node_args}

    @staticmethod
    def get_input_type(ort_session: InferenceSession, name: str) -> str:
        """Return the ORT type string (e.g. ``"tensor(float)"``) of input ``name``.

        Raises:
            ValueError: if the session has no input called ``name``.
        """
        for node_arg in ort_session.get_inputs():
            if node_arg.name == name:
                return node_arg.type
        raise ValueError(f"input name {name} not found")

    @staticmethod
    def get_output_type(ort_session: InferenceSession, name: str) -> str:
        """Return the ORT type string (e.g. ``"tensor(float)"``) of output ``name``.

        Raises:
            ValueError: if the session has no output called ``name``.
        """
        for node_arg in ort_session.get_outputs():
            if node_arg.name == name:
                return node_arg.type
        raise ValueError(f"output name {name} not found")

    @staticmethod
    def ort_type_to_numpy_type(ort_type: str):
        """Return the NumPy scalar type for an ORT type string.

        ``tensor(bfloat16)`` is not supported because NumPy has no bfloat16 type.

        Raises:
            ValueError: if ``ort_type`` is not in the table.
        """
        ort_type_to_numpy_type_map = {
            "tensor(int64)": numpy.longlong,
            "tensor(int32)": numpy.intc,
            "tensor(float)": numpy.float32,
            "tensor(double)": numpy.float64,
            "tensor(float16)": numpy.float16,
            "tensor(bool)": bool,
            "tensor(uint8)": numpy.uint8,
            "tensor(int8)": numpy.int8,
        }
        return TypeHelper._lookup(ort_type_to_numpy_type_map, ort_type)

    @staticmethod
    def ort_type_to_torch_type(ort_type: str):
        """Return the torch dtype for an ORT type string.

        Raises:
            ValueError: if ``ort_type`` is not in the table.
        """
        ort_type_to_torch_type_map = {
            "tensor(int64)": torch.int64,
            "tensor(int32)": torch.int32,
            "tensor(float)": torch.float32,
            "tensor(double)": torch.float64,
            "tensor(float16)": torch.float16,
            "tensor(bfloat16)": torch.bfloat16,
            "tensor(bool)": torch.bool,
            "tensor(uint8)": torch.uint8,
            "tensor(int8)": torch.int8,
        }
        return TypeHelper._lookup(ort_type_to_torch_type_map, ort_type)

    @staticmethod
    def get_io_onnx_type_map(ort_session: InferenceSession) -> dict[str, int]:
        """Create a mapping from input/output name to onnx data type"""
        return TypeHelper._io_type_map(ort_session, TypeHelper.ort_type_to_onnx_type)

    @staticmethod
    def ort_type_to_onnx_type(ort_type: str):
        """Return the ``onnx.TensorProto`` data type value for an ORT type string.

        Raises:
            ValueError: if ``ort_type`` is not in the table.
        """
        ort_type_to_onnx_type_map = {
            "tensor(int64)": TensorProto.INT64,
            "tensor(int32)": TensorProto.INT32,
            "tensor(float)": TensorProto.FLOAT,
            "tensor(double)": TensorProto.DOUBLE,
            "tensor(float16)": TensorProto.FLOAT16,
            "tensor(bfloat16)": TensorProto.BFLOAT16,
            "tensor(bool)": TensorProto.BOOL,
            "tensor(uint8)": TensorProto.UINT8,
            "tensor(int8)": TensorProto.INT8,
        }
        return TypeHelper._lookup(ort_type_to_onnx_type_map, ort_type)

    @staticmethod
    def numpy_type_to_torch_type(numpy_type: numpy.dtype):
        """Return the torch dtype for a NumPy scalar type or ``numpy.dtype`` instance.

        The lookup goes through ``numpy.dtype(numpy_type).name``. This makes ``numpy.float32``
        and ``numpy.dtype("float32")`` resolve alike. It also handles platform aliases of the
        same width: ``numpy.int64`` and ``numpy.longlong`` are distinct types on LP64
        platforms such as Linux. A dict keyed by scalar types misses both cases, because
        ``numpy.dtype`` objects do not hash like the scalar types they compare equal to.

        Raises:
            ValueError: if ``numpy_type`` is None, is not a valid dtype, or has no torch
                mapping.
        """
        dtype_name_to_torch_type_map = {
            "int64": torch.int64,
            "int32": torch.int32,
            "float32": torch.float32,
            "float64": torch.float64,
            "float16": torch.float16,
            "bool": torch.bool,
            "uint8": torch.uint8,
            "int8": torch.int8,
        }
        # numpy.dtype(None) silently means float64, so None has to be rejected explicitly.
        if numpy_type is None:
            raise ValueError(f"{numpy_type} not found in map")
        try:
            dtype_name = numpy.dtype(numpy_type).name
        except (TypeError, ValueError):
            raise ValueError(f"{numpy_type} not found in map") from None
        return TypeHelper._lookup(dtype_name_to_torch_type_map, dtype_name, numpy_type)

    @staticmethod
    def torch_type_to_numpy_type(torch_type: torch.dtype):
        """Return the NumPy scalar type for a torch dtype.

        ``torch.bfloat16`` is not supported because NumPy has no bfloat16 type.

        Raises:
            ValueError: if ``torch_type`` is not in the table.
        """
        torch_type_to_numpy_type_map = {
            torch.int64: numpy.longlong,
            torch.int32: numpy.intc,
            torch.float32: numpy.float32,
            torch.float64: numpy.float64,
            torch.float16: numpy.float16,
            torch.bool: bool,
            torch.uint8: numpy.uint8,
            torch.int8: numpy.int8,
        }
        return TypeHelper._lookup(torch_type_to_numpy_type_map, torch_type)

    @staticmethod
    def get_io_numpy_type_map(ort_session: InferenceSession) -> dict[str, numpy.dtype]:
        """Create a mapping from input/output name to numpy data type"""
        return TypeHelper._io_type_map(ort_session, TypeHelper.ort_type_to_numpy_type)

    @staticmethod
    def get_io_torch_type_map(ort_session: InferenceSession) -> dict[str, torch.dtype]:
        """Create a mapping from input/output name to torch data type"""
        return TypeHelper._io_type_map(ort_session, TypeHelper.ort_type_to_torch_type)
|
|
149
|
+
|
|
150
|
+
|
|
151
|
+
class IOBindingHelper:
    """Static helpers that bind torch tensors to an ONNX Runtime session via I/O binding."""

    @staticmethod
    def _device_id(tensor: torch.Tensor) -> int:
        """Return the device ordinal of ``tensor``; 0 when torch reports no index (e.g. CPU)."""
        return tensor.device.index if tensor.device.index is not None else 0

    @staticmethod
    def get_output_buffers(ort_session: InferenceSession, output_shapes, device):
        """Returns a dictionary of output name as key, and 1D tensor as value. The tensor has enough space for given shape.

        Args:
            ort_session: session whose output element types determine each buffer's dtype.
            output_shapes: mapping from output name to the shape the buffer must be able to hold.
            device: torch device on which the buffers are allocated.
        """
        output_buffers = {}
        for name, shape in output_shapes.items():
            ort_type = TypeHelper.get_output_type(ort_session, name)
            torch_type = TypeHelper.ort_type_to_torch_type(ort_type)
            # numpy.prod(()) is the float 1.0; cast to int so rank-0 (scalar) outputs get a valid size.
            output_buffers[name] = torch.empty(int(numpy.prod(shape)), dtype=torch_type, device=device)
        return output_buffers

    @staticmethod
    def prepare_io_binding(
        ort_session,
        input_ids: torch.Tensor,
        position_ids: torch.Tensor | None,
        attention_mask: torch.Tensor | None,
        past: list[torch.Tensor] | None,
        output_buffers,
        output_shapes,
    ):
        """IO binding for a session: bind inputs (input_ids, position_ids, attention_mask, past_*) and outputs.

        Inputs that are ``None`` (``past``, ``attention_mask``, ``position_ids``) are not bound.
        Every tensor is bound on the device (type and ordinal) it actually lives on.

        Args:
            ort_session: session to create the binding for.
            output_buffers: mapping from output name to a preallocated tensor (see ``get_output_buffers``).
            output_shapes: mapping from output name to the shape to bind for that output.

        Returns:
            The populated I/O binding object.

        Raises:
            AssertionError: if any bound input tensor is not contiguous.
        """
        name_to_onnx_type = TypeHelper.get_io_onnx_type_map(ort_session)

        # Bind inputs and outputs to onnxruntime session
        io_binding = ort_session.io_binding()

        def bind_tensor_input(name: str, tensor: torch.Tensor, data_ptr: int | None = None) -> None:
            # Binding passes a raw pointer, so the memory layout must be dense.
            assert tensor.is_contiguous()
            io_binding.bind_input(
                name,
                tensor.device.type,
                IOBindingHelper._device_id(tensor),
                name_to_onnx_type[name],
                list(tensor.size()),
                tensor.data_ptr() if data_ptr is None else data_ptr,
            )

        # Bind inputs
        bind_tensor_input("input_ids", input_ids)

        if past is not None:
            for i, past_i in enumerate(past):
                data_ptr = past_i.data_ptr()
                if data_ptr == 0:
                    # When past_sequence_length is 0, its data_ptr will be zero. IO Binding asserts that data_ptr shall not be zero.
                    # Here we workaround and pass data pointer of input_ids. Actual data is not used for past so it does not matter.
                    data_ptr = input_ids.data_ptr()
                bind_tensor_input(f"past_{i}", past_i, data_ptr)

        if attention_mask is not None:
            bind_tensor_input("attention_mask", attention_mask)

        if position_ids is not None:
            bind_tensor_input("position_ids", position_ids)

        # Bind outputs
        for output in ort_session.get_outputs():
            output_name = output.name
            output_buffer = output_buffers[output_name]
            logger.debug(
                "%s device type=%s shape=%s", output_name, output_buffer.device.type, list(output_buffer.size())
            )
            io_binding.bind_output(
                output_name,
                output_buffer.device.type,
                IOBindingHelper._device_id(output_buffer),
                name_to_onnx_type[output_name],
                output_shapes[output_name],
                output_buffer.data_ptr(),
            )

        return io_binding

    @staticmethod
    def get_outputs_from_io_binding_buffer(ort_session, output_buffers, output_shapes, return_numpy=True):
        """Copy each output out of its flat I/O-binding buffer, in session output order.

        Args:
            ort_session: session whose outputs define the order of the result list.
            output_buffers: mapping from output name to the flat buffer the session wrote into.
            output_shapes: mapping from output name to the actual output shape.
            return_numpy: when True, copy results to CPU and return numpy arrays; otherwise
                return detached torch tensors on the buffers' device.

        Returns:
            A list with one entry per session output.
        """
        ort_outputs = []
        for output in ort_session.get_outputs():
            output_name = output.name
            buffer = output_buffers[output_name]
            shape = output_shapes[output_name]
            # Clone so the result does not alias a buffer that later runs will overwrite.
            # int() guards against numpy.prod(()) == 1.0 (a float) for scalar outputs.
            copy_tensor = buffer[0 : int(numpy.prod(shape))].reshape(shape).clone().detach()
            ort_outputs.append(copy_tensor.cpu().numpy() if return_numpy else copy_tensor)
        return ort_outputs
|
|
261
|
+
|
|
262
|
+
|
|
263
|
+
class CudaSession:
    """Inference Session with IO Binding for ONNX Runtime CUDA or TensorRT provider

    Inputs and outputs are torch tensors bound by raw data pointer. When CUDA graph
    is enabled, inputs are copied into preallocated, fixed-shape input tensors so
    the bound addresses stay stable across runs.
    """

    def __init__(self, ort_session: InferenceSession, device: torch.device, enable_cuda_graph=False):
        self.ort_session = ort_session
        self.input_names = [input.name for input in self.ort_session.get_inputs()]
        self.output_names = [output.name for output in self.ort_session.get_outputs()]
        self.io_name_to_onnx_type = TypeHelper.get_io_onnx_type_map(self.ort_session)
        self.io_name_to_torch_type = TypeHelper.get_io_torch_type_map(self.ort_session)
        self.io_binding = self.ort_session.io_binding()
        self.enable_cuda_graph = enable_cuda_graph

        self.input_tensors = OrderedDict()
        self.output_tensors = OrderedDict()
        self.device = device

        # Pairs of input and output names that share the same buffer.
        # Stored in both directions (input -> output and output -> input).
        self.buffer_sharing: dict[str, str] = {}

    @staticmethod
    def _device_id(tensor: torch.Tensor) -> int:
        """Return the device ordinal of ``tensor``; 0 when torch reports no index (e.g. CPU)."""
        return tensor.device.index if tensor.device.index is not None else 0

    def set_buffer_sharing(self, input_name: str, output_name: str):
        """Declare that ``input_name`` and ``output_name`` use the same memory buffer."""
        assert input_name in self.input_names
        assert output_name in self.output_names
        self.buffer_sharing[input_name] = output_name
        self.buffer_sharing[output_name] = input_name

    def __del__(self):
        # Release tensors before the binding. Skip attributes that were never set
        # (e.g. when __init__ raised) instead of raising AttributeError during finalization.
        for attr_name in ("input_tensors", "output_tensors", "io_binding"):
            if hasattr(self, attr_name):
                delattr(self, attr_name)

    def bind_input_and_buffer_sharing(self, name: str, tensor: torch.Tensor):
        """Bind ``tensor`` as input ``name``; also bind it as the paired output if the name is shared.

        A 0-dim tensor is bound with shape ``[1]``. The shared output is bound with the input's
        element type and shape (assumes the paired output has the same type/shape; verify when
        configuring sharing).
        """
        device_id = self._device_id(tensor)
        tensor_shape = [1] if len(tensor.shape) == 0 else list(tensor.shape)

        self.io_binding.bind_input(
            name,
            tensor.device.type,
            device_id,
            self.io_name_to_onnx_type[name],
            tensor_shape,
            tensor.data_ptr(),
        )

        if name in self.buffer_sharing:
            self.io_binding.bind_output(
                self.buffer_sharing[name],
                tensor.device.type,
                device_id,
                self.io_name_to_onnx_type[name],
                tensor_shape,
                tensor.data_ptr(),
            )
            self.output_tensors[self.buffer_sharing[name]] = tensor

    def allocate_buffers(self, shape_dict: ShapeDict):
        """Allocate tensors for I/O Binding

        With CUDA graph enabled, a fixed input tensor is allocated and bound for each input in
        ``shape_dict``. An already allocated input must keep its shape, otherwise RuntimeError is
        raised. Output tensors are (re)allocated and bound whenever their shape changes, except
        outputs that share a buffer with an input (those are bound together with the input).
        Buffers are allocated directly on ``self.device``.
        """
        if self.enable_cuda_graph:
            for name, shape in shape_dict.items():
                if name not in self.input_names:
                    continue

                # Reuse allocated buffer when the shape is same
                if name in self.input_tensors:
                    if tuple(self.input_tensors[name].shape) == tuple(shape):
                        continue
                    raise RuntimeError("Expect static input shape for cuda graph")

                torch_dtype = self.io_name_to_torch_type[name]
                tensor = torch.empty(tuple(shape), dtype=torch_dtype, device=self.device)
                self.input_tensors[name] = tensor
                self.bind_input_and_buffer_sharing(name, tensor)

        for name, shape in shape_dict.items():
            if name not in self.output_names:
                continue

            # Reuse allocated buffer when the shape is same
            if name in self.output_tensors and tuple(self.output_tensors[name].shape) == tuple(shape):
                continue

            # Shared outputs are bound in bind_input_and_buffer_sharing.
            if name in self.buffer_sharing:
                continue

            torch_dtype = self.io_name_to_torch_type[name]
            tensor = torch.empty(tuple(shape), dtype=torch_dtype, device=self.device)
            self.output_tensors[name] = tensor

            self.io_binding.bind_output(
                name,
                tensor.device.type,
                self._device_id(tensor),
                self.io_name_to_onnx_type[name],
                list(tensor.size()),
                tensor.data_ptr(),
            )

    def infer(self, feed_dict: dict[str, torch.Tensor], run_options: RunOptions = None, synchronize: bool = True):
        """Bind input tensors and run inference

        Entries of ``feed_dict`` whose names are not session inputs are ignored. With CUDA graph
        enabled, each input is copied into its preallocated tensor instead of being rebound.

        Returns:
            ``self.output_tensors`` itself (not a copy); a later run overwrites its contents.
        """
        for name, tensor in feed_dict.items():
            assert isinstance(tensor, torch.Tensor) and tensor.is_contiguous()
            if name in self.input_names:
                if self.enable_cuda_graph:
                    assert self.input_tensors[name].nelement() == tensor.nelement()
                    assert self.input_tensors[name].dtype == tensor.dtype
                    assert tensor.device.type == "cuda"
                    self.input_tensors[name].copy_(tensor)
                else:
                    self.bind_input_and_buffer_sharing(name, tensor)

        if synchronize:
            self.io_binding.synchronize_inputs()
            self.ort_session.run_with_iobinding(self.io_binding, run_options)
            self.io_binding.synchronize_outputs()
        else:
            self.ort_session.run_with_iobinding(self.io_binding, run_options)

        return self.output_tensors

    @staticmethod
    def get_cuda_provider_options(device_id: int, enable_cuda_graph: bool, stream: int = 0) -> dict[str, Any]:
        """Build CUDA execution provider options.

        ``user_compute_stream`` is only set when ``stream`` is non-zero; it is passed as a string.
        """
        options = {
            "device_id": device_id,
            "arena_extend_strategy": "kSameAsRequested",
            "enable_cuda_graph": enable_cuda_graph,
        }

        # Stream is address of a CUDA stream. 0 means the default stream.
        if stream != 0:
            options["user_compute_stream"] = str(stream)

        return options
|
|
390
|
+
|
|
391
|
+
|
|
392
|
+
class GpuBinding(CudaSession):
    """A CudaSession whose buffers are allocated up front and tagged with a GPU graph id.

    When ``enable_gpu_graph`` is set, a deep copy of ``shape_dict`` is kept so callers can
    look up this binding by shape later; otherwise ``shape_dict`` attribute is ``None``.
    """

    def __init__(
        self,
        ort_session: InferenceSession,
        device: torch.device,
        shape_dict: ShapeDict,
        enable_gpu_graph: bool = False,
        gpu_graph_id: int = -1,
        stream: int = 0,
        buffer_sharing: dict[str, str] | None = None,
    ):
        super().__init__(ort_session, device, enable_gpu_graph)

        # Sharing must be registered before allocation so shared outputs are not allocated separately.
        for shared_input, shared_output in (buffer_sharing or {}).items():
            self.set_buffer_sharing(shared_input, shared_output)

        self.allocate_buffers(shape_dict)
        self.gpu_graph_id = gpu_graph_id
        # For cuda graph, we need to keep a copy of shape_dict to check if the shape is same in inference later.
        self.shape_dict = None if not enable_gpu_graph else copy.deepcopy(shape_dict)
        self.stream = stream
        # GPU graph id used by the most recent run (callers record it, e.g. into image metadata).
        self.last_run_gpu_graph_id = None

    def get_run_options(self, disable_cuda_graph_in_run: bool = False) -> RunOptions:
        """Create RunOptions carrying the ``gpu_graph_id`` entry and remember that id.

        The id is -1 when ``disable_cuda_graph_in_run`` is true, else ``self.gpu_graph_id``.
        """
        if disable_cuda_graph_in_run:
            graph_id = -1
        else:
            graph_id = self.gpu_graph_id

        run_options = RunOptions()
        run_options.add_run_config_entry("gpu_graph_id", str(graph_id))
        self.last_run_gpu_graph_id = graph_id
        return run_options

    def infer(self, feed_dict: dict[str, torch.Tensor], disable_cuda_graph_in_run: bool = False):
        """Run inference with this binding's graph id; disables EP synchronization when a stream is set."""
        opts = self.get_run_options(disable_cuda_graph_in_run)
        if self.stream:
            opts.add_run_config_entry("disable_synchronize_execution_providers", "1")
        return super().infer(feed_dict, opts)
|
|
434
|
+
|
|
435
|
+
|
|
436
|
+
class GpuBindingManager:
    """A manager for I/O bindings that support multiple CUDA Graphs.

    Bindings with a captured graph are looked up by exact ``shape_dict`` equality. New shapes get
    a new graph binding until ``max_cuda_graphs`` is reached; after that (or when graphs are not
    requested) a single shared graph-less binding is reused and its buffers re-allocated.
    """

    def __init__(self, ort_session: InferenceSession, device: torch.device, stream: int = 0, max_cuda_graphs: int = 1):
        self.ort_session = ort_session
        self.device = device
        # Bindings with CUDA graph enabled; the list index is the binding's gpu_graph_id.
        # A graph binding may still skip the graph for an individual run.
        self.graph_bindings = []
        # Lazily created binding used when no CUDA graph is used.
        self.no_graph_binding = None
        self.stream = stream
        self.max_cuda_graphs = max_cuda_graphs

    def get_binding(
        self,
        shape_dict: ShapeDict,
        use_cuda_graph: bool = False,
        buffer_sharing: dict[str, str] | None = None,
    ) -> GpuBinding:
        """Return a binding for ``shape_dict``, creating one when needed.

        ``buffer_sharing`` only takes effect when a binding is newly constructed.
        """
        # A graph already captured for this exact shape is always preferred.
        existing = next((binding for binding in self.graph_bindings if binding.shape_dict == shape_dict), None)
        if existing is not None:
            return existing

        may_capture_new_graph = len(self.graph_bindings) < self.max_cuda_graphs and use_cuda_graph
        if not may_capture_new_graph:
            if self.no_graph_binding is not None:
                self.no_graph_binding.allocate_buffers(shape_dict)
            else:
                self.no_graph_binding = GpuBinding(
                    self.ort_session, self.device, shape_dict, stream=self.stream, buffer_sharing=buffer_sharing
                )
            return self.no_graph_binding

        # Unseen shape with capacity left: capture a new graph whose id is its list position.
        new_binding = GpuBinding(
            self.ort_session,
            self.device,
            shape_dict,
            enable_gpu_graph=True,
            gpu_graph_id=len(self.graph_bindings),
            stream=self.stream,
            buffer_sharing=buffer_sharing,
        )
        self.graph_bindings.append(new_binding)
        return new_binding
|