onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,778 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
# Modified from TensorRT demo diffusion, which has the following license:
|
|
6
|
+
#
|
|
7
|
+
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
8
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
9
|
+
#
|
|
10
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
11
|
+
# you may not use this file except in compliance with the License.
|
|
12
|
+
# You may obtain a copy of the License at
|
|
13
|
+
#
|
|
14
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
15
|
+
#
|
|
16
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
17
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
18
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
19
|
+
# See the License for the specific language governing permissions and
|
|
20
|
+
# limitations under the License.
|
|
21
|
+
# --------------------------------------------------------------------------
|
|
22
|
+
import argparse
|
|
23
|
+
import os
|
|
24
|
+
import sys
|
|
25
|
+
from importlib.metadata import PackageNotFoundError, version
|
|
26
|
+
from typing import Any, Dict, List, Optional
|
|
27
|
+
|
|
28
|
+
import controlnet_aux
|
|
29
|
+
import cv2
|
|
30
|
+
import numpy as np
|
|
31
|
+
import torch
|
|
32
|
+
from cuda import cudart
|
|
33
|
+
from diffusion_models import PipelineInfo
|
|
34
|
+
from engine_builder import EngineType, get_engine_paths, get_engine_type
|
|
35
|
+
from PIL import Image
|
|
36
|
+
from pipeline_stable_diffusion import StableDiffusionPipeline
|
|
37
|
+
|
|
38
|
+
|
|
39
|
+
class RawTextArgumentDefaultsHelpFormatter(argparse.ArgumentDefaultsHelpFormatter, argparse.RawTextHelpFormatter):
    """Help formatter that preserves raw text layout while also showing argument defaults."""


def arg_parser(description: str):
    """Create an ArgumentParser using the combined raw-text/defaults help formatter.

    Args:
        description: text shown at the top of the --help output.

    Returns:
        A configured argparse.ArgumentParser.
    """
    parser = argparse.ArgumentParser(
        formatter_class=RawTextArgumentDefaultsHelpFormatter,
        description=description,
    )
    return parser
|
|
48
|
+
|
|
49
|
+
|
|
50
|
+
def set_default_arguments(args):
    """Fill in defaults for resolution, denoising steps, scheduler and guidance.

    Only arguments the user left as None are touched. Defaults depend on the
    model version and on whether an LCM UNet or a turbo model is in use.
    """
    version = args.version

    # Default square resolution is version dependent.
    if args.height is None:
        args.height = PipelineInfo.default_resolution(version)
    if args.width is None:
        args.width = PipelineInfo.default_resolution(version)

    # LCM applies when the XL base uses the --lcm flag, or when LCM LoRA weights are loaded.
    uses_lcm = (version == "xl-1.0" and args.lcm) or "lcm" in args.lora_weights
    uses_turbo = version in ("sd-turbo", "xl-turbo")

    if args.denoising_steps is None:
        if uses_turbo:
            args.denoising_steps = 4
        elif uses_lcm:
            args.denoising_steps = 8
        else:
            args.denoising_steps = 30 if version == "xl-1.0" else 50

    if args.scheduler is None:
        if uses_lcm or uses_turbo:
            args.scheduler = "LCM"
        else:
            args.scheduler = "EulerA" if version == "xl-1.0" else "DDIM"

    if args.guidance is None:
        if uses_lcm or uses_turbo:
            args.guidance = 0.0
        else:
            args.guidance = 5.0 if version == "xl-1.0" else 7.5
|
|
68
|
+
|
|
69
|
+
|
|
70
|
+
def parse_arguments(is_xl: bool, parser):
    """Populate *parser* with demo options, parse the command line, and post-process.

    Args:
        is_xl: True for Stable Diffusion XL pipelines. Adds LCM/refiner options
               and switches the default model version to "xl-1.0".
        parser: an argparse.ArgumentParser (typically from arg_parser) to populate.

    Returns:
        The parsed argparse.Namespace with defaults filled in and validated.

    Raises:
        ValueError: if the requested image height or width is not a multiple of 64.
    """
    engines = ["ORT_CUDA", "ORT_TRT", "TRT", "TORCH"]

    parser.add_argument(
        "-e",
        "--engine",
        type=str,
        default=engines[0],
        choices=engines,
        # f-string so the engine list is actually rendered in --help
        # (the original literal "{engines}" was printed verbatim).
        help=f"Backend engine in {engines}. "
        "ORT_CUDA is CUDA execution provider; ORT_TRT is Tensorrt execution provider; TRT is TensorRT",
    )

    supported_versions = PipelineInfo.supported_versions(is_xl)
    parser.add_argument(
        "-v",
        "--version",
        type=str,
        default="xl-1.0" if is_xl else "1.5",
        choices=supported_versions,
        help="Version of Stable Diffusion" + (" XL." if is_xl else "."),
    )

    # Validation below enforces divisibility by 64, so the help says 64 (not 8).
    parser.add_argument(
        "-y",
        "--height",
        type=int,
        default=None,
        help="Height of image to generate (must be multiple of 64).",
    )
    parser.add_argument(
        "-x", "--width", type=int, default=None, help="Width of image to generate (must be multiple of 64)."
    )

    parser.add_argument(
        "-s",
        "--scheduler",
        type=str,
        default=None,
        choices=["DDIM", "EulerA", "UniPC", "LCM"],
        # Parentheses are required: without them the conditional expression binds
        # the whole concatenation and the non-XL help collapses to "".
        help="Scheduler for diffusion process" + (" of base" if is_xl else ""),
    )

    parser.add_argument(
        "-wd",
        "--work-dir",
        default=".",
        help="Root Directory to store torch or ONNX models, built engines and output images etc.",
    )

    parser.add_argument(
        "-i",
        "--engine-dir",
        default=None,
        help="Root Directory to store built engines or optimized ONNX models etc.",
    )

    parser.add_argument("prompt", nargs="*", default=[""], help="Text prompt(s) to guide image generation.")

    parser.add_argument(
        "-n",
        "--negative-prompt",
        nargs="*",
        default=[""],
        help="Optional negative prompt(s) to guide the image generation.",
    )
    parser.add_argument(
        "-b",
        "--batch-size",
        type=int,
        default=1,
        choices=[1, 2, 4, 8, 16],
        help="Number of times to repeat the prompt (batch size multiplier).",
    )

    parser.add_argument(
        "-d",
        "--denoising-steps",
        type=int,
        default=None,
        help="Number of denoising steps" + (" in base." if is_xl else "."),
    )

    parser.add_argument(
        "-g",
        "--guidance",
        type=float,
        default=None,
        help="Higher guidance scale encourages to generate images that are closely linked to the text prompt.",
    )

    parser.add_argument(
        "-ls", "--lora-scale", type=float, default=1, help="Scale of LoRA weights, default 1 (must between 0 and 1)"
    )
    parser.add_argument("-lw", "--lora-weights", type=str, default="", help="LoRA weights to apply in the base model")

    if is_xl:
        parser.add_argument(
            "--lcm",
            action="store_true",
            help="Use fine-tuned latent consistency model to replace the UNet in base.",
        )

        parser.add_argument(
            "-rs",
            "--refiner-scheduler",
            type=str,
            default="EulerA",
            choices=["DDIM", "EulerA", "UniPC"],
            help="Scheduler for diffusion process of refiner.",
        )

        parser.add_argument(
            "-rg",
            "--refiner-guidance",
            type=float,
            default=5.0,
            help="Guidance scale used in refiner.",
        )

        parser.add_argument(
            "-rd",
            "--refiner-denoising-steps",
            type=int,
            default=30,
            help="Number of denoising steps in refiner. Note that actual steps is refiner_denoising_steps * strength.",
        )

        parser.add_argument(
            "--strength",
            type=float,
            default=0.3,
            help="A value between 0 and 1. The higher the value less the final image similar to the seed image.",
        )

        parser.add_argument(
            "-r",
            "--enable-refiner",
            action="store_true",
            help="Enable SDXL refiner to refine image from base pipeline.",
        )

    # ONNX export
    parser.add_argument(
        "--onnx-opset",
        type=int,
        default=None,
        choices=range(14, 18),
        help="Select ONNX opset version to target for exported models.",
    )

    # Engine build options.
    parser.add_argument(
        "-db",
        "--build-dynamic-batch",
        action="store_true",
        help="Build TensorRT engines to support dynamic batch size.",
    )
    parser.add_argument(
        "-ds",
        "--build-dynamic-shape",
        action="store_true",
        help="Build TensorRT engines to support dynamic image sizes.",
    )
    parser.add_argument("--max-batch-size", type=int, default=None, choices=[1, 2, 4, 8, 16, 32], help="Max batch size")

    # Inference related options
    parser.add_argument(
        "-nw", "--num-warmup-runs", type=int, default=5, help="Number of warmup runs before benchmarking performance."
    )
    parser.add_argument("--nvtx-profile", action="store_true", help="Enable NVTX markers for performance profiling.")
    parser.add_argument("--seed", type=int, default=None, help="Seed for random generator to get consistent results.")
    parser.add_argument("--deterministic", action="store_true", help="use deterministic algorithms.")
    parser.add_argument("-dc", "--disable-cuda-graph", action="store_true", help="Disable cuda graph.")

    parser.add_argument("--framework-model-dir", default=None, help="framework model directory")

    group = parser.add_argument_group("Options for ORT_CUDA engine only")
    group.add_argument("--enable-vae-slicing", action="store_true", help="True will feed only one image to VAE once.")
    group.add_argument("--max-cuda-graphs", type=int, default=1, help="Max number of cuda graphs to use. Default 1.")
    group.add_argument("--user-compute-stream", action="store_true", help="Use user compute stream.")

    # TensorRT only options
    group = parser.add_argument_group("Options for TensorRT (--engine=TRT) only")
    group.add_argument(
        "--build-all-tactics", action="store_true", help="Build TensorRT engines using all tactic sources."
    )

    args = parser.parse_args()

    # Fill in version-dependent defaults for height/width/steps/scheduler/guidance.
    set_default_arguments(args)

    # Validate image dimensions: latent/VAE pipeline requires multiples of 64.
    if args.height % 64 != 0 or args.width % 64 != 0:
        raise ValueError(
            f"Image height and width have to be divisible by 64 but specified as: {args.height} and {args.width}."
        )

    # CUDA graph capture requires static shapes, so dynamic shape builds disable it.
    if (args.build_dynamic_batch or args.build_dynamic_shape) and not args.disable_cuda_graph:
        print("[I] CUDA Graph is disabled since dynamic input shape is configured.")
        args.disable_cuda_graph = True

    if args.onnx_opset is None:
        args.onnx_opset = 14 if args.engine == "ORT_CUDA" else 17

    if is_xl:
        if args.version == "xl-turbo":
            if args.lcm:
                print("[I] sdxl-turbo cannot use with LCM.")
                args.lcm = False

        assert args.strength > 0.0 and args.strength < 1.0

        assert not (args.lcm and args.lora_weights), "it is not supported to use both lcm unet and Lora together"

    # LCM scheduler works best with low guidance and few steps; clamp both.
    if args.scheduler == "LCM":
        if args.guidance > 2.0:
            print("[I] Use --guidance=0.0 (no more than 2.0) when LCM scheduler is used.")
            args.guidance = 0.0
        if args.denoising_steps > 16:
            print("[I] Use --denoising_steps=8 (no more than 16) when LCM scheduler is used.")
            args.denoising_steps = 8

    print(args)

    return args
|
|
296
|
+
|
|
297
|
+
|
|
298
|
+
def max_batch(args):
    """Return the maximum batch size allowed for the configured engine.

    An explicit --max-batch-size always wins. Otherwise the limit is derived
    from whether classifier-free guidance doubles the effective batch, and is
    tightened for non ORT_CUDA engines when dynamic shape or large images are
    requested.
    """
    if args.max_batch_size:
        return args.max_batch_size

    # Classifier-free guidance runs conditional + unconditional passes,
    # doubling the effective batch fed to the engine.
    multiplier = 2 if args.guidance > 1.0 else 1
    limit = 32 // multiplier
    large_or_dynamic = args.build_dynamic_shape or args.height > 512 or args.width > 512
    if args.engine != "ORT_CUDA" and large_or_dynamic:
        limit = 8 // multiplier
    return limit
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
def get_metadata(args, is_xl: bool = False) -> Dict[str, Any]:
    """Collect run configuration, installed package versions and device info.

    The returned dict is intended to be stored alongside generated images so a
    run can be reproduced: command line, prompts, image size, engine settings,
    scheduler/guidance (split into base/refiner sections when the SDXL refiner
    is enabled), package versions and the CUDA device name.
    """
    quoted_argv = ['"' + x + '"' if " " in x else x for x in sys.argv]
    metadata = {
        "command": " ".join(quoted_argv),
        "args.prompt": args.prompt,
        "args.negative_prompt": args.negative_prompt,
        "args.batch_size": args.batch_size,
        "height": args.height,
        "width": args.width,
        "cuda_graph": not args.disable_cuda_graph,
        "vae_slicing": args.enable_vae_slicing,
        "engine": args.engine,
    }

    if args.lora_weights:
        metadata["lora_weights"] = args.lora_weights
        metadata["lora_scale"] = args.lora_scale

    if args.controlnet_type:
        metadata["controlnet_type"] = args.controlnet_type
        metadata["controlnet_scale"] = args.controlnet_scale

    if is_xl and args.enable_refiner:
        # Base and refiner stages each have their own scheduler/steps/guidance.
        metadata["base.scheduler"] = args.scheduler
        metadata["base.denoising_steps"] = args.denoising_steps
        metadata["base.guidance"] = args.guidance
        metadata["refiner.strength"] = args.strength
        metadata["refiner.scheduler"] = args.refiner_scheduler
        metadata["refiner.denoising_steps"] = args.refiner_denoising_steps
        metadata["refiner.guidance"] = args.refiner_guidance
    else:
        metadata["scheduler"] = args.scheduler
        metadata["denoising_steps"] = args.denoising_steps
        metadata["guidance"] = args.guidance

    # Version of installed python packages; packages that are not installed
    # are silently skipped.
    package_names = (
        "onnxruntime-gpu",
        "torch",
        "tensorrt",
        "transformers",
        "diffusers",
        "onnx",
        "onnx-graphsurgeon",
        "polygraphy",
        "controlnet_aux",
    )
    found = []
    for name in package_names:
        try:
            found.append(f"{name}=={version(name)}")
        except PackageNotFoundError:
            continue
    metadata["packages"] = " ".join(found)

    metadata["device"] = torch.cuda.get_device_name()
    metadata["torch.version.cuda"] = torch.version.cuda

    return metadata
|
|
366
|
+
|
|
367
|
+
|
|
368
|
+
def repeat_prompt(args):
    """Expand prompts to the effective batch size.

    The prompt list is repeated batch_size times; a single negative prompt is
    broadcast across the whole batch, while a multi-entry negative prompt list
    is passed through unchanged.

    Raises:
        ValueError: when prompt or negative_prompt is not a list.
    """
    if not isinstance(args.prompt, list):
        raise ValueError(f"`prompt` must be of type `str` or `str` list, but is {type(args.prompt)}")

    repeated = args.prompt * args.batch_size

    if not isinstance(args.negative_prompt, list):
        raise ValueError(
            f"`--negative-prompt` must be of type `str` or `str` list, but is {type(args.negative_prompt)}"
        )

    one_negative = len(args.negative_prompt) == 1
    negative = args.negative_prompt * len(repeated) if one_negative else args.negative_prompt

    return repeated, negative
|
|
384
|
+
|
|
385
|
+
|
|
386
|
+
def initialize_pipeline(
    version="xl-turbo",
    is_refiner: bool = False,
    is_inpaint: bool = False,
    engine_type=EngineType.ORT_CUDA,
    work_dir: str = ".",
    engine_dir=None,
    onnx_opset: int = 17,
    scheduler="EulerA",
    height=512,
    width=512,
    nvtx_profile=False,
    use_cuda_graph=True,
    build_dynamic_batch=False,
    build_dynamic_shape=False,
    min_image_size: int = 512,
    max_image_size: int = 1024,
    max_batch_size: int = 16,
    opt_batch_size: int = 1,
    build_all_tactics: bool = False,
    do_classifier_free_guidance: bool = False,
    lcm: bool = False,
    controlnet=None,
    lora_weights=None,
    lora_scale: float = 1.0,
    use_fp16_vae: bool = True,
    use_vae: bool = True,
    framework_model_dir: Optional[str] = None,
    max_cuda_graphs: int = 1,
):
    """Create a StableDiffusionPipeline and build or load its inference engines.

    Builds a PipelineInfo for the requested model version, resolves working
    directories (ONNX models, engines, outputs, timing cache), constructs the
    pipeline object, then builds/loads engines according to engine_type
    (ORT_CUDA, ORT_TRT, TRT, or TORCH).

    Returns:
        The constructed StableDiffusionPipeline with engines ready.

    Raises:
        RuntimeError: when engine_dir is provided but does not exist, or when
            engine_type is not one of the supported engine types.
    """
    pipeline_info = PipelineInfo(
        version,
        is_refiner=is_refiner,
        is_inpaint=is_inpaint,
        use_vae=use_vae,
        min_image_size=min_image_size,
        max_image_size=max_image_size,
        use_fp16_vae=use_fp16_vae,
        use_lcm=lcm,
        do_classifier_free_guidance=do_classifier_free_guidance,
        controlnet=controlnet,
        lora_weights=lora_weights,
        lora_scale=lora_scale,
    )

    # Keep the user-supplied directory: engine_dir is reassigned by
    # get_engine_paths below.
    input_engine_dir = engine_dir

    onnx_dir, engine_dir, output_dir, framework_model_dir, timing_cache = get_engine_paths(
        work_dir=work_dir, pipeline_info=pipeline_info, engine_type=engine_type, framework_model_dir=framework_model_dir
    )

    pipeline = StableDiffusionPipeline(
        pipeline_info,
        scheduler=scheduler,
        output_dir=output_dir,
        verbose=False,
        nvtx_profile=nvtx_profile,
        max_batch_size=max_batch_size,
        use_cuda_graph=use_cuda_graph,
        framework_model_dir=framework_model_dir,
        engine_type=engine_type,
    )

    import_engine_dir = None
    if input_engine_dir:
        if not os.path.exists(input_engine_dir):
            raise RuntimeError(f"--engine_dir directory does not exist: {input_engine_dir}")

        # Support importing from optimized diffusers onnx pipeline
        # (a directory containing model_index.json); otherwise the directory
        # is used directly as the engine directory.
        if engine_type == EngineType.ORT_CUDA and os.path.exists(os.path.join(input_engine_dir, "model_index.json")):
            import_engine_dir = input_engine_dir
        else:
            engine_dir = input_engine_dir

    # With dynamic shape, engines are optimized for the model's default image
    # size rather than the requested height/width.
    opt_image_height = pipeline_info.default_image_size() if build_dynamic_shape else height
    opt_image_width = pipeline_info.default_image_size() if build_dynamic_shape else width

    if engine_type == EngineType.ORT_CUDA:
        pipeline.backend.build_engines(
            engine_dir=engine_dir,
            framework_model_dir=framework_model_dir,
            onnx_dir=onnx_dir,
            tmp_dir=os.path.join(work_dir or ".", engine_type.name, pipeline_info.short_name(), "tmp"),
            device_id=torch.cuda.current_device(),
            import_engine_dir=import_engine_dir,
            max_cuda_graphs=max_cuda_graphs,
        )
    elif engine_type == EngineType.ORT_TRT:
        pipeline.backend.build_engines(
            engine_dir,
            framework_model_dir,
            onnx_dir,
            onnx_opset,
            opt_image_height=opt_image_height,
            opt_image_width=opt_image_width,
            opt_batch_size=opt_batch_size,
            static_batch=not build_dynamic_batch,
            static_image_shape=not build_dynamic_shape,
            max_workspace_size=0,
            device_id=torch.cuda.current_device(),
            timing_cache=timing_cache,
        )
    elif engine_type == EngineType.TRT:
        # TRT engines are loaded (and built on demand by the backend) rather
        # than built here directly.
        pipeline.backend.load_engines(
            engine_dir,
            framework_model_dir,
            onnx_dir,
            onnx_opset,
            opt_batch_size=opt_batch_size,
            opt_image_height=opt_image_height,
            opt_image_width=opt_image_width,
            static_batch=not build_dynamic_batch,
            static_shape=not build_dynamic_shape,
            enable_all_tactics=build_all_tactics,
            timing_cache=timing_cache,
        )
    elif engine_type == EngineType.TORCH:
        pipeline.backend.build_engines(framework_model_dir)
    else:
        raise RuntimeError("invalid engine type")

    return pipeline
|
|
508
|
+
|
|
509
|
+
|
|
510
|
+
def load_pipelines(args, batch_size=None):
    """Create the base pipeline (and optional SDXL refiner) from parsed args.

    Validates the batch size against the engine's limit, chooses an image-size
    range appropriate for the engine/model version, then initializes the base
    pipeline and, for SDXL with --enable-refiner, a refiner pipeline. For the
    TRT engine, device memory is shared between the two pipelines' engines.

    Returns:
        Tuple (base, refiner) where refiner is None unless the SDXL refiner
        is enabled.

    Raises:
        ValueError: when the effective batch size exceeds the allowed maximum.
    """
    engine_type = get_engine_type(args.engine)

    # Register TensorRT plugins
    if engine_type == EngineType.TRT:
        from trt_utilities import init_trt_plugins

        init_trt_plugins()

    max_batch_size = max_batch(args)

    # Effective batch = number of prompts times per-prompt batch size.
    if batch_size is None:
        assert isinstance(args.prompt, list)
        batch_size = len(args.prompt) * args.batch_size

    if batch_size > max_batch_size:
        raise ValueError(f"Batch size {batch_size} is larger than allowed {max_batch_size}.")

    # For TensorRT, performance of engine built with dynamic shape is very sensitive to the range of image size.
    # Here, we reduce the range of image size for TensorRT to trade-off flexibility and performance.
    # This range can cover most frequent shape of landscape (832x1216), portrait (1216x832) or square (1024x1024).
    if args.version == "xl-turbo":
        min_image_size = 512
        max_image_size = 768 if args.engine != "ORT_CUDA" else 1024
    elif args.version == "xl-1.0":
        min_image_size = 832 if args.engine != "ORT_CUDA" else 512
        max_image_size = 1216 if args.engine != "ORT_CUDA" else 2048
    else:
        # This range can cover common used shape of landscape 512x768, portrait 768x512, or square 512x512 and 768x768.
        min_image_size = 512 if args.engine != "ORT_CUDA" else 256
        max_image_size = 768 if args.engine != "ORT_CUDA" else 1024

    # Shared keyword arguments for initialize_pipeline; mutated below for the
    # refiner, so key order and the base/refiner overrides must stay in sync.
    params = {
        "version": args.version,
        "is_refiner": False,
        "is_inpaint": False,
        "engine_type": engine_type,
        "work_dir": args.work_dir,
        "engine_dir": args.engine_dir,
        "onnx_opset": args.onnx_opset,
        "scheduler": args.scheduler,
        "height": args.height,
        "width": args.width,
        "nvtx_profile": args.nvtx_profile,
        "use_cuda_graph": not args.disable_cuda_graph,
        "build_dynamic_batch": args.build_dynamic_batch,
        "build_dynamic_shape": args.build_dynamic_shape,
        "min_image_size": min_image_size,
        "max_image_size": max_image_size,
        "max_batch_size": max_batch_size,
        "opt_batch_size": 1 if args.build_dynamic_batch else batch_size,
        "build_all_tactics": args.build_all_tactics,
        "do_classifier_free_guidance": args.guidance > 1.0,
        "controlnet": args.controlnet_type,
        "lora_weights": args.lora_weights,
        "lora_scale": args.lora_scale,
        "use_fp16_vae": "xl" in args.version,
        "use_vae": True,
        "framework_model_dir": args.framework_model_dir,
        "max_cuda_graphs": args.max_cuda_graphs,
    }

    if "xl" in args.version:
        params["lcm"] = args.lcm
        # When a refiner is used, the base pipeline skips its VAE (the refiner
        # decodes the final latents).
        params["use_vae"] = not args.enable_refiner
    base = initialize_pipeline(**params)

    refiner = None
    if "xl" in args.version and args.enable_refiner:
        params["version"] = "xl-1.0"  # Allow SDXL Turbo to use refiner.
        params["is_refiner"] = True
        params["scheduler"] = args.refiner_scheduler
        params["do_classifier_free_guidance"] = args.refiner_guidance > 1.0
        params["lcm"] = False
        params["controlnet"] = None
        params["lora_weights"] = None
        params["use_vae"] = True
        params["use_fp16_vae"] = True
        refiner = initialize_pipeline(**params)

    if engine_type == EngineType.TRT:
        # Allocate one device buffer large enough for either pipeline and
        # share it between both sets of engines.
        max_device_memory = max(base.backend.max_device_memory(), (refiner or base).backend.max_device_memory())
        _, shared_device_memory = cudart.cudaMalloc(max_device_memory)
        base.backend.activate_engines(shared_device_memory)
        if refiner:
            refiner.backend.activate_engines(shared_device_memory)

    if engine_type == EngineType.ORT_CUDA:
        enable_vae_slicing = args.enable_vae_slicing
        if batch_size > 4 and not enable_vae_slicing and (args.height >= 1024 and args.width >= 1024):
            print(
                "Updating enable_vae_slicing to be True to avoid cuDNN error for batch size > 4 and resolution >= 1024."
            )
            enable_vae_slicing = True
        if enable_vae_slicing:
            # NOTE(review): slicing is enabled only on the refiner when one
            # exists (the pipeline that runs the VAE decode) — confirm the
            # base never decodes in that configuration.
            (refiner or base).backend.enable_vae_slicing()
    return base, refiner
|
|
607
|
+
|
|
608
|
+
|
|
609
|
+
def get_depth_image(image):
    """Create a depth-map control image for the SDXL depth ControlNet.

    Runs DPT-hybrid (MiDaS) depth estimation on the input, min-max normalizes
    the map to [0, 1], and returns it as a 3-channel PIL image.
    """
    from transformers import DPTFeatureExtractor, DPTForDepthEstimation

    estimator = DPTForDepthEstimation.from_pretrained("Intel/dpt-hybrid-midas").to("cuda")
    extractor = DPTFeatureExtractor.from_pretrained("Intel/dpt-hybrid-midas")

    pixels = extractor(images=image, return_tensors="pt").pixel_values.to("cuda")
    with torch.no_grad(), torch.autocast("cuda"):
        depth = estimator(pixels).predicted_depth

    # The depth map is 384x384 by default; interpolate to the default output
    # size here. It will be resized to the output image size later, so the
    # size could be changed to avoid interpolating twice.
    depth = torch.nn.functional.interpolate(
        depth.unsqueeze(1),
        size=(1024, 1024),
        mode="bicubic",
        align_corners=False,
    )

    # Min-max normalize per image, then replicate the channel to RGB.
    lo = torch.amin(depth, dim=[1, 2, 3], keepdim=True)
    hi = torch.amax(depth, dim=[1, 2, 3], keepdim=True)
    normalized = (depth - lo) / (hi - lo)
    rgb = torch.cat([normalized] * 3, dim=1)

    array = rgb.permute(0, 2, 3, 1).cpu().numpy()[0]
    return Image.fromarray((array * 255.0).clip(0, 255).astype(np.uint8))
|
|
638
|
+
|
|
639
|
+
|
|
640
|
+
def get_canny_image(image) -> Image.Image:
    """Create a Canny-edge control image for the SDXL ControlNet."""
    edges = cv2.Canny(np.array(image), 100, 200)
    # Replicate the single-channel edge map into three channels.
    rgb = np.concatenate([edges[:, :, None]] * 3, axis=2)
    return Image.fromarray(rgb)
|
|
650
|
+
|
|
651
|
+
|
|
652
|
+
def process_controlnet_images_xl(args) -> List[Image.Image]:
    """Load and preprocess the single control image for the SDXL ControlNet."""
    # SDXL supports exactly one ControlNet in this demo.
    assert len(args.controlnet_image) == 1
    source = Image.open(args.controlnet_image[0]).convert("RGB")

    kind = args.controlnet_type[0]
    if kind == "canny":
        return [get_canny_image(source)]
    if kind == "depth":
        return [get_depth_image(source)]
    raise ValueError(f"This controlnet type is not supported for SDXL or Turbo: {args.controlnet_type}.")
|
|
668
|
+
|
|
669
|
+
|
|
670
|
+
def add_controlnet_arguments(parser, is_xl: bool = False):
    """Register ControlNet command-line options on the given parser.

    The set of valid --controlnet-type choices depends on whether the model
    is SDXL (is_xl) or SD 1.5.
    """
    group = parser.add_argument_group("Options for ControlNet (supports 1.5, sd-turbo, xl-turbo, xl-1.0).")

    group.add_argument(
        "-ci",
        "--controlnet-image",
        nargs="*",
        type=str,
        default=[],
        help="Path to the input regular RGB image/images for controlnet",
    )

    type_choices = list(PipelineInfo.supported_controlnet("xl-1.0" if is_xl else "1.5").keys())
    group.add_argument(
        "-ct",
        "--controlnet-type",
        nargs="*",
        type=str,
        default=[],
        choices=type_choices,
        help="A list of controlnet type",
    )

    group.add_argument(
        "-cs",
        "--controlnet-scale",
        nargs="*",
        type=float,
        default=[],
        help="The outputs of the controlnet are multiplied by `controlnet_scale` before they are added to the residual in the original unet. Default is 0.5 for SDXL, or 1.0 for SD 1.5",
    )
|
|
701
|
+
|
|
702
|
+
|
|
703
|
+
def process_controlnet_image(controlnet_type: str, image: Image.Image, height, width):
    """Preprocess a control image for ControlNet v1.1 with Stable Diffusion 1.5.

    Runs the annotator matching controlnet_type on the RGB input and resizes
    the annotated result.

    Raises:
        ValueError: when controlnet_type has no matching annotator.
    """
    # NOTE(review): PIL's resize expects (width, height); this passes
    # (height, width) — preserved as-is, confirm intent for non-square sizes.
    target_size = (height, width)
    rgb = image.convert("RGB")

    if controlnet_type == "canny":
        annotated = controlnet_aux.CannyDetector()(rgb)
    elif controlnet_type == "normalbae":
        annotated = controlnet_aux.NormalBaeDetector.from_pretrained("lllyasviel/Annotators")(rgb)
    elif controlnet_type == "depth":
        annotated = controlnet_aux.LeresDetector.from_pretrained("lllyasviel/Annotators")(rgb)
    elif controlnet_type == "mlsd":
        annotated = controlnet_aux.MLSDdetector.from_pretrained("lllyasviel/Annotators")(rgb)
    elif controlnet_type == "openpose":
        annotated = controlnet_aux.OpenposeDetector.from_pretrained("lllyasviel/Annotators")(rgb)
    elif controlnet_type == "scribble":
        annotated = controlnet_aux.HEDdetector.from_pretrained("lllyasviel/Annotators")(rgb, scribble=True)
    elif controlnet_type == "seg":
        annotated = controlnet_aux.SamDetector.from_pretrained("ybelkada/segment-anything", subfolder="checkpoints")(
            rgb
        )
    else:
        raise ValueError(f"There is no demo image of this controlnet_type: {controlnet_type}")
    return annotated.resize(target_size)
|
|
736
|
+
|
|
737
|
+
|
|
738
|
+
def process_controlnet_arguments(args):
    """Validate ControlNet arguments and return (control images, scales tensor).

    Returns (None, None) when no ControlNet is requested. Fills in default
    scales (mutating args.controlnet_scale in place) when none are given.

    Raises:
        ValueError: when image/type/scale counts disagree, the model version
            does not support ControlNet, or multiple ControlNets are requested
            for SDXL/Turbo.
    """
    assert isinstance(args.controlnet_type, list)
    assert isinstance(args.controlnet_scale, list)
    assert isinstance(args.controlnet_image, list)

    num_types = len(args.controlnet_type)
    num_images = len(args.controlnet_image)
    if num_images != num_types:
        raise ValueError(
            f"Numbers of controlnet_image {num_images} should be equal to number of controlnet_type {num_types}."
        )

    if num_types == 0:
        return None, None

    if args.version not in ["1.5", "xl-1.0", "xl-turbo", "sd-turbo"]:
        raise ValueError("This demo only supports ControlNet in Stable Diffusion 1.5, XL or Turbo.")

    is_xl = "xl" in args.version
    if is_xl and num_types > 1:
        raise ValueError("This demo only support one ControlNet for Stable Diffusion XL or Turbo.")

    # One scale per ControlNet: default 0.5 for SDXL, 1.0 for SD 1.5.
    if not args.controlnet_scale:
        args.controlnet_scale = [0.5 if is_xl else 1.0] * num_types
    elif num_types != len(args.controlnet_scale):
        raise ValueError(
            f"Numbers of controlnet_type {num_types} should be equal to number of controlnet_scale {len(args.controlnet_scale)}."
        )

    # Convert controlnet scales to tensor
    controlnet_scale = torch.FloatTensor(args.controlnet_scale)

    if is_xl:
        images = process_controlnet_images_xl(args)
    else:
        images = [
            process_controlnet_image(ctype, Image.open(path), args.height, args.width)
            for ctype, path in zip(args.controlnet_type, args.controlnet_image)
        ]

    return images, controlnet_scale
|