onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py
@@ -0,0 +1,288 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+import gc
+import logging
+import os
+
+import torch
+from cuda import cudart
+from diffusion_models import PipelineInfo
+from engine_builder import EngineBuilder, EngineType
+from packaging import version
+
+import onnxruntime as ort
+from onnxruntime.transformers.io_binding_helper import CudaSession
+
+logger = logging.getLogger(__name__)
+
+
+class OrtTensorrtEngine(CudaSession):
+    def __init__(
+        self,
+        engine_path,
+        device_id,
+        onnx_path,
+        fp16,
+        input_profile,
+        workspace_size,
+        enable_cuda_graph,
+        timing_cache_path=None,
+    ):
+        self.engine_path = engine_path
+        self.ort_trt_provider_options = self.get_tensorrt_provider_options(
+            input_profile,
+            workspace_size,
+            fp16,
+            device_id,
+            enable_cuda_graph,
+            timing_cache_path=timing_cache_path,
+        )
+
+        session_options = ort.SessionOptions()
+        session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL
+        logger.info("creating TRT EP session for %s", onnx_path)
+        ort_session = ort.InferenceSession(
+            onnx_path,
+            session_options,
+            providers=[
+                ("TensorrtExecutionProvider", self.ort_trt_provider_options),
+            ],
+        )
+        logger.info("created TRT EP session for %s", onnx_path)
+
+        device = torch.device("cuda", device_id)
+        super().__init__(ort_session, device, enable_cuda_graph)
+
+    def get_tensorrt_provider_options(
+        self, input_profile, workspace_size, fp16, device_id, enable_cuda_graph, timing_cache_path=None
+    ):
+        trt_ep_options = {
+            "device_id": device_id,
+            "trt_fp16_enable": fp16,
+            "trt_engine_cache_enable": True,
+            "trt_timing_cache_enable": True,
+            "trt_detailed_build_log": True,
+            "trt_engine_cache_path": self.engine_path,
+        }
+
+        if version.parse(ort.__version__) > version.parse("1.16.2") and timing_cache_path is not None:
+            trt_ep_options["trt_timing_cache_path"] = timing_cache_path
+
+        if enable_cuda_graph:
+            trt_ep_options["trt_cuda_graph_enable"] = True
+
+        if workspace_size > 0:
+            trt_ep_options["trt_max_workspace_size"] = workspace_size
+
+        if input_profile:
+            min_shapes = []
+            max_shapes = []
+            opt_shapes = []
+            for name, profile in input_profile.items():
+                assert isinstance(profile, list) and len(profile) == 3
+                min_shape = profile[0]
+                opt_shape = profile[1]
+                max_shape = profile[2]
+                assert len(min_shape) == len(opt_shape) and len(opt_shape) == len(max_shape)
+
+                min_shapes.append(f"{name}:" + "x".join([str(x) for x in min_shape]))
+                opt_shapes.append(f"{name}:" + "x".join([str(x) for x in opt_shape]))
+                max_shapes.append(f"{name}:" + "x".join([str(x) for x in max_shape]))
+
+            trt_ep_options["trt_profile_min_shapes"] = ",".join(min_shapes)
+            trt_ep_options["trt_profile_max_shapes"] = ",".join(max_shapes)
+            trt_ep_options["trt_profile_opt_shapes"] = ",".join(opt_shapes)
+
+        logger.info("trt_ep_options=%s", trt_ep_options)
+
+        return trt_ep_options
+
+    def allocate_buffers(self, shape_dict, device):
+        super().allocate_buffers(shape_dict)
+
+
+class OrtTensorrtEngineBuilder(EngineBuilder):
+    def __init__(
+        self,
+        pipeline_info: PipelineInfo,
+        max_batch_size=16,
+        device="cuda",
+        use_cuda_graph=False,
+    ):
+        """
+        Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder.
+
+        Args:
+            pipeline_info (PipelineInfo):
+                Version and Type of pipeline.
+            max_batch_size (int):
+                Maximum batch size for dynamic batch engine.
+            device (str):
+                device to run.
+            use_cuda_graph (bool):
+                Use CUDA graph to capture engine execution and then launch inference
+        """
+        super().__init__(
+            EngineType.ORT_TRT,
+            pipeline_info,
+            max_batch_size=max_batch_size,
+            device=device,
+            use_cuda_graph=use_cuda_graph,
+        )
+
+    def has_engine_file(self, engine_path):
+        if os.path.isdir(engine_path):
+            children = os.scandir(engine_path)
+            for entry in children:
+                if entry.is_file() and entry.name.endswith(".engine"):
+                    return True
+        return False
+
+    def get_work_space_size(self, model_name, max_workspace_size):
+        gibibyte = 2**30
+        workspace_size = 4 * gibibyte if model_name == "clip" else max_workspace_size
+        if workspace_size == 0:
+            _, free_mem, _ = cudart.cudaMemGetInfo()
+            # The following logic are adopted from TensorRT demo diffusion.
+            if free_mem > 6 * gibibyte:
+                workspace_size = free_mem - 4 * gibibyte
+        return workspace_size
+
+    def build_engines(
+        self,
+        engine_dir,
+        framework_model_dir,
+        onnx_dir,
+        onnx_opset,
+        opt_image_height,
+        opt_image_width,
+        opt_batch_size=1,
+        static_batch=False,
+        static_image_shape=True,
+        max_workspace_size=0,
+        device_id=0,
+        timing_cache=None,
+    ):
+        self.torch_device = torch.device("cuda", device_id)
+        self.load_models(framework_model_dir)
+
+        if not os.path.isdir(engine_dir):
+            os.makedirs(engine_dir)
+
+        if not os.path.isdir(onnx_dir):
+            os.makedirs(onnx_dir)
+
+        # Load lora only when we need export text encoder or UNet to ONNX.
+        load_lora = False
+        if self.pipeline_info.lora_weights:
+            for model_name, model_obj in self.models.items():
+                if model_name not in ["clip", "clip2", "unet", "unetxl"]:
+                    continue
+                profile_id = model_obj.get_profile_id(
+                    opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
+                )
+                engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
+                if not self.has_engine_file(engine_path):
+                    onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
+                    onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
+                    if not os.path.exists(onnx_opt_path):
+                        if not os.path.exists(onnx_path):
+                            load_lora = True
+                            break
+
+        # Export models to ONNX
+        self.disable_torch_spda()
+        pipe = self.load_pipeline_with_lora() if load_lora else None
+
+        for model_name, model_obj in self.models.items():
+            if model_name == "vae" and self.vae_torch_fallback:
+                continue
+
+            profile_id = model_obj.get_profile_id(
+                opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
+            )
+            engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
+            if not self.has_engine_file(engine_path):
+                onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
+                onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
+                if not os.path.exists(onnx_opt_path):
+                    if not os.path.exists(onnx_path):
+                        logger.info(f"Exporting model: {onnx_path}")
+                        model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
+
+                        with torch.inference_mode(), torch.autocast("cuda"):
+                            inputs = model_obj.get_sample_input(opt_batch_size, opt_image_height, opt_image_width)
+                            torch.onnx.export(
+                                model,
+                                inputs,
+                                onnx_path,
+                                export_params=True,
+                                opset_version=onnx_opset,
+                                do_constant_folding=True,
+                                input_names=model_obj.get_input_names(),
+                                output_names=model_obj.get_output_names(),
+                                dynamic_axes=model_obj.get_dynamic_axes(),
+                            )
+                        del model
+                        torch.cuda.empty_cache()
+                        gc.collect()
+                    else:
+                        logger.info("Found cached model: %s", onnx_path)
+
+                    # Optimize onnx
+                    if not os.path.exists(onnx_opt_path):
+                        logger.info("Generating optimizing model: %s", onnx_opt_path)
+                        model_obj.optimize_trt(onnx_path, onnx_opt_path)
+                    else:
+                        logger.info("Found cached optimized model: %s", onnx_opt_path)
+        self.enable_torch_spda()
+
+        built_engines = {}
+        for model_name, model_obj in self.models.items():
+            if model_name == "vae" and self.vae_torch_fallback:
+                continue
+
+            profile_id = model_obj.get_profile_id(
+                opt_batch_size, opt_image_height, opt_image_width, static_batch, static_image_shape
+            )
+
+            engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
+            onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
+            if not self.has_engine_file(engine_path):
+                logger.info(
+                    "Building TensorRT engine for %s from %s to %s. It can take a while to complete...",
+                    model_name,
+                    onnx_opt_path,
+                    engine_path,
+                )
+            else:
+                logger.info("Reuse cached TensorRT engine in directory %s", engine_path)
+
+            input_profile = model_obj.get_input_profile(
+                opt_batch_size,
+                opt_image_height,
+                opt_image_width,
+                static_batch=static_batch,
+                static_image_shape=static_image_shape,
+            )
+
+            engine = OrtTensorrtEngine(
+                engine_path,
+                device_id,
+                onnx_opt_path,
+                fp16=True,
+                input_profile=input_profile,
+                workspace_size=self.get_work_space_size(model_name, max_workspace_size),
+                enable_cuda_graph=self.use_cuda_graph,
+                timing_cache_path=timing_cache,
+            )
+
+            built_engines[model_name] = engine
+
+        self.engines = built_engines
+
+    def run_engine(self, model_name, feed_dict):
+        return self.engines[model_name].infer(feed_dict)
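Note: the file above configures the TensorRT ExecutionProvider through the provider-options dictionary returned by `get_tensorrt_provider_options`. For reference only (not part of the package), the sketch below passes a dictionary of the same shape directly to `onnxruntime.InferenceSession`; the model path, the input name `sample`, and the shape values are placeholder assumptions.

```python
import onnxruntime as ort

# Hypothetical values for illustration; in the package these come from the engine builder.
trt_ep_options = {
    "device_id": 0,
    "trt_fp16_enable": True,
    "trt_engine_cache_enable": True,
    "trt_engine_cache_path": "./trt_cache",
    # Shape profiles use the "name:dim1xdim2x..." encoding assembled above.
    "trt_profile_min_shapes": "sample:1x4x64x64",
    "trt_profile_opt_shapes": "sample:2x4x64x64",
    "trt_profile_max_shapes": "sample:8x4x64x64",
}

session_options = ort.SessionOptions()
# Leave graph optimization to TensorRT, as the builder above does.
session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL

session = ort.InferenceSession(
    "model.onnx",  # placeholder path
    session_options,
    providers=[("TensorrtExecutionProvider", trt_ep_options)],
)
```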
onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py
@@ -0,0 +1,395 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+# Modified from TensorRT demo diffusion, which has the following license:
+#
+# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: Apache-2.0
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# --------------------------------------------------------------------------
+
+import gc
+import os
+import pathlib
+from collections import OrderedDict
+
+import numpy as np
+import tensorrt as trt
+import torch
+from cuda import cudart
+from diffusion_models import PipelineInfo
+from engine_builder import EngineBuilder, EngineType
+from polygraphy.backend.common import bytes_from_path
+from polygraphy.backend.trt import (
+    CreateConfig,
+    ModifyNetworkOutputs,
+    Profile,
+    engine_from_bytes,
+    engine_from_network,
+    network_from_onnx_path,
+    save_engine,
+)
+
+# Map of numpy dtype -> torch dtype
+numpy_to_torch_dtype_dict = {
+    np.int32: torch.int32,
+    np.int64: torch.int64,
+    np.float16: torch.float16,
+    np.float32: torch.float32,
+}
+
+
+def _cuda_assert(cuda_ret):
+    err = cuda_ret[0]
+    if err != cudart.cudaError_t.cudaSuccess:
+        raise RuntimeError(
+            f"CUDA ERROR: {err}, error code reference: https://nvidia.github.io/cuda-python/module/cudart.html#cuda.cudart.cudaError_t"
+        )
+    if len(cuda_ret) > 1:
+        return cuda_ret[1]
+    return None
+
+
+class TensorrtEngine:
+    def __init__(
+        self,
+        engine_path,
+    ):
+        self.engine_path = engine_path
+        self.engine = None
+        self.context = None
+        self.buffers = OrderedDict()
+        self.tensors = OrderedDict()
+        self.cuda_graph_instance = None
+
+    def __del__(self):
+        del self.engine
+        del self.context
+        del self.buffers
+        del self.tensors
+
+    def build(
+        self,
+        onnx_path,
+        fp16,
+        input_profile=None,
+        enable_all_tactics=False,
+        timing_cache=None,
+        update_output_names=None,
+    ):
+        print(f"Building TensorRT engine for {onnx_path}: {self.engine_path}")
+        p = Profile()
+        if input_profile:
+            for name, dims in input_profile.items():
+                assert len(dims) == 3
+                p.add(name, min=dims[0], opt=dims[1], max=dims[2])
+
+        config_kwargs = {}
+        if not enable_all_tactics:
+            config_kwargs["tactic_sources"] = []
+
+        network = network_from_onnx_path(onnx_path, flags=[trt.OnnxParserFlag.NATIVE_INSTANCENORM])
+        if update_output_names:
+            print(f"Updating network outputs to {update_output_names}")
+            network = ModifyNetworkOutputs(network, update_output_names)
+        engine = engine_from_network(
+            network,
+            config=CreateConfig(
+                fp16=fp16, refittable=False, profiles=[p], load_timing_cache=timing_cache, **config_kwargs
+            ),
+            save_timing_cache=timing_cache,
+        )
+        save_engine(engine, path=self.engine_path)
+
+    def load(self):
+        print(f"Loading TensorRT engine: {self.engine_path}")
+        self.engine = engine_from_bytes(bytes_from_path(self.engine_path))
+
+    def activate(self, reuse_device_memory=None):
+        if reuse_device_memory:
+            self.context = self.engine.create_execution_context_without_device_memory()
+            self.context.device_memory = reuse_device_memory
+        else:
+            self.context = self.engine.create_execution_context()
+
+    def allocate_buffers(self, shape_dict=None, device="cuda"):
+        for idx in range(self.engine.num_io_tensors):
+            binding = self.engine[idx]
+            if shape_dict and binding in shape_dict:
+                shape = shape_dict[binding]
+            else:
+                shape = self.engine.get_binding_shape(binding)
+            dtype = trt.nptype(self.engine.get_binding_dtype(binding))
+            if self.engine.binding_is_input(binding):
+                self.context.set_binding_shape(idx, shape)
+            tensor = torch.empty(tuple(shape), dtype=numpy_to_torch_dtype_dict[dtype]).to(device=device)
+            self.tensors[binding] = tensor
+
+    def infer(self, feed_dict, stream, use_cuda_graph=False):
+        for name, buf in feed_dict.items():
+            self.tensors[name].copy_(buf)
+
+        for name, tensor in self.tensors.items():
+            self.context.set_tensor_address(name, tensor.data_ptr())
+
+        if use_cuda_graph:
+            if self.cuda_graph_instance is not None:
+                _cuda_assert(cudart.cudaGraphLaunch(self.cuda_graph_instance, stream))
+                _cuda_assert(cudart.cudaStreamSynchronize(stream))
+            else:
+                # do inference before CUDA graph capture
+                noerror = self.context.execute_async_v3(stream)
+                if not noerror:
+                    raise ValueError("ERROR: inference failed.")
+                # capture cuda graph
+                _cuda_assert(
+                    cudart.cudaStreamBeginCapture(stream, cudart.cudaStreamCaptureMode.cudaStreamCaptureModeGlobal)
+                )
+                self.context.execute_async_v3(stream)
+                self.graph = _cuda_assert(cudart.cudaStreamEndCapture(stream))
+
+                from cuda import nvrtc  # noqa: PLC0415
+
+                result, major, minor = nvrtc.nvrtcVersion()
+                assert result == nvrtc.nvrtcResult(0)
+                if major < 12:
+                    self.cuda_graph_instance = _cuda_assert(
+                        cudart.cudaGraphInstantiate(self.graph, b"", 0)
+                    )  # cuda < 12
+                else:
+                    self.cuda_graph_instance = _cuda_assert(cudart.cudaGraphInstantiate(self.graph, 0))  # cuda >= 12
+        else:
+            noerror = self.context.execute_async_v3(stream)
+            if not noerror:
+                raise ValueError("ERROR: inference failed.")
+
+        return self.tensors
+
+
+class TensorrtEngineBuilder(EngineBuilder):
+    """
+    Helper class to hide the detail of TensorRT Engine from pipeline.
+    """
+
+    def __init__(
+        self,
+        pipeline_info: PipelineInfo,
+        max_batch_size=16,
+        device="cuda",
+        use_cuda_graph=False,
+    ):
+        """
+        Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder.
+
+        Args:
+            pipeline_info (PipelineInfo):
+                Version and Type of pipeline.
+            max_batch_size (int):
+                Maximum batch size for dynamic batch engine.
+            device (str):
+                device to run.
+            use_cuda_graph (bool):
+                Use CUDA graph to capture engine execution and then launch inference
+        """
+        super().__init__(
+            EngineType.TRT,
+            pipeline_info,
+            max_batch_size=max_batch_size,
+            device=device,
+            use_cuda_graph=use_cuda_graph,
+        )
+
+        self.stream = None
+        self.shared_device_memory = None
+
+    def load_resources(self, image_height, image_width, batch_size):
+        super().load_resources(image_height, image_width, batch_size)
+
+        self.stream = _cuda_assert(cudart.cudaStreamCreate())
+
+    def teardown(self):
+        super().teardown()
+
+        if self.shared_device_memory:
+            cudart.cudaFree(self.shared_device_memory)
+
+        cudart.cudaStreamDestroy(self.stream)
+        del self.stream
+
+    def load_engines(
+        self,
+        engine_dir,
+        framework_model_dir,
+        onnx_dir,
+        onnx_opset,
+        opt_batch_size,
+        opt_image_height,
+        opt_image_width,
+        static_batch=False,
+        static_shape=True,
+        enable_all_tactics=False,
+        timing_cache=None,
+    ):
+        """
+        Build and load engines for TensorRT accelerated inference.
+        Export ONNX models first, if applicable.
+
+        Args:
+            engine_dir (str):
+                Directory to write the TensorRT engines.
+            framework_model_dir (str):
+                Directory to write the framework model ckpt.
+            onnx_dir (str):
+                Directory to write the ONNX models.
+            onnx_opset (int):
+                ONNX opset version to export the models.
+            opt_batch_size (int):
+                Batch size to optimize for during engine building.
+            opt_image_height (int):
+                Image height to optimize for during engine building. Must be a multiple of 8.
+            opt_image_width (int):
+                Image width to optimize for during engine building. Must be a multiple of 8.
+            static_batch (bool):
+                Build engine only for specified opt_batch_size.
+            static_shape (bool):
+                Build engine only for specified opt_image_height & opt_image_width. Default = True.
+            enable_all_tactics (bool):
+                Enable all tactic sources during TensorRT engine builds.
+            timing_cache (str):
+                Path to the timing cache to accelerate build or None
+        """
+        # Create directory
+        for directory in [engine_dir, onnx_dir]:
+            if not os.path.exists(directory):
+                print(f"[I] Create directory: {directory}")
+                pathlib.Path(directory).mkdir(parents=True)
+
+        self.load_models(framework_model_dir)
+
+        # Load lora only when we need export text encoder or UNet to ONNX.
+        load_lora = False
+        if self.pipeline_info.lora_weights:
+            for model_name, model_obj in self.models.items():
+                if model_name not in ["clip", "clip2", "unet", "unetxl"]:
+                    continue
+                profile_id = model_obj.get_profile_id(
+                    opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
+                )
+                engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
+                if not os.path.exists(engine_path):
+                    onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
+                    onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
+                    if not os.path.exists(onnx_opt_path):
+                        if not os.path.exists(onnx_path):
+                            load_lora = True
+                            break
+
+        # Export models to ONNX
+        self.disable_torch_spda()
+        pipe = self.load_pipeline_with_lora() if load_lora else None
+
+        for model_name, model_obj in self.models.items():
+            if model_name == "vae" and self.vae_torch_fallback:
+                continue
+            profile_id = model_obj.get_profile_id(
+                opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
+            )
+            engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
+            if not os.path.exists(engine_path):
+                onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
+                onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
+                if not os.path.exists(onnx_opt_path):
+                    if not os.path.exists(onnx_path):
+                        print(f"Exporting model: {onnx_path}")
+                        model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
+
+                        with torch.inference_mode(), torch.autocast("cuda"):
+                            inputs = model_obj.get_sample_input(1, opt_image_height, opt_image_width)
+                            torch.onnx.export(
+                                model,
+                                inputs,
+                                onnx_path,
+                                export_params=True,
+                                opset_version=onnx_opset,
+                                do_constant_folding=True,
+                                input_names=model_obj.get_input_names(),
+                                output_names=model_obj.get_output_names(),
+                                dynamic_axes=model_obj.get_dynamic_axes(),
+                            )
+                        del model
+                        torch.cuda.empty_cache()
+                        gc.collect()
+                    else:
+                        print(f"Found cached model: {onnx_path}")
+
+                    # Optimize onnx
+                    if not os.path.exists(onnx_opt_path):
+                        print(f"Generating optimizing model: {onnx_opt_path}")
+                        model_obj.optimize_trt(onnx_path, onnx_opt_path)
+                    else:
+                        print(f"Found cached optimized model: {onnx_opt_path} ")
+        self.enable_torch_spda()
+
+        # Build TensorRT engines
+        for model_name, model_obj in self.models.items():
+            if model_name == "vae" and self.vae_torch_fallback:
+                continue
+            profile_id = model_obj.get_profile_id(
+                opt_batch_size, opt_image_height, opt_image_width, static_batch, static_shape
+            )
+            engine_path = self.get_engine_path(engine_dir, model_name, profile_id)
+            engine = TensorrtEngine(engine_path)
+            onnx_opt_path = self.get_onnx_path(model_name, onnx_dir, opt=True)
+
+            if not os.path.exists(engine.engine_path):
+                engine.build(
+                    onnx_opt_path,
+                    fp16=True,
+                    input_profile=model_obj.get_input_profile(
+                        opt_batch_size,
+                        opt_image_height,
+                        opt_image_width,
+                        static_batch,
+                        static_shape,
+                    ),
+                    enable_all_tactics=enable_all_tactics,
+                    timing_cache=timing_cache,
+                    update_output_names=None,
+                )
+            self.engines[model_name] = engine
+
+        # Load TensorRT engines
+        for model_name in self.models:
+            if model_name == "vae" and self.vae_torch_fallback:
+                continue
+            self.engines[model_name].load()
+
+    def max_device_memory(self):
+        max_device_memory = 0
+        for engine in self.engines.values():
+            max_device_memory = max(max_device_memory, engine.engine.device_memory_size)
+        return max_device_memory
+
+    def activate_engines(self, shared_device_memory=None):
+        if shared_device_memory is None:
+            max_device_memory = self.max_device_memory()
+            _, shared_device_memory = cudart.cudaMalloc(max_device_memory)
+        self.shared_device_memory = shared_device_memory
+        # Load and activate TensorRT engines
+        for engine in self.engines.values():
+            engine.activate(reuse_device_memory=self.shared_device_memory)
+
+    def run_engine(self, model_name, feed_dict):
+        return self.engines[model_name].infer(feed_dict, self.stream, use_cuda_graph=self.use_cuda_graph)