onnxruntime-directml 1.20.0__cp313-cp313-win_amd64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- onnxruntime/LICENSE +21 -0
- onnxruntime/Privacy.md +21 -0
- onnxruntime/ThirdPartyNotices.txt +6508 -0
- onnxruntime/__init__.py +78 -0
- onnxruntime/backend/__init__.py +6 -0
- onnxruntime/backend/backend.py +174 -0
- onnxruntime/backend/backend_rep.py +53 -0
- onnxruntime/capi/DirectML.dll +0 -0
- onnxruntime/capi/__init__.py +4 -0
- onnxruntime/capi/_ld_preload.py +7 -0
- onnxruntime/capi/_pybind_state.py +33 -0
- onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
- onnxruntime/capi/onnxruntime.dll +0 -0
- onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
- onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
- onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
- onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
- onnxruntime/capi/onnxruntime_validation.py +150 -0
- onnxruntime/capi/version_info.py +2 -0
- onnxruntime/datasets/__init__.py +17 -0
- onnxruntime/datasets/logreg_iris.onnx +0 -0
- onnxruntime/datasets/mul_1.onnx +0 -0
- onnxruntime/datasets/sigmoid.onnx +13 -0
- onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
- onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
- onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
- onnxruntime/quantization/__init__.py +16 -0
- onnxruntime/quantization/base_quantizer.py +532 -0
- onnxruntime/quantization/calibrate.py +1245 -0
- onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
- onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
- onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
- onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
- onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
- onnxruntime/quantization/fusions/__init__.py +3 -0
- onnxruntime/quantization/fusions/fusion.py +311 -0
- onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
- onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
- onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
- onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
- onnxruntime/quantization/onnx_model.py +580 -0
- onnxruntime/quantization/onnx_quantizer.py +1008 -0
- onnxruntime/quantization/operators/__init__.py +2 -0
- onnxruntime/quantization/operators/activation.py +119 -0
- onnxruntime/quantization/operators/argmax.py +18 -0
- onnxruntime/quantization/operators/attention.py +73 -0
- onnxruntime/quantization/operators/base_operator.py +26 -0
- onnxruntime/quantization/operators/binary_op.py +72 -0
- onnxruntime/quantization/operators/concat.py +62 -0
- onnxruntime/quantization/operators/conv.py +258 -0
- onnxruntime/quantization/operators/direct_q8.py +78 -0
- onnxruntime/quantization/operators/embed_layernorm.py +121 -0
- onnxruntime/quantization/operators/gather.py +64 -0
- onnxruntime/quantization/operators/gavgpool.py +62 -0
- onnxruntime/quantization/operators/gemm.py +166 -0
- onnxruntime/quantization/operators/lstm.py +117 -0
- onnxruntime/quantization/operators/matmul.py +231 -0
- onnxruntime/quantization/operators/maxpool.py +34 -0
- onnxruntime/quantization/operators/norm.py +40 -0
- onnxruntime/quantization/operators/pad.py +100 -0
- onnxruntime/quantization/operators/pooling.py +67 -0
- onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
- onnxruntime/quantization/operators/resize.py +34 -0
- onnxruntime/quantization/operators/softmax.py +74 -0
- onnxruntime/quantization/operators/split.py +63 -0
- onnxruntime/quantization/operators/where.py +87 -0
- onnxruntime/quantization/preprocess.py +141 -0
- onnxruntime/quantization/qdq_loss_debug.py +389 -0
- onnxruntime/quantization/qdq_quantizer.py +1187 -0
- onnxruntime/quantization/quant_utils.py +891 -0
- onnxruntime/quantization/quantize.py +748 -0
- onnxruntime/quantization/registry.py +106 -0
- onnxruntime/quantization/shape_inference.py +187 -0
- onnxruntime/quantization/tensor_quant_overrides.py +516 -0
- onnxruntime/tools/__init__.py +10 -0
- onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
- onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
- onnxruntime/tools/file_utils.py +46 -0
- onnxruntime/tools/logger.py +11 -0
- onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
- onnxruntime/tools/mobile_helpers/__init__.py +0 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
- onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
- onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
- onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
- onnxruntime/tools/offline_tuning.py +169 -0
- onnxruntime/tools/onnx_model_utils.py +413 -0
- onnxruntime/tools/onnx_randomizer.py +85 -0
- onnxruntime/tools/onnxruntime_test.py +164 -0
- onnxruntime/tools/optimize_onnx_model.py +55 -0
- onnxruntime/tools/ort_format_model/__init__.py +25 -0
- onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
- onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
- onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
- onnxruntime/tools/ort_format_model/types.py +84 -0
- onnxruntime/tools/ort_format_model/utils.py +62 -0
- onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
- onnxruntime/tools/pytorch_export_helpers.py +131 -0
- onnxruntime/tools/qdq_helpers/__init__.py +0 -0
- onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
- onnxruntime/tools/reduced_build_config_parser.py +202 -0
- onnxruntime/tools/symbolic_shape_infer.py +3016 -0
- onnxruntime/tools/update_onnx_opset.py +31 -0
- onnxruntime/transformers/__init__.py +8 -0
- onnxruntime/transformers/affinity_helper.py +40 -0
- onnxruntime/transformers/benchmark.py +944 -0
- onnxruntime/transformers/benchmark_helper.py +646 -0
- onnxruntime/transformers/bert_perf_test.py +634 -0
- onnxruntime/transformers/bert_test_data.py +642 -0
- onnxruntime/transformers/compare_bert_results.py +246 -0
- onnxruntime/transformers/constants.py +47 -0
- onnxruntime/transformers/convert_generation.py +3124 -0
- onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
- onnxruntime/transformers/convert_to_packing_mode.py +387 -0
- onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
- onnxruntime/transformers/float16.py +501 -0
- onnxruntime/transformers/fusion_attention.py +1235 -0
- onnxruntime/transformers/fusion_attention_clip.py +257 -0
- onnxruntime/transformers/fusion_attention_sam2.py +534 -0
- onnxruntime/transformers/fusion_attention_unet.py +1304 -0
- onnxruntime/transformers/fusion_attention_vae.py +301 -0
- onnxruntime/transformers/fusion_bart_attention.py +640 -0
- onnxruntime/transformers/fusion_base.py +137 -0
- onnxruntime/transformers/fusion_bias_add.py +58 -0
- onnxruntime/transformers/fusion_biasgelu.py +66 -0
- onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
- onnxruntime/transformers/fusion_conformer_attention.py +143 -0
- onnxruntime/transformers/fusion_embedlayer.py +811 -0
- onnxruntime/transformers/fusion_fastgelu.py +360 -0
- onnxruntime/transformers/fusion_gelu.py +259 -0
- onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
- onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
- onnxruntime/transformers/fusion_gpt_attention.py +546 -0
- onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
- onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
- onnxruntime/transformers/fusion_group_norm.py +179 -0
- onnxruntime/transformers/fusion_layernorm.py +465 -0
- onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
- onnxruntime/transformers/fusion_options.py +340 -0
- onnxruntime/transformers/fusion_qordered_attention.py +421 -0
- onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
- onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
- onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
- onnxruntime/transformers/fusion_quickgelu.py +74 -0
- onnxruntime/transformers/fusion_reshape.py +173 -0
- onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
- onnxruntime/transformers/fusion_shape.py +110 -0
- onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
- onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
- onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
- onnxruntime/transformers/fusion_transpose.py +168 -0
- onnxruntime/transformers/fusion_utils.py +307 -0
- onnxruntime/transformers/huggingface_models.py +167 -0
- onnxruntime/transformers/import_utils.py +20 -0
- onnxruntime/transformers/io_binding_helper.py +442 -0
- onnxruntime/transformers/large_model_exporter.py +395 -0
- onnxruntime/transformers/machine_info.py +221 -0
- onnxruntime/transformers/metrics.py +164 -0
- onnxruntime/transformers/models/bart/__init__.py +12 -0
- onnxruntime/transformers/models/bart/export.py +98 -0
- onnxruntime/transformers/models/bert/__init__.py +12 -0
- onnxruntime/transformers/models/bert/eval_squad.py +329 -0
- onnxruntime/transformers/models/gpt2/__init__.py +12 -0
- onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
- onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
- onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
- onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
- onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
- onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
- onnxruntime/transformers/models/llama/__init__.py +12 -0
- onnxruntime/transformers/models/llama/benchmark.py +703 -0
- onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
- onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
- onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
- onnxruntime/transformers/models/llama/dist_settings.py +57 -0
- onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
- onnxruntime/transformers/models/llama/llama_parity.py +309 -0
- onnxruntime/transformers/models/llama/llama_torch.py +47 -0
- onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
- onnxruntime/transformers/models/longformer/__init__.py +12 -0
- onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
- onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
- onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
- onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
- onnxruntime/transformers/models/phi2/__init__.py +12 -0
- onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
- onnxruntime/transformers/models/phi2/inference_example.py +414 -0
- onnxruntime/transformers/models/sam2/__init__.py +12 -0
- onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
- onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
- onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
- onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
- onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
- onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
- onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
- onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
- onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
- onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
- onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
- onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
- onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
- onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
- onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
- onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
- onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
- onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
- onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
- onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
- onnxruntime/transformers/models/t5/__init__.py +12 -0
- onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
- onnxruntime/transformers/models/t5/past_helper.py +150 -0
- onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
- onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
- onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
- onnxruntime/transformers/models/t5/t5_helper.py +272 -0
- onnxruntime/transformers/models/whisper/__init__.py +12 -0
- onnxruntime/transformers/models/whisper/benchmark.py +610 -0
- onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
- onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
- onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
- onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
- onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
- onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
- onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
- onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
- onnxruntime/transformers/onnx_exporter.py +717 -0
- onnxruntime/transformers/onnx_model.py +1569 -0
- onnxruntime/transformers/onnx_model_bart.py +142 -0
- onnxruntime/transformers/onnx_model_bert.py +481 -0
- onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
- onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
- onnxruntime/transformers/onnx_model_clip.py +40 -0
- onnxruntime/transformers/onnx_model_conformer.py +33 -0
- onnxruntime/transformers/onnx_model_gpt2.py +101 -0
- onnxruntime/transformers/onnx_model_phi.py +930 -0
- onnxruntime/transformers/onnx_model_sam2.py +138 -0
- onnxruntime/transformers/onnx_model_t5.py +791 -0
- onnxruntime/transformers/onnx_model_tnlr.py +227 -0
- onnxruntime/transformers/onnx_model_unet.py +259 -0
- onnxruntime/transformers/onnx_model_vae.py +43 -0
- onnxruntime/transformers/onnx_utils.py +55 -0
- onnxruntime/transformers/optimizer.py +612 -0
- onnxruntime/transformers/profiler.py +725 -0
- onnxruntime/transformers/quantize_helper.py +76 -0
- onnxruntime/transformers/shape_infer_helper.py +122 -0
- onnxruntime/transformers/shape_optimizer.py +401 -0
- onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
- onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
- onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
- onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
- onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
- onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
|
@@ -0,0 +1,1181 @@
|
|
|
1
|
+
# -------------------------------------------------------------------------
|
|
2
|
+
# Copyright (c) Microsoft Corporation. All rights reserved.
|
|
3
|
+
# Licensed under the MIT License.
|
|
4
|
+
# --------------------------------------------------------------------------
|
|
5
|
+
# Modified from utilities.py of TensorRT demo diffusion, which has the following license:
|
|
6
|
+
#
|
|
7
|
+
# Copyright 2022 The HuggingFace Inc. team.
|
|
8
|
+
# SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
|
9
|
+
# SPDX-License-Identifier: Apache-2.0
|
|
10
|
+
#
|
|
11
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
12
|
+
# you may not use this file except in compliance with the License.
|
|
13
|
+
# You may obtain a copy of the License at
|
|
14
|
+
#
|
|
15
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
|
16
|
+
#
|
|
17
|
+
# Unless required by applicable law or agreed to in writing, software
|
|
18
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
19
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
20
|
+
# See the License for the specific language governing permissions and
|
|
21
|
+
# limitations under the License.
|
|
22
|
+
# --------------------------------------------------------------------------
|
|
23
|
+
|
|
24
|
+
from typing import List, Optional
|
|
25
|
+
|
|
26
|
+
import numpy as np
|
|
27
|
+
import torch
|
|
28
|
+
|
|
29
|
+
|
|
30
|
+
class DDIMScheduler:
    """DDIM (Denoising Diffusion Implicit Models) scheduler, https://arxiv.org/pdf/2010.02502.pdf.

    Expected call order: ``set_timesteps(n)`` -> ``configure()`` -> ``scale_model_input`` / ``step``
    for each ``idx`` in ``range(n)``. ``configure`` precomputes the per-step variance and the
    ``alphas_cumprod`` values gathered at the selected timesteps. ``step`` and ``add_noise`` index
    these by ``idx``, the position in ``self.timesteps``, not by the timestep value itself.
    """

    def __init__(
        self,
        device="cuda",
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        clip_sample: bool = False,
        set_alpha_to_one: bool = False,
        steps_offset: int = 1,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "leading",
    ):
        """Create the scheduler.

        Args:
            device: device that ``set_timesteps`` and ``configure`` move their tensors to.
            num_train_timesteps: number of diffusion steps in the training schedule.
            beta_start: first beta of the schedule.
            beta_end: last beta of the schedule.
            clip_sample: if True, ``step`` clamps the predicted x_0 to [-1, 1].
            set_alpha_to_one: if True, the "previous" alpha product used at the final step is 1.0;
                otherwise it is ``alphas_cumprod[0]``.
            steps_offset: value added to every timestep when ``timestep_spacing == "leading"``.
            prediction_type: model output type, one of "epsilon", "sample" or "v_prediction".
            timestep_spacing: one of "linspace", "leading" or "trailing" (see ``set_timesteps``).
        """
        # this schedule is very specific to the latent diffusion model:
        # betas are linear in sqrt space and then squared.
        betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2

        alphas = 1.0 - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)
        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        # At every step in ddim, we are looking into the previous alphas_cumprod.
        # For the final step, there is no previous alphas_cumprod because we are already at 0.
        # `set_alpha_to_one` decides whether we set this parameter simply to one or
        # whether we use the final alpha of the "non-previous" one.
        self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]

        # settable values (overwritten by set_timesteps / configure)
        self.num_inference_steps = None
        self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
        self.steps_offset = steps_offset
        self.num_train_timesteps = num_train_timesteps
        self.clip_sample = clip_sample
        self.prediction_type = prediction_type
        self.device = device
        self.timestep_spacing = timestep_spacing

    def configure(self):
        """Precompute the per-step tensors used by ``step`` and ``add_noise``.

        Sets ``self.variance`` (one entry per inference step) and ``self.filtered_alphas_cumprod``
        (``alphas_cumprod`` gathered at ``self.timesteps``). It also moves ``self.final_alpha_cumprod``
        to ``self.device``.

        Raises:
            ValueError: if ``set_timesteps`` has not been called yet.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # Distance (in training timesteps) to the "previous" timestep; constant for every step.
        step_size = self.num_train_timesteps // self.num_inference_steps
        variance = np.zeros(self.num_inference_steps, dtype=np.float32)
        for idx, timestep in enumerate(self.timesteps):
            prev_timestep = timestep - step_size
            variance[idx] = self._get_variance(timestep, prev_timestep)
        self.variance = torch.from_numpy(variance).to(self.device)

        # alphas_cumprod lives on the CPU, so index it with CPU int64 indices.
        timesteps = self.timesteps.long().cpu()
        self.filtered_alphas_cumprod = self.alphas_cumprod[timesteps].to(self.device)
        self.final_alpha_cumprod = self.final_alpha_cumprod.to(self.device)

    def scale_model_input(self, sample: torch.FloatTensor, idx, *args, **kwargs) -> torch.FloatTensor:
        """Return ``sample`` unchanged: DDIM does not rescale the model input."""
        return sample

    def _get_variance(self, timestep, prev_timestep):
        """Return the DDIM variance term of formula (16) for ``timestep`` -> ``prev_timestep``.

        A negative ``prev_timestep`` means "before the first training step". In that case
        ``self.final_alpha_cumprod`` is used as the previous alpha product.
        """
        alpha_prod_t = self.alphas_cumprod[timestep]
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
        beta_prod_t = 1 - alpha_prod_t
        beta_prod_t_prev = 1 - alpha_prod_t_prev

        variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)

        return variance

    def set_timesteps(self, num_inference_steps: int):
        """Select the ``num_inference_steps`` integer timesteps used at inference, in descending order.

        Spacing strategies, chosen by ``self.timestep_spacing``:
            "linspace": evenly spaced over [0, num_train_timesteps - 1], rounded to integers.
            "leading": multiples of ``num_train_timesteps // num_inference_steps``, plus ``steps_offset``.
            "trailing": stepping backwards by ``num_train_timesteps / num_inference_steps``,
                starting from ``num_train_timesteps - 1``.

        The result is stored in ``self.timesteps`` on ``self.device``.

        Raises:
            ValueError: if ``self.timestep_spacing`` is not one of the strategies above.
        """
        self.num_inference_steps = num_inference_steps
        if self.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.num_train_timesteps - 1, num_inference_steps).round()[::-1].copy().astype(np.int64)
            )
        elif self.timestep_spacing == "leading":
            step_ratio = self.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
            timesteps += self.steps_offset
        elif self.timestep_spacing == "trailing":
            step_ratio = self.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.round(np.arange(self.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        self.timesteps = torch.from_numpy(timesteps).to(self.device)

    def step(
        self,
        model_output,
        sample,
        idx,
        timestep,
        eta: float = 0.0,
        use_clipped_model_output: bool = False,
        generator=None,
        variance_noise: Optional[torch.FloatTensor] = None,
    ):
        """Compute the sample for the next (less noisy) timestep.

        Args:
            model_output: output of the denoising model at this step.
            sample: current latent sample x_t.
            idx: position of the current timestep in ``self.timesteps``.
            timestep: current timestep value (not used by the computation, which works from ``idx``).
            eta: weight of the stochastic noise term; 0.0 gives deterministic DDIM.
            use_clipped_model_output: re-derive the noise prediction from the (possibly clipped) x_0.
            generator: optional ``torch.Generator`` used to draw the noise when ``eta > 0``.
            variance_noise: explicit noise tensor used when ``eta > 0``. It cannot be combined
                with ``generator``.

        Returns:
            The previous sample x_{t-1}.

        Raises:
            ValueError: if ``set_timesteps`` was not called, if ``prediction_type`` is unknown, or if
                both ``generator`` and ``variance_noise`` are given while ``eta > 0``.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
        # Notation (<variable name> -> <name in paper>):
        # - pred_noise_t -> e_theta(x_t, t)
        # - pred_original_sample -> f_theta(x_t, t) or x_0
        # - std_dev_t -> sigma_t
        # - eta -> η
        # - pred_sample_direction -> "direction pointing to x_t"
        # - pred_prev_sample -> "x_t-1"

        # self.timesteps is descending, so the "previous" (less noisy) timestep is the next index.
        prev_idx = idx + 1
        alpha_prod_t = self.filtered_alphas_cumprod[idx]
        alpha_prod_t_prev = (
            self.filtered_alphas_cumprod[prev_idx] if prev_idx < self.num_inference_steps else self.final_alpha_cumprod
        )

        beta_prod_t = 1 - alpha_prod_t

        # 3. compute predicted original sample from predicted noise also called
        # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        if self.prediction_type == "epsilon":
            pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
        elif self.prediction_type == "sample":
            pred_original_sample = model_output
        elif self.prediction_type == "v_prediction":
            pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
            # predict V
            model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
        else:
            raise ValueError(
                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
                " `v_prediction`"
            )

        # 4. Clip "predicted x_0"
        if self.clip_sample:
            pred_original_sample = torch.clamp(pred_original_sample, -1, 1)

        # 5. compute variance: "sigma_t(η)" -> see formula (16)
        # o_t = sqrt((1 - a_t-1)/(1 - a_t)) * sqrt(1 - a_t/a_t-1)
        variance = self.variance[idx]
        std_dev_t = eta * variance ** (0.5)

        if use_clipped_model_output:
            # the model_output is always re-derived from the clipped x_0 in Glide
            model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)

        # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output

        # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
        prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction

        if eta > 0:
            # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
            device = model_output.device
            if variance_noise is not None and generator is not None:
                raise ValueError(
                    "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
                    " `variance_noise` stays `None`."
                )

            if variance_noise is None:
                variance_noise = torch.randn(
                    model_output.shape, generator=generator, device=device, dtype=model_output.dtype
                )
            noise_term = std_dev_t * variance_noise

            prev_sample = prev_sample + noise_term

        return prev_sample

    def add_noise(self, init_latents, noise, idx, latent_timestep):
        """Noise ``init_latents`` to the level of inference step ``idx`` (requires ``configure``).

        ``latent_timestep`` is accepted for interface compatibility but is not used; the noise level
        comes from ``self.filtered_alphas_cumprod[idx]``.
        """
        sqrt_alpha_prod = self.filtered_alphas_cumprod[idx] ** 0.5
        sqrt_one_minus_alpha_prod = (1 - self.filtered_alphas_cumprod[idx]) ** 0.5
        noisy_latents = sqrt_alpha_prod * init_latents + sqrt_one_minus_alpha_prod * noise

        return noisy_latents
|
|
207
|
+
|
|
208
|
+
|
|
209
|
+
class EulerAncestralDiscreteScheduler:
    """Ancestral Euler sampler over a scaled-linear beta schedule.

    Holds the noise levels (``sigmas``) and the inference ``timesteps`` and advances a
    sample one denoising step per :meth:`step` call, adding fresh noise scaled by
    ``sigma_up`` after each deterministic Euler update.
    """

    def __init__(
        self,
        num_train_timesteps: int = 1000,
        beta_start: float = 0.0001,
        beta_end: float = 0.02,
        device="cuda",
        steps_offset: int = 1,
        prediction_type: str = "epsilon",
        timestep_spacing: str = "trailing",  # set default to trailing for SDXL Turbo
    ):
        # Betas are linear in sqrt(beta) and then squared ("scaled linear").
        betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
        alphas = 1.0 - betas
        self.alphas_cumprod = torch.cumprod(alphas, dim=0)

        # sigma = sqrt((1 - alpha_bar) / alpha_bar), reversed (largest first) with a trailing 0.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = self.sigmas.max()

        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.is_scale_input_called = False

        self._step_index = None

        self.device = device
        self.num_train_timesteps = num_train_timesteps
        self.steps_offset = steps_offset
        self.prediction_type = prediction_type
        self.timestep_spacing = timestep_spacing

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
    def _init_step_index(self, timestep):
        """Set ``self._step_index`` to the position of ``timestep`` in ``self.timesteps``.

        Raises IndexError if ``timestep`` is not in the schedule.
        """
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        index_candidates = (self.timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        if len(index_candidates) > 1:
            step_index = index_candidates[1]
        else:
            step_index = index_candidates[0]

        self._step_index = step_index.item()

    def scale_model_input(self, sample: torch.FloatTensor, idx, timestep, *args, **kwargs) -> torch.FloatTensor:
        """Divide ``sample`` by sqrt(sigma**2 + 1) for the current step.

        ``idx`` is unused. ``timestep`` is only used to initialise the step index on the first call.
        """
        if self._step_index is None:
            self._init_step_index(timestep)

        sigma = self.sigmas[self._step_index]
        sample = sample / ((sigma**2 + 1) ** 0.5)
        self.is_scale_input_called = True
        return sample

    def set_timesteps(self, num_inference_steps: int):
        """Build the inference ``timesteps`` and matching ``sigmas`` and reset the step index.

        Raises:
            ValueError: if ``self.timestep_spacing`` is not "linspace", "leading" or "trailing".
        """
        self.num_inference_steps = num_inference_steps

        if self.timestep_spacing == "linspace":
            timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
        elif self.timestep_spacing == "leading":
            step_ratio = self.num_train_timesteps // self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
            timesteps += self.steps_offset
        elif self.timestep_spacing == "trailing":
            step_ratio = self.num_train_timesteps / self.num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(self.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        # Interpolate the training sigmas at the (possibly fractional) inference timesteps, then append 0.
        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
        sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
        self.sigmas = torch.from_numpy(sigmas).to(device=self.device)
        self.timesteps = torch.from_numpy(timesteps).to(device=self.device)

        self._step_index = None

    def configure(self):
        """Precompute per-step ``self.dts`` (sigma_down - sigma) and ``self.sigmas_up``.

        Entry i uses the transition sigmas[i] -> sigmas[i + 1], which is the same formula
        :meth:`step` applies. The computation is vectorized, so it needs no per-timestep
        index search and no per-element device-to-host copies.
        """
        sigma_from = self.sigmas[:-1]
        sigma_to = self.sigmas[1:]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        self.dts = (sigma_down - sigma_from).to(device=self.device, dtype=torch.float32)
        self.sigmas_up = sigma_up.to(device=self.device, dtype=torch.float32)

    def step(
        self,
        model_output,
        sample,
        idx,
        timestep,
        generator=None,
    ):
        """Advance ``sample`` by one ancestral Euler step and increment the step index.

        Args:
            model_output: Network output, interpreted according to ``self.prediction_type``.
            sample: Current noisy sample.
            idx: Unused; kept for interface compatibility.
            timestep: Used only to initialise the step index on the first call.
            generator: Optional ``torch.Generator`` for the ancestral noise.

        Returns:
            The sample for the next (less noisy) step.

        Raises:
            ValueError: for an unsupported ``prediction_type``.
        """
        if self._step_index is None:
            self._init_step_index(timestep)
        sigma = self.sigmas[self._step_index]

        # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
        if self.prediction_type == "epsilon":
            pred_original_sample = sample - sigma * model_output
        elif self.prediction_type == "v_prediction":
            # * c_out + input * c_skip
            pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
        else:
            raise ValueError(
                f"prediction_type given as {self.prediction_type} must be one of `epsilon`, or `v_prediction`"
            )

        # Split the move to sigma_to into a deterministic part (down to sigma_down)
        # and a stochastic part (fresh noise with std sigma_up).
        sigma_from = self.sigmas[self._step_index]
        sigma_to = self.sigmas[self._step_index + 1]
        sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
        sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5

        # 2. Convert to an ODE derivative
        derivative = (sample - pred_original_sample) / sigma

        dt = sigma_down - sigma

        prev_sample = sample + derivative * dt

        device = model_output.device
        noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator)

        prev_sample = prev_sample + noise * sigma_up

        # upon completion increase step index by one
        self._step_index += 1

        return prev_sample

    def add_noise(self, original_samples, noise, idx, timestep=None):
        """Return ``original_samples + noise * sigma`` at the noise level of ``timestep``.

        Args:
            original_samples: Clean samples. The sigmas are cast to their device and dtype.
            noise: Noise tensor broadcastable against ``original_samples``.
            idx: Index into the inference schedule. Only used when ``timestep`` is None,
                in which case ``self.sigmas[idx]`` is taken directly.
            timestep: A scalar (0-d) or 1-D collection of timesteps that appear in
                ``self.timesteps``. It is matched against the schedule to find each sigma.
        """
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)

        if timestep is None:
            # No timestep given: use the sigma at the requested inference-schedule position.
            step_indices = [int(idx)]
        else:
            schedule_timesteps = self.timesteps.to(original_samples.device)
            # reshape(-1) turns a 0-d timestep into a 1-element batch so it can be iterated.
            timesteps = torch.as_tensor(timestep, device=original_samples.device).reshape(-1)
            step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]

        sigma = sigmas[step_indices].flatten()
        # Append singleton dims so sigma broadcasts over the non-batch sample dims.
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        noisy_samples = original_samples + noise * sigma
        return noisy_samples
|
|
378
|
+
|
|
379
|
+
|
|
380
|
+
class UniPCMultistepScheduler:
    """UniPC multistep sampler (predictor + corrector) over a scaled-linear beta schedule.

    Keeps the last ``solver_order`` converted model outputs. Each :meth:`step` optionally
    corrects the incoming sample (UniC) and then predicts the next one (UniP).
    """

    def __init__(
        self,
        device="cuda",
        num_train_timesteps: int = 1000,
        beta_start: float = 0.00085,
        beta_end: float = 0.012,
        solver_order: int = 2,
        prediction_type: str = "epsilon",
        thresholding: bool = False,
        dynamic_thresholding_ratio: float = 0.995,
        sample_max_value: float = 1.0,
        predict_x0: bool = True,
        solver_type: str = "bh2",
        lower_order_final: bool = True,
        disable_corrector: Optional[List[int]] = None,
        use_karras_sigmas: Optional[bool] = False,
        timestep_spacing: str = "linspace",
        steps_offset: int = 0,
        sigma_min=None,
        sigma_max=None,
    ):
        self.device = device
        # Betas are linear in sqrt(beta) and then squared ("scaled linear").
        self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2

        self.alphas = 1.0 - self.betas
        self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
        # Currently we only support VP-type noise schedule
        self.alpha_t = torch.sqrt(self.alphas_cumprod)
        self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
        self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)

        # standard deviation of the initial noise distribution
        self.init_noise_sigma = 1.0

        self.predict_x0 = predict_x0
        # setable values
        self.num_inference_steps = None
        timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
        self.timesteps = torch.from_numpy(timesteps)
        self.model_outputs = [None] * solver_order
        self.timestep_list = [None] * solver_order
        self.lower_order_nums = 0
        self.disable_corrector = disable_corrector if disable_corrector else []
        self.last_sample = None

        self._step_index = None

        self.num_train_timesteps = num_train_timesteps
        self.solver_order = solver_order
        self.prediction_type = prediction_type
        self.thresholding = thresholding
        self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
        self.sample_max_value = sample_max_value
        self.solver_type = solver_type
        self.lower_order_final = lower_order_final
        self.use_karras_sigmas = use_karras_sigmas
        self.timestep_spacing = timestep_spacing
        self.steps_offset = steps_offset
        self.sigma_min = sigma_min
        self.sigma_max = sigma_max

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    def set_timesteps(self, num_inference_steps: int):
        """Build inference ``timesteps``/``sigmas`` and reset the multistep history.

        Raises:
            ValueError: if ``self.timestep_spacing`` is not "linspace", "leading" or "trailing".
        """
        if self.timestep_spacing == "linspace":
            timesteps = (
                np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
                .round()[::-1][:-1]
                .copy()
                .astype(np.int64)
            )
        elif self.timestep_spacing == "leading":
            step_ratio = self.num_train_timesteps // (num_inference_steps + 1)
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
            timesteps += self.steps_offset
        elif self.timestep_spacing == "trailing":
            step_ratio = self.num_train_timesteps / num_inference_steps
            # creates integer timesteps by multiplying by ratio
            # casting to int to avoid issues when num_inference_step is power of 3
            timesteps = np.arange(self.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
            timesteps -= 1
        else:
            raise ValueError(
                f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
            )

        sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
        if self.use_karras_sigmas:
            # Replace the schedule with Karras sigmas and map each back to a (rounded) timestep.
            log_sigmas = np.log(sigmas)
            sigmas = np.flip(sigmas).copy()
            sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
            timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
            sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
        else:
            sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
            # Final sigma is the smallest training sigma (timestep 0).
            sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
            sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)

        self.sigmas = torch.from_numpy(sigmas)
        self.timesteps = torch.from_numpy(timesteps).to(device=self.device, dtype=torch.int64)

        self.num_inference_steps = len(timesteps)

        self.model_outputs = [
            None,
        ] * self.solver_order
        self.lower_order_nums = 0
        self.last_sample = None

        # add an index counter for schedulers that allow duplicated timesteps
        self._step_index = None

    # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
    def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
        """Dynamic thresholding: clamp each sample to [-s, s] and divide by s.

        s is the ``dynamic_thresholding_ratio`` quantile of |sample|, clamped to [1, sample_max_value].
        """
        dtype = sample.dtype
        batch_size, channels, *remaining_dims = sample.shape

        if dtype not in (torch.float32, torch.float64):
            sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

        # Flatten sample for doing quantile calculation along each image
        sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))

        abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

        s = torch.quantile(abs_sample, self.dynamic_thresholding_ratio, dim=1)
        s = torch.clamp(
            s, min=1, max=self.sample_max_value
        )  # When clamped to min=1, equivalent to standard clipping to [-1, 1]
        s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
        sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

        sample = sample.reshape(batch_size, channels, *remaining_dims)
        sample = sample.to(dtype)

        return sample

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
    def _sigma_to_t(self, sigma, log_sigmas):
        """Map a sigma to a fractional timestep by linear interpolation in log-sigma space."""
        # get log sigma
        log_sigma = np.log(np.maximum(sigma, 1e-10))

        # get distribution
        dists = log_sigma - log_sigmas[:, np.newaxis]

        # get sigmas range
        low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
        high_idx = low_idx + 1

        low = log_sigmas[low_idx]
        high = log_sigmas[high_idx]

        # interpolate sigmas
        w = (low - log_sigma) / (low - high)
        w = np.clip(w, 0, 1)

        # transform interpolation to time range
        t = (1 - w) * low_idx + w * high_idx
        t = t.reshape(sigma.shape)
        return t

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
    def _sigma_to_alpha_sigma_t(self, sigma):
        """Return (alpha_t, sigma_t) = (1/sqrt(sigma^2+1), sigma/sqrt(sigma^2+1))."""
        alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
        sigma_t = sigma * alpha_t

        return alpha_t, sigma_t

    # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
    def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
        """Constructs the noise schedule of Karras et al. (2022)."""

        sigma_min = self.sigma_min
        sigma_max = self.sigma_max

        sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
        sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()

        rho = 7.0  # 7.0 is the value used in the paper
        ramp = np.linspace(0, 1, num_inference_steps)
        min_inv_rho = sigma_min ** (1 / rho)
        max_inv_rho = sigma_max ** (1 / rho)
        sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
        return sigmas

    def convert_model_output(
        self,
        model_output: torch.FloatTensor,
        *args,
        sample: torch.FloatTensor = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """Convert the raw network output to what the solver consumes.

        When ``predict_x0`` is set this is an x0 prediction, thresholded if enabled.
        Otherwise it is an epsilon prediction.

        Raises:
            ValueError: if ``sample`` is missing or ``prediction_type`` is unsupported.
        """
        timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError("missing `sample` as a required keyword argument")
        if timestep is not None:
            print(
                "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        sigma = self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)

        if self.predict_x0:
            if self.prediction_type == "epsilon":
                x0_pred = (sample - sigma_t * model_output) / alpha_t
            elif self.prediction_type == "sample":
                x0_pred = model_output
            elif self.prediction_type == "v_prediction":
                x0_pred = alpha_t * sample - sigma_t * model_output
            else:
                raise ValueError(
                    f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the UniPCMultistepScheduler."
                )

            if self.thresholding:
                x0_pred = self._threshold_sample(x0_pred)

            return x0_pred
        else:
            if self.prediction_type == "epsilon":
                return model_output
            elif self.prediction_type == "sample":
                epsilon = (sample - alpha_t * model_output) / sigma_t
                return epsilon
            elif self.prediction_type == "v_prediction":
                epsilon = alpha_t * model_output + sigma_t * sample
                return epsilon
            else:
                raise ValueError(
                    f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
                    " `v_prediction` for the UniPCMultistepScheduler."
                )

    def multistep_uni_p_bh_update(
        self,
        model_output: torch.FloatTensor,
        *args,
        sample: torch.FloatTensor = None,
        order: Optional[int] = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """UniP predictor: step ``sample`` from sigmas[step_index] to sigmas[step_index + 1].

        Uses the last ``order`` entries of ``self.model_outputs``. ``model_output`` itself is
        not read here.

        Raises:
            ValueError: if ``sample`` or ``order`` is missing.
            NotImplementedError: for an unknown ``solver_type``.
        """
        prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
        if sample is None:
            if len(args) > 1:
                sample = args[1]
            else:
                raise ValueError(" missing `sample` as a required keyword argument")
        if order is None:
            if len(args) > 2:
                order = args[2]
            else:
                raise ValueError(" missing `order` as a required keyword argument")
        if prev_timestep is not None:
            print(
                "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )
        model_output_list = self.model_outputs

        m0 = model_output_list[-1]
        x = sample

        sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = sample.device

        rks = []
        d1s = []
        for i in range(1, order):
            si = self.step_index - i
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            d1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        r = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.solver_type == "bh1":
            b_h = hh
        elif self.solver_type == "bh2":
            b_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            r.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / b_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        r = torch.stack(r)
        b = torch.tensor(b, device=device)

        if len(d1s) > 0:
            d1s = torch.stack(d1s, dim=1)  # (B, K)
            # for order 2, we use a simplified version
            if order == 2:
                rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
            else:
                # The solve runs in float32; cast so the einsum below matches d1s' dtype (e.g. fp16/fp64).
                rhos_p = torch.linalg.solve(r[:-1, :-1], b[:-1]).to(device=device, dtype=x.dtype)
        else:
            d1s = None

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if d1s is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, d1s)
            else:
                pred_res = 0
            x_t = x_t_ - alpha_t * b_h * pred_res
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if d1s is not None:
                pred_res = torch.einsum("k,bkc...->bc...", rhos_p, d1s)
            else:
                pred_res = 0
            x_t = x_t_ - sigma_t * b_h * pred_res

        x_t = x_t.to(x.dtype)
        return x_t

    def multistep_uni_c_bh_update(
        self,
        this_model_output: torch.FloatTensor,
        *args,
        last_sample: torch.FloatTensor = None,
        this_sample: torch.FloatTensor = None,
        order: Optional[int] = None,
        **kwargs,
    ) -> torch.FloatTensor:
        """UniC corrector: refine the step from sigmas[step_index - 1] to sigmas[step_index].

        ``this_model_output`` is the converted output at the current sample. Note that
        ``this_sample`` is only used for its device.

        Raises:
            ValueError: if ``last_sample``, ``this_sample`` or ``order`` is missing.
            NotImplementedError: for an unknown ``solver_type``.
        """
        this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
        if last_sample is None:
            if len(args) > 1:
                last_sample = args[1]
            else:
                raise ValueError(" missing`last_sample` as a required keyword argument")
        if this_sample is None:
            if len(args) > 2:
                this_sample = args[2]
            else:
                raise ValueError(" missing`this_sample` as a required keyword argument")
        if order is None:
            if len(args) > 3:
                order = args[3]
            else:
                raise ValueError(" missing`order` as a required keyword argument")
        if this_timestep is not None:
            print(
                "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
            )

        model_output_list = self.model_outputs

        m0 = model_output_list[-1]
        x = last_sample
        model_t = this_model_output

        sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
        alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)

        lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
        lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)

        h = lambda_t - lambda_s0
        device = this_sample.device

        rks = []
        d1s = []
        for i in range(1, order):
            si = self.step_index - (i + 1)
            mi = model_output_list[-(i + 1)]
            alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
            lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
            rk = (lambda_si - lambda_s0) / h
            rks.append(rk)
            d1s.append((mi - m0) / rk)

        rks.append(1.0)
        rks = torch.tensor(rks, device=device)

        r = []
        b = []

        hh = -h if self.predict_x0 else h
        h_phi_1 = torch.expm1(hh)  # h\phi_1(h) = e^h - 1
        h_phi_k = h_phi_1 / hh - 1

        factorial_i = 1

        if self.solver_type == "bh1":
            b_h = hh
        elif self.solver_type == "bh2":
            b_h = torch.expm1(hh)
        else:
            raise NotImplementedError()

        for i in range(1, order + 1):
            r.append(torch.pow(rks, i - 1))
            b.append(h_phi_k * factorial_i / b_h)
            factorial_i *= i + 1
            h_phi_k = h_phi_k / hh - 1 / factorial_i

        r = torch.stack(r)
        b = torch.tensor(b, device=device)

        if len(d1s) > 0:
            d1s = torch.stack(d1s, dim=1)
        else:
            d1s = None

        # for order 1, we use a simplified version
        if order == 1:
            rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
        else:
            # The solve runs in float32; cast so the einsum below matches d1s' dtype (e.g. fp16/fp64).
            rhos_c = torch.linalg.solve(r, b).to(device=device, dtype=x.dtype)

        if self.predict_x0:
            x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
            if d1s is not None:
                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], d1s)
            else:
                corr_res = 0
            d1_t = model_t - m0
            x_t = x_t_ - alpha_t * b_h * (corr_res + rhos_c[-1] * d1_t)
        else:
            x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
            if d1s is not None:
                corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], d1s)
            else:
                corr_res = 0
            d1_t = model_t - m0
            x_t = x_t_ - sigma_t * b_h * (corr_res + rhos_c[-1] * d1_t)
        x_t = x_t.to(x.dtype)
        return x_t

    def _init_step_index(self, timestep):
        """Locate ``timestep`` in ``self.timesteps``. Fall back to the last index if it is absent."""
        if isinstance(timestep, torch.Tensor):
            timestep = timestep.to(self.timesteps.device)

        index_candidates = (self.timesteps == timestep).nonzero()

        if len(index_candidates) == 0:
            step_index = len(self.timesteps) - 1
        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        elif len(index_candidates) > 1:
            step_index = index_candidates[1].item()
        else:
            step_index = index_candidates[0].item()

        self._step_index = step_index

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: int,
        sample: torch.FloatTensor,
        return_dict: bool = True,
    ):
        """Run one UniPC step: optional UniC correction of ``sample``, then UniP prediction.

        Returns:
            The predicted previous sample. When ``return_dict`` is False it is returned as a
            1-tuple ``(prev_sample,)``; despite the name, no dict is returned.

        Raises:
            ValueError: if :meth:`set_timesteps` has not been called.
        """
        if self.num_inference_steps is None:
            raise ValueError(
                "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        use_corrector = (
            self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
        )

        model_output_convert = self.convert_model_output(model_output, sample=sample)
        if use_corrector:
            # self.this_order was set by the previous call (use_corrector requires step_index > 0).
            sample = self.multistep_uni_c_bh_update(
                this_model_output=model_output_convert,
                last_sample=self.last_sample,
                this_sample=sample,
                order=self.this_order,
            )

        # Shift the history window left and append the newest converted output.
        for i in range(self.solver_order - 1):
            self.model_outputs[i] = self.model_outputs[i + 1]
            self.timestep_list[i] = self.timestep_list[i + 1]

        self.model_outputs[-1] = model_output_convert
        self.timestep_list[-1] = timestep

        if self.lower_order_final:
            this_order = min(self.solver_order, len(self.timesteps) - self.step_index)
        else:
            this_order = self.solver_order

        self.this_order = min(this_order, self.lower_order_nums + 1)  # warmup for multistep
        assert self.this_order > 0

        self.last_sample = sample
        prev_sample = self.multistep_uni_p_bh_update(
            model_output=model_output,  # pass the original non-converted model output, in case solver-p is used
            sample=sample,
            order=self.this_order,
        )

        if self.lower_order_nums < self.solver_order:
            self.lower_order_nums += 1

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return prev_sample

    def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
        """Identity: UniPC needs no input scaling."""
        return sample

    def add_noise(
        self,
        original_samples: torch.FloatTensor,
        noise: torch.FloatTensor,
        idx,
        timesteps: torch.IntTensor,
    ) -> torch.FloatTensor:
        """Return ``alpha_t * original_samples + sigma_t * noise`` for each timestep.

        ``timesteps`` may be a scalar (0-d) or 1-D. Every value must occur in ``self.timesteps``.
        ``idx`` is unused.
        """
        # Make sure sigmas and timesteps have the same device and dtype as original_samples
        sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
        schedule_timesteps = self.timesteps.to(original_samples.device)
        # reshape(-1) turns a 0-d timestep into a 1-element batch so it can be iterated.
        timesteps = torch.as_tensor(timesteps, device=original_samples.device).reshape(-1)

        step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
        sigma = sigmas[step_indices].flatten()
        while len(sigma.shape) < len(original_samples.shape):
            sigma = sigma.unsqueeze(-1)

        alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
        noisy_samples = alpha_t * original_samples + sigma_t * noise
        return noisy_samples

    def configure(self):
        """No precomputation is needed for UniPC."""
        pass

    def __len__(self):
        return self.num_train_timesteps
|
|
957
|
+
|
|
958
|
+
|
|
959
|
+
# Modified from diffusers.schedulers.LCMScheduler
|
|
960
|
+
class LCMScheduler:
|
|
961
|
+
def __init__(
|
|
962
|
+
self,
|
|
963
|
+
device="cuda",
|
|
964
|
+
num_train_timesteps: int = 1000,
|
|
965
|
+
beta_start: float = 0.00085,
|
|
966
|
+
beta_end: float = 0.012,
|
|
967
|
+
original_inference_steps: int = 50,
|
|
968
|
+
clip_sample: bool = False,
|
|
969
|
+
clip_sample_range: float = 1.0,
|
|
970
|
+
steps_offset: int = 0,
|
|
971
|
+
prediction_type: str = "epsilon",
|
|
972
|
+
thresholding: bool = False,
|
|
973
|
+
dynamic_thresholding_ratio: float = 0.995,
|
|
974
|
+
sample_max_value: float = 1.0,
|
|
975
|
+
timestep_spacing: str = "leading",
|
|
976
|
+
timestep_scaling: float = 10.0,
|
|
977
|
+
):
|
|
978
|
+
self.device = device
|
|
979
|
+
self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
|
|
980
|
+
self.alphas = 1.0 - self.betas
|
|
981
|
+
self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
|
|
982
|
+
self.final_alpha_cumprod = self.alphas_cumprod[0]
|
|
983
|
+
# standard deviation of the initial noise distribution
|
|
984
|
+
self.init_noise_sigma = 1.0
|
|
985
|
+
# setable values
|
|
986
|
+
self.num_inference_steps = None
|
|
987
|
+
self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
|
|
988
|
+
|
|
989
|
+
self.num_train_timesteps = num_train_timesteps
|
|
990
|
+
self.clip_sample = clip_sample
|
|
991
|
+
self.clip_sample_range = clip_sample_range
|
|
992
|
+
self.steps_offset = steps_offset
|
|
993
|
+
self.prediction_type = prediction_type
|
|
994
|
+
self.thresholding = thresholding
|
|
995
|
+
self.timestep_spacing = timestep_spacing
|
|
996
|
+
self.timestep_scaling = timestep_scaling
|
|
997
|
+
self.original_inference_steps = original_inference_steps
|
|
998
|
+
self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
|
|
999
|
+
self.sample_max_value = sample_max_value
|
|
1000
|
+
|
|
1001
|
+
self._step_index = None
|
|
1002
|
+
|
|
1003
|
+
# Adapted from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
def _init_step_index(self, timestep):
    """Set ``self._step_index`` to the position of ``timestep`` in ``self.timesteps``.

    If ``timestep`` occurs more than once in the schedule, the second
    occurrence is chosen; otherwise the single occurrence is used. If it does
    not occur at all, indexing the empty match list raises ``IndexError``.

    Args:
        timestep: an int or a tensor; a tensor is first moved to the device of
            ``self.timesteps`` so the comparison happens on one device.
    """
    if isinstance(timestep, torch.Tensor):
        timestep = timestep.to(self.timesteps.device)

    positions = (self.timesteps == timestep).nonzero()
    chosen = positions[1] if len(positions) > 1 else positions[0]
    self._step_index = chosen.item()
|
|
1016
|
+
|
|
1017
|
+
@property
def step_index(self):
    """Position in ``self.timesteps`` of the step about to run.

    ``None`` until the first call to ``step`` initializes it. ``set_timesteps``
    resets it to ``None``.
    """
    current = self._step_index
    return current
|
|
1020
|
+
|
|
1021
|
+
def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
    """Return ``sample`` unchanged.

    This scheduler performs no input scaling (its ``init_noise_sigma`` is 1.0).
    Any extra positional or keyword arguments are accepted and ignored,
    presumably so callers can use the same signature as other schedulers.
    """
    # Identity: the model consumes the latent exactly as given.
    return sample
|
|
1023
|
+
|
|
1024
|
+
# Adapted from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
    """Apply per-item dynamic thresholding to a predicted ``x_0``.

    For every batch item, ``s`` is the ``self.dynamic_thresholding_ratio``
    quantile of the absolute values of all of its elements, clamped to
    ``[1, self.sample_max_value]``. The item is then clamped to ``[-s, s]``
    and divided by ``s``. When ``s`` clamps to 1 this is plain clipping to
    ``[-1, 1]``.

    Args:
        sample: tensor of shape ``(batch, channels, *rest)``; it must have at
            least two dimensions, and ``rest`` may be empty. Inputs that are
            not float32/float64 are upcast to float32 for the computation.

    Returns:
        A tensor with the same shape and dtype as ``sample``.
    """
    dtype = sample.dtype
    batch_size, channels, *remaining_dims = sample.shape

    if dtype not in (torch.float32, torch.float64):
        sample = sample.float()  # upcast for quantile calculation, and clamp not implemented for cpu half

    # Flatten each batch item so the quantile is taken over all of its elements.
    # int() is required: with no trailing dims np.prod([]) is the float 1.0,
    # which reshape() rejects.
    sample = sample.reshape(batch_size, channels * int(np.prod(remaining_dims)))

    abs_sample = sample.abs()  # "a certain percentile absolute pixel value"

    s = torch.quantile(abs_sample, self.dynamic_thresholding_ratio, dim=1)
    # When clamped to min=1, equivalent to standard clipping to [-1, 1]
    s = torch.clamp(s, min=1, max=self.sample_max_value)
    s = s.unsqueeze(1)  # (batch_size, 1) because clamp will broadcast along dim=0
    sample = torch.clamp(sample, -s, s) / s  # "we threshold xt0 to the range [-s, s] and then divide by s"

    sample = sample.reshape(batch_size, channels, *remaining_dims)
    sample = sample.to(dtype)

    return sample
|
|
1048
|
+
|
|
1049
|
+
def set_timesteps(
    self,
    num_inference_steps: int,
    strength: float = 1.0,
):
    """Build the LCM inference timestep schedule.

    The training-step schedule is ``k * c - 1`` for
    ``k = 1 .. int(original_inference_steps * strength)``, where
    ``c = num_train_timesteps // original_inference_steps``. The inference
    schedule walks that list backwards from its largest entry with stride
    ``len(list) // num_inference_steps`` and keeps the first
    ``num_inference_steps`` entries.

    Side effects: sets ``self.num_inference_steps``, replaces
    ``self.timesteps`` with a ``torch.long`` tensor on ``self.device``, and
    resets ``self._step_index`` to ``None``.

    Args:
        num_inference_steps: number of denoising steps; must be >= 1 and no
            larger than ``num_train_timesteps`` or ``original_inference_steps``.
        strength: fraction of the original LCM training schedule to use.

    Raises:
        AssertionError: if the step counts violate the upper bounds above.
        ValueError: if ``num_inference_steps < 1``, or if
            ``int(original_inference_steps * strength)`` is smaller than
            ``num_inference_steps``.
    """
    assert num_inference_steps <= self.num_train_timesteps

    self.num_inference_steps = num_inference_steps
    original_steps = self.original_inference_steps

    assert original_steps <= self.num_train_timesteps
    assert num_inference_steps <= original_steps

    # Guard the integer division below; zero/negative counts previously
    # crashed with ZeroDivisionError or produced a nonsensical schedule.
    if num_inference_steps < 1:
        raise ValueError(f"num_inference_steps must be at least 1, got {num_inference_steps}.")

    # LCM Timesteps Setting
    # Currently, only linear spacing is supported.
    c = self.num_train_timesteps // original_steps
    # LCM Training Steps Schedule: c - 1, 2c - 1, ..., n*c - 1
    lcm_origin_timesteps = np.arange(1, int(original_steps * strength) + 1) * c - 1
    skipping_step = len(lcm_origin_timesteps) // num_inference_steps
    if skipping_step == 0:
        # A zero stride is not a valid slice step; report the real cause.
        raise ValueError(
            f"strength={strength} keeps only {len(lcm_origin_timesteps)} of the {original_steps} original"
            f" LCM timesteps, fewer than num_inference_steps={num_inference_steps}."
        )
    # LCM Inference Steps Schedule: largest timestep first, fixed stride.
    timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]

    self.timesteps = torch.from_numpy(timesteps.copy()).to(device=self.device, dtype=torch.long)

    self._step_index = None
|
|
1074
|
+
|
|
1075
|
+
def get_scalings_for_boundary_condition_discrete(self, timestep):
    """Return the ``(c_skip, c_out)`` boundary-condition coefficients for ``timestep``.

    With ``t = timestep * self.timestep_scaling`` and ``sigma_data = 0.5``:
    ``c_skip = sigma_data**2 / (t**2 + sigma_data**2)`` and
    ``c_out = t / sqrt(t**2 + sigma_data**2)``.

    Side effect: stores ``self.sigma_data = 0.5`` on the instance.
    """
    self.sigma_data = 0.5  # Default: 0.5
    t_scaled = timestep * self.timestep_scaling

    sigma_sq = self.sigma_data**2
    denom = t_scaled**2 + sigma_sq
    c_skip = sigma_sq / denom
    c_out = t_scaled / denom**0.5
    return c_skip, c_out
|
|
1082
|
+
|
|
1083
|
+
def step(
    self,
    model_output: torch.FloatTensor,
    timestep: int,
    sample: torch.FloatTensor,
    generator: Optional[torch.Generator] = None,
):
    """Run one LCM denoising step and return ``(prev_sample,)``.

    The step estimates ``x_0`` from ``model_output`` according to
    ``self.prediction_type`` and optionally thresholds or clips it. It then
    combines it with ``sample`` using the ``c_skip``/``c_out``
    boundary-condition scalings. On every step except the last, the result is
    re-noised to the next timestep in the schedule with Gaussian noise drawn
    from ``generator``. Afterwards ``self._step_index`` is advanced by one.

    Args:
        model_output: the model prediction (noise, x_0 or v, per ``prediction_type``).
        timestep: the current timestep. It indexes ``self.alphas_cumprod`` and,
            on the first call, locates the position in ``self.timesteps``.
        sample: the current latent.
        generator: optional RNG used for the injected noise.

    Returns:
        A one-element tuple holding the latent for the next step.

    Raises:
        ValueError: if ``set_timesteps`` has not been called or
            ``prediction_type`` is not supported.
    """
    if self.num_inference_steps is None:
        raise ValueError(
            "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
        )

    # First call after set_timesteps(): locate our position in the schedule.
    if self.step_index is None:
        self._init_step_index(timestep)

    # Target timestep: the next schedule entry, or the current timestep on the final step.
    next_index = self.step_index + 1
    prev_timestep = self.timesteps[next_index] if next_index < len(self.timesteps) else timestep

    # Cumulative alpha products at the current and target timesteps.
    # NOTE(review): schedules built by set_timesteps() contain only values >= 0,
    # so the final_alpha_cumprod fallback appears unreachable in normal use.
    alpha_prod_t = self.alphas_cumprod[timestep]
    if prev_timestep >= 0:
        alpha_prod_t_prev = self.alphas_cumprod[prev_timestep]
    else:
        alpha_prod_t_prev = self.final_alpha_cumprod
    beta_prod_t = 1 - alpha_prod_t
    beta_prod_t_prev = 1 - alpha_prod_t_prev

    c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)

    # Recover the model's estimate of the clean sample x_0.
    kind = self.prediction_type
    if kind == "epsilon":  # model predicts the noise
        x0_pred = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
    elif kind == "sample":  # model predicts x_0 directly
        x0_pred = model_output
    elif kind == "v_prediction":  # model predicts v
        x0_pred = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
    else:
        raise ValueError(
            f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample` or"
            " `v_prediction` for `LCMScheduler`."
        )

    # Dynamic thresholding takes precedence over plain clipping.
    if self.thresholding:
        x0_pred = self._threshold_sample(x0_pred)
    elif self.clip_sample:
        x0_pred = x0_pred.clamp(-self.clip_sample_range, self.clip_sample_range)

    # Boundary-condition blend of the x_0 estimate and the current sample.
    denoised = c_out * x0_pred + c_skip * sample

    # No noise on the final step, which also means none for one-step sampling.
    if self.step_index == self.num_inference_steps - 1:
        prev_sample = denoised
    else:
        noise = torch.randn(model_output.shape, device=model_output.device, dtype=denoised.dtype, generator=generator)
        prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise

    self._step_index += 1

    return (prev_sample,)
|
|
1152
|
+
|
|
1153
|
+
# Adapted from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
def add_noise(
    self,
    original_samples: torch.FloatTensor,
    noise: torch.FloatTensor,
    timesteps: torch.IntTensor,
) -> torch.FloatTensor:
    """Diffuse ``original_samples`` forward to ``timesteps``.

    Returns ``sqrt(a) * original_samples + sqrt(1 - a) * noise`` with
    ``a = self.alphas_cumprod[timesteps]``. The per-timestep coefficients are
    flattened and reshaped to ``(N, 1, ..., 1)`` so that each one broadcasts
    over the trailing dimensions of ``original_samples``.
    """
    # Put the schedule on the samples' device/dtype and the indices on their device.
    schedule = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
    timesteps = timesteps.to(original_samples.device)

    # One leading axis for the timesteps, then singleton axes for broadcasting.
    trailing_ones = [1] * (original_samples.dim() - 1)
    signal_scale = (schedule[timesteps] ** 0.5).flatten().reshape(-1, *trailing_ones)
    noise_scale = ((1 - schedule[timesteps]) ** 0.5).flatten().reshape(-1, *trailing_ones)

    return signal_scale * original_samples + noise_scale * noise
|
|
1176
|
+
|
|
1177
|
+
def configure(self):
    """No-op hook: this scheduler has nothing to configure."""
|
|
1179
|
+
|
|
1180
|
+
def __len__(self):
    """Return ``self.num_train_timesteps``, the number of training timesteps."""
    total = self.num_train_timesteps
    return total
|