onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl
This diff shows the content of publicly available package versions that have been released to one of the supported registries. The information is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6121 -0
- onnxruntime/__init__.py +418 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +175 -0
- onnxruntime/backend/backend_rep.py +52 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/build_and_package_info.py +2 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +154 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +18 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +19 -0
- onnxruntime/quantization/base_quantizer.py +529 -0
- onnxruntime/quantization/calibrate.py +1267 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
- onnxruntime/quantization/fusions/__init__.py +4 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
- onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
- onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
- onnxruntime/quantization/neural_compressor/__init__.py +1 -0
- onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
- onnxruntime/quantization/neural_compressor/util.py +80 -0
- onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
- onnxruntime/quantization/onnx_model.py +600 -0
- onnxruntime/quantization/onnx_quantizer.py +1163 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +260 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +172 -0
- onnxruntime/quantization/operators/lstm.py +121 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +172 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1477 -0
- onnxruntime/quantization/quant_utils.py +1051 -0
- onnxruntime/quantization/quantize.py +953 -0
- onnxruntime/quantization/registry.py +110 -0
- onnxruntime/quantization/shape_inference.py +204 -0
- onnxruntime/quantization/static_quantize_runner.py +256 -0
- onnxruntime/quantization/tensor_quant_overrides.py +520 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
- onnxruntime/tools/file_utils.py +47 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +416 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +56 -0
- onnxruntime/tools/ort_format_model/__init__.py +27 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +85 -0
- onnxruntime/tools/ort_format_model/utils.py +61 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/qnn/add_trans_cast.py +292 -0
- onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
- onnxruntime/tools/qnn/preprocess.py +165 -0
- onnxruntime/tools/reduced_build_config_parser.py +203 -0
- onnxruntime/tools/remove_initializer_from_input.py +37 -0
- onnxruntime/tools/symbolic_shape_infer.py +3094 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +942 -0
- onnxruntime/transformers/benchmark_helper.py +643 -0
- onnxruntime/transformers/bert_perf_test.py +629 -0
- onnxruntime/transformers/bert_test_data.py +641 -0
- onnxruntime/transformers/compare_bert_results.py +256 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3605 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +385 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1189 -0
- onnxruntime/transformers/fusion_attention_clip.py +340 -0
- onnxruntime/transformers/fusion_attention_sam2.py +533 -0
- onnxruntime/transformers/fusion_attention_unet.py +1307 -0
- onnxruntime/transformers/fusion_attention_vae.py +300 -0
- onnxruntime/transformers/fusion_bart_attention.py +435 -0
- onnxruntime/transformers/fusion_base.py +141 -0
- onnxruntime/transformers/fusion_bias_add.py +57 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
- onnxruntime/transformers/fusion_conformer_attention.py +222 -0
- onnxruntime/transformers/fusion_constant_fold.py +144 -0
- onnxruntime/transformers/fusion_embedlayer.py +810 -0
- onnxruntime/transformers/fusion_fastgelu.py +492 -0
- onnxruntime/transformers/fusion_gelu.py +258 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +180 -0
- onnxruntime/transformers/fusion_layernorm.py +489 -0
- onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +420 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
- onnxruntime/transformers/fusion_shape.py +109 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +167 -0
- onnxruntime/transformers/fusion_utils.py +321 -0
- onnxruntime/transformers/huggingface_models.py +74 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +487 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +230 -0
- onnxruntime/transformers/metrics.py +163 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +700 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
- onnxruntime/transformers/models/llama/llama_parity.py +343 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
- onnxruntime/transformers/models/t5/t5_helper.py +302 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +585 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
- onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
- onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
- onnxruntime/transformers/onnx_exporter.py +719 -0
- onnxruntime/transformers/onnx_model.py +1636 -0
- onnxruntime/transformers/onnx_model_bart.py +141 -0
- onnxruntime/transformers/onnx_model_bert.py +488 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
- onnxruntime/transformers/onnx_model_clip.py +42 -0
- onnxruntime/transformers/onnx_model_conformer.py +32 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_mmdit.py +112 -0
- onnxruntime/transformers/onnx_model_phi.py +929 -0
- onnxruntime/transformers/onnx_model_sam2.py +137 -0
- onnxruntime/transformers/onnx_model_t5.py +985 -0
- onnxruntime/transformers/onnx_model_tnlr.py +226 -0
- onnxruntime/transformers/onnx_model_unet.py +258 -0
- onnxruntime/transformers/onnx_model_vae.py +42 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +620 -0
- onnxruntime/transformers/past_helper.py +149 -0
- onnxruntime/transformers/profile_result_processor.py +358 -0
- onnxruntime/transformers/profiler.py +434 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +121 -0
- onnxruntime/transformers/shape_optimizer.py +400 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
- onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
- onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py
@@ -0,0 +1,831 @@
# -------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License.
# --------------------------------------------------------------------------
# Modified from TensorRT demo diffusion, which has the following license:
#
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------

import os
import pathlib
import random
import time
from typing import Any

import numpy as np
import nvtx
import torch
from cuda import cudart
from diffusion_models import PipelineInfo, get_tokenizer
from diffusion_schedulers import DDIMScheduler, EulerAncestralDiscreteScheduler, LCMScheduler, UniPCMultistepScheduler
from engine_builder import EngineType
from engine_builder_ort_cuda import OrtCudaEngineBuilder
from engine_builder_ort_trt import OrtTensorrtEngineBuilder
from engine_builder_tensorrt import TensorrtEngineBuilder
from engine_builder_torch import TorchEngineBuilder
from PIL import Image


class StableDiffusionPipeline:
    """
    Stable Diffusion pipeline using TensorRT.
    """

    def __init__(
        self,
        pipeline_info: PipelineInfo,
        max_batch_size=16,
        scheduler="DDIM",
        device="cuda",
        output_dir=".",
        verbose=False,
        nvtx_profile=False,
        use_cuda_graph=False,
        framework_model_dir="pytorch_model",
        engine_type: EngineType = EngineType.ORT_CUDA,
    ):
        """
        Initializes the Diffusion pipeline.

        Args:
            pipeline_info (PipelineInfo):
                Version and Type of pipeline.
            max_batch_size (int):
                Maximum batch size for dynamic batch engine.
            scheduler (str):
                The scheduler to guide the denoising process. Must be one of [DDIM, EulerA, UniPC, LCM].
            device (str):
                PyTorch device to run inference. Default: 'cuda'
            output_dir (str):
                Output directory for log files and image artifacts
            verbose (bool):
                Enable verbose logging.
            nvtx_profile (bool):
                Insert NVTX profiling markers.
            use_cuda_graph (bool):
                Use CUDA graph to capture engine execution and then launch inference
            framework_model_dir (str):
                cache directory for framework checkpoints
            engine_type (EngineType)
                backend engine type like ORT_TRT or TRT
        """

        self.pipeline_info = pipeline_info
        self.version = pipeline_info.version

        self.vae_scaling_factor = pipeline_info.vae_scaling_factor()

        self.max_batch_size = max_batch_size

        self.framework_model_dir = framework_model_dir
        self.output_dir = output_dir
        for directory in [self.framework_model_dir, self.output_dir]:
            if not os.path.exists(directory):
                print(f"[I] Create directory: {directory}")
                pathlib.Path(directory).mkdir(parents=True)

        self.device = device
        self.torch_device = torch.device(device, torch.cuda.current_device())
        self.verbose = verbose
        self.nvtx_profile = nvtx_profile

        self.use_cuda_graph = use_cuda_graph

        self.tokenizer = None
        self.tokenizer2 = None

        self.generator = torch.Generator(device="cuda")
        self.actual_steps = None

        self.current_scheduler = None
        self.set_scheduler(scheduler)

        # backend engine
        self.engine_type = engine_type
        if engine_type == EngineType.TRT:
            self.backend = TensorrtEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
        elif engine_type == EngineType.ORT_TRT:
            self.backend = OrtTensorrtEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
        elif engine_type == EngineType.ORT_CUDA:
            self.backend = OrtCudaEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
        elif engine_type == EngineType.TORCH:
            self.backend = TorchEngineBuilder(pipeline_info, max_batch_size, device, use_cuda_graph)
        else:
            raise RuntimeError(f"Backend engine type {engine_type.name} is not supported")

        # Load text tokenizer
        if not self.pipeline_info.is_xl_refiner():
            self.tokenizer = get_tokenizer(self.pipeline_info, self.framework_model_dir, subfolder="tokenizer")

        if self.pipeline_info.is_xl():
            self.tokenizer2 = get_tokenizer(self.pipeline_info, self.framework_model_dir, subfolder="tokenizer_2")

        self.control_image_processor = None
        if self.pipeline_info.is_xl() and self.pipeline_info.controlnet:
            from diffusers.image_processor import VaeImageProcessor  # noqa: PLC0415

            self.control_image_processor = VaeImageProcessor(
                vae_scale_factor=8, do_convert_rgb=True, do_normalize=False
            )

        # Create CUDA events
        self.events = {}
        for stage in ["clip", "denoise", "vae", "vae_encoder", "pil"]:
            for marker in ["start", "stop"]:
                self.events[stage + "-" + marker] = cudart.cudaEventCreate()[1]
        self.markers = {}

    def is_backend_tensorrt(self):
        return self.engine_type == EngineType.TRT

    def set_scheduler(self, scheduler: str):
        if scheduler == self.current_scheduler:
            return

        # Scheduler options
        sched_opts = {"num_train_timesteps": 1000, "beta_start": 0.00085, "beta_end": 0.012}
        if self.version in ("2.0", "2.1"):
            sched_opts["prediction_type"] = "v_prediction"
        else:
            sched_opts["prediction_type"] = "epsilon"

        if scheduler == "DDIM":
            self.scheduler = DDIMScheduler(device=self.device, **sched_opts)
        elif scheduler == "EulerA":
            self.scheduler = EulerAncestralDiscreteScheduler(device=self.device, **sched_opts)
        elif scheduler == "UniPC":
            self.scheduler = UniPCMultistepScheduler(device=self.device, **sched_opts)
        elif scheduler == "LCM":
            self.scheduler = LCMScheduler(device=self.device, **sched_opts)
        else:
            raise ValueError("Scheduler should be either DDIM, EulerA, UniPC or LCM")

        self.current_scheduler = scheduler
        self.denoising_steps = None

    def set_denoising_steps(self, denoising_steps: int):
        if not (self.denoising_steps == denoising_steps and isinstance(self.scheduler, DDIMScheduler)):
            self.scheduler.set_timesteps(denoising_steps)
            self.scheduler.configure()
            self.denoising_steps = denoising_steps

    def load_resources(self, image_height, image_width, batch_size):
        # If engine is built with static input shape, call this only once after engine build.
        # Otherwise, it need be called before every inference run.
        self.backend.load_resources(image_height, image_width, batch_size)

    def set_random_seed(self, seed):
        if isinstance(seed, int):
            self.generator.manual_seed(seed)
        else:
            self.generator.seed()

    def get_current_seed(self):
        return self.generator.initial_seed()

    def teardown(self):
        for e in self.events.values():
            cudart.cudaEventDestroy(e)

        if self.backend:
            self.backend.teardown()

    def run_engine(self, model_name, feed_dict):
        return self.backend.run_engine(model_name, feed_dict)

    def initialize_latents(self, batch_size, unet_channels, latent_height, latent_width):
        latents_dtype = torch.float16
        latents_shape = (batch_size, unet_channels, latent_height, latent_width)
        latents = torch.randn(latents_shape, device=self.device, dtype=latents_dtype, generator=self.generator)
        # Scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    def initialize_timesteps(self, timesteps, strength):
        """Initialize timesteps for refiner."""
        self.scheduler.set_timesteps(timesteps)
        offset = self.scheduler.steps_offset if hasattr(self.scheduler, "steps_offset") else 0
        init_timestep = int(timesteps * strength) + offset
        init_timestep = min(init_timestep, timesteps)
        t_start = max(timesteps - init_timestep + offset, 0)
        timesteps = self.scheduler.timesteps[t_start:].to(self.device)
        return timesteps, t_start

    def initialize_refiner(self, batch_size, image, strength):
        """Add noise to a reference image."""
        # Initialize timesteps
        timesteps, t_start = self.initialize_timesteps(self.denoising_steps, strength)

        latent_timestep = timesteps[:1].repeat(batch_size)

        # Pre-process input image
        image = self.preprocess_images(batch_size, (image,))[0]

        # VAE encode init image
        if image.shape[1] == 4:
            init_latents = image
        else:
            init_latents = self.encode_image(image)

        # Add noise to latents using timesteps
        noise = torch.randn(init_latents.shape, device=self.device, dtype=torch.float16, generator=self.generator)

        latents = self.scheduler.add_noise(init_latents, noise, t_start, latent_timestep)

        return timesteps, t_start, latents

    def _get_add_time_ids(
        self,
        original_size,
        crops_coords_top_left,
        target_size,
        aesthetic_score,
        negative_aesthetic_score,
        dtype,
        requires_aesthetics_score,
    ):
        if requires_aesthetics_score:
            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
            add_neg_time_ids = list(original_size + crops_coords_top_left + (negative_aesthetic_score,))
        else:
            add_time_ids = list(original_size + crops_coords_top_left + target_size)
            add_neg_time_ids = list(original_size + crops_coords_top_left + target_size)

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)

        return add_time_ids, add_neg_time_ids

    def start_profile(self, name, color="blue"):
        if self.nvtx_profile:
            self.markers[name] = nvtx.start_range(message=name, color=color)
        event_name = name + "-start"
        if event_name in self.events:
            cudart.cudaEventRecord(self.events[event_name], 0)

    def stop_profile(self, name):
        event_name = name + "-stop"
        if event_name in self.events:
            cudart.cudaEventRecord(self.events[event_name], 0)
        if self.nvtx_profile:
            nvtx.end_range(self.markers[name])

    def preprocess_images(self, batch_size, images=()):
        self.start_profile("preprocess", color="pink")
        init_images = []
        for i in images:
            image = i.to(self.device)
            if image.shape[0] != batch_size:
                image = image.repeat(batch_size, 1, 1, 1)
            init_images.append(image)
        self.stop_profile("preprocess")
        return tuple(init_images)

    def preprocess_controlnet_images(
        self, batch_size, images=None, do_classifier_free_guidance=True, height=1024, width=1024
    ):
        """
        Process a list of PIL.Image.Image as control images, and return a torch tensor.
        """
        if images is None:
            return None
        self.start_profile("preprocess", color="pink")

        if not self.pipeline_info.is_xl():
            images = [
                torch.from_numpy(
                    (np.array(image.convert("RGB")).astype(np.float32) / 255.0)[..., None].transpose(3, 2, 0, 1)
                )
                .to(device=self.device, dtype=torch.float16)
                .repeat_interleave(batch_size, dim=0)
                for image in images
            ]
        else:
            images = [
                self.control_image_processor.preprocess(image, height=height, width=width)
                .to(device=self.device, dtype=torch.float16)
                .repeat_interleave(batch_size, dim=0)
                for image in images
            ]

        if do_classifier_free_guidance:
            images = [torch.cat([i] * 2) for i in images]
        images = torch.cat([image[None, ...] for image in images], dim=0)

        self.stop_profile("preprocess")
        return images

    def encode_prompt(
        self,
        prompt,
        negative_prompt,
        encoder="clip",
        tokenizer=None,
        pooled_outputs=False,
        output_hidden_states=False,
        force_zeros_for_empty_prompt=False,
        do_classifier_free_guidance=True,
        dtype=torch.float16,
    ):
        if tokenizer is None:
            tokenizer = self.tokenizer

        self.start_profile("clip", color="green")

        def tokenize(prompt, output_hidden_states):
            text_input_ids = (
                tokenizer(
                    prompt,
                    padding="max_length",
                    max_length=tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                .input_ids.type(torch.int32)
                .to(self.device)
            )

            hidden_states = None
            if self.engine_type == EngineType.TORCH:
                outputs = self.backend.engines[encoder](text_input_ids)
                text_embeddings = outputs[0]
                if output_hidden_states:
                    hidden_states = outputs["last_hidden_state"]
            else:
                outputs = self.run_engine(encoder, {"input_ids": text_input_ids})
                text_embeddings = outputs["text_embeddings"]
                if output_hidden_states:
                    hidden_states = outputs["hidden_states"]
            return text_embeddings, hidden_states

        # Tokenize prompt
        text_embeddings, hidden_states = tokenize(prompt, output_hidden_states)

        # NOTE: output tensor for CLIP must be cloned because it will be overwritten when called again for negative prompt
        text_embeddings = text_embeddings.clone()
        if hidden_states is not None:
            hidden_states = hidden_states.clone()

        # Note: negative prompt embedding is not needed for SD XL when guidance <= 1
        if do_classifier_free_guidance:
            # For SD XL base, handle force_zeros_for_empty_prompt
            is_empty_negative_prompt = all(not i for i in negative_prompt)
            if force_zeros_for_empty_prompt and is_empty_negative_prompt:
                uncond_embeddings = torch.zeros_like(text_embeddings)
                if output_hidden_states:
                    uncond_hidden_states = torch.zeros_like(hidden_states)
            else:
                # Tokenize negative prompt
                uncond_embeddings, uncond_hidden_states = tokenize(negative_prompt, output_hidden_states)

            # Concatenate the unconditional and text embeddings into a single batch to avoid doing two forward passes for classifier free guidance
            text_embeddings = torch.cat([uncond_embeddings, text_embeddings])

            if output_hidden_states:
                hidden_states = torch.cat([uncond_hidden_states, hidden_states])

        self.stop_profile("clip")

        if pooled_outputs:
            # For text encoder in sdxl base
            return hidden_states.to(dtype=dtype), text_embeddings.to(dtype=dtype)

        if output_hidden_states:
            # For text encoder 2 in sdxl base or refiner
            return hidden_states.to(dtype=dtype)

        # For text encoder in sd 1.5
        return text_embeddings.to(dtype=dtype)

    def denoise_latent(
        self,
        latents,
        text_embeddings,
        denoiser="unet",
        timesteps=None,
        step_offset=0,
        guidance=7.5,
        add_kwargs=None,
    ):
        do_classifier_free_guidance = guidance > 1.0

        self.start_profile("denoise", color="blue")

        if not isinstance(timesteps, torch.Tensor):
            timesteps = self.scheduler.timesteps

        for step_index, timestep in enumerate(timesteps):
            # Expand the latents if we are doing classifier free guidance
            latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents

            latent_model_input = self.scheduler.scale_model_input(
                latent_model_input, step_offset + step_index, timestep
            )

            # Predict the noise residual
            if self.nvtx_profile:
                nvtx_unet = nvtx.start_range(message="unet", color="blue")

            params = {
                "sample": latent_model_input,
                "timestep": timestep.to(latents.dtype),
                "encoder_hidden_states": text_embeddings,
            }

            if add_kwargs:
                params.update(add_kwargs)

            noise_pred = self.run_engine(denoiser, params)["latent"]

            if self.nvtx_profile:
                nvtx.end_range(nvtx_unet)

            # perform guidance
            if do_classifier_free_guidance:
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred = noise_pred_uncond + guidance * (noise_pred_text - noise_pred_uncond)

            if type(self.scheduler) is UniPCMultistepScheduler:
                latents = self.scheduler.step(noise_pred, timestep, latents, return_dict=False)[0]
            elif type(self.scheduler) is LCMScheduler:
                latents = self.scheduler.step(noise_pred, timestep, latents, generator=self.generator)[0]
            else:
                latents = self.scheduler.step(noise_pred, latents, step_offset + step_index, timestep)

        # The actual number of steps. It might be different from denoising_steps.
        self.actual_steps = len(timesteps)

        self.stop_profile("denoise")
        return latents

    def encode_image(self, image):
        self.start_profile("vae_encoder", color="red")
        init_latents = self.run_engine("vae_encoder", {"images": image})["latent"]
        init_latents = self.vae_scaling_factor * init_latents
        self.stop_profile("vae_encoder")
        return init_latents

    def decode_latent(self, latents):
        self.start_profile("vae", color="red")
        images = self.backend.vae_decode(latents)
        self.stop_profile("vae")
        return images

    def print_summary(self, tic, toc, batch_size, vae_enc=False, pil=False) -> dict[str, Any]:
        throughput = batch_size / (toc - tic)
        latency_clip = cudart.cudaEventElapsedTime(self.events["clip-start"], self.events["clip-stop"])[1]
        latency_unet = cudart.cudaEventElapsedTime(self.events["denoise-start"], self.events["denoise-stop"])[1]
        latency_vae = cudart.cudaEventElapsedTime(self.events["vae-start"], self.events["vae-stop"])[1]
        latency_vae_encoder = (
            cudart.cudaEventElapsedTime(self.events["vae_encoder-start"], self.events["vae_encoder-stop"])[1]
            if vae_enc
            else None
        )
        latency_pil = cudart.cudaEventElapsedTime(self.events["pil-start"], self.events["pil-stop"])[1] if pil else None

        latency = (toc - tic) * 1000.0

        print("|----------------|--------------|")
        print("| {:^14} | {:^12} |".format("Module", "Latency"))
        print("|----------------|--------------|")
        if vae_enc:
            print("| {:^14} | {:>9.2f} ms |".format("VAE-Enc", latency_vae_encoder))
        print("| {:^14} | {:>9.2f} ms |".format("CLIP", latency_clip))
        print(
            "| {:^14} | {:>9.2f} ms |".format(
                "UNet" + ("+CNet" if self.pipeline_info.controlnet else "") + " x " + str(self.actual_steps),
                latency_unet,
            )
        )
        print("| {:^14} | {:>9.2f} ms |".format("VAE-Dec", latency_vae))
        pipeline = "Refiner" if self.pipeline_info.is_xl_refiner() else "Pipeline"
        if pil:
            print("| {:^14} | {:>9.2f} ms |".format("PIL", latency_pil))
        print("|----------------|--------------|")
        print(f"| {pipeline:^14} | {latency:>9.2f} ms |")
        print("|----------------|--------------|")
        print(f"Throughput: {throughput:.2f} image/s")

        perf_data = {
            "latency_clip": latency_clip,
            "latency_unet": latency_unet,
            "latency_vae": latency_vae,
            "latency_pil": latency_pil,
            "latency": latency,
            "throughput": throughput,
        }
        if vae_enc:
            perf_data["latency_vae_encoder"] = latency_vae_encoder
        return perf_data

    @staticmethod
    def pt_to_pil(images):
        images = (
            ((images + 1) * 255 / 2).clamp(0, 255).detach().permute(0, 2, 3, 1).round().type(torch.uint8).cpu().numpy()
        )
        return [Image.fromarray(images[i]) for i in range(images.shape[0])]

    @staticmethod
    def pt_to_numpy(images: torch.FloatTensor):
        """
        Convert a PyTorch tensor to a NumPy image.
        """
        return ((images + 1) / 2).clamp(0, 1).detach().permute(0, 2, 3, 1).float().cpu().numpy()

    def metadata(self) -> dict[str, Any]:
        data = {
            "actual_steps": self.actual_steps,
            "seed": self.get_current_seed(),
            "name": self.pipeline_info.name(),
            "custom_vae": self.pipeline_info.custom_fp16_vae(),
            "custom_unet": self.pipeline_info.custom_unet(),
        }

        if self.engine_type == EngineType.ORT_CUDA:
            for engine_name, engine in self.backend.engines.items():
                data.update(engine.metadata(engine_name))

        return data

    def save_images(self, images: list, prompt: list[str], negative_prompt: list[str], metadata: dict[str, Any]):
        session_id = str(random.randint(1000, 9999))
        for i, image in enumerate(images):
            seed = str(self.get_current_seed())
            prefix = "".join(x for x in prompt[i] if x.isalnum() or x in ", -").replace(" ", "_")[:20]
            parts = [prefix, session_id, str(i + 1), str(seed), self.current_scheduler, str(self.actual_steps)]
            image_path = os.path.join(self.output_dir, "-".join(parts) + ".png")
            print(f"Saving image {i + 1} / {len(images)} to: {image_path}")

            from PIL import PngImagePlugin  # noqa: PLC0415

            info = PngImagePlugin.PngInfo()
            for k, v in metadata.items():
                info.add_text(k, str(v))
            info.add_text("prompt", prompt[i])
            info.add_text("negative_prompt", negative_prompt[i])

            image.save(image_path, "PNG", pnginfo=info)

    def _infer(
        self,
        prompt,
        negative_prompt,
        image_height,
        image_width,
        denoising_steps=30,
        guidance=5.0,
        seed=None,
        image=None,
        strength=0.3,
        controlnet_images=None,
        controlnet_scales=None,
        show_latency=False,
        output_type="pil",
    ):
        if show_latency:
            torch.cuda.synchronize()
            start_time = time.perf_counter()

        assert len(prompt) == len(negative_prompt)
        batch_size = len(prompt)

        self.set_denoising_steps(denoising_steps)
        self.set_random_seed(seed)

        timesteps = None
        step_offset = 0
        with torch.inference_mode(), torch.autocast("cuda"):
            if image is not None:
                timesteps, step_offset, latents = self.initialize_refiner(
                    batch_size=batch_size,
                    image=image,
                    strength=strength,
                )
            else:
                # Pre-initialize latents
                latents = self.initialize_latents(
                    batch_size=batch_size,
                    unet_channels=4,
                    latent_height=(image_height // 8),
                    latent_width=(image_width // 8),
                )

            do_classifier_free_guidance = guidance > 1.0
            if not self.pipeline_info.is_xl():
                denoiser = "unet"
                text_embeddings = self.encode_prompt(
                    prompt,
                    negative_prompt,
                    do_classifier_free_guidance=do_classifier_free_guidance,
                    dtype=latents.dtype,
                )
                add_kwargs = {}
            else:
                denoiser = "unetxl"

                # Time embeddings
                original_size = (image_height, image_width)
                crops_coords_top_left = (0, 0)
                target_size = (image_height, image_width)
                aesthetic_score = 6.0
                negative_aesthetic_score = 2.5
                add_time_ids, add_negative_time_ids = self._get_add_time_ids(
                    original_size,
                    crops_coords_top_left,
                    target_size,
                    aesthetic_score,
                    negative_aesthetic_score,
                    dtype=latents.dtype,
                    requires_aesthetics_score=self.pipeline_info.is_xl_refiner(),
                )
                if do_classifier_free_guidance:
                    add_time_ids = torch.cat([add_negative_time_ids, add_time_ids], dim=0)
                add_time_ids = add_time_ids.to(device=self.device).repeat(batch_size, 1)

                if self.pipeline_info.is_xl_refiner():
                    # CLIP text encoder 2
                    text_embeddings, pooled_embeddings2 = self.encode_prompt(
                        prompt,
                        negative_prompt,
                        encoder="clip2",
                        tokenizer=self.tokenizer2,
                        pooled_outputs=True,
                        output_hidden_states=True,
                        dtype=latents.dtype,
                    )
                    add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids}
                else:  # XL Base
                    # CLIP text encoder
                    text_embeddings = self.encode_prompt(
                        prompt,
                        negative_prompt,
                        encoder="clip",
                        tokenizer=self.tokenizer,
                        output_hidden_states=True,
                        force_zeros_for_empty_prompt=True,
                        do_classifier_free_guidance=do_classifier_free_guidance,
                        dtype=latents.dtype,
                    )
                    # CLIP text encoder 2
                    text_embeddings2, pooled_embeddings2 = self.encode_prompt(
                        prompt,
                        negative_prompt,
                        encoder="clip2",
                        tokenizer=self.tokenizer2,
                        pooled_outputs=True,
                        output_hidden_states=True,
                        force_zeros_for_empty_prompt=True,
                        do_classifier_free_guidance=do_classifier_free_guidance,
                        dtype=latents.dtype,
                    )

                    # Merged text embeddings
                    text_embeddings = torch.cat([text_embeddings, text_embeddings2], dim=-1)

                    add_kwargs = {"text_embeds": pooled_embeddings2, "time_ids": add_time_ids}

                if self.pipeline_info.controlnet:
                    controlnet_images = self.preprocess_controlnet_images(
                        latents.shape[0],
                        controlnet_images,
                        do_classifier_free_guidance=do_classifier_free_guidance,
                        height=image_height,
                        width=image_width,
                    )
                    add_kwargs.update(
                        {
                            "controlnet_images": controlnet_images,
                            "controlnet_scales": controlnet_scales.to(controlnet_images.dtype).to(controlnet_images.device),
                        }
                    )

            # UNet denoiser
            latents = self.denoise_latent(
                latents,
                text_embeddings,
                timesteps=timesteps,
                step_offset=step_offset,
                denoiser=denoiser,
                guidance=guidance,
                add_kwargs=add_kwargs,
            )

        with torch.inference_mode():
            # VAE decode latent
            if output_type == "latent":
                images = latents
            else:
                images = self.decode_latent(latents / self.vae_scaling_factor)
                if output_type == "pil":
                    self.start_profile("pil", color="green")
                    images = self.pt_to_pil(images)
                    self.stop_profile("pil")

        perf_data = None
        if show_latency:
            torch.cuda.synchronize()
            end_time = time.perf_counter()
            perf_data = self.print_summary(
                start_time, end_time, batch_size, vae_enc=self.pipeline_info.is_xl_refiner(), pil=(output_type == "pil")
            )

        return images, perf_data

    def run(
        self,
        prompt: list[str],
        negative_prompt: list[str],
        image_height: int,
        image_width: int,
        denoising_steps: int = 30,
        guidance: float = 5.0,
        seed: int | None = None,
        image: torch.Tensor | None = None,
        strength: float = 0.3,
        controlnet_images: torch.Tensor | None = None,
        controlnet_scales: torch.Tensor | None = None,
        show_latency: bool = False,
        output_type: str = "pil",
        deterministic: bool = False,
    ):
        """
        Run the diffusion pipeline.

        Args:
            prompt (List[str]):
                The text prompt to guide image generation.
            negative_prompt (List[str]):
                The prompt not to guide the image generation.
            image_height (int):
                Height (in pixels) of the image to be generated. Must be a multiple of 8.
            image_width (int):
                Width (in pixels) of the image to be generated. Must be a multiple of 8.
            denoising_steps (int):
                Number of denoising steps. More steps usually lead to higher quality image at the expense of slower inference.
            guidance (float):
                Higher guidance scale encourages to generate images that are closely linked to the text prompt.
            seed (int):
                Seed for the random generator
            image (tuple[torch.Tensor]):
                Reference image.
            strength (float):
                Indicates extent to transform the reference image, which is used as a starting point,
                and more noise is added the higher the strength.
            show_latency (bool):
                Whether return latency data.
            output_type (str):
                It can be "latent", "pt" or "pil".
        """
        if deterministic:
            torch.use_deterministic_algorithms(True)

        if self.is_backend_tensorrt():
            import tensorrt as trt  # noqa: PLC0415
            from trt_utilities import TRT_LOGGER  # noqa: PLC0415

            with trt.Runtime(TRT_LOGGER):
                return self._infer(
                    prompt,
                    negative_prompt,
                    image_height,
                    image_width,
                    denoising_steps=denoising_steps,
                    guidance=guidance,
                    seed=seed,
                    image=image,
                    strength=strength,
                    controlnet_images=controlnet_images,
                    controlnet_scales=controlnet_scales,
                    show_latency=show_latency,
                    output_type=output_type,
                )
        else:
            return self._infer(
                prompt,
                negative_prompt,
                image_height,
                image_width,
                denoising_steps=denoising_steps,
                guidance=guidance,
                seed=seed,
                image=image,
                strength=strength,
                controlnet_images=controlnet_images,
                controlnet_scales=controlnet_scales,
                show_latency=show_latency,
                output_type=output_type,
            )