onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1319 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
# Modified from stable_diffusion_tensorrt_txt2img.py in diffusers and TensorRT demo diffusion,
|
|
6
|
+
# which has the following license:
|
|
7
|
+
#
|
|
8
|
+
# Copyright 2023 The HuggingFace Inc. team.
|
|
9
|
+
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
10
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
11
|
+
#
|
|
12
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
13
|
+
# you may not use this file except in compliance with the License.
|
|
14
|
+
# You may obtain a copy of the License at
|
|
15
|
+
#
|
|
16
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
17
|
+
#
|
|
18
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
19
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
20
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
21
|
+
# See the License for the specific language governing permissions and
|
|
22
|
+
# limitations under the License.
|
|
23
|
+
|
|
24
|
+
import logging
|
|
25
|
+
import os
|
|
26
|
+
import tempfile
|
|
27
|
+
from typing import Dict, List, Optional
|
|
28
|
+
|
|
29
|
+
import onnx
|
|
30
|
+
import onnx_graphsurgeon as gs
|
|
31
|
+
import torch
|
|
32
|
+
from diffusers.models import AutoencoderKL, ControlNetModel, UNet2DConditionModel
|
|
33
|
+
from onnx import GraphProto, ModelProto, shape_inference
|
|
34
|
+
from ort_optimizer import OrtStableDiffusionOptimizer
|
|
35
|
+
from polygraphy.backend.onnx.loader import fold_constants
|
|
36
|
+
from transformers import CLIPTextModel, CLIPTextModelWithProjection, CLIPTokenizer
|
|
37
|
+
|
|
38
|
+
from onnxruntime.transformers.onnx_model import OnnxModel
|
|
39
|
+
|
|
40
|
+
# Module-level logger, named after this module.
logger = logging.getLogger(__name__)
|
|
41
|
+
|
|
42
|
+
|
|
43
|
+
class TrtOptimizer:
    """Hold an ONNX model as an onnx-graphsurgeon graph and apply graph passes to it.

    Each pass replaces ``self.graph``; ``get_optimized_onnx_graph`` exports the current graph back to ONNX.
    """

    def __init__(self, onnx_graph):
        # Wrap the incoming ONNX model in a graphsurgeon graph so later passes can edit it.
        self.graph = gs.import_onnx(onnx_graph)

    def cleanup(self):
        """Run graphsurgeon ``cleanup()`` and then ``toposort()`` on the current graph."""
        cleaned = self.graph.cleanup()
        cleaned.toposort()

    def get_optimized_onnx_graph(self):
        """Export the current graphsurgeon graph as an ONNX model and return it."""
        exported = gs.export_onnx(self.graph)
        return exported

    def select_outputs(self, keep, names=None):
        """Keep only the graph outputs at the given indices, optionally renaming them.

        Args:
            keep: Indices into the current ``graph.outputs``, in the order the new outputs should appear.
            names: Optional new names; ``names[i]`` is assigned to the i-th kept output.
        """
        previous_outputs = self.graph.outputs
        self.graph.outputs = [previous_outputs[index] for index in keep]
        if not names:
            return
        for position, new_name in enumerate(names):
            self.graph.outputs[position].name = new_name

    def fold_constants(self):
        """Fold constants with polygraphy (allowing ONNX Runtime shape inference) and reload the graph."""
        # Inside this method, ``fold_constants`` resolves to the module-level polygraphy function.
        folded_model = fold_constants(gs.export_onnx(self.graph), allow_onnxruntime_shape_inference=True)
        self.graph = gs.import_onnx(folded_model)

    def infer_shapes(self):
        """Run ONNX shape inference on the graph and reload the annotated result.

        Models whose serialized size reaches ``onnx.checker.MAXIMUM_PROTOBUF`` are saved with external
        data to a temporary directory and processed with the path-based API. Presumably this is because
        the in-memory API cannot handle protos that large.
        """
        model_proto = gs.export_onnx(self.graph)
        too_large = model_proto.ByteSize() >= onnx.checker.MAXIMUM_PROTOBUF

        if not too_large:
            model_proto = shape_inference.infer_shapes(model_proto)
        else:
            with tempfile.TemporaryDirectory() as scratch_dir:
                src_path = os.path.join(scratch_dir, "model.onnx")
                onnx.save_model(
                    model_proto,
                    src_path,
                    save_as_external_data=True,
                    all_tensors_to_one_file=True,
                    convert_attribute=False,
                )
                dst_path = os.path.join(scratch_dir, "model_with_shape.onnx")
                onnx.shape_inference.infer_shapes_path(src_path, dst_path)
                # Load while the temporary directory (and its external data file) still exists.
                model_proto = onnx.load(dst_path)

        self.graph = gs.import_onnx(model_proto)
|
|
82
|
+
|
|
83
|
+
|
|
84
|
+
class PipelineInfo:
|
|
85
|
+
def __init__(
|
|
86
|
+
self,
|
|
87
|
+
version: str,
|
|
88
|
+
is_inpaint: bool = False,
|
|
89
|
+
is_refiner: bool = False,
|
|
90
|
+
use_vae=True, # TODO: this has couple with output type of pipeline
|
|
91
|
+
min_image_size=256,
|
|
92
|
+
max_image_size=1024,
|
|
93
|
+
use_fp16_vae=True,
|
|
94
|
+
use_lcm=False,
|
|
95
|
+
do_classifier_free_guidance=True,
|
|
96
|
+
controlnet=None,
|
|
97
|
+
lora_weights=None,
|
|
98
|
+
lora_scale=1.0,
|
|
99
|
+
):
|
|
100
|
+
self.version = version
|
|
101
|
+
self._is_inpaint = is_inpaint
|
|
102
|
+
self._is_refiner = is_refiner
|
|
103
|
+
self._use_vae = use_vae
|
|
104
|
+
self._min_image_size = min_image_size
|
|
105
|
+
self._max_image_size = max_image_size
|
|
106
|
+
self._use_fp16_vae = use_fp16_vae
|
|
107
|
+
self._use_lcm = use_lcm
|
|
108
|
+
self.do_classifier_free_guidance = do_classifier_free_guidance and not use_lcm
|
|
109
|
+
self.controlnet = controlnet # A list of control net type
|
|
110
|
+
self.lora_weights = lora_weights
|
|
111
|
+
self.lora_scale = lora_scale
|
|
112
|
+
|
|
113
|
+
if is_refiner:
|
|
114
|
+
assert not use_lcm
|
|
115
|
+
assert self.is_xl()
|
|
116
|
+
|
|
117
|
+
def is_inpaint(self) -> bool:
|
|
118
|
+
return self._is_inpaint
|
|
119
|
+
|
|
120
|
+
def is_xl(self) -> bool:
|
|
121
|
+
return "xl" in self.version
|
|
122
|
+
|
|
123
|
+
def is_xl_turbo(self) -> bool:
|
|
124
|
+
return self.version == "xl-turbo"
|
|
125
|
+
|
|
126
|
+
def is_xl_base(self) -> bool:
|
|
127
|
+
return self.version == "xl-1.0" and not self._is_refiner
|
|
128
|
+
|
|
129
|
+
def is_xl_base_or_turbo(self) -> bool:
|
|
130
|
+
return self.is_xl_base() or self.is_xl_turbo()
|
|
131
|
+
|
|
132
|
+
def is_xl_refiner(self) -> bool:
|
|
133
|
+
return self.version == "xl-1.0" and self._is_refiner
|
|
134
|
+
|
|
135
|
+
def use_safetensors(self) -> bool:
|
|
136
|
+
return self.is_xl() or self.version in ["sd-turbo"]
|
|
137
|
+
|
|
138
|
+
def stages(self) -> List[str]:
|
|
139
|
+
if self.is_xl_base_or_turbo():
|
|
140
|
+
return ["clip", "clip2", "unetxl"] + (["vae"] if self._use_vae else [])
|
|
141
|
+
|
|
142
|
+
if self.is_xl_refiner():
|
|
143
|
+
return ["clip2", "unetxl", "vae"]
|
|
144
|
+
|
|
145
|
+
return ["clip", "unet", "vae"]
|
|
146
|
+
|
|
147
|
+
def vae_scaling_factor(self) -> float:
|
|
148
|
+
return 0.13025 if self.is_xl() else 0.18215
|
|
149
|
+
|
|
150
|
+
def vae_torch_fallback(self) -> bool:
|
|
151
|
+
return self.is_xl() and not self._use_fp16_vae
|
|
152
|
+
|
|
153
|
+
def custom_fp16_vae(self) -> Optional[str]:
|
|
154
|
+
# For SD XL, use a VAE that fine-tuned to run in fp16 precision without generating NaNs
|
|
155
|
+
return "madebyollin/sdxl-vae-fp16-fix" if self._use_fp16_vae and self.is_xl() else None
|
|
156
|
+
|
|
157
|
+
def custom_unet(self) -> Optional[str]:
|
|
158
|
+
return "latent-consistency/lcm-sdxl" if self._use_lcm and self.is_xl_base() else None
|
|
159
|
+
|
|
160
|
+
@staticmethod
|
|
161
|
+
def supported_versions(is_xl: bool):
|
|
162
|
+
return ["xl-1.0", "xl-turbo"] if is_xl else ["1.4", "1.5", "2.0-base", "2.0", "2.1", "2.1-base", "sd-turbo"]
|
|
163
|
+
|
|
164
|
+
@staticmethod
|
|
165
|
+
def supported_models():
|
|
166
|
+
return {
|
|
167
|
+
"CompVis/stable-diffusion-v1-4": "1.4",
|
|
168
|
+
"runwayml/stable-diffusion-v1-5": "1.5",
|
|
169
|
+
"stabilityai/stable-diffusion-2-base": "2.0-base",
|
|
170
|
+
"stabilityai/stable-diffusion-2": "2.0",
|
|
171
|
+
"stabilityai/stable-diffusion-2-1": "2.1",
|
|
172
|
+
"stabilityai/stable-diffusion-2-1-base": "2.1",
|
|
173
|
+
"stabilityai/stable-diffusion-xl-base-1.0": "xl-1.0",
|
|
174
|
+
"stabilityai/stable-diffusion-xl-refiner-1.0": "xl-1.0",
|
|
175
|
+
"stabilityai/sdxl-turbo": "xl-turbo",
|
|
176
|
+
"stabilityai/sd-turbo": "sd-turbo",
|
|
177
|
+
# "runwayml/stable-diffusion-inpainting": "1.5",
|
|
178
|
+
# "stabilityai/stable-diffusion-2-inpainting": "2.0",
|
|
179
|
+
}
|
|
180
|
+
|
|
181
|
+
def name(self) -> str:
|
|
182
|
+
if self.version == "1.4":
|
|
183
|
+
if self.is_inpaint():
|
|
184
|
+
return "runwayml/stable-diffusion-inpainting"
|
|
185
|
+
else:
|
|
186
|
+
return "CompVis/stable-diffusion-v1-4"
|
|
187
|
+
elif self.version == "1.5":
|
|
188
|
+
if self.is_inpaint():
|
|
189
|
+
return "runwayml/stable-diffusion-inpainting"
|
|
190
|
+
else:
|
|
191
|
+
return "runwayml/stable-diffusion-v1-5"
|
|
192
|
+
elif self.version == "2.0-base":
|
|
193
|
+
if self.is_inpaint():
|
|
194
|
+
return "stabilityai/stable-diffusion-2-inpainting"
|
|
195
|
+
else:
|
|
196
|
+
return "stabilityai/stable-diffusion-2-base"
|
|
197
|
+
elif self.version == "2.0":
|
|
198
|
+
if self.is_inpaint():
|
|
199
|
+
return "stabilityai/stable-diffusion-2-inpainting"
|
|
200
|
+
else:
|
|
201
|
+
return "stabilityai/stable-diffusion-2"
|
|
202
|
+
elif self.version == "2.1":
|
|
203
|
+
return "stabilityai/stable-diffusion-2-1"
|
|
204
|
+
elif self.version == "2.1-base":
|
|
205
|
+
return "stabilityai/stable-diffusion-2-1-base"
|
|
206
|
+
elif self.version == "xl-1.0":
|
|
207
|
+
if self.is_xl_refiner():
|
|
208
|
+
return "stabilityai/stable-diffusion-xl-refiner-1.0"
|
|
209
|
+
else:
|
|
210
|
+
return "stabilityai/stable-diffusion-xl-base-1.0"
|
|
211
|
+
elif self.version == "xl-turbo":
|
|
212
|
+
return "stabilityai/sdxl-turbo"
|
|
213
|
+
elif self.version == "sd-turbo":
|
|
214
|
+
return "stabilityai/sd-turbo"
|
|
215
|
+
|
|
216
|
+
raise ValueError(f"Incorrect version {self.version}")
|
|
217
|
+
|
|
218
|
+
def short_name(self) -> str:
|
|
219
|
+
return self.name().split("/")[-1].replace("stable-diffusion", "sd")
|
|
220
|
+
|
|
221
|
+
def clip_embedding_dim(self):
|
|
222
|
+
# TODO: can we read from config instead
|
|
223
|
+
if self.version in ("1.4", "1.5"):
|
|
224
|
+
return 768
|
|
225
|
+
elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base", "sd-turbo"):
|
|
226
|
+
return 1024
|
|
227
|
+
elif self.is_xl_base_or_turbo():
|
|
228
|
+
return 768
|
|
229
|
+
else:
|
|
230
|
+
raise ValueError(f"Invalid version {self.version}")
|
|
231
|
+
|
|
232
|
+
def clipwithproj_embedding_dim(self):
|
|
233
|
+
if self.is_xl():
|
|
234
|
+
return 1280
|
|
235
|
+
else:
|
|
236
|
+
raise ValueError(f"Invalid version {self.version}")
|
|
237
|
+
|
|
238
|
+
def unet_embedding_dim(self):
|
|
239
|
+
if self.version in ("1.4", "1.5"):
|
|
240
|
+
return 768
|
|
241
|
+
elif self.version in ("2.0", "2.0-base", "2.1", "2.1-base", "sd-turbo"):
|
|
242
|
+
return 1024
|
|
243
|
+
elif self.is_xl_base_or_turbo():
|
|
244
|
+
return 2048
|
|
245
|
+
elif self.is_xl_refiner():
|
|
246
|
+
return 1280
|
|
247
|
+
else:
|
|
248
|
+
raise ValueError(f"Invalid version {self.version}")
|
|
249
|
+
|
|
250
|
+
def min_image_size(self):
|
|
251
|
+
return self._min_image_size
|
|
252
|
+
|
|
253
|
+
def max_image_size(self):
|
|
254
|
+
return self._max_image_size
|
|
255
|
+
|
|
256
|
+
@staticmethod
|
|
257
|
+
def default_resolution(version: str) -> int:
|
|
258
|
+
if version == "xl-1.0":
|
|
259
|
+
return 1024
|
|
260
|
+
if version in ("2.0", "2.1"):
|
|
261
|
+
return 768
|
|
262
|
+
return 512
|
|
263
|
+
|
|
264
|
+
def default_image_size(self) -> int:
|
|
265
|
+
return PipelineInfo.default_resolution(self.version)
|
|
266
|
+
|
|
267
|
+
@staticmethod
|
|
268
|
+
def supported_controlnet(version="1.5"):
|
|
269
|
+
if version in ("xl-1.0", "xl-turbo"):
|
|
270
|
+
return {
|
|
271
|
+
"canny": "diffusers/controlnet-canny-sdxl-1.0",
|
|
272
|
+
"depth": "diffusers/controlnet-depth-sdxl-1.0",
|
|
273
|
+
}
|
|
274
|
+
elif version == "1.5":
|
|
275
|
+
return {
|
|
276
|
+
"canny": "lllyasviel/control_v11p_sd15_canny",
|
|
277
|
+
"depth": "lllyasviel/control_v11f1p_sd15_depth",
|
|
278
|
+
"openpose": "lllyasviel/control_v11p_sd15_openpose",
|
|
279
|
+
# "tile": "lllyasviel/control_v11f1e_sd15_tile",
|
|
280
|
+
# "lineart": "lllyasviel/control_v11p_sd15_lineart",
|
|
281
|
+
# "inpaint": "lllyasviel/control_v11p_sd15_inpaint",
|
|
282
|
+
# "softedge": "lllyasviel/control_v11p_sd15_softedge",
|
|
283
|
+
"mlsd": "lllyasviel/control_v11p_sd15_mlsd",
|
|
284
|
+
"scribble": "lllyasviel/control_v11p_sd15_scribble",
|
|
285
|
+
# "ip2p": "lllyasviel/control_v11e_sd15_ip2p",
|
|
286
|
+
"normalbae": "lllyasviel/control_v11p_sd15_normalbae",
|
|
287
|
+
"seg": "lllyasviel/control_v11p_sd15_seg",
|
|
288
|
+
# "shuffle": "lllyasviel/control_v11e_sd15_shuffle",
|
|
289
|
+
# "lineart_anime": "lllyasviel/control_v11p_sd15s2_lineart_anime",
|
|
290
|
+
}
|
|
291
|
+
return None
|
|
292
|
+
|
|
293
|
+
def controlnet_name(self):
    """Return the model ids of the ControlNets configured on this pipeline.

    Returns None when no ControlNet is configured, or when this pipeline's version has
    no supported-ControlNet mapping. An unknown ControlNet type raises KeyError.
    """
    if not self.controlnet:
        return None
    mapping = PipelineInfo.supported_controlnet(self.version)
    return None if mapping is None else [mapping[kind] for kind in self.controlnet]
|
|
301
|
+
|
|
302
|
+
|
|
303
|
+
class BaseModel:
    """Base class for the exportable components (text encoders, UNets, ...) of a diffusion pipeline.

    Subclasses override the naming/shape hooks (``load_model``, ``get_input_names``,
    ``get_output_names``, ``get_dynamic_axes``, ``get_input_profile``, ``get_shape_dict``,
    ``get_sample_input``). This class provides the shared PyTorch loading, ONNX optimization
    (ONNX Runtime or TensorRT) and batch/image/latent range helpers.
    """

    def __init__(
        self,
        pipeline_info: PipelineInfo,
        model,
        device,
        fp16: bool = False,
        max_batch_size: int = 16,
        embedding_dim: int = 768,
        text_maxlen: int = 77,
    ):
        self.name = self.__class__.__name__

        self.pipeline_info = pipeline_info

        self.model = model
        self.fp16 = fp16
        self.device = device

        self.min_batch = 1
        self.max_batch = max_batch_size
        self.min_image_shape = pipeline_info.min_image_size()
        self.max_image_shape = pipeline_info.max_image_size()
        # Latent dimensions are the image dimensions divided by 8.
        self.min_latent_shape = self.min_image_shape // 8
        self.max_latent_shape = self.max_image_shape // 8

        self.embedding_dim = embedding_dim
        self.text_maxlen = text_maxlen

    def get_batch_multiplier(self):
        """Return 2 when classifier-free guidance is enabled (the effective batch is doubled), else 1."""
        return 2 if self.pipeline_info.do_classifier_free_guidance else 1

    def get_ort_optimizer(self):
        """Return an ``OrtStableDiffusionOptimizer`` for this model's type.

        The type is looked up from the concrete class name; unknown class names raise KeyError.
        """
        model_name_to_model_type = {
            "CLIP": "clip",
            "UNet": "unet",
            "VAE": "vae",
            "UNetXL": "unet",
            "CLIPWithProj": "clip",
        }
        model_type = model_name_to_model_type[self.name]
        return OrtStableDiffusionOptimizer(model_type)

    def get_model(self):
        """Return the model object passed to the constructor."""
        return self.model

    def from_pretrained(self, model_class, framework_model_dir, subfolder=None, model_name=None, **kwargs):
        """Load ``model_class`` through a local cache directory and move it to ``self.device``.

        The cache path is ``framework_model_dir/model_name[/subfolder]``. ``model_name`` defaults
        to ``self.pipeline_info.name()``.

        On a cache miss the model is fetched with
        ``model_class.from_pretrained(model_name, subfolder=..., use_safetensors=..., **kwargs)``
        and saved into the cache. On a cache hit it is loaded from the cache directory, and
        ``kwargs`` are NOT forwarded.
        """
        if model_name is None:
            model_name = self.pipeline_info.name()

        if subfolder:
            model_dir = os.path.join(framework_model_dir, model_name, subfolder)
        else:
            model_dir = os.path.join(framework_model_dir, model_name)

        if not os.path.exists(model_dir):
            model = model_class.from_pretrained(
                model_name,
                subfolder=subfolder,
                use_safetensors=self.pipeline_info.use_safetensors(),
                **kwargs,
            ).to(self.device)
            model.save_pretrained(model_dir)
        else:
            print(f"Load {self.name} pytorch model from: {model_dir}")

            model = model_class.from_pretrained(model_dir).to(self.device)
        return model

    def load_model(self, framework_model_dir: str, subfolder: str):
        """Load the PyTorch model. Overridden by subclasses; the base implementation returns None."""
        pass

    def get_input_names(self) -> List[str]:
        """Return the ONNX graph input names. Overridden by subclasses."""
        pass

    def get_output_names(self) -> List[str]:
        """Return the ONNX graph output names. Overridden by subclasses."""
        pass

    def get_dynamic_axes(self) -> Dict[str, Dict[int, str]]:
        """Return ``{tensor_name: {axis_index: axis_name}}`` for export. Overridden by subclasses."""
        pass

    def get_sample_input(self, batch_size, image_height, image_width) -> tuple:
        """Return example inputs used for export. Overridden by subclasses."""
        pass

    def get_profile_id(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Build a string identifying the shape profile (for TensorRT EP).

        The id encodes the batch range (or the static batch) and, for every model except
        ``CLIP``, the image height/width range (or the static image size). UNet models without
        classifier-free guidance use a ``_b1_`` prefix instead of ``_b_``.
        """
        (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            _,
            _,
            _,
            _,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)

        if (self.name in ["UNet", "UNetXL"]) and (self.get_batch_multiplier() == 1):
            profile_id = f"_b1_{batch_size}" if static_batch else f"_b1_{min_batch}_{max_batch}"
        else:
            profile_id = f"_b_{batch_size}" if static_batch else f"_b_{min_batch}_{max_batch}"

        if self.name != "CLIP":
            if static_image_shape:
                profile_id += f"_h_{image_height}_w_{image_width}"
            else:
                profile_id += f"_h_{min_image_height}_{max_image_height}_w_{min_image_width}_{max_image_width}"

        return profile_id

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """For TensorRT: return (min, requested, max) input shapes. Overridden by subclasses; the base returns None."""

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return concrete input/output shapes for the given sizes. Overridden by subclasses."""
        pass

    def fp32_input_output_names(self) -> List[str]:
        """For CUDA EP, we export ONNX model with FP32 first, then convert it to mixed precision model.
        This is a list of input or output names that are kept as float32 in optimized model.
        """
        return []

    def optimize_ort(
        self,
        input_onnx_path,
        optimized_onnx_path,
        to_fp16=True,
        fp32_op_list=None,
        optimize_by_ort=True,
        optimize_by_fusion=True,
        tmp_dir=None,
    ):
        """Optimize an exported ONNX model for ONNX Runtime and write it to ``optimized_onnx_path``.

        The inputs/outputs named by ``fp32_input_output_names()`` are kept in float32.
        """
        optimizer = self.get_ort_optimizer()
        optimizer.optimize(
            input_onnx_path,
            optimized_onnx_path,
            float16=to_fp16,
            keep_io_types=self.fp32_input_output_names(),
            fp32_op_list=fp32_op_list,
            optimize_by_ort=optimize_by_ort,
            optimize_by_fusion=optimize_by_fusion,
            tmp_dir=tmp_dir,
        )

    def optimize_trt(self, input_onnx_path, optimized_onnx_path):
        """Clean up, constant-fold and shape-infer an ONNX model for TensorRT, then save it.

        Graphs larger than ``onnx.checker.MAXIMUM_PROTOBUF`` are saved with all tensors in a
        single external data file.
        """
        onnx_graph = onnx.load(input_onnx_path)
        opt = TrtOptimizer(onnx_graph)
        opt.cleanup()
        opt.fold_constants()
        opt.infer_shapes()
        opt.cleanup()
        onnx_opt_graph = opt.get_optimized_onnx_graph()

        if onnx_opt_graph.ByteSize() > onnx.checker.MAXIMUM_PROTOBUF:
            onnx.save_model(
                onnx_opt_graph,
                optimized_onnx_path,
                save_as_external_data=True,
                all_tensors_to_one_file=True,
                convert_attribute=False,
            )
        else:
            onnx.save(onnx_opt_graph, optimized_onnx_path)

    def check_dims(self, batch_size, image_height, image_width):
        """Validate the batch and image sizes against this model's limits.

        Returns:
            ``(latent_height, latent_width)``, i.e. the image dimensions divided by 8.

        Raises:
            AssertionError: if the batch size is outside ``[min_batch, max_batch]``, if either
                image dimension is not a multiple of 8, or if a latent dimension is outside
                ``[min_latent_shape, max_latent_shape]``.
        """
        assert self.min_batch <= batch_size <= self.max_batch, (
            f"batch_size {batch_size} is outside [{self.min_batch}, {self.max_batch}]"
        )
        # BOTH dimensions must be divisible by 8; otherwise `// 8` below silently truncates and the
        # latent shape no longer corresponds to the image shape.
        assert image_height % 8 == 0 and image_width % 8 == 0, (
            f"image size {image_height}x{image_width} must be a multiple of 8 in both dimensions"
        )
        latent_height = image_height // 8
        latent_width = image_width // 8
        assert self.min_latent_shape <= latent_height <= self.max_latent_shape, (
            f"latent height {latent_height} is outside [{self.min_latent_shape}, {self.max_latent_shape}]"
        )
        assert self.min_latent_shape <= latent_width <= self.max_latent_shape, (
            f"latent width {latent_width} is outside [{self.min_latent_shape}, {self.max_latent_shape}]"
        )
        return (latent_height, latent_width)

    def get_minmax_dims(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return the min/max batch, image and latent dimensions for a shape profile.

        When ``static_batch`` (resp. ``static_image_shape``) is true, the min and max collapse to
        the requested batch size (resp. image/latent size); otherwise the model-wide limits are used.

        Returns:
            A 10-tuple ``(min_batch, max_batch, min_image_height, max_image_height, min_image_width,
            max_image_width, min_latent_height, max_latent_height, min_latent_width, max_latent_width)``.
        """
        min_batch = batch_size if static_batch else self.min_batch
        max_batch = batch_size if static_batch else self.max_batch
        latent_height = image_height // 8
        latent_width = image_width // 8
        min_image_height = image_height if static_image_shape else self.min_image_shape
        max_image_height = image_height if static_image_shape else self.max_image_shape
        min_image_width = image_width if static_image_shape else self.min_image_shape
        max_image_width = image_width if static_image_shape else self.max_image_shape
        min_latent_height = latent_height if static_image_shape else self.min_latent_shape
        max_latent_height = latent_height if static_image_shape else self.max_latent_shape
        min_latent_width = latent_width if static_image_shape else self.min_latent_shape
        max_latent_width = latent_width if static_image_shape else self.max_latent_shape
        return (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        )
|
|
503
|
+
|
|
504
|
+
|
|
505
|
+
class CLIP(BaseModel):
    """Export/optimization helper for the CLIP text encoder (``CLIPTextModel``).

    When ``pipeline_info.is_xl()`` is true, a ``hidden_states`` output is added to the optimized
    graph. It is taken from an intermediate encoder layer selected by ``clip_skip``.
    """

    def __init__(
        self,
        pipeline_info: PipelineInfo,
        model,
        device,
        max_batch_size,
        embedding_dim: int = 0,
        clip_skip=0,
    ):
        super().__init__(
            pipeline_info,
            model=model,
            device=device,
            max_batch_size=max_batch_size,
            # A non-positive embedding_dim means "use the pipeline's CLIP embedding size".
            embedding_dim=embedding_dim if embedding_dim > 0 else pipeline_info.clip_embedding_dim(),
        )
        self.output_hidden_state = pipeline_info.is_xl()

        # see https://github.com/huggingface/diffusers/pull/5057 for more information of clip_skip.
        # Clip_skip=1 means that the output of the pre-final layer will be used for computing the prompt embeddings.
        self.clip_skip = clip_skip

    def get_input_names(self):
        """Return the graph input names."""
        return ["input_ids"]

    def get_output_names(self):
        """Return the exported graph output names."""
        # The exported onnx model has no hidden_state. For SD-XL, We will add hidden_state to optimized onnx model.
        return ["text_embeddings"]

    def get_dynamic_axes(self):
        """Batch (B) and sequence (S) axes are dynamic for both input and output."""
        return {"input_ids": {0: "B", 1: "S"}, "text_embeddings": {0: "B", 1: "S"}}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """TensorRT profile for ``input_ids``: (min, requested, max) batch with a fixed ``text_maxlen``."""
        self.check_dims(batch_size, image_height, image_width)
        min_batch, max_batch, _, _, _, _, _, _, _, _ = self.get_minmax_dims(
            batch_size, image_height, image_width, static_batch, static_image_shape
        )
        return {
            "input_ids": [(min_batch, self.text_maxlen), (batch_size, self.text_maxlen), (max_batch, self.text_maxlen)]
        }

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return concrete shapes of the inputs/outputs, including ``hidden_states`` when enabled."""
        self.check_dims(batch_size, image_height, image_width)
        output = {
            "input_ids": (batch_size, self.text_maxlen),
            "text_embeddings": (batch_size, self.text_maxlen, self.embedding_dim),
        }

        if self.output_hidden_state:
            output["hidden_states"] = (batch_size, self.text_maxlen, self.embedding_dim)

        return output

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return an all-zero int32 ``input_ids`` tensor of shape (batch_size, text_maxlen)."""
        self.check_dims(batch_size, image_height, image_width)
        return (torch.zeros(batch_size, self.text_maxlen, dtype=torch.int32, device=self.device),)

    def add_hidden_states_graph_output(self, model: ModelProto, optimized_onnx_path, use_external_data_format=False):
        """Add a ``hidden_states`` graph output to ``model`` and save it to ``optimized_onnx_path``.

        The output is the ``Add_1`` result of encoder layer ``hidden_layers - 1 - clip_skip``. Here
        ``hidden_layers`` is the highest layer index found in node output names. A Cast node
        converts it to the element type of the first graph output.

        Raises:
            AssertionError: if ``clip_skip`` is negative or not below ``hidden_layers``.
            RuntimeError: if the selected layer output does not exist in the graph.
        """
        graph: GraphProto = model.graph

        # Highest encoder layer index referenced by node outputs, parsed from names like
        # "/text_model/encoder/layers.<index>/...".
        hidden_layers = -1
        for node in graph.node:
            for name in node.output:
                if "layers" in name:
                    hidden_layers = max(int(name.split(".")[1].split("/")[0]), hidden_layers)

        assert self.clip_skip >= 0 and self.clip_skip < hidden_layers

        node_output_name = f"/text_model/encoder/layers.{hidden_layers - 1 - self.clip_skip}/Add_1_output_0"

        # search the name in outputs of all node
        if not any(node_output_name in node.output for node in graph.node):
            raise RuntimeError("Failed to find hidden_states graph output in clip")

        # Insert a Cast (fp32 -> fp16) node so that hidden_states has same data type as the first graph output.
        graph_output_name = "hidden_states"
        cast_node = onnx.helper.make_node("Cast", inputs=[node_output_name], outputs=[graph_output_name])
        cast_node.attribute.extend([onnx.helper.make_attribute("to", graph.output[0].type.tensor_type.elem_type)])

        hidden_state = graph.output.add()
        hidden_state.CopyFrom(
            onnx.helper.make_tensor_value_info(
                graph_output_name,
                graph.output[0].type.tensor_type.elem_type,
                ["B", "S", self.embedding_dim],
            )
        )

        onnx_model = OnnxModel(model)
        onnx_model.add_node(cast_node)
        onnx_model.save_model_to_file(optimized_onnx_path, use_external_data_format=use_external_data_format)

    def optimize_ort(
        self,
        input_onnx_path,
        optimized_onnx_path,
        to_fp16=True,
        fp32_op_list=None,
        optimize_by_ort=True,
        optimize_by_fusion=True,
        tmp_dir=None,
    ):
        """Optimize the exported CLIP model for ONNX Runtime.

        Without hidden states only ``text_embeddings`` is kept. With hidden states (SD-XL) and
        fusion enabled, ``hidden_states`` is first added to a temporary copy of the model, which
        is then optimized. With fusion disabled, the input is assumed to already contain it.
        ``tmp_dir`` is always forwarded to the optimizer as given by the caller.
        """
        optimizer = self.get_ort_optimizer()
        common_args = {
            "float16": to_fp16,
            "keep_io_types": [],
            "fp32_op_list": fp32_op_list,
            "optimize_by_ort": optimize_by_ort,
            "optimize_by_fusion": optimize_by_fusion,
            "tmp_dir": tmp_dir,
        }

        if not self.output_hidden_state:
            optimizer.optimize(
                input_onnx_path,
                optimized_onnx_path,
                keep_outputs=["text_embeddings"],
                **common_args,
            )
        elif optimize_by_fusion:
            # Use a distinct name so the caller's `tmp_dir` argument is not shadowed.
            with tempfile.TemporaryDirectory() as tmp_model_dir:
                # Save to a temporary file so that we can load it with Onnx Runtime.
                logger.info("Saving a temporary model to add hidden_states to graph output ...")
                tmp_model_path = os.path.join(tmp_model_dir, "model.onnx")

                model = onnx.load(input_onnx_path)
                self.add_hidden_states_graph_output(model, tmp_model_path, use_external_data_format=True)
                optimizer.optimize(
                    tmp_model_path,
                    optimized_onnx_path,
                    keep_outputs=["text_embeddings", "hidden_states"],
                    **common_args,
                )
        else:  # input is optimized model, there is no need to add hidden states.
            optimizer.optimize(
                input_onnx_path,
                optimized_onnx_path,
                keep_outputs=["text_embeddings", "hidden_states"],
                **common_args,
            )

    def optimize_trt(self, input_onnx_path, optimized_onnx_path):
        """Optimize for TensorRT: keep only output #0 (renamed ``text_embeddings``), then add hidden states if enabled."""
        onnx_graph = onnx.load(input_onnx_path)
        opt = TrtOptimizer(onnx_graph)
        opt.select_outputs([0])  # delete graph output#1
        opt.cleanup()
        opt.fold_constants()
        opt.infer_shapes()
        opt.select_outputs([0], names=["text_embeddings"])  # rename network output
        opt.cleanup()
        onnx_opt_graph = opt.get_optimized_onnx_graph()
        if self.output_hidden_state:
            self.add_hidden_states_graph_output(onnx_opt_graph, optimized_onnx_path)
        else:
            onnx.save(onnx_opt_graph, optimized_onnx_path)

    def load_model(self, framework_model_dir, subfolder="text_encoder"):
        """Load the ``CLIPTextModel`` through the local model cache."""
        return self.from_pretrained(CLIPTextModel, framework_model_dir, subfolder)
|
|
679
|
+
|
|
680
|
+
|
|
681
|
+
class CLIPWithProj(CLIP):
    """CLIP text encoder with projection (``CLIPTextModelWithProjection``).

    Unlike ``CLIP``, its ``text_embeddings`` output is 2-D: (batch, embedding_dim). The
    embedding size comes from ``pipeline_info.clipwithproj_embedding_dim()``.
    """

    def __init__(
        self,
        pipeline_info: PipelineInfo,
        model,
        device,
        max_batch_size=16,
        clip_skip=0,
    ):
        projection_dim = pipeline_info.clipwithproj_embedding_dim()
        super().__init__(
            pipeline_info,
            model,
            device=device,
            max_batch_size=max_batch_size,
            embedding_dim=projection_dim,
            clip_skip=clip_skip,
        )

    def load_model(self, framework_model_dir, subfolder="text_encoder_2"):
        """Load the ``CLIPTextModelWithProjection`` through the local model cache."""
        return self.from_pretrained(CLIPTextModelWithProjection, framework_model_dir, subfolder)

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return concrete input/output shapes; ``hidden_states`` is included only when enabled."""
        self.check_dims(batch_size, image_height, image_width)
        token_shape = (batch_size, self.text_maxlen)
        shapes = {
            "input_ids": token_shape,
            "text_embeddings": (batch_size, self.embedding_dim),
        }
        if self.output_hidden_state:
            shapes["hidden_states"] = (*token_shape, self.embedding_dim)
        return shapes
|
|
713
|
+
|
|
714
|
+
|
|
715
|
+
class UNet2DConditionControlNetModel(torch.nn.Module):
    """Wrap a UNet and a list of ControlNets into one module for ONNX export.

    Each ControlNet's down/mid residuals are multiplied by that ControlNet's entry in
    ``controlnet_scales``. They are summed across ControlNets and passed to the UNet as
    additional residuals.
    """

    def __init__(self, unet, controlnets: torch.nn.ModuleList):
        super().__init__()
        self.unet = unet
        # One ControlNet per entry of `controlnet_images` / `controlnet_scales` in forward().
        self.controlnets = controlnets

    def forward(self, sample, timestep, encoder_hidden_states, controlnet_images, controlnet_scales):
        """Run every ControlNet, merge their scaled residuals, and return the UNet's first output.

        ``controlnet_images`` and ``controlnet_scales`` are iterated along their first dimension,
        in lockstep with ``self.controlnets``. With no ControlNets, the UNet runs without
        additional residuals.
        """
        down_block_res_samples = None
        mid_block_res_sample = None

        for controlnet_image, conditioning_scale, controlnet in zip(
            controlnet_images, controlnet_scales, self.controlnets
        ):
            down_samples, mid_sample = controlnet(
                sample,
                timestep,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=controlnet_image,
                return_dict=False,
            )

            # Scale out-of-place so the ControlNet's returned tensors are not mutated.
            down_samples = [down_sample * conditioning_scale for down_sample in down_samples]
            mid_sample = mid_sample * conditioning_scale

            # merge samples
            if down_block_res_samples is None:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample = mid_block_res_sample + mid_sample

        noise_pred = self.unet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample,
        )
        return noise_pred[0]
|
|
754
|
+
|
|
755
|
+
|
|
756
|
+
# Modified from convert_stable_diffusion_controlnet_to_onnx.py in diffusers
|
|
757
|
+
class UNet2DConditionXLControlNetModel(torch.nn.Module):
    """Wrap an SD-XL UNet and a list of ControlNets into one module for ONNX export.

    Each ControlNet receives its own entry of ``controlnet_scales`` as ``conditioning_scale``.
    The residuals it returns are summed across ControlNets and passed to the UNet together
    with the ``text_embeds`` / ``time_ids`` conditioning.
    """

    def __init__(self, unet, controlnets: torch.nn.ModuleList):
        super().__init__()
        self.unet = unet
        # One ControlNet per entry of `controlnet_images` / `controlnet_scales` in forward().
        self.controlnets = controlnets

    def forward(
        self,
        sample,
        timestep,
        encoder_hidden_states,
        text_embeds,
        time_ids,
        controlnet_images,
        controlnet_scales,
    ):
        """Run every ControlNet, merge their residuals, and return the UNet's first output.

        ``controlnet_images`` and ``controlnet_scales`` are iterated along their first dimension,
        in lockstep with ``self.controlnets``. With no ControlNets, the UNet runs without
        additional residuals.
        """
        added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
        down_block_res_samples = None
        mid_block_res_sample = None

        for controlnet_image, conditioning_scale, controlnet in zip(
            controlnet_images, controlnet_scales, self.controlnets
        ):
            down_samples, mid_sample = controlnet(
                sample,
                timestep,
                encoder_hidden_states=encoder_hidden_states,
                controlnet_cond=controlnet_image,
                conditioning_scale=conditioning_scale,
                added_cond_kwargs=added_cond_kwargs,
                return_dict=False,
            )

            # merge samples (out-of-place, so the ControlNet's returned tensors are not mutated)
            if down_block_res_samples is None:
                down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
            else:
                down_block_res_samples = [
                    samples_prev + samples_curr
                    for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
                ]
                mid_block_res_sample = mid_block_res_sample + mid_sample

        noise_pred = self.unet(
            sample,
            timestep,
            encoder_hidden_states=encoder_hidden_states,
            down_block_additional_residuals=down_block_res_samples,
            mid_block_additional_residual=mid_block_res_sample,
            added_cond_kwargs=added_cond_kwargs,
            return_dict=False,
        )
        return noise_pred[0]
|
|
807
|
+
|
|
808
|
+
|
|
809
|
+
class UNet(BaseModel):
    """Export helper for a ``UNet2DConditionModel``, optionally combined with ControlNets.

    When ``pipeline_info.controlnet_name()`` returns model ids, ``load_model`` wraps the UNet and
    those ControlNets in ``UNet2DConditionControlNetModel``. The graph then gains
    ``controlnet_images`` / ``controlnet_scales`` inputs.
    """

    def __init__(
        self,
        pipeline_info: PipelineInfo,
        model,
        device,
        fp16=False,  # used by TRT
        max_batch_size=16,
        text_maxlen=77,
        unet_dim=4,
    ):
        super().__init__(
            pipeline_info,
            model=model,
            device=device,
            fp16=fp16,
            max_batch_size=max_batch_size,
            embedding_dim=pipeline_info.unet_embedding_dim(),
            text_maxlen=text_maxlen,
        )

        # Channel count of the `sample` input.
        self.unet_dim = unet_dim
        # ControlNet model ids, or None when ControlNet is not used.
        self.controlnet = pipeline_info.controlnet_name()

    def load_model(self, framework_model_dir, subfolder="unet"):
        """Load the fp16 UNet (plus ControlNets if configured); cast to float32 unless ``self.fp16``."""
        model = self.from_pretrained(
            UNet2DConditionModel,
            framework_model_dir,
            subfolder,
            variant="fp16",
            torch_dtype=torch.float16,
        )

        if self.controlnet:
            controlnets = torch.nn.ModuleList(
                [
                    self.from_pretrained(
                        ControlNetModel,
                        framework_model_dir,
                        subfolder=None,
                        model_name=controlnet_id,
                        torch_dtype=torch.float16,
                    )
                    for controlnet_id in self.controlnet
                ]
            )
            model = UNet2DConditionControlNetModel(model, controlnets)

        return model if self.fp16 else model.to(torch.float32)

    def get_input_names(self):
        """Return graph input names; ControlNet inputs are appended when ControlNet is enabled."""
        names = ["sample", "timestep", "encoder_hidden_states"]
        if self.controlnet:
            names += ["controlnet_images", "controlnet_scales"]
        return names

    def get_output_names(self):
        """Return the single graph output name."""
        return ["latent"]

    def get_dynamic_axes(self):
        """Return dynamic axes; the batch axis is named "2B" under classifier-free guidance, else "B"."""
        batch_axis = "B" if self.get_batch_multiplier() != 2 else "2B"
        axes = {
            "sample": {0: batch_axis, 2: "H", 3: "W"},
            "encoder_hidden_states": {0: batch_axis},
            "latent": {0: batch_axis, 2: "H", 3: "W"},
        }
        if self.controlnet:
            axes["controlnet_images"] = {1: batch_axis, 3: "8H", 4: "8W"}
        return axes

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return TensorRT (min, requested, max) shape triples for each dynamic input."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)

        multiplier = self.get_batch_multiplier()
        batches = (multiplier * min_batch, multiplier * batch_size, multiplier * max_batch)
        latent_sizes = (
            (min_latent_height, min_latent_width),
            (latent_height, latent_width),
            (max_latent_height, max_latent_width),
        )

        profile = {
            "sample": [(b, self.unet_dim, h, w) for b, (h, w) in zip(batches, latent_sizes)],
            "encoder_hidden_states": [(b, self.text_maxlen, self.embedding_dim) for b in batches],
        }

        if self.controlnet:
            num_controlnets = len(self.controlnet)
            image_sizes = (
                (min_image_height, min_image_width),
                (image_height, image_width),
                (max_image_height, max_image_width),
            )
            profile["controlnet_images"] = [
                (num_controlnets, b, 3, h, w) for b, (h, w) in zip(batches, image_sizes)
            ]
        return profile

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return concrete input/output shapes for the given batch and image size."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        effective_batch = self.get_batch_multiplier() * batch_size
        shapes = {
            "sample": (effective_batch, self.unet_dim, latent_height, latent_width),
            "timestep": [1],
            "encoder_hidden_states": (effective_batch, self.text_maxlen, self.embedding_dim),
            "latent": (effective_batch, 4, latent_height, latent_width),
        }

        if self.controlnet:
            num_controlnets = len(self.controlnet)
            shapes["controlnet_images"] = (num_controlnets, effective_batch, 3, image_height, image_width)
            shapes["controlnet_scales"] = [num_controlnets]
        return shapes

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return random example inputs (in ``get_input_names`` order) for export."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        tensor_opts = {"dtype": torch.float16 if self.fp16 else torch.float32, "device": self.device}
        effective_batch = self.get_batch_multiplier() * batch_size

        inputs = [
            torch.randn(effective_batch, self.unet_dim, latent_height, latent_width, **tensor_opts),
            torch.tensor([1.0], **tensor_opts),
            torch.randn(effective_batch, self.text_maxlen, self.embedding_dim, **tensor_opts),
        ]

        if self.controlnet:
            num_controlnets = len(self.controlnet)
            inputs.append(torch.randn(num_controlnets, effective_batch, 3, image_height, image_width, **tensor_opts))
            inputs.append(torch.randn(num_controlnets, **tensor_opts))
        return tuple(inputs)
|
|
959
|
+
|
|
960
|
+
|
|
961
|
+
class UNetXL(BaseModel):
|
|
962
|
+
def __init__(
    self,
    pipeline_info: PipelineInfo,
    model,
    device,
    fp16=False,  # used by TRT
    max_batch_size=16,
    text_maxlen=77,
    unet_dim=4,
    time_dim=6,
):
    """Configure the UNetXL export helper.

    ``unet_dim`` is the channel count of ``sample`` and ``time_dim`` is the width of ``time_ids``.
    The text embedding width comes from ``pipeline_info.unet_embedding_dim()``.
    """
    super().__init__(
        pipeline_info,
        model,
        device=device,
        fp16=fp16,
        max_batch_size=max_batch_size,
        embedding_dim=pipeline_info.unet_embedding_dim(),
        text_maxlen=text_maxlen,
    )
    self.unet_dim = unet_dim
    self.time_dim = time_dim

    # Optional custom UNet model name (falsy when not configured).
    self.custom_unet = pipeline_info.custom_unet()
    # ControlNet model ids, or None when ControlNet is not used.
    self.controlnet = pipeline_info.controlnet_name()
|
|
987
|
+
|
|
988
|
+
def load_model(self, framework_model_dir, subfolder="unet", always_download_fp16=True):
    """Load the UNetXL PyTorch model, optionally wrapped with ControlNets.

    Args:
        framework_model_dir: root directory of the local model cache.
        subfolder: subfolder holding the UNet weights.
        always_download_fp16: fetch fp16 weights even when ``self.fp16`` is false; the model is
            then cast to float32 after loading.

    A custom UNet (``self.custom_unet``) is cached under
    ``framework_model_dir/<custom_unet>/<subfolder>``. ControlNets listed in
    ``self.controlnet`` are loaded with ``ControlNetModel.from_pretrained`` and combined
    through ``UNet2DConditionXLControlNetModel``.
    """
    use_fp16_weights = self.fp16 or always_download_fp16
    options = {"variant": "fp16", "torch_dtype": torch.float16} if use_fp16_weights else {}

    if self.custom_unet:
        model_dir = os.path.join(framework_model_dir, self.custom_unet, subfolder)
        if not os.path.exists(model_dir):
            unet = UNet2DConditionModel.from_pretrained(self.custom_unet, **options)
            unet.save_pretrained(model_dir)
        else:
            # The cache above is written by save_pretrained() without a variant, so requesting
            # variant="fp16" here would make diffusers look for "*.fp16.*" weight files that were
            # never saved. Reload the local copy with the dtype option only.
            local_options = {key: value for key, value in options.items() if key != "variant"}
            unet = UNet2DConditionModel.from_pretrained(model_dir, **local_options)
        model = unet.to(self.device)
    else:
        model = self.from_pretrained(UNet2DConditionModel, framework_model_dir, subfolder, **options)

    if always_download_fp16 and not self.fp16:
        model = model.to(torch.float32)

    if self.controlnet:
        cnet_model_opts = {"torch_dtype": torch.float16} if use_fp16_weights else {}
        controlnets = torch.nn.ModuleList(
            [ControlNetModel.from_pretrained(path, **cnet_model_opts).to(self.device) for path in self.controlnet]
        )
        model = UNet2DConditionXLControlNetModel(model, controlnets)

        if always_download_fp16 and not self.fp16:
            model = model.to(torch.float32)

    return model
|
|
1016
|
+
|
|
1017
|
+
def get_input_names(self):
    """Return graph input names; ControlNet inputs are appended when ControlNet is enabled."""
    base_names = ["sample", "timestep", "encoder_hidden_states", "text_embeds", "time_ids"]
    if not self.controlnet:
        return base_names
    return base_names + ["controlnet_images", "controlnet_scales"]
|
|
1022
|
+
|
|
1023
|
+
def get_output_names(self):
    """Return the graph output names (a single ``latent`` output)."""
    output_names = ["latent"]
    return output_names
|
|
1025
|
+
|
|
1026
|
+
def get_dynamic_axes(self):
    """Return dynamic axes for export.

    The batch axis is named "2B" when the batch multiplier is 2 (classifier-free guidance), else "B".
    """
    batch_axis = "B" if self.get_batch_multiplier() != 2 else "2B"
    spatial = {0: batch_axis, 2: "H", 3: "W"}
    axes = {
        "sample": dict(spatial),
        "encoder_hidden_states": {0: batch_axis},
        "text_embeds": {0: batch_axis},
        "time_ids": {0: batch_axis},
        "latent": dict(spatial),
    }
    if self.controlnet:
        axes["controlnet_images"] = {1: batch_axis, 3: "8H", 4: "8W"}
    return axes
|
|
1043
|
+
|
|
1044
|
+
def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
    """Return TensorRT (min, requested, max) shape triples for each dynamic input."""
    latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
    (
        min_batch,
        max_batch,
        min_image_height,
        max_image_height,
        min_image_width,
        max_image_width,
        min_latent_height,
        max_latent_height,
        min_latent_width,
        max_latent_width,
    ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)

    multiplier = self.get_batch_multiplier()
    batches = (multiplier * min_batch, multiplier * batch_size, multiplier * max_batch)
    latent_sizes = (
        (min_latent_height, min_latent_width),
        (latent_height, latent_width),
        (max_latent_height, max_latent_width),
    )

    profile = {
        "sample": [(b, self.unet_dim, h, w) for b, (h, w) in zip(batches, latent_sizes)],
        "encoder_hidden_states": [(b, self.text_maxlen, self.embedding_dim) for b in batches],
        "text_embeds": [(b, 1280) for b in batches],
        "time_ids": [(b, self.time_dim) for b in batches],
    }

    if self.controlnet:
        num_controlnets = len(self.controlnet)
        image_sizes = (
            (min_image_height, min_image_width),
            (image_height, image_width),
            (max_image_height, max_image_width),
        )
        profile["controlnet_images"] = [
            (num_controlnets, b, 3, h, w) for b, (h, w) in zip(batches, image_sizes)
        ]
    return profile
|
|
1089
|
+
|
|
1090
|
+
def get_shape_dict(self, batch_size, image_height, image_width):
    """Return the concrete shape of every UNet input/output for one request.

    Batch dimensions are ``get_batch_multiplier() * batch_size``. When
    ``self.controlnet`` is truthy, ``controlnet_images`` (a tuple) and
    ``controlnet_scales`` (a one-element list) are appended.
    """
    latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
    batch = self.get_batch_multiplier() * batch_size

    shapes = {
        "sample": (batch, self.unet_dim, latent_height, latent_width),
        "timestep": (1,),
        "encoder_hidden_states": (batch, self.text_maxlen, self.embedding_dim),
        "text_embeds": (batch, 1280),
        "time_ids": (batch, self.time_dim),
        "latent": (batch, 4, latent_height, latent_width),
    }

    if self.controlnet:
        num_controlnets = len(self.controlnet)
        shapes["controlnet_images"] = (num_controlnets, batch, 3, image_height, image_width)
        shapes["controlnet_scales"] = [num_controlnets]
    return shapes
|
|
1110
|
+
|
|
1111
|
+
def get_sample_input(self, batch_size, image_height, image_width):
    """Build random example inputs for exporting/tracing the UNet.

    Without ControlNet the result is
    ``(sample, timestep, encoder_hidden_states, {"added_cond_kwargs": {...}})``,
    where the trailing dict carries ``text_embeds`` and ``time_ids``.
    With ControlNet everything is passed positionally:
    ``(sample, timestep, encoder_hidden_states, text_embeds, time_ids,
    controlnet_images, controlnet_scales)``.

    Tensors use float16 when ``self.fp16`` is set, else float32. The random
    draws happen in the order listed above, so a seeded generator yields the
    same tensors as before.
    """
    latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
    dtype = torch.float16 if self.fp16 else torch.float32
    batch = self.get_batch_multiplier() * batch_size
    device = self.device

    def _randn(*shape):
        # Random tensor with the shared dtype/device.
        return torch.randn(*shape, dtype=dtype, device=device)

    sample = _randn(batch, self.unet_dim, latent_height, latent_width)
    timestep = torch.tensor([1.0], dtype=dtype, device=device)
    encoder_hidden_states = _randn(batch, self.text_maxlen, self.embedding_dim)
    text_embeds = _randn(batch, 1280)
    time_ids = _randn(batch, self.time_dim)

    if self.controlnet:
        num_controlnets = len(self.controlnet)
        controlnet_images = _randn(num_controlnets, batch, 3, image_height, image_width)
        controlnet_scales = _randn(num_controlnets)
        return (
            sample,
            timestep,
            encoder_hidden_states,
            text_embeds,
            time_ids,
            controlnet_images,
            controlnet_scales,
        )

    added_cond_kwargs = {"text_embeds": text_embeds, "time_ids": time_ids}
    return (sample, timestep, encoder_hidden_states, {"added_cond_kwargs": added_cond_kwargs})
|
|
1144
|
+
|
|
1145
|
+
|
|
1146
|
+
# VAE Decoder
|
|
1147
|
+
class VAE(BaseModel):
    """VAE decoder wrapper whose exported graph maps ``latent`` to ``images``.

    Per ``get_shape_dict``, ``latent`` is (B, 4, latent_h, latent_w) and
    ``images`` is (B, 3, image_h, image_w); the dynamic-axis names label the
    image spatial dims as 8x the latent ones.
    """

    def __init__(
        self,
        pipeline_info: PipelineInfo,
        model,
        device,
        max_batch_size,
        fp16: bool = False,
        custom_fp16_vae: Optional[str] = None,
    ):
        super().__init__(
            pipeline_info,
            model=model,
            device=device,
            fp16=fp16,
            max_batch_size=max_batch_size,
        )
        # SD XL needs a custom-trained fp16 VAE to get the fp16 speed-up while
        # avoiding overflow; when set, the decoder is loaded from this name/path.
        self.custom_fp16_vae = custom_fp16_vae

    def load_model(self, framework_model_dir, subfolder: str = "vae_decoder"):
        """Load the AutoencoderKL used as decoder, caching it on disk.

        The cache directory is ``framework_model_dir/<model name>/<subfolder>``,
        where the model name is ``custom_fp16_vae`` if set, else the pipeline
        name. On a cache miss the model is fetched and saved there. A custom
        VAE is always loaded with ``torch_dtype=torch.float16``. The returned
        module's ``forward`` is rebound to ``decode``.
        """
        custom_vae = self.custom_fp16_vae
        model_dir = os.path.join(framework_model_dir, custom_vae or self.pipeline_info.name(), subfolder)
        dtype_kwargs = {"torch_dtype": torch.float16} if custom_vae else {}

        if os.path.exists(model_dir):
            print(f"Load {self.name} pytorch model from: {model_dir}")
            vae = AutoencoderKL.from_pretrained(model_dir, **dtype_kwargs).to(self.device)
        else:
            if custom_vae:
                vae = AutoencoderKL.from_pretrained(custom_vae, **dtype_kwargs)
            else:
                vae = AutoencoderKL.from_pretrained(
                    self.pipeline_info.name(),
                    subfolder="vae",
                    use_safetensors=self.pipeline_info.use_safetensors(),
                )
            vae = vae.to(self.device)
            vae.save_pretrained(model_dir)

        # Export/trace the decoder path rather than the full autoencoder forward.
        vae.forward = vae.decode
        return vae

    def get_input_names(self):
        """Graph input names."""
        return ["latent"]

    def get_output_names(self):
        """Graph output names."""
        return ["images"]

    def get_dynamic_axes(self):
        """Dynamic axes for export: batch and spatial dims of input and output."""
        latent_axes = {0: "B", 2: "H", 3: "W"}
        image_axes = {0: "B", 2: "8H", 3: "8W"}
        return {"latent": latent_axes, "images": image_axes}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return min/opt/max shapes for the ``latent`` input."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        (
            min_batch,
            max_batch,
            _min_image_height,
            _max_image_height,
            _min_image_width,
            _max_image_width,
            min_latent_height,
            max_latent_height,
            min_latent_width,
            max_latent_width,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)
        latent_profile = [
            (min_batch, 4, min_latent_height, min_latent_width),
            (batch_size, 4, latent_height, latent_width),
            (max_batch, 4, max_latent_height, max_latent_width),
        ]
        return {"latent": latent_profile}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return concrete input/output shapes for one request."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        latent_shape = (batch_size, 4, latent_height, latent_width)
        image_shape = (batch_size, 3, image_height, image_width)
        return {"latent": latent_shape, "images": image_shape}

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a one-element tuple holding a random latent tensor."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        dtype = torch.float16 if self.fp16 else torch.float32
        latent = torch.randn(batch_size, 4, latent_height, latent_width, dtype=dtype, device=self.device)
        return (latent,)

    def fp32_input_output_names(self) -> List[str]:
        """No inputs/outputs are forced to fp32 for this model."""
        return list()
|
|
1237
|
+
|
|
1238
|
+
|
|
1239
|
+
def get_tokenizer(pipeline_info: PipelineInfo, framework_model_dir, subfolder="tokenizer"):
    """Return a CLIPTokenizer, caching it under ``framework_model_dir``.

    The cache path is ``framework_model_dir/<pipeline name>/<subfolder>``. If it
    exists the tokenizer is loaded from there; otherwise it is fetched from the
    pipeline's model repo (same ``subfolder``) and saved to that path.
    """
    tokenizer_dir = os.path.join(framework_model_dir, pipeline_info.name(), subfolder)

    if os.path.exists(tokenizer_dir):
        print(f"[I] Load tokenizer pytorch model from: {tokenizer_dir}")
        return CLIPTokenizer.from_pretrained(tokenizer_dir)

    # NOTE(review): use_safetensors is forwarded to the tokenizer loader as-is;
    # whether it has any effect for a tokenizer is not visible here — confirm.
    tokenizer = CLIPTokenizer.from_pretrained(
        pipeline_info.name(),
        subfolder=subfolder,
        use_safetensors=pipeline_info.is_xl(),
    )
    tokenizer.save_pretrained(tokenizer_dir)
    return tokenizer
|
|
1253
|
+
|
|
1254
|
+
|
|
1255
|
+
class TorchVAEEncoder(torch.nn.Module):
    """Wrap an autoencoder so that ``forward`` runs only its encoder.

    ``forward(x)`` calls ``vae_encoder.encode(x)`` and returns
    ``latent_dist.sample()`` from the result, i.e. a sample from the latent
    distribution rather than its mode (presumably random — confirm against
    the distribution class in use).
    """

    def __init__(self, vae_encoder):
        super().__init__()
        self.vae_encoder = vae_encoder

    def forward(self, x):
        posterior = self.vae_encoder.encode(x).latent_dist
        return posterior.sample()
|
|
1262
|
+
|
|
1263
|
+
|
|
1264
|
+
class VAEEncoder(BaseModel):
    """VAE encoder wrapper whose exported graph maps ``images`` to ``latent``.

    Per ``get_shape_dict``, ``images`` is (B, 3, image_h, image_w) and
    ``latent`` is (B, 4, latent_h, latent_w).
    """

    def __init__(self, pipeline_info: PipelineInfo, model, device, max_batch_size):
        super().__init__(
            pipeline_info,
            model=model,
            device=device,
            max_batch_size=max_batch_size,
        )

    def load_model(self, framework_model_dir, subfolder="vae_encoder"):
        """Load the AutoencoderKL via the base helper and wrap it as encoder-only."""
        autoencoder = self.from_pretrained(AutoencoderKL, framework_model_dir, subfolder)
        return TorchVAEEncoder(autoencoder)

    def get_input_names(self):
        """Graph input names."""
        return ["images"]

    def get_output_names(self):
        """Graph output names."""
        return ["latent"]

    def get_dynamic_axes(self):
        """Dynamic axes for export: batch and spatial dims of input and output."""
        image_axes = {0: "B", 2: "8H", 3: "8W"}
        latent_axes = {0: "B", 2: "H", 3: "W"}
        return {"images": image_axes, "latent": latent_axes}

    def get_input_profile(self, batch_size, image_height, image_width, static_batch, static_image_shape):
        """Return min/opt/max shapes for the ``images`` input."""
        # Called only for its validation side; the latent dims are not needed here.
        self.check_dims(batch_size, image_height, image_width)

        (
            min_batch,
            max_batch,
            min_image_height,
            max_image_height,
            min_image_width,
            max_image_width,
            _min_latent_height,
            _max_latent_height,
            _min_latent_width,
            _max_latent_width,
        ) = self.get_minmax_dims(batch_size, image_height, image_width, static_batch, static_image_shape)

        image_profile = [
            (min_batch, 3, min_image_height, min_image_width),
            (batch_size, 3, image_height, image_width),
            (max_batch, 3, max_image_height, max_image_width),
        ]
        return {"images": image_profile}

    def get_shape_dict(self, batch_size, image_height, image_width):
        """Return concrete input/output shapes for one request."""
        latent_height, latent_width = self.check_dims(batch_size, image_height, image_width)
        image_shape = (batch_size, 3, image_height, image_width)
        latent_shape = (batch_size, 4, latent_height, latent_width)
        return {"images": image_shape, "latent": latent_shape}

    def get_sample_input(self, batch_size, image_height, image_width):
        """Return a random float32 image tensor for export/tracing.

        Unlike ``VAE.get_sample_input`` this returns the bare tensor, not a
        one-element tuple; kept as-is since callers may depend on it.
        """
        self.check_dims(batch_size, image_height, image_width)
        shape = (batch_size, 3, image_height, image_width)
        return torch.randn(*shape, dtype=torch.float32, device=self.device)
|