onnxruntime-directml 1.20.0-cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py
@@ -0,0 +1,108 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+import logging
+
+from diffusion_models import PipelineInfo
+from engine_builder import EngineBuilder, EngineType
+
+logger = logging.getLogger(__name__)
+
+
+class TorchEngineBuilder(EngineBuilder):
+    def __init__(
+        self,
+        pipeline_info: PipelineInfo,
+        max_batch_size=16,
+        device="cuda",
+        use_cuda_graph=False,
+    ):
+        """
+        Initializes the PyTorch engine builder.
+
+        Args:
+            pipeline_info (PipelineInfo):
+                Version and type of pipeline.
+            max_batch_size (int):
+                Maximum batch size for dynamic batch engine.
+            device (str):
+                Device to run.
+            use_cuda_graph (bool):
+                Use CUDA graph to capture engine execution and then launch inference.
+        """
+        super().__init__(
+            EngineType.TORCH,
+            pipeline_info,
+            max_batch_size=max_batch_size,
+            device=device,
+            use_cuda_graph=use_cuda_graph,
+        )
+
+        self.compile_config = {}
+        if use_cuda_graph:
+            self.compile_config = {
+                "clip": {"mode": "reduce-overhead", "dynamic": False},
+                "clip2": {"mode": "reduce-overhead", "dynamic": False},
+                "unet": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
+                "unetxl": {"mode": "reduce-overhead", "fullgraph": True, "dynamic": False},
+                "vae": {"mode": "reduce-overhead", "fullgraph": False, "dynamic": False},
+            }
+
+    def build_engines(
+        self,
+        framework_model_dir: str,
+    ):
+        import torch
+
+        self.torch_device = torch.device("cuda", torch.cuda.current_device())
+        self.load_models(framework_model_dir)
+
+        pipe = self.load_pipeline_with_lora() if self.pipeline_info.lora_weights else None
+
+        built_engines = {}
+        for model_name, model_obj in self.models.items():
+            model = self.get_or_load_model(pipe, model_name, model_obj, framework_model_dir)
+            if self.pipeline_info.is_xl() and not self.custom_fp16_vae:
+                model = model.to(device=self.torch_device, dtype=torch.float32)
+            else:
+                model = model.to(device=self.torch_device, dtype=torch.float16)
+
+            if model_name in self.compile_config:
+                compile_config = self.compile_config[model_name]
+                if model_name in ["unet", "unetxl"]:
+                    model.to(memory_format=torch.channels_last)
+                engine = torch.compile(model, **compile_config)
+                built_engines[model_name] = engine
+            else:  # eager mode
+                built_engines[model_name] = model
+
+        self.engines = built_engines
+
+    def run_engine(self, model_name, feed_dict):
+        if model_name in ["unet", "unetxl"]:
+            if "controlnet_images" in feed_dict:
+                return {"latent": self.engines[model_name](**feed_dict)}
+
+            if model_name == "unetxl":
+                added_cond_kwargs = {k: feed_dict[k] for k in feed_dict if k in ["text_embeds", "time_ids"]}
+                return {
+                    "latent": self.engines[model_name](
+                        feed_dict["sample"],
+                        feed_dict["timestep"],
+                        feed_dict["encoder_hidden_states"],
+                        added_cond_kwargs=added_cond_kwargs,
+                        return_dict=False,
+                    )[0]
+                }
+
+            return {
+                "latent": self.engines[model_name](
+                    feed_dict["sample"], feed_dict["timestep"], feed_dict["encoder_hidden_states"], return_dict=False
+                )[0]
+            }
+
+        if model_name in ["vae_encoder"]:
+            return {"latent": self.engines[model_name](feed_dict["images"])}
+
+        raise RuntimeError(f"Shall not reach here: {model_name}")
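For orientation, a minimal sketch of how this builder is typically driven. The `PipelineInfo` constructor arguments, tensor shapes, and directory path below are illustrative assumptions, not part of the packaged API; only `build_engines`, `run_engine`, and the `feed_dict` keys come from the file above.

```python
# Hypothetical driver for TorchEngineBuilder; shapes, paths, and the PipelineInfo
# arguments are assumptions, while the method calls mirror the file above.
import torch

from diffusion_models import PipelineInfo
from engine_builder_torch import TorchEngineBuilder

pipeline_info = PipelineInfo("1.5")  # assumed constructor arguments
builder = TorchEngineBuilder(pipeline_info, max_batch_size=4, use_cuda_graph=True)
builder.build_engines("./torch_models")  # loads models, then torch.compile where configured

feed_dict = {
    "sample": torch.randn(2, 4, 64, 64, device="cuda", dtype=torch.float16),
    "timestep": torch.tensor([999], device="cuda"),
    "encoder_hidden_states": torch.randn(2, 77, 768, device="cuda", dtype=torch.float16),
}
latent = builder.run_engine("unet", feed_dict)["latent"]
```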
onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py
@@ -0,0 +1,350 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+#
+# This script converts stable diffusion onnx models from float to half (mixed) precision for GPU inference.
+#
+# Before running this script, follow README.md to setup python environment and convert stable diffusion checkpoint
+# to float32 onnx models.
+#
+# For example, the float32 ONNX pipeline is saved to ./sd-v1-5 directory, you can optimize and convert it to float16
+# like the following:
+#    python optimize_pipeline.py -i ./sd-v1-5 -o ./sd-v1-5-fp16 --float16
+#
+# Note that the optimizations are carried out for CUDA Execution Provider at first; other EPs may not have support
+# for the fused operators. Users can disable the operator fusion manually to work around this.
+
+import argparse
+import logging
+import os
+import shutil
+import tempfile
+from pathlib import Path
+from typing import List, Optional
+
+import __init__  # noqa: F401. Workaround to run this script directly
+import coloredlogs
+import onnx
+from fusion_options import FusionOptions
+from onnx_model_clip import ClipOnnxModel
+from onnx_model_unet import UnetOnnxModel
+from onnx_model_vae import VaeOnnxModel
+from optimizer import optimize_by_onnxruntime, optimize_model
+from packaging import version
+
+import onnxruntime
+
+logger = logging.getLogger(__name__)
+
+
+def has_external_data(onnx_model_path):
+    original_model = onnx.load_model(str(onnx_model_path), load_external_data=False)
+    for initializer in original_model.graph.initializer:
+        if initializer.HasField("data_location") and initializer.data_location == onnx.TensorProto.EXTERNAL:
+            return True
+    return False
+
+
+def _optimize_sd_pipeline(
+    source_dir: Path,
+    target_dir: Path,
+    use_external_data_format: Optional[bool],
+    float16: bool,
+    force_fp32_ops: List[str],
+    enable_runtime_optimization: bool,
+    args,
+):
+    """Optimize onnx models used in stable diffusion onnx pipeline and optionally convert to float16.
+
+    Args:
+        source_dir (Path): Root of input directory of stable diffusion onnx pipeline with float32 models.
+        target_dir (Path): Root of output directory of stable diffusion onnx pipeline with optimized models.
+        use_external_data_format (Optional[bool]): use external data format.
+        float16 (bool): use half precision.
+        force_fp32_ops (List[str]): operators that are forced to run in float32.
+        enable_runtime_optimization (bool): run graph optimization using Onnx Runtime.
+
+    Raises:
+        RuntimeError: input onnx model does not exist.
+        RuntimeError: output onnx model path already exists.
+    """
+    model_type_mapping = {
+        "unet": "unet",
+        "vae_encoder": "vae",
+        "vae_decoder": "vae",
+        "text_encoder": "clip",
+        "text_encoder_2": "clip",
+        "safety_checker": "unet",
+    }
+
+    model_type_class_mapping = {
+        "unet": UnetOnnxModel,
+        "vae": VaeOnnxModel,
+        "clip": ClipOnnxModel,
+    }
+
+    force_fp32_operators = {
+        "unet": [],
+        "vae_encoder": [],
+        "vae_decoder": [],
+        "text_encoder": [],
+        "text_encoder_2": [],
+        "safety_checker": [],
+    }
+
+    is_xl = (source_dir / "text_encoder_2").exists()
+
+    if force_fp32_ops:
+        for fp32_operator in force_fp32_ops:
+            parts = fp32_operator.split(":")
+            if len(parts) == 2 and parts[0] in force_fp32_operators and (parts[1] and parts[1][0].isupper()):
+                force_fp32_operators[parts[0]].append(parts[1])
+            else:
+                raise ValueError(
+                    f"--force_fp32_ops shall be in the format of module:operator like unet:Attention, got {fp32_operator}"
+                )
+
+    for name, model_type in model_type_mapping.items():
+        onnx_model_path = source_dir / name / "model.onnx"
+        if not os.path.exists(onnx_model_path):
+            if name != "safety_checker":
+                logger.info("input onnx model does not exist: %s", onnx_model_path)
+            # some models are optional so we do not raise error here.
+            continue
+
+        # Prepare output directory
+        optimized_model_path = target_dir / name / "model.onnx"
+        output_dir = optimized_model_path.parent
+        output_dir.mkdir(parents=True, exist_ok=True)
+
+        if use_external_data_format is None:
+            use_external_data_format = has_external_data(onnx_model_path)
+
+        # Graph fusion before fp16 conversion, otherwise they cannot be fused later.
+        logger.info(f"Optimize {onnx_model_path}...")
+
+        args.model_type = model_type
+        fusion_options = FusionOptions.parse(args)
+
+        if model_type in ["unet"]:
+            # Some optimizations are not available in v1.14 or older version: packed QKV and BiasAdd
+            has_all_optimizations = version.parse(onnxruntime.__version__) >= version.parse("1.15.0")
+            fusion_options.enable_packed_kv = float16 and fusion_options.enable_packed_kv
+            fusion_options.enable_packed_qkv = float16 and has_all_optimizations and fusion_options.enable_packed_qkv
+            fusion_options.enable_bias_add = has_all_optimizations and fusion_options.enable_bias_add
+
+        m = optimize_model(
+            str(onnx_model_path),
+            model_type=model_type,
+            num_heads=0,  # will be deduced from graph
+            hidden_size=0,  # will be deduced from graph
+            opt_level=0,
+            optimization_options=fusion_options,
+            use_gpu=True,
+            provider=args.provider,
+        )
+
+        if float16:
+            # For SD-XL, using FP16 in the VAE decoder causes NaN and black images, so we keep it in FP32.
+            if is_xl and name == "vae_decoder":
+                logger.info("Skip converting %s to float16 to avoid NaN", name)
+            else:
+                logger.info("Convert %s to float16 ...", name)
+                m.convert_float_to_float16(
+                    keep_io_types=False,
+                    op_block_list=force_fp32_operators[name],
+                )
+
+        if enable_runtime_optimization:
+            # Use this step to see the final graph that is executed by Onnx Runtime.
+            with tempfile.TemporaryDirectory() as tmp_dir:
+                # Save to a temporary file so that we can load it with Onnx Runtime.
+                logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
+                tmp_model_path = Path(tmp_dir) / "model.onnx"
+                m.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format)
+                ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx"
+                optimize_by_onnxruntime(
+                    str(tmp_model_path),
+                    use_gpu=True,
+                    provider=args.provider,
+                    optimized_model_path=str(ort_optimized_model_path),
+                    save_as_external_data=use_external_data_format,
+                )
+                model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
+                m = model_type_class_mapping[model_type](model)
+
+        m.get_operator_statistics()
+        m.get_fused_operator_statistics()
+        m.save_model_to_file(str(optimized_model_path), use_external_data_format=use_external_data_format)
+        logger.info("%s is optimized", name)
+        logger.info("*" * 20)
+
+
+def _copy_extra_directory(source_dir: Path, target_dir: Path):
+    """Copy extra directory that does not have onnx model.
+
+    Args:
+        source_dir (Path): source directory
+        target_dir (Path): target directory
+
+    Raises:
+        RuntimeError: source path does not exist
+    """
+    extra_dirs = ["scheduler", "tokenizer", "tokenizer_2", "feature_extractor"]
+
+    for name in extra_dirs:
+        source_path = source_dir / name
+        if not os.path.exists(source_path):
+            continue
+
+        target_path = target_dir / name
+        shutil.copytree(source_path, target_path)
+        logger.info("%s => %s", source_path, target_path)
+
+    extra_files = ["model_index.json"]
+    for name in extra_files:
+        source_path = source_dir / name
+        if not os.path.exists(source_path):
+            raise RuntimeError(f"source path does not exist: {source_path}")
+
+        target_path = target_dir / name
+        shutil.copyfile(source_path, target_path)
+        logger.info("%s => %s", source_path, target_path)
+
+    # Some directories are optional
+    onnx_model_dirs = ["text_encoder", "text_encoder_2", "unet", "vae_encoder", "vae_decoder", "safety_checker"]
+    for onnx_model_dir in onnx_model_dirs:
+        source_path = source_dir / onnx_model_dir / "config.json"
+        target_path = target_dir / onnx_model_dir / "config.json"
+        if source_path.exists():
+            target_path.parent.mkdir(parents=True, exist_ok=True)
+            shutil.copyfile(source_path, target_path)
+            logger.info("%s => %s", source_path, target_path)
+
+
+def optimize_stable_diffusion_pipeline(
+    input_dir: str,
+    output_dir: str,
+    overwrite: bool,
+    use_external_data_format: Optional[bool],
+    float16: bool,
+    enable_runtime_optimization: bool,
+    args,
+):
+    if os.path.exists(output_dir):
+        if overwrite:
+            shutil.rmtree(output_dir, ignore_errors=True)
+        else:
+            raise RuntimeError(f"output directory already exists: {output_dir}. Add --overwrite to empty the directory.")
+
+    source_dir = Path(input_dir)
+    target_dir = Path(output_dir)
+    target_dir.mkdir(parents=True, exist_ok=True)
+
+    _copy_extra_directory(source_dir, target_dir)
+
+    _optimize_sd_pipeline(
+        source_dir,
+        target_dir,
+        use_external_data_format,
+        float16,
+        args.force_fp32_ops,
+        enable_runtime_optimization,
+        args,
+    )
+
+
+def parse_arguments(argv: Optional[List[str]] = None):
+    """Parse arguments
+
+    Returns:
+        Namespace: arguments
+    """
+    parser = argparse.ArgumentParser()
+
+    parser.add_argument(
+        "-i",
+        "--input",
+        required=True,
+        type=str,
+        help="Root of input directory of stable diffusion onnx pipeline with float32 models.",
+    )
+
+    parser.add_argument(
+        "-o",
+        "--output",
+        required=True,
+        type=str,
+        help="Root of output directory of stable diffusion onnx pipeline with optimized models.",
+    )
+
+    parser.add_argument(
+        "--float16",
+        required=False,
+        action="store_true",
+        help="Output models of half or mixed precision.",
+    )
+    parser.set_defaults(float16=False)
+
+    parser.add_argument(
+        "--force_fp32_ops",
+        required=False,
+        nargs="+",
+        type=str,
+        help="Force given operators (like unet:Attention) to run in float32. It is case sensitive!",
+    )
+
+    parser.add_argument(
+        "--inspect",
+        required=False,
+        action="store_true",
+        help="Save the optimized graph from Onnx Runtime. "
+        "This option has no impact on inference performance except it might reduce session creation time.",
+    )
+    parser.set_defaults(inspect=False)
+
+    parser.add_argument(
+        "--overwrite",
+        required=False,
+        action="store_true",
+        help="Overwrite existing files.",
+    )
+    parser.set_defaults(overwrite=False)
+
+    parser.add_argument(
+        "-e",
+        "--use_external_data_format",
+        required=False,
+        action="store_true",
+        help="Onnx models larger than 2GB need to use external data format. "
+        "If specified, save each onnx model to two files: one for onnx graph, another for weights. "
+        "If not specified, use same format as original model by default.",
+    )
+    parser.set_defaults(use_external_data_format=None)
+
+    parser.add_argument(
+        "--provider",
+        required=False,
+        type=str,
+        default=None,
+        help="Execution provider to use.",
+    )
+
+    FusionOptions.add_arguments(parser)
+
+    args = parser.parse_args(argv)
+    return args
+
+
+def main(argv: Optional[List[str]] = None):
+    args = parse_arguments(argv)
+    logger.info("Arguments: %s", str(args))
+    optimize_stable_diffusion_pipeline(
+        args.input, args.output, args.overwrite, args.use_external_data_format, args.float16, args.inspect, args
+    )
+
+
+if __name__ == "__main__":
+    coloredlogs.install(fmt="%(funcName)20s: %(message)s")
+    main()
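Besides the CLI shown in the file header, `main()` accepts an explicit argument list, so the same optimization can be driven from Python. A sketch, assuming it is run from the script's own directory (the module relies on path-relative imports); all paths below are placeholders and the flags mirror `parse_arguments()` above.

```python
# Sketch: programmatic use of optimize_pipeline.main(); flags mirror parse_arguments() above.
from optimize_pipeline import main  # assumes the current directory is the script's directory

main(
    [
        "-i", "./sd-v1-5",                     # float32 ONNX pipeline (input)
        "-o", "./sd-v1-5-fp16",                # output directory for optimized models
        "--float16",                           # convert to half/mixed precision
        "--overwrite",                         # replace the output directory if it exists
        "--force_fp32_ops", "unet:Attention",  # optional: keep selected ops in float32
    ]
)
```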
onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py
@@ -0,0 +1,136 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+# --------------------------------------------------------------------------
+
+"""
+ONNX Model Optimizer for Stable Diffusion
+"""
+
+import gc
+import logging
+import os
+import shutil
+import tempfile
+from pathlib import Path
+
+import onnx
+from packaging import version
+
+from onnxruntime.transformers.fusion_options import FusionOptions
+from onnxruntime.transformers.onnx_model_clip import ClipOnnxModel
+from onnxruntime.transformers.onnx_model_unet import UnetOnnxModel
+from onnxruntime.transformers.onnx_model_vae import VaeOnnxModel
+from onnxruntime.transformers.optimizer import optimize_by_onnxruntime, optimize_model
+
+logger = logging.getLogger(__name__)
+
+
+class OrtStableDiffusionOptimizer:
+    def __init__(self, model_type: str):
+        assert model_type in ["vae", "unet", "clip"]
+        self.model_type = model_type
+        self.model_type_class_mapping = {
+            "unet": UnetOnnxModel,
+            "vae": VaeOnnxModel,
+            "clip": ClipOnnxModel,
+        }
+
+    def _optimize_by_ort(self, onnx_model, use_external_data_format, tmp_dir):
+        # Save to a temporary file so that we can load it with Onnx Runtime.
+        logger.info("Saving a temporary model to run OnnxRuntime graph optimizations...")
+        tmp_model_path = Path(tmp_dir) / "model.onnx"
+        onnx_model.save_model_to_file(str(tmp_model_path), use_external_data_format=use_external_data_format)
+
+        del onnx_model
+        gc.collect()
+
+        ort_optimized_model_path = Path(tmp_dir) / "optimized.onnx"
+        optimize_by_onnxruntime(
+            str(tmp_model_path),
+            use_gpu=True,
+            optimized_model_path=str(ort_optimized_model_path),
+            save_as_external_data=use_external_data_format,
+            external_data_filename="optimized.onnx_data",
+        )
+        model = onnx.load(str(ort_optimized_model_path), load_external_data=True)
+        return self.model_type_class_mapping[self.model_type](model)
+
+    def optimize_by_ort(self, onnx_model, use_external_data_format=False, tmp_dir=None):
+        # Use this step to see the final graph that is executed by Onnx Runtime.
+        if tmp_dir is None:
+            with tempfile.TemporaryDirectory() as temp_dir:
+                return self._optimize_by_ort(onnx_model, use_external_data_format, temp_dir)
+        else:
+            os.makedirs(tmp_dir, exist_ok=True)
+            model = self._optimize_by_ort(onnx_model, use_external_data_format, tmp_dir)
+            shutil.rmtree(tmp_dir)
+            return model
+
+    def optimize(
+        self,
+        input_fp32_onnx_path,
+        optimized_onnx_path,
+        float16=True,
+        keep_io_types=False,
+        fp32_op_list=None,
+        keep_outputs=None,
+        optimize_by_ort=True,
+        optimize_by_fusion=True,
+        final_target_float16=True,
+        tmp_dir=None,
+    ):
+        """Optimize onnx model using ONNX Runtime transformers optimizer"""
+        logger.info(f"Optimize {input_fp32_onnx_path}...")
+
+        if optimize_by_fusion:
+            fusion_options = FusionOptions(self.model_type)
+
+            # It is allowed to have float16=False and final_target_float16=True, to use fp32 as an intermediate optimization step.
+            # For the rare fp32 use case, we can disable packed kv/qkv since there is no fp32 TRT fused attention kernel.
+            if self.model_type in ["unet"] and not final_target_float16:
+                fusion_options.enable_packed_kv = False
+                fusion_options.enable_packed_qkv = False
+
+            m = optimize_model(
+                input_fp32_onnx_path,
+                model_type=self.model_type,
+                num_heads=0,  # will be deduced from graph
+                hidden_size=0,  # will be deduced from graph
+                opt_level=0,
+                optimization_options=fusion_options,
+                use_gpu=True,
+            )
+        else:
+            model = onnx.load_model(input_fp32_onnx_path, load_external_data=True)
+            m = self.model_type_class_mapping[self.model_type](model)
+
+        if keep_outputs:
+            m.prune_graph(outputs=keep_outputs)
+
+        model_size = m.model.ByteSize()
+
+        # model size might be negative (overflow?) in Windows.
+        use_external_data_format = model_size <= 0 or model_size >= onnx.checker.MAXIMUM_PROTOBUF
+
+        # Note that ORT < 1.16 could not save model larger than 2GB.
+        # This step is optional since it has no impact on inference latency.
+        # The optimized model is not portable. It could only run in the same execution provider (CUDA EP in this case).
+        # When the model has been optimized by onnxruntime, we can disable optimization in SessionOptions
+        # to save session creation time. Another benefit is to inspect the final graph for development purposes.
+        from onnxruntime import __version__ as ort_version
+
+        if optimize_by_ort and (version.parse(ort_version) >= version.parse("1.16.0") or not use_external_data_format):
+            m = self.optimize_by_ort(m, use_external_data_format=use_external_data_format, tmp_dir=tmp_dir)
+
+        if float16:
+            logger.info("Convert to float16 ...")
+            m.convert_float_to_float16(
+                keep_io_types=keep_io_types,
+                op_block_list=fp32_op_list,
+            )
+
+        m.get_operator_statistics()
+        m.get_fused_operator_statistics()
+        m.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format)
+        logger.info("%s is optimized: %s", self.model_type, optimized_onnx_path)
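A short usage sketch for the class above, with placeholder paths. The keyword arguments mirror `optimize()` as defined in this file; the particular op kept in float32 is only an illustrative choice, not a recommendation from the package.

```python
# Sketch: optimizing a single exported UNet model with OrtStableDiffusionOptimizer.
from onnxruntime.transformers.models.stable_diffusion.ort_optimizer import OrtStableDiffusionOptimizer

optimizer = OrtStableDiffusionOptimizer("unet")  # one of "vae", "unet", "clip"
optimizer.optimize(
    input_fp32_onnx_path="./unet/model.onnx",
    optimized_onnx_path="./unet_fp16/model.onnx",
    float16=True,                       # convert to fp16 after graph fusion
    keep_io_types=True,                 # keep float32 graph inputs/outputs
    fp32_op_list=["RandomNormalLike"],  # illustrative op kept in float32
)
```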