onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/stable_diffusion/engine_builder.py

@@ -0,0 +1,295 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
import hashlib
import os
from enum import Enum

import torch
from diffusion_models import CLIP, VAE, CLIPWithProj, PipelineInfo, UNet, UNetXL


class EngineType(Enum):
    ORT_CUDA = 0  # ONNX Runtime CUDA Execution Provider
    ORT_TRT = 1  # ONNX Runtime TensorRT Execution Provider
    TRT = 2  # TensorRT
    TORCH = 3  # PyTorch


def get_engine_type(name: str) -> EngineType:
    name_to_type = {
        "ORT_CUDA": EngineType.ORT_CUDA,
        "ORT_TRT": EngineType.ORT_TRT,
        "TRT": EngineType.TRT,
        "TORCH": EngineType.TORCH,
    }
    return name_to_type[name]


class EngineBuilder:
    def __init__(
        self,
        engine_type: EngineType,
        pipeline_info: PipelineInfo,
        device="cuda",
        max_batch_size=16,
        use_cuda_graph=False,
    ):
        """
        Initializes the Engine Builder.

        Args:
            pipeline_info (PipelineInfo):
                Version and Type of pipeline.
            device (str | torch.device):
                device to run engine
            max_batch_size (int):
                Maximum batch size for dynamic batch engine.
            use_cuda_graph (bool):
                Use CUDA graph to capture engine execution and then launch inference
        """
        self.engine_type = engine_type
        self.pipeline_info = pipeline_info
        self.max_batch_size = max_batch_size
        self.use_cuda_graph = use_cuda_graph
        self.device = torch.device(device)
        self.torch_device = torch.device(device, torch.cuda.current_device())
        self.stages = pipeline_info.stages()

        self.vae_torch_fallback = self.pipeline_info.vae_torch_fallback() and self.engine_type != EngineType.TORCH
        self.custom_fp16_vae = self.pipeline_info.custom_fp16_vae()

        self.models = {}
        self.engines = {}
        self.torch_models = {}
        self.use_vae_slicing = False

        self.torch_sdpa = getattr(torch.nn.functional, "scaled_dot_product_attention", None)

    def enable_vae_slicing(self):
        self.use_vae_slicing = True

    def disable_torch_spda(self):
        if hasattr(torch.nn.functional, "scaled_dot_product_attention"):
            delattr(torch.nn.functional, "scaled_dot_product_attention")

    def enable_torch_spda(self):
        if (not hasattr(torch.nn.functional, "scaled_dot_product_attention")) and self.torch_sdpa:
            torch.nn.functional.scaled_dot_product_attention = self.torch_sdpa

    def teardown(self):
        for engine in self.engines.values():
            del engine
        self.engines = {}

    def get_diffusers_module_name(self, model_name):
        name_mapping = {
            "clip": "text_encoder",
            "clip2": "text_encoder_2",
            "unet": "unet",
            "unetxl": "unet",
            "vae": "vae_decoder",
        }
        return name_mapping.get(model_name, model_name)

    def get_cached_model_name(self, model_name):
        model_name = self.get_diffusers_module_name(model_name)
        is_unet = model_name == "unet"
        hash_source = []
        if model_name in ["text_encoder", "text_encoder_2", "unet"] and self.pipeline_info.lora_weights:
            if self.pipeline_info.lora_weights in [
                "latent-consistency/lcm-lora-sdxl",
                "latent-consistency/lcm-lora-sdv1-5",
            ]:
                if is_unet:
                    model_name = "unet_lcm-lora"
                else:
                    model_name = model_name + "_lora"
            hash_source.append(self.pipeline_info.lora_weights)

        # TODO(tianleiwu): save custom model to a directory named by its original model.
        if is_unet and self.pipeline_info.custom_unet():
            model_name = model_name + "_lcm"

        if model_name in ["unet"] and self.pipeline_info.controlnet:
            model_name = model_name + "_" + "_".join(self.pipeline_info.controlnet)

        if hash_source:
            model_name += "_" + hashlib.sha256("\t".join(hash_source).encode("utf-8")).hexdigest()[:8]

        # TODO: When we support original VAE, we shall save custom VAE to another directory.

        if self.pipeline_info.is_inpaint():
            model_name += "_inpaint"
        return model_name

    def get_model_dir(self, model_name, root_dir, opt=True, suffix="", create=True):
        engine_name = self.engine_type.name.lower()
        if engine_name != "ort_cuda" and not suffix:
            suffix = f".{engine_name}" if opt else ""
        directory_name = self.get_cached_model_name(model_name) + suffix
        onnx_model_dir = os.path.join(root_dir, directory_name)
        if create:
            os.makedirs(onnx_model_dir, exist_ok=True)
        return onnx_model_dir

    def get_onnx_path(self, model_name, onnx_dir, opt=True, suffix=""):
        onnx_model_dir = self.get_model_dir(model_name, onnx_dir, opt=opt, suffix=suffix)
        return os.path.join(onnx_model_dir, "model.onnx")

    def get_engine_path(self, engine_dir, model_name, profile_id):
        return os.path.join(engine_dir, self.get_cached_model_name(model_name) + profile_id)

    def load_pipeline_with_lora(self):
        """Load text encoders and UNet with diffusers pipeline"""
        from diffusers import DiffusionPipeline  # noqa: PLC0415

        pipeline = DiffusionPipeline.from_pretrained(
            self.pipeline_info.name(),
            variant="fp16",
            torch_dtype=torch.float16,
        )
        pipeline.load_lora_weights(self.pipeline_info.lora_weights)
        pipeline.fuse_lora(lora_scale=self.pipeline_info.lora_scale)

        del pipeline.vae
        pipeline.vae = None
        return pipeline

    def get_or_load_model(self, pipeline, model_name, model_obj, framework_model_dir):
        if model_name in ["clip", "clip2", "unet", "unetxl"] and pipeline:
            if model_name == "clip":
                model = pipeline.text_encoder
                pipeline.text_encoder = None
            elif model_name == "clip2":
                model = pipeline.text_encoder_2
                pipeline.text_encoder_2 = None
            else:
                model = pipeline.unet
                pipeline.unet = None
        else:
            model = model_obj.load_model(framework_model_dir)

        return model.to(self.torch_device)

    def load_models(self, framework_model_dir: str):
        # For TRT or ORT_TRT, we will export fp16 torch model for UNet and VAE
        # For ORT_CUDA, we export fp32 model first, then optimize to fp16.
        export_fp16 = self.engine_type in [EngineType.ORT_TRT, EngineType.TRT]

        if "clip" in self.stages:
            self.models["clip"] = CLIP(
                self.pipeline_info,
                None,  # not loaded yet
                device=self.torch_device,
                max_batch_size=self.max_batch_size,
                clip_skip=0,
            )

        if "clip2" in self.stages:
            self.models["clip2"] = CLIPWithProj(
                self.pipeline_info,
                None,  # not loaded yet
                device=self.torch_device,
                max_batch_size=self.max_batch_size,
                clip_skip=0,
            )

        if "unet" in self.stages:
            self.models["unet"] = UNet(
                self.pipeline_info,
                None,  # not loaded yet
                device=self.torch_device,
                fp16=export_fp16,
                max_batch_size=self.max_batch_size,
                unet_dim=(9 if self.pipeline_info.is_inpaint() else 4),
            )

        if "unetxl" in self.stages:
            self.models["unetxl"] = UNetXL(
                self.pipeline_info,
                None,  # not loaded yet
                device=self.torch_device,
                fp16=export_fp16,
                max_batch_size=self.max_batch_size,
                unet_dim=4,
                time_dim=(5 if self.pipeline_info.is_xl_refiner() else 6),
            )

        # VAE Decoder
        if "vae" in self.stages:
            self.models["vae"] = VAE(
                self.pipeline_info,
                None,  # not loaded yet
                device=self.torch_device,
                max_batch_size=self.max_batch_size,
                fp16=export_fp16,
                custom_fp16_vae=self.custom_fp16_vae,
            )

            if self.vae_torch_fallback:
                self.torch_models["vae"] = self.models["vae"].load_model(framework_model_dir)

    def load_resources(self, image_height, image_width, batch_size):
        if self.engine_type == EngineType.TORCH:
            return

        # Allocate buffers for I/O bindings
        for model_name, obj in self.models.items():
            if model_name == "vae" and self.vae_torch_fallback:
                continue
            slice_size = 1 if (model_name == "vae" and self.use_vae_slicing) else batch_size
            self.engines[model_name].allocate_buffers(
                shape_dict=obj.get_shape_dict(slice_size, image_height, image_width), device=self.torch_device
            )

    def _vae_decode(self, latents):
        if self.engine_type == EngineType.TORCH:
            if self.pipeline_info.is_xl() and not self.custom_fp16_vae:  # need upcast
                latents = latents.to(dtype=torch.float32)
                images = self.engines["vae"](latents)["sample"]
            else:
                images = self.engines["vae"](latents)["sample"]
        elif self.vae_torch_fallback:
            if not self.custom_fp16_vae:
                latents = latents.to(dtype=torch.float32)
                self.torch_models["vae"] = self.torch_models["vae"].to(dtype=torch.float32)
            images = self.torch_models["vae"](latents)["sample"]
        else:
            if self.pipeline_info.is_xl() and not self.custom_fp16_vae:  # need upcast
                images = self.run_engine("vae", {"latent": latents.to(dtype=torch.float32)})["images"]
            else:
                images = self.run_engine("vae", {"latent": latents})["images"]

        return images

    def vae_decode(self, latents):
        if self.use_vae_slicing:
            # The output tensor points to same buffer. Need clone it to avoid overwritten.
            decoded_slices = [self._vae_decode(z_slice).clone() for z_slice in latents.split(1)]
            return torch.cat(decoded_slices)

        return self._vae_decode(latents)


def get_engine_paths(
    work_dir: str, pipeline_info: PipelineInfo, engine_type: EngineType, framework_model_dir: str | None = None
):
    root_dir = work_dir or "."
    short_name = pipeline_info.short_name()

    # When both ORT_CUDA and ORT_TRT/TRT is used, we shall make sub directory for each engine since
    # ORT_CUDA need fp32 torch model, while ORT_TRT/TRT use fp16 torch model.
    onnx_dir = os.path.join(root_dir, engine_type.name, short_name, "onnx")
    engine_dir = os.path.join(root_dir, engine_type.name, short_name, "engine")
    output_dir = os.path.join(root_dir, engine_type.name, short_name, "output")

    timing_cache = os.path.join(root_dir, engine_type.name, "timing_cache")

    # Shared among ORT_CUDA, ORT_TRT and TRT engines, and need use load_model(..., always_download_fp16=True)
    # So that the shared model is always fp16.
    if framework_model_dir is None:
        framework_model_dir = os.path.join(root_dir, "torch_model")

    return onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache
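
For orientation, a minimal usage sketch of the helpers added in engine_builder.py above (assuming the stable_diffusion scripts directory is on sys.path; nothing beyond the names shown in the diff is guaranteed):

from engine_builder import EngineType, get_engine_type

# Map a configuration string to the enum defined above.
engine_type = get_engine_type("ORT_CUDA")
assert engine_type is EngineType.ORT_CUDA

# get_engine_paths(work_dir, pipeline_info, engine_type) lays out directories as:
#   <work_dir>/<ENGINE NAME>/<pipeline short name>/{onnx, engine, output}
#   <work_dir>/<ENGINE NAME>/timing_cache
#   <work_dir>/torch_model    (default shared framework model dir)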
onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py

@@ -0,0 +1,387 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------

import gc
import logging
import os

import onnx
import torch
from diffusion_models import PipelineInfo
from engine_builder import EngineBuilder, EngineType
from packaging import version

import onnxruntime as ort
from onnxruntime.transformers.io_binding_helper import CudaSession, GpuBindingManager
from onnxruntime.transformers.onnx_model import OnnxModel

logger = logging.getLogger(__name__)


class OrtCudaEngine:
    def __init__(
        self,
        onnx_path,
        device_id: int = 0,
        enable_cuda_graph: bool = False,
        disable_optimization: bool = False,
        max_cuda_graphs: int = 1,
    ):
        self.onnx_path = onnx_path
        self.provider = "CUDAExecutionProvider"
        self.stream = torch.cuda.current_stream().cuda_stream
        self.provider_options = CudaSession.get_cuda_provider_options(device_id, enable_cuda_graph, self.stream)
        session_options = ort.SessionOptions()

        # When the model has been optimized by onnxruntime, we can disable optimization to save session creation time.
        if disable_optimization:
            session_options.graph_optimization_level = ort.GraphOptimizationLevel.ORT_DISABLE_ALL

        logger.info("creating CUDA EP session for %s", onnx_path)
        ort_session = ort.InferenceSession(
            onnx_path,
            session_options,
            providers=[
                (self.provider, self.provider_options),
                "CPUExecutionProvider",
            ],
        )
        logger.info("created CUDA EP session for %s", onnx_path)

        device = torch.device("cuda", device_id)
        self.enable_cuda_graph = enable_cuda_graph

        # Support multiple CUDA graphs for different input shapes.
        # For clip2 model that disabled cuda graph, max_cuda_graphs is updated to 0 here.
        self.gpu_binding_manager = GpuBindingManager(
            ort_session=ort_session,
            device=device,
            stream=self.stream,
            max_cuda_graphs=max_cuda_graphs if enable_cuda_graph else 0,
        )

        self.current_gpu_binding = None

    def metadata(self, name: str):
        data = {}
        if self.current_gpu_binding is not None:
            if self.current_gpu_binding.last_run_gpu_graph_id >= 0:
                data[f"{name}.gpu_graph_id"] = self.current_gpu_binding.last_run_gpu_graph_id
        return data

    def infer(self, feed_dict: dict[str, torch.Tensor]):
        return self.current_gpu_binding.infer(feed_dict=feed_dict, disable_cuda_graph_in_run=not self.enable_cuda_graph)

    def allocate_buffers(self, shape_dict, device):
        self.current_gpu_binding = self.gpu_binding_manager.get_binding(
            shape_dict=shape_dict, use_cuda_graph=self.enable_cuda_graph
        )


class _ModelConfig:
    """
    Configuration of one model (like Clip, UNet etc) on ONNX export and optimization for CUDA provider.
    For example, if you want to use fp32 in layer normalization, set the following:
        force_fp32_ops=["SkipLayerNormalization", "LayerNormalization"]
    """

    def __init__(
        self,
        onnx_opset_version: int,
        use_cuda_graph: bool,
        fp16: bool = True,
        force_fp32_ops: list[str] | None = None,
        optimize_by_ort: bool = True,
    ):
        self.onnx_opset_version = onnx_opset_version
        self.use_cuda_graph = use_cuda_graph
        self.fp16 = fp16
        self.force_fp32_ops = force_fp32_ops
        self.optimize_by_ort = optimize_by_ort


class OrtCudaEngineBuilder(EngineBuilder):
    def __init__(
        self,
        pipeline_info: PipelineInfo,
        max_batch_size=16,
        device="cuda",
        use_cuda_graph=False,
    ):
        """
        Initializes the ONNX Runtime TensorRT ExecutionProvider Engine Builder.

        Args:
            pipeline_info (PipelineInfo):
                Version and Type of pipeline.
            max_batch_size (int):
                Maximum batch size for dynamic batch engine.
            device (str):
                device to run.
            use_cuda_graph (bool):
                Use CUDA graph to capture engine execution and then launch inference
        """
        super().__init__(
            EngineType.ORT_CUDA,
            pipeline_info,
            max_batch_size=max_batch_size,
            device=device,
            use_cuda_graph=use_cuda_graph,
        )

        self.model_config = {}

    def _configure(
        self,
        model_name: str,
        onnx_opset_version: int,
        use_cuda_graph: bool,
        fp16: bool = True,
        force_fp32_ops: list[str] | None = None,
        optimize_by_ort: bool = True,
    ):
        self.model_config[model_name] = _ModelConfig(
            onnx_opset_version,
            use_cuda_graph,
            fp16=fp16,
            force_fp32_ops=force_fp32_ops,
            optimize_by_ort=optimize_by_ort,
        )

    def configure_xl(self, onnx_opset_version: int):
        self._configure(
            "clip",
            onnx_opset_version=onnx_opset_version,
            use_cuda_graph=self.use_cuda_graph,
        )
        self._configure(
            "clip2",
            onnx_opset_version=onnx_opset_version,  # TODO: ArgMax-12 is not implemented in CUDA
            use_cuda_graph=False,  # TODO: fix Runtime Error with cuda graph
        )
        self._configure(
            "unetxl",
            onnx_opset_version=onnx_opset_version,
            use_cuda_graph=self.use_cuda_graph,
        )

        self._configure(
            "vae",
            onnx_opset_version=onnx_opset_version,
            use_cuda_graph=self.use_cuda_graph,
        )

    def optimized_onnx_path(self, engine_dir, model_name):
        suffix = "" if self.model_config[model_name].fp16 else ".fp32"
        return self.get_onnx_path(model_name, engine_dir, opt=True, suffix=suffix)

    def import_diffusers_engine(self, diffusers_onnx_dir: str, engine_dir: str):
        """Import optimized onnx models for diffusers from Olive or optimize_pipeline tools.

        Args:
            diffusers_onnx_dir (str): optimized onnx directory of Olive
            engine_dir (str): the directory to store imported onnx
        """
        if version.parse(ort.__version__) < version.parse("1.17.0"):
            print("Skip importing since onnxruntime-gpu version < 1.17.0.")
            return

        for model_name, model_obj in self.models.items():
            onnx_import_path = self.optimized_onnx_path(diffusers_onnx_dir, model_name)
            if not os.path.exists(onnx_import_path):
                print(f"{onnx_import_path} not existed. Skip importing.")
                continue

            onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
            if os.path.exists(onnx_opt_path):
                print(f"{onnx_opt_path} existed. Skip importing.")
                continue

            if model_name == "vae" and self.pipeline_info.is_xl():
                print(f"Skip importing VAE since it is not fully compatible with float16: {onnx_import_path}.")
                continue

            model = OnnxModel(onnx.load(onnx_import_path, load_external_data=True))

            if model_name in ["clip", "clip2"]:
                hidden_states_per_layer = []
                for output in model.graph().output:
                    if output.name.startswith("hidden_states."):
                        hidden_states_per_layer.append(output.name)
                if hidden_states_per_layer:
                    kept_hidden_states = hidden_states_per_layer[-2 - model_obj.clip_skip]
                    model.rename_graph_output(kept_hidden_states, "hidden_states")

                model.rename_graph_output(
                    "last_hidden_state" if model_name == "clip" else "text_embeds", "text_embeddings"
                )
                model.prune_graph(
                    ["text_embeddings", "hidden_states"] if hidden_states_per_layer else ["text_embeddings"]
                )

                if model_name == "clip2":
                    model.change_graph_input_type(model.find_graph_input("input_ids"), onnx.TensorProto.INT32)

                model.save_model_to_file(onnx_opt_path, use_external_data_format=(model_name == "clip2"))
            elif model_name in ["unet", "unetxl"]:
                model.rename_graph_output("out_sample", "latent")
                model.save_model_to_file(onnx_opt_path, use_external_data_format=True)

            del model
            continue

    def build_engines(
        self,
        engine_dir: str,
        framework_model_dir: str,
        onnx_dir: str,
        tmp_dir: str | None = None,
        onnx_opset_version: int = 17,
        device_id: int = 0,
        save_fp32_intermediate_model: bool = False,
        import_engine_dir: str | None = None,
        max_cuda_graphs: int = 1,
    ):
        self.torch_device = torch.device("cuda", device_id)
        self.load_models(framework_model_dir)

        if not os.path.isdir(engine_dir):
            os.makedirs(engine_dir)

        if not os.path.isdir(onnx_dir):
            os.makedirs(onnx_dir)

        # Add default configuration if missing
        if self.pipeline_info.is_xl():
            self.configure_xl(onnx_opset_version)
        for model_name in self.models:
            if model_name not in self.model_config:
                self.model_config[model_name] = _ModelConfig(onnx_opset_version, self.use_cuda_graph)

        # Import Engine
        if import_engine_dir:
            if self.pipeline_info.is_xl():
                self.import_diffusers_engine(import_engine_dir, engine_dir)
            else:
                print(f"Only support importing SDXL onnx. Ignore --engine-dir {import_engine_dir}")

        # Load lora only when we need export text encoder or UNet to ONNX.
        load_lora = False
        if self.pipeline_info.lora_weights:
            for model_name in self.models:
                if model_name not in ["clip", "clip2", "unet", "unetxl"]:
                    continue
                onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
                onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
                if not os.path.exists(onnx_opt_path):
                    if not os.path.exists(onnx_path):
                        load_lora = True
                        break

        # Export models to ONNX
        self.disable_torch_spda()
        pipe = self.load_pipeline_with_lora() if load_lora else None

        for model_name, model_obj in self.models.items():
            if model_name == "vae" and self.vae_torch_fallback:
                continue

            onnx_path = self.get_onnx_path(model_name, onnx_dir, opt=False)
            onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
            if not os.path.exists(onnx_opt_path):
                if not os.path.exists(onnx_path):
                    print("----")
                    logger.info("Exporting model: %s", onnx_path)

                    model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
                    model = model.to(torch.float32)

                    with torch.inference_mode():
                        # For CUDA EP, export FP32 onnx since some graph fusion only supports fp32 graph pattern.
                        # Export model with sample of batch size 1, image size 512 x 512
                        inputs = model_obj.get_sample_input(1, 512, 512)

                        torch.onnx.export(
                            model,
                            inputs,
                            onnx_path,
                            export_params=True,
                            opset_version=self.model_config[model_name].onnx_opset_version,
                            do_constant_folding=True,
                            input_names=model_obj.get_input_names(),
                            output_names=model_obj.get_output_names(),
                            dynamic_axes=model_obj.get_dynamic_axes(),
                        )
                    del model
                    torch.cuda.empty_cache()
                    gc.collect()
                else:
                    logger.info("Found cached model: %s", onnx_path)

                # Generate fp32 optimized model.
                # If final target is fp16 model, we save fp32 optimized model so that it is easy to tune
                # fp16 conversion. That could save a lot of time in developing.
                use_fp32_intermediate = save_fp32_intermediate_model and self.model_config[model_name].fp16
                onnx_fp32_path = onnx_path
                if use_fp32_intermediate:
                    onnx_fp32_path = self.get_onnx_path(model_name, engine_dir, opt=True, suffix=".fp32")
                    if not os.path.exists(onnx_fp32_path):
                        print("------")
                        logger.info("Generating optimized model: %s", onnx_fp32_path)
                        model_obj.optimize_ort(
                            onnx_path,
                            onnx_fp32_path,
                            to_fp16=False,
                            fp32_op_list=self.model_config[model_name].force_fp32_ops,
                            optimize_by_ort=self.model_config[model_name].optimize_by_ort,
                            tmp_dir=self.get_model_dir(model_name, tmp_dir, opt=False, suffix=".fp32", create=False),
                        )
                    else:
                        logger.info("Found cached optimized model: %s", onnx_fp32_path)

                # Generate the final optimized model.
                if not os.path.exists(onnx_opt_path):
                    print("------")
                    logger.info("Generating optimized model: %s", onnx_opt_path)

                    # When there is fp32 intermediate optimized model, this will just convert model from fp32 to fp16.
                    optimize_by_ort = False if use_fp32_intermediate else self.model_config[model_name].optimize_by_ort

                    model_obj.optimize_ort(
                        onnx_fp32_path,
                        onnx_opt_path,
                        to_fp16=self.model_config[model_name].fp16,
                        fp32_op_list=self.model_config[model_name].force_fp32_ops,
                        optimize_by_ort=optimize_by_ort,
                        optimize_by_fusion=not use_fp32_intermediate,
                        tmp_dir=self.get_model_dir(model_name, tmp_dir, opt=False, suffix=".ort", create=False),
                    )
                else:
                    logger.info("Found cached optimized model: %s", onnx_opt_path)
        self.enable_torch_spda()

        built_engines = {}
        for model_name in self.models:
            if model_name == "vae" and self.vae_torch_fallback:
                continue

            onnx_opt_path = self.optimized_onnx_path(engine_dir, model_name)
            use_cuda_graph = self.model_config[model_name].use_cuda_graph

            engine = OrtCudaEngine(
                onnx_opt_path,
                device_id=device_id,
                enable_cuda_graph=use_cuda_graph,
                disable_optimization=False,
                max_cuda_graphs=max_cuda_graphs,
            )

            logger.info("%s options for %s: %s", engine.provider, model_name, engine.provider_options)
            built_engines[model_name] = engine

        self.engines = built_engines

    def run_engine(self, model_name, feed_dict):
        return self.engines[model_name].infer(feed_dict)