onnxruntime_directml-1.20.0-cp313-cp313-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (305)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6508 -0
  4. onnxruntime/__init__.py +78 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +174 -0
  7. onnxruntime/backend/backend_rep.py +53 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  13. onnxruntime/capi/onnxruntime.dll +0 -0
  14. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  15. onnxruntime/capi/onnxruntime_inference_collection.py +1108 -0
  16. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  17. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  18. onnxruntime/capi/onnxruntime_validation.py +150 -0
  19. onnxruntime/capi/version_info.py +2 -0
  20. onnxruntime/datasets/__init__.py +17 -0
  21. onnxruntime/datasets/logreg_iris.onnx +0 -0
  22. onnxruntime/datasets/mul_1.onnx +0 -0
  23. onnxruntime/datasets/sigmoid.onnx +13 -0
  24. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  27. onnxruntime/quantization/__init__.py +16 -0
  28. onnxruntime/quantization/base_quantizer.py +532 -0
  29. onnxruntime/quantization/calibrate.py +1245 -0
  30. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  31. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  32. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  33. onnxruntime/quantization/execution_providers/qnn/preprocess.py +307 -0
  34. onnxruntime/quantization/execution_providers/qnn/quant_config.py +387 -0
  35. onnxruntime/quantization/fusions/__init__.py +3 -0
  36. onnxruntime/quantization/fusions/fusion.py +311 -0
  37. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  38. onnxruntime/quantization/fusions/fusion_layernorm.py +135 -0
  39. onnxruntime/quantization/matmul_4bits_quantizer.py +1480 -0
  40. onnxruntime/quantization/matmul_bnb4_quantizer.py +240 -0
  41. onnxruntime/quantization/onnx_model.py +580 -0
  42. onnxruntime/quantization/onnx_quantizer.py +1008 -0
  43. onnxruntime/quantization/operators/__init__.py +2 -0
  44. onnxruntime/quantization/operators/activation.py +119 -0
  45. onnxruntime/quantization/operators/argmax.py +18 -0
  46. onnxruntime/quantization/operators/attention.py +73 -0
  47. onnxruntime/quantization/operators/base_operator.py +26 -0
  48. onnxruntime/quantization/operators/binary_op.py +72 -0
  49. onnxruntime/quantization/operators/concat.py +62 -0
  50. onnxruntime/quantization/operators/conv.py +258 -0
  51. onnxruntime/quantization/operators/direct_q8.py +78 -0
  52. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  53. onnxruntime/quantization/operators/gather.py +64 -0
  54. onnxruntime/quantization/operators/gavgpool.py +62 -0
  55. onnxruntime/quantization/operators/gemm.py +166 -0
  56. onnxruntime/quantization/operators/lstm.py +117 -0
  57. onnxruntime/quantization/operators/matmul.py +231 -0
  58. onnxruntime/quantization/operators/maxpool.py +34 -0
  59. onnxruntime/quantization/operators/norm.py +40 -0
  60. onnxruntime/quantization/operators/pad.py +100 -0
  61. onnxruntime/quantization/operators/pooling.py +67 -0
  62. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  63. onnxruntime/quantization/operators/resize.py +34 -0
  64. onnxruntime/quantization/operators/softmax.py +74 -0
  65. onnxruntime/quantization/operators/split.py +63 -0
  66. onnxruntime/quantization/operators/where.py +87 -0
  67. onnxruntime/quantization/preprocess.py +141 -0
  68. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  69. onnxruntime/quantization/qdq_quantizer.py +1187 -0
  70. onnxruntime/quantization/quant_utils.py +891 -0
  71. onnxruntime/quantization/quantize.py +748 -0
  72. onnxruntime/quantization/registry.py +106 -0
  73. onnxruntime/quantization/shape_inference.py +187 -0
  74. onnxruntime/quantization/tensor_quant_overrides.py +516 -0
  75. onnxruntime/tools/__init__.py +10 -0
  76. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  77. onnxruntime/tools/convert_onnx_models_to_ort.py +377 -0
  78. onnxruntime/tools/file_utils.py +46 -0
  79. onnxruntime/tools/logger.py +11 -0
  80. onnxruntime/tools/make_dynamic_shape_fixed.py +72 -0
  81. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  82. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +33 -0
  83. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  84. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  85. onnxruntime/tools/mobile_helpers/usability_checker.py +739 -0
  86. onnxruntime/tools/offline_tuning.py +169 -0
  87. onnxruntime/tools/onnx_model_utils.py +413 -0
  88. onnxruntime/tools/onnx_randomizer.py +85 -0
  89. onnxruntime/tools/onnxruntime_test.py +164 -0
  90. onnxruntime/tools/optimize_onnx_model.py +55 -0
  91. onnxruntime/tools/ort_format_model/__init__.py +25 -0
  92. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +663 -0
  93. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  94. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  95. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  96. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  97. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  98. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  99. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  100. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  140. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  141. onnxruntime/tools/ort_format_model/types.py +84 -0
  142. onnxruntime/tools/ort_format_model/utils.py +62 -0
  143. onnxruntime/tools/pytorch_export_contrib_ops.py +108 -0
  144. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  145. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  146. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  147. onnxruntime/tools/reduced_build_config_parser.py +202 -0
  148. onnxruntime/tools/symbolic_shape_infer.py +3016 -0
  149. onnxruntime/tools/update_onnx_opset.py +31 -0
  150. onnxruntime/transformers/__init__.py +8 -0
  151. onnxruntime/transformers/affinity_helper.py +40 -0
  152. onnxruntime/transformers/benchmark.py +944 -0
  153. onnxruntime/transformers/benchmark_helper.py +646 -0
  154. onnxruntime/transformers/bert_perf_test.py +634 -0
  155. onnxruntime/transformers/bert_test_data.py +642 -0
  156. onnxruntime/transformers/compare_bert_results.py +246 -0
  157. onnxruntime/transformers/constants.py +47 -0
  158. onnxruntime/transformers/convert_generation.py +3124 -0
  159. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  160. onnxruntime/transformers/convert_to_packing_mode.py +387 -0
  161. onnxruntime/transformers/dynamo_onnx_helper.py +104 -0
  162. onnxruntime/transformers/float16.py +501 -0
  163. onnxruntime/transformers/fusion_attention.py +1235 -0
  164. onnxruntime/transformers/fusion_attention_clip.py +257 -0
  165. onnxruntime/transformers/fusion_attention_sam2.py +534 -0
  166. onnxruntime/transformers/fusion_attention_unet.py +1304 -0
  167. onnxruntime/transformers/fusion_attention_vae.py +301 -0
  168. onnxruntime/transformers/fusion_bart_attention.py +640 -0
  169. onnxruntime/transformers/fusion_base.py +137 -0
  170. onnxruntime/transformers/fusion_bias_add.py +58 -0
  171. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  172. onnxruntime/transformers/fusion_biassplitgelu.py +111 -0
  173. onnxruntime/transformers/fusion_conformer_attention.py +143 -0
  174. onnxruntime/transformers/fusion_embedlayer.py +811 -0
  175. onnxruntime/transformers/fusion_fastgelu.py +360 -0
  176. onnxruntime/transformers/fusion_gelu.py +259 -0
  177. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  178. onnxruntime/transformers/fusion_gemmfastgelu.py +122 -0
  179. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  180. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  181. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  182. onnxruntime/transformers/fusion_group_norm.py +179 -0
  183. onnxruntime/transformers/fusion_layernorm.py +465 -0
  184. onnxruntime/transformers/fusion_nhwc_conv.py +100 -0
  185. onnxruntime/transformers/fusion_options.py +340 -0
  186. onnxruntime/transformers/fusion_qordered_attention.py +421 -0
  187. onnxruntime/transformers/fusion_qordered_gelu.py +119 -0
  188. onnxruntime/transformers/fusion_qordered_layernorm.py +123 -0
  189. onnxruntime/transformers/fusion_qordered_matmul.py +217 -0
  190. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  191. onnxruntime/transformers/fusion_reshape.py +173 -0
  192. onnxruntime/transformers/fusion_rotary_attention.py +1592 -0
  193. onnxruntime/transformers/fusion_shape.py +110 -0
  194. onnxruntime/transformers/fusion_simplified_layernorm.py +159 -0
  195. onnxruntime/transformers/fusion_skip_group_norm.py +255 -0
  196. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  197. onnxruntime/transformers/fusion_transpose.py +168 -0
  198. onnxruntime/transformers/fusion_utils.py +307 -0
  199. onnxruntime/transformers/huggingface_models.py +167 -0
  200. onnxruntime/transformers/import_utils.py +20 -0
  201. onnxruntime/transformers/io_binding_helper.py +442 -0
  202. onnxruntime/transformers/large_model_exporter.py +395 -0
  203. onnxruntime/transformers/machine_info.py +221 -0
  204. onnxruntime/transformers/metrics.py +164 -0
  205. onnxruntime/transformers/models/bart/__init__.py +12 -0
  206. onnxruntime/transformers/models/bart/export.py +98 -0
  207. onnxruntime/transformers/models/bert/__init__.py +12 -0
  208. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  209. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  210. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  211. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +561 -0
  212. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1032 -0
  213. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  214. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  215. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  216. onnxruntime/transformers/models/llama/__init__.py +12 -0
  217. onnxruntime/transformers/models/llama/benchmark.py +703 -0
  218. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  219. onnxruntime/transformers/models/llama/benchmark_e2e.py +606 -0
  220. onnxruntime/transformers/models/llama/convert_to_onnx.py +1027 -0
  221. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  222. onnxruntime/transformers/models/llama/llama_inputs.py +503 -0
  223. onnxruntime/transformers/models/llama/llama_parity.py +309 -0
  224. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  225. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  226. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  227. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  228. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  229. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  230. onnxruntime/transformers/models/longformer/longformer_helper.py +77 -0
  231. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  232. onnxruntime/transformers/models/phi2/convert_to_onnx.py +576 -0
  233. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  234. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  235. onnxruntime/transformers/models/sam2/benchmark_sam2.py +625 -0
  236. onnxruntime/transformers/models/sam2/convert_to_onnx.py +260 -0
  237. onnxruntime/transformers/models/sam2/image_decoder.py +273 -0
  238. onnxruntime/transformers/models/sam2/image_encoder.py +186 -0
  239. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  240. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  241. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  242. onnxruntime/transformers/models/sam2/sam2_demo.py +322 -0
  243. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +280 -0
  244. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  245. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  246. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1429 -0
  247. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  248. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +102 -0
  249. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +268 -0
  250. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  251. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1319 -0
  252. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1181 -0
  253. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +296 -0
  254. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +388 -0
  255. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  256. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  257. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  258. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +350 -0
  259. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  260. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  261. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  262. onnxruntime/transformers/models/t5/__init__.py +12 -0
  263. onnxruntime/transformers/models/t5/convert_to_onnx.py +278 -0
  264. onnxruntime/transformers/models/t5/past_helper.py +150 -0
  265. onnxruntime/transformers/models/t5/t5_decoder.py +438 -0
  266. onnxruntime/transformers/models/t5/t5_encoder.py +171 -0
  267. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +299 -0
  268. onnxruntime/transformers/models/t5/t5_helper.py +272 -0
  269. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  270. onnxruntime/transformers/models/whisper/benchmark.py +610 -0
  271. onnxruntime/transformers/models/whisper/benchmark_all.py +528 -0
  272. onnxruntime/transformers/models/whisper/convert_to_onnx.py +536 -0
  273. onnxruntime/transformers/models/whisper/whisper_chain.py +329 -0
  274. onnxruntime/transformers/models/whisper/whisper_decoder.py +402 -0
  275. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  276. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +306 -0
  277. onnxruntime/transformers/models/whisper/whisper_helper.py +524 -0
  278. onnxruntime/transformers/models/whisper/whisper_openai_helper.py +84 -0
  279. onnxruntime/transformers/onnx_exporter.py +717 -0
  280. onnxruntime/transformers/onnx_model.py +1569 -0
  281. onnxruntime/transformers/onnx_model_bart.py +142 -0
  282. onnxruntime/transformers/onnx_model_bert.py +481 -0
  283. onnxruntime/transformers/onnx_model_bert_keras.py +475 -0
  284. onnxruntime/transformers/onnx_model_bert_tf.py +589 -0
  285. onnxruntime/transformers/onnx_model_clip.py +40 -0
  286. onnxruntime/transformers/onnx_model_conformer.py +33 -0
  287. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  288. onnxruntime/transformers/onnx_model_phi.py +930 -0
  289. onnxruntime/transformers/onnx_model_sam2.py +138 -0
  290. onnxruntime/transformers/onnx_model_t5.py +791 -0
  291. onnxruntime/transformers/onnx_model_tnlr.py +227 -0
  292. onnxruntime/transformers/onnx_model_unet.py +259 -0
  293. onnxruntime/transformers/onnx_model_vae.py +43 -0
  294. onnxruntime/transformers/onnx_utils.py +55 -0
  295. onnxruntime/transformers/optimizer.py +612 -0
  296. onnxruntime/transformers/profiler.py +725 -0
  297. onnxruntime/transformers/quantize_helper.py +76 -0
  298. onnxruntime/transformers/shape_infer_helper.py +122 -0
  299. onnxruntime/transformers/shape_optimizer.py +401 -0
  300. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  301. onnxruntime_directml-1.20.0.dist-info/METADATA +187 -0
  302. onnxruntime_directml-1.20.0.dist-info/RECORD +305 -0
  303. onnxruntime_directml-1.20.0.dist-info/WHEEL +5 -0
  304. onnxruntime_directml-1.20.0.dist-info/entry_points.txt +2 -0
  305. onnxruntime_directml-1.20.0.dist-info/top_level.txt +1 -0
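The listing above is the complete content of the wheel: the native DirectML binaries (onnxruntime/capi/DirectML.dll, onnxruntime.dll, onnxruntime_pybind11_state.pyd), the Python API, the quantization/transformers tooling, and sample models under onnxruntime/datasets. As a minimal, illustrative sketch (not part of the diff itself), assuming the wheel is installed, the packaged runtime could be exercised against the bundled sigmoid.onnx sample (item 23 above) with the DirectML execution provider:

import numpy as np
import onnxruntime as ort
from onnxruntime.datasets import get_example

# Path to the sigmoid.onnx sample shipped in onnxruntime/datasets (item 23 above).
model_path = get_example("sigmoid.onnx")

# Prefer the DirectML execution provider from this wheel; fall back to CPU if it is unavailable.
session = ort.InferenceSession(model_path, providers=["DmlExecutionProvider", "CPUExecutionProvider"])

inp = session.get_inputs()[0]
# Replace any symbolic dimensions with 1 so a dummy input tensor can be built.
shape = [d if isinstance(d, int) else 1 for d in inp.shape]
x = np.random.rand(*shape).astype(np.float32)

outputs = session.run(None, {inp.name: x})
print(outputs[0].shape)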
onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py
@@ -0,0 +1,1181 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ # Modified from utilities.py of TensorRT demo diffusion, which has the following license:
6
+ #
7
+ # Copyright 2022 The HuggingFace Inc. team.
8
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
9
+ # SPDX-License-Identifier: Apache-2.0
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ # --------------------------------------------------------------------------
23
+
24
+ from typing import List, Optional
25
+
26
+ import numpy as np
27
+ import torch
28
+
29
+
30
+ class DDIMScheduler:
31
+ def __init__(
32
+ self,
33
+ device="cuda",
34
+ num_train_timesteps: int = 1000,
35
+ beta_start: float = 0.0001,
36
+ beta_end: float = 0.02,
37
+ clip_sample: bool = False,
38
+ set_alpha_to_one: bool = False,
39
+ steps_offset: int = 1,
40
+ prediction_type: str = "epsilon",
41
+ timestep_spacing: str = "leading",
42
+ ):
43
+ # this schedule is very specific to the latent diffusion model.
44
+ betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
45
+
46
+ alphas = 1.0 - betas
47
+ self.alphas_cumprod = torch.cumprod(alphas, dim=0)
48
+ # standard deviation of the initial noise distribution
49
+ self.init_noise_sigma = 1.0
50
+
51
+ # At every step in ddim, we are looking into the previous alphas_cumprod
52
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
53
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
54
+ # whether we use the final alpha of the "non-previous" one.
55
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
56
+
57
+ # setable values
58
+ self.num_inference_steps = None
59
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
60
+ self.steps_offset = steps_offset
61
+ self.num_train_timesteps = num_train_timesteps
62
+ self.clip_sample = clip_sample
63
+ self.prediction_type = prediction_type
64
+ self.device = device
65
+ self.timestep_spacing = timestep_spacing
66
+
67
+ def configure(self):
68
+ variance = np.zeros(self.num_inference_steps, dtype=np.float32)
69
+ for idx, timestep in enumerate(self.timesteps):
70
+ prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
71
+ variance[idx] = self._get_variance(timestep, prev_timestep)
72
+ self.variance = torch.from_numpy(variance).to(self.device)
73
+
74
+ timesteps = self.timesteps.long().cpu()
75
+ self.filtered_alphas_cumprod = self.alphas_cumprod[timesteps].to(self.device)
76
+ self.final_alpha_cumprod = self.final_alpha_cumprod.to(self.device)
77
+
78
+ def scale_model_input(self, sample: torch.FloatTensor, idx, *args, **kwargs) -> torch.FloatTensor:
79
+ return sample
80
+
81
+ def _get_variance(self, timestep, prev_timestep):
82
+ alpha_prod_t = self.alphas_cumprod[timestep]
83
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
84
+ beta_prod_t = 1 - alpha_prod_t
85
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
86
+
87
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
88
+
89
+ return variance
90
+
91
+ def set_timesteps(self, num_inference_steps: int):
92
+ self.num_inference_steps = num_inference_steps
93
+ if self.timestep_spacing == "leading":
94
+ step_ratio = self.num_train_timesteps // self.num_inference_steps
95
+ # creates integer timesteps by multiplying by ratio
96
+ # casting to int to avoid issues when num_inference_step is power of 3
97
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
98
+ timesteps += self.steps_offset
99
+ elif self.timestep_spacing == "trailing":
100
+ step_ratio = self.num_train_timesteps / self.num_inference_steps
101
+ # creates integer timesteps by multiplying by ratio
102
+ # casting to int to avoid issues when num_inference_step is power of 3
103
+ timesteps = np.round(np.arange(self.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
104
+ timesteps -= 1
105
+ else:
106
+ raise ValueError(
107
+ f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
108
+ )
109
+
110
+ self.timesteps = torch.from_numpy(timesteps).to(self.device)
111
+
112
+ def step(
113
+ self,
114
+ model_output,
115
+ sample,
116
+ idx,
117
+ timestep,
118
+ eta: float = 0.0,
119
+ use_clipped_model_output: bool = False,
120
+ generator=None,
121
+ variance_noise: torch.FloatTensor = None,
122
+ ):
123
+ if self.num_inference_steps is None:
124
+ raise ValueError(
125
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
126
+ )
127
+
128
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
129
+ # Ideally, read the DDIM paper for an in-depth understanding of the notation
130
+
131
+ # Notation: <variable name> -> <name in paper>
132
+ # - pred_noise_t -> e_theta(x_t, t)
133
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
134
+ # - std_dev_t -> sigma_t
135
+ # - eta -> η
136
+ # - pred_sample_direction -> "direction pointing to x_t"
137
+ # - pred_prev_sample -> "x_t-1"
138
+
139
+ prev_idx = idx + 1
140
+ alpha_prod_t = self.filtered_alphas_cumprod[idx]
141
+ alpha_prod_t_prev = (
142
+ self.filtered_alphas_cumprod[prev_idx] if prev_idx < self.num_inference_steps else self.final_alpha_cumprod
143
+ )
144
+
145
+ beta_prod_t = 1 - alpha_prod_t
146
+
147
+ # 3. compute predicted original sample from predicted noise also called
148
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
149
+ if self.prediction_type == "epsilon":
150
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
151
+ elif self.prediction_type == "sample":
152
+ pred_original_sample = model_output
153
+ elif self.prediction_type == "v_prediction":
154
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
155
+ # predict V
156
+ model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
157
+ else:
158
+ raise ValueError(
159
+ f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
160
+ " `v_prediction`"
161
+ )
162
+
163
+ # 4. Clip "predicted x_0"
164
+ if self.clip_sample:
165
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
166
+
167
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
168
+ # o_t = sqrt((1 - a_t-1)/(1 - a_t)) * sqrt(1 - a_t/a_t-1)
169
+ variance = self.variance[idx]
170
+ std_dev_t = eta * variance ** (0.5)
171
+
172
+ if use_clipped_model_output:
173
+ # the model_output is always re-derived from the clipped x_0 in Glide
174
+ model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
175
+
176
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
177
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
178
+
179
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
180
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
181
+
182
+ if eta > 0:
183
+ # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
184
+ device = model_output.device
185
+ if variance_noise is not None and generator is not None:
186
+ raise ValueError(
187
+ "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
188
+ " `variance_noise` stays `None`."
189
+ )
190
+
191
+ if variance_noise is None:
192
+ variance_noise = torch.randn(
193
+ model_output.shape, generator=generator, device=device, dtype=model_output.dtype
194
+ )
195
+ variance = std_dev_t * variance_noise
196
+
197
+ prev_sample = prev_sample + variance
198
+
199
+ return prev_sample
200
+
201
+ def add_noise(self, init_latents, noise, idx, latent_timestep):
202
+ sqrt_alpha_prod = self.filtered_alphas_cumprod[idx] ** 0.5
203
+ sqrt_one_minus_alpha_prod = (1 - self.filtered_alphas_cumprod[idx]) ** 0.5
204
+ noisy_latents = sqrt_alpha_prod * init_latents + sqrt_one_minus_alpha_prod * noise
205
+
206
+ return noisy_latents
207
+
208
+
209
+ class EulerAncestralDiscreteScheduler:
210
+ def __init__(
211
+ self,
212
+ num_train_timesteps: int = 1000,
213
+ beta_start: float = 0.0001,
214
+ beta_end: float = 0.02,
215
+ device="cuda",
216
+ steps_offset: int = 1,
217
+ prediction_type: str = "epsilon",
218
+ timestep_spacing: str = "trailing", # set default to trailing for SDXL Turbo
219
+ ):
220
+ betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
221
+ alphas = 1.0 - betas
222
+ self.alphas_cumprod = torch.cumprod(alphas, dim=0)
223
+
224
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
225
+ sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
226
+ self.sigmas = torch.from_numpy(sigmas)
227
+
228
+ # standard deviation of the initial noise distribution
229
+ self.init_noise_sigma = self.sigmas.max()
230
+
231
+ # setable values
232
+ self.num_inference_steps = None
233
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
234
+ self.timesteps = torch.from_numpy(timesteps)
235
+ self.is_scale_input_called = False
236
+
237
+ self._step_index = None
238
+
239
+ self.device = device
240
+ self.num_train_timesteps = num_train_timesteps
241
+ self.steps_offset = steps_offset
242
+ self.prediction_type = prediction_type
243
+ self.timestep_spacing = timestep_spacing
244
+
245
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
246
+ def _init_step_index(self, timestep):
247
+ if isinstance(timestep, torch.Tensor):
248
+ timestep = timestep.to(self.timesteps.device)
249
+
250
+ index_candidates = (self.timesteps == timestep).nonzero()
251
+
252
+ # The sigma index that is taken for the **very** first `step`
253
+ # is always the second index (or the last index if there is only 1)
254
+ # This way we can ensure we don't accidentally skip a sigma in
255
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
256
+ if len(index_candidates) > 1:
257
+ step_index = index_candidates[1]
258
+ else:
259
+ step_index = index_candidates[0]
260
+
261
+ self._step_index = step_index.item()
262
+
263
+ def scale_model_input(self, sample: torch.FloatTensor, idx, timestep, *args, **kwargs) -> torch.FloatTensor:
264
+ if self._step_index is None:
265
+ self._init_step_index(timestep)
266
+
267
+ sigma = self.sigmas[self._step_index]
268
+ sample = sample / ((sigma**2 + 1) ** 0.5)
269
+ self.is_scale_input_called = True
270
+ return sample
271
+
272
+ def set_timesteps(self, num_inference_steps: int):
273
+ self.num_inference_steps = num_inference_steps
274
+
275
+ if self.timestep_spacing == "linspace":
276
+ timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
277
+ elif self.timestep_spacing == "leading":
278
+ step_ratio = self.num_train_timesteps // self.num_inference_steps
279
+ # creates integer timesteps by multiplying by ratio
280
+ # casting to int to avoid issues when num_inference_step is power of 3
281
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
282
+ timesteps += self.steps_offset
283
+ elif self.timestep_spacing == "trailing":
284
+ step_ratio = self.num_train_timesteps / self.num_inference_steps
285
+ # creates integer timesteps by multiplying by ratio
286
+ # casting to int to avoid issues when num_inference_step is power of 3
287
+ timesteps = (np.arange(self.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
288
+ timesteps -= 1
289
+ else:
290
+ raise ValueError(
291
+ f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
292
+ )
293
+
294
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
295
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
296
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
297
+ self.sigmas = torch.from_numpy(sigmas).to(device=self.device)
298
+ self.timesteps = torch.from_numpy(timesteps).to(device=self.device)
299
+
300
+ self._step_index = None
301
+
302
+ def configure(self):
303
+ dts = np.zeros(self.num_inference_steps, dtype=np.float32)
304
+ sigmas_up = np.zeros(self.num_inference_steps, dtype=np.float32)
305
+ for idx, timestep in enumerate(self.timesteps):
306
+ step_index = (self.timesteps == timestep).nonzero().item()
307
+ sigma = self.sigmas[step_index]
308
+
309
+ sigma_from = self.sigmas[step_index]
310
+ sigma_to = self.sigmas[step_index + 1]
311
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
312
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
313
+ dt = sigma_down - sigma
314
+ dts[idx] = dt
315
+ sigmas_up[idx] = sigma_up
316
+
317
+ self.dts = torch.from_numpy(dts).to(self.device)
318
+ self.sigmas_up = torch.from_numpy(sigmas_up).to(self.device)
319
+
320
+ def step(
321
+ self,
322
+ model_output,
323
+ sample,
324
+ idx,
325
+ timestep,
326
+ generator=None,
327
+ ):
328
+ if self._step_index is None:
329
+ self._init_step_index(timestep)
330
+ sigma = self.sigmas[self._step_index]
331
+
332
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
333
+ if self.prediction_type == "epsilon":
334
+ pred_original_sample = sample - sigma * model_output
335
+ elif self.prediction_type == "v_prediction":
336
+ # * c_out + input * c_skip
337
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
338
+ else:
339
+ raise ValueError(
340
+ f"prediction_type given as {self.prediction_type} must be one of `epsilon`, or `v_prediction`"
341
+ )
342
+
343
+ sigma_from = self.sigmas[self._step_index]
344
+ sigma_to = self.sigmas[self._step_index + 1]
345
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
346
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
347
+
348
+ # 2. Convert to an ODE derivative
349
+ derivative = (sample - pred_original_sample) / sigma
350
+
351
+ dt = sigma_down - sigma
352
+
353
+ prev_sample = sample + derivative * dt
354
+
355
+ device = model_output.device
356
+ noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(device)
357
+
358
+ prev_sample = prev_sample + noise * sigma_up
359
+
360
+ # upon completion increase step index by one
361
+ self._step_index += 1
362
+
363
+ return prev_sample
364
+
365
+ def add_noise(self, original_samples, noise, idx, timestep=None):
366
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
367
+ schedule_timesteps = self.timesteps.to(original_samples.device)
368
+ timesteps = timestep.to(original_samples.device)
369
+
370
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
371
+
372
+ sigma = sigmas[step_indices].flatten()
373
+ while len(sigma.shape) < len(original_samples.shape):
374
+ sigma = sigma.unsqueeze(-1)
375
+
376
+ noisy_samples = original_samples + noise * sigma
377
+ return noisy_samples
378
+
379
+
380
+ class UniPCMultistepScheduler:
381
+ def __init__(
382
+ self,
383
+ device="cuda",
384
+ num_train_timesteps: int = 1000,
385
+ beta_start: float = 0.00085,
386
+ beta_end: float = 0.012,
387
+ solver_order: int = 2,
388
+ prediction_type: str = "epsilon",
389
+ thresholding: bool = False,
390
+ dynamic_thresholding_ratio: float = 0.995,
391
+ sample_max_value: float = 1.0,
392
+ predict_x0: bool = True,
393
+ solver_type: str = "bh2",
394
+ lower_order_final: bool = True,
395
+ disable_corrector: Optional[List[int]] = None,
396
+ use_karras_sigmas: Optional[bool] = False,
397
+ timestep_spacing: str = "linspace",
398
+ steps_offset: int = 0,
399
+ sigma_min=None,
400
+ sigma_max=None,
401
+ ):
402
+ self.device = device
403
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
404
+
405
+ self.alphas = 1.0 - self.betas
406
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
407
+ # Currently we only support VP-type noise schedule
408
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
409
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
410
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
411
+
412
+ # standard deviation of the initial noise distribution
413
+ self.init_noise_sigma = 1.0
414
+
415
+ self.predict_x0 = predict_x0
416
+ # setable values
417
+ self.num_inference_steps = None
418
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
419
+ self.timesteps = torch.from_numpy(timesteps)
420
+ self.model_outputs = [None] * solver_order
421
+ self.timestep_list = [None] * solver_order
422
+ self.lower_order_nums = 0
423
+ self.disable_corrector = disable_corrector if disable_corrector else []
424
+ self.last_sample = None
425
+
426
+ self._step_index = None
427
+
428
+ self.num_train_timesteps = num_train_timesteps
429
+ self.solver_order = solver_order
430
+ self.prediction_type = prediction_type
431
+ self.thresholding = thresholding
432
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
433
+ self.sample_max_value = sample_max_value
434
+ self.solver_type = solver_type
435
+ self.lower_order_final = lower_order_final
436
+ self.use_karras_sigmas = use_karras_sigmas
437
+ self.timestep_spacing = timestep_spacing
438
+ self.steps_offset = steps_offset
439
+ self.sigma_min = sigma_min
440
+ self.sigma_max = sigma_max
441
+
442
+ @property
443
+ def step_index(self):
444
+ """
445
+ The index counter for current timestep. It will increase 1 after each scheduler step.
446
+ """
447
+ return self._step_index
448
+
449
+ def set_timesteps(self, num_inference_steps: int):
450
+ if self.timestep_spacing == "linspace":
451
+ timesteps = (
452
+ np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
453
+ .round()[::-1][:-1]
454
+ .copy()
455
+ .astype(np.int64)
456
+ )
457
+ elif self.timestep_spacing == "leading":
458
+ step_ratio = self.num_train_timesteps // (num_inference_steps + 1)
459
+ # creates integer timesteps by multiplying by ratio
460
+ # casting to int to avoid issues when num_inference_step is power of 3
461
+ timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
462
+ timesteps += self.steps_offset
463
+ elif self.timestep_spacing == "trailing":
464
+ step_ratio = self.num_train_timesteps / num_inference_steps
465
+ # creates integer timesteps by multiplying by ratio
466
+ # casting to int to avoid issues when num_inference_step is power of 3
467
+ timesteps = np.arange(self.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
468
+ timesteps -= 1
469
+ else:
470
+ raise ValueError(
471
+ f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
472
+ )
473
+
474
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
475
+ if self.use_karras_sigmas:
476
+ log_sigmas = np.log(sigmas)
477
+ sigmas = np.flip(sigmas).copy()
478
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
479
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
480
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
481
+ else:
482
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
483
+ sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
484
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
485
+
486
+ self.sigmas = torch.from_numpy(sigmas)
487
+ self.timesteps = torch.from_numpy(timesteps).to(device=self.device, dtype=torch.int64)
488
+
489
+ self.num_inference_steps = len(timesteps)
490
+
491
+ self.model_outputs = [
492
+ None,
493
+ ] * self.solver_order
494
+ self.lower_order_nums = 0
495
+ self.last_sample = None
496
+
497
+ # add an index counter for schedulers that allow duplicated timesteps
498
+ self._step_index = None
499
+
500
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
501
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
502
+ dtype = sample.dtype
503
+ batch_size, channels, *remaining_dims = sample.shape
504
+
505
+ if dtype not in (torch.float32, torch.float64):
506
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
507
+
508
+ # Flatten sample for doing quantile calculation along each image
509
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
510
+
511
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
512
+
513
+ s = torch.quantile(abs_sample, self.dynamic_thresholding_ratio, dim=1)
514
+ s = torch.clamp(
515
+ s, min=1, max=self.sample_max_value
516
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
517
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
518
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
519
+
520
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
521
+ sample = sample.to(dtype)
522
+
523
+ return sample
524
+
525
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
526
+ def _sigma_to_t(self, sigma, log_sigmas):
527
+ # get log sigma
528
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
529
+
530
+ # get distribution
531
+ dists = log_sigma - log_sigmas[:, np.newaxis]
532
+
533
+ # get sigmas range
534
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
535
+ high_idx = low_idx + 1
536
+
537
+ low = log_sigmas[low_idx]
538
+ high = log_sigmas[high_idx]
539
+
540
+ # interpolate sigmas
541
+ w = (low - log_sigma) / (low - high)
542
+ w = np.clip(w, 0, 1)
543
+
544
+ # transform interpolation to time range
545
+ t = (1 - w) * low_idx + w * high_idx
546
+ t = t.reshape(sigma.shape)
547
+ return t
548
+
549
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
550
+ def _sigma_to_alpha_sigma_t(self, sigma):
551
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
552
+ sigma_t = sigma * alpha_t
553
+
554
+ return alpha_t, sigma_t
555
+
556
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
557
+ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
558
+ """Constructs the noise schedule of Karras et al. (2022)."""
559
+
560
+ sigma_min = self.sigma_min
561
+ sigma_max = self.sigma_max
562
+
563
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
564
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
565
+
566
+ rho = 7.0 # 7.0 is the value used in the paper
567
+ ramp = np.linspace(0, 1, num_inference_steps)
568
+ min_inv_rho = sigma_min ** (1 / rho)
569
+ max_inv_rho = sigma_max ** (1 / rho)
570
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
571
+ return sigmas
572
+
573
+ def convert_model_output(
574
+ self,
575
+ model_output: torch.FloatTensor,
576
+ *args,
577
+ sample: torch.FloatTensor = None,
578
+ **kwargs,
579
+ ) -> torch.FloatTensor:
580
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
581
+ if sample is None:
582
+ if len(args) > 1:
583
+ sample = args[1]
584
+ else:
585
+ raise ValueError("missing `sample` as a required keyword argument")
586
+ if timestep is not None:
587
+ print(
588
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
589
+ )
590
+
591
+ sigma = self.sigmas[self.step_index]
592
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
593
+
594
+ if self.predict_x0:
595
+ if self.prediction_type == "epsilon":
596
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
597
+ elif self.prediction_type == "sample":
598
+ x0_pred = model_output
599
+ elif self.prediction_type == "v_prediction":
600
+ x0_pred = alpha_t * sample - sigma_t * model_output
601
+ else:
602
+ raise ValueError(
603
+ f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
604
+ " `v_prediction` for the UniPCMultistepScheduler."
605
+ )
606
+
607
+ if self.thresholding:
608
+ x0_pred = self._threshold_sample(x0_pred)
609
+
610
+ return x0_pred
611
+ else:
612
+ if self.prediction_type == "epsilon":
613
+ return model_output
614
+ elif self.prediction_type == "sample":
615
+ epsilon = (sample - alpha_t * model_output) / sigma_t
616
+ return epsilon
617
+ elif self.prediction_type == "v_prediction":
618
+ epsilon = alpha_t * model_output + sigma_t * sample
619
+ return epsilon
620
+ else:
621
+ raise ValueError(
622
+ f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
623
+ " `v_prediction` for the UniPCMultistepScheduler."
624
+ )
625
+
626
+ def multistep_uni_p_bh_update(
627
+ self,
628
+ model_output: torch.FloatTensor,
629
+ *args,
630
+ sample: torch.FloatTensor = None,
631
+ order: Optional[int] = None,
632
+ **kwargs,
633
+ ) -> torch.FloatTensor:
634
+ prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
635
+ if sample is None:
636
+ if len(args) > 1:
637
+ sample = args[1]
638
+ else:
639
+ raise ValueError(" missing `sample` as a required keyword argument")
640
+ if order is None:
641
+ if len(args) > 2:
642
+ order = args[2]
643
+ else:
644
+ raise ValueError(" missing `order` as a required keyword argument")
645
+ if prev_timestep is not None:
646
+ print(
647
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
648
+ )
649
+ model_output_list = self.model_outputs
650
+
651
+ # s0 = self.timestep_list[-1]
652
+ m0 = model_output_list[-1]
653
+ x = sample
654
+
655
+ sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
656
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
657
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
658
+
659
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
660
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
661
+
662
+ h = lambda_t - lambda_s0
663
+ device = sample.device
664
+
665
+ rks = []
666
+ d1s = []
667
+ for i in range(1, order):
668
+ si = self.step_index - i
669
+ mi = model_output_list[-(i + 1)]
670
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
671
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
672
+ rk = (lambda_si - lambda_s0) / h
673
+ rks.append(rk)
674
+ d1s.append((mi - m0) / rk)
675
+
676
+ rks.append(1.0)
677
+ rks = torch.tensor(rks, device=device)
678
+
679
+ r = []
680
+ b = []
681
+
682
+ hh = -h if self.predict_x0 else h
683
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
684
+ h_phi_k = h_phi_1 / hh - 1
685
+
686
+ factorial_i = 1
687
+
688
+ if self.solver_type == "bh1":
689
+ b_h = hh
690
+ elif self.solver_type == "bh2":
691
+ b_h = torch.expm1(hh)
692
+ else:
693
+ raise NotImplementedError()
694
+
695
+ for i in range(1, order + 1):
696
+ r.append(torch.pow(rks, i - 1))
697
+ b.append(h_phi_k * factorial_i / b_h)
698
+ factorial_i *= i + 1
699
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
700
+
701
+ r = torch.stack(r)
702
+ b = torch.tensor(b, device=device)
703
+
704
+ if len(d1s) > 0:
705
+ d1s = torch.stack(d1s, dim=1) # (B, K)
706
+ # for order 2, we use a simplified version
707
+ if order == 2:
708
+ rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
709
+ else:
710
+ rhos_p = torch.linalg.solve(r[:-1, :-1], b[:-1])
711
+ else:
712
+ d1s = None
713
+
714
+ if self.predict_x0:
715
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
716
+ if d1s is not None:
717
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, d1s)
718
+ else:
719
+ pred_res = 0
720
+ x_t = x_t_ - alpha_t * b_h * pred_res
721
+ else:
722
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
723
+ if d1s is not None:
724
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, d1s)
725
+ else:
726
+ pred_res = 0
727
+ x_t = x_t_ - sigma_t * b_h * pred_res
728
+
729
+ x_t = x_t.to(x.dtype)
730
+ return x_t
731
+
732
+ def multistep_uni_c_bh_update(
733
+ self,
734
+ this_model_output: torch.FloatTensor,
735
+ *args,
736
+ last_sample: torch.FloatTensor = None,
737
+ this_sample: torch.FloatTensor = None,
738
+ order: Optional[int] = None,
739
+ **kwargs,
740
+ ) -> torch.FloatTensor:
741
+ this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
742
+ if last_sample is None:
743
+ if len(args) > 1:
744
+ last_sample = args[1]
745
+ else:
746
+ raise ValueError(" missing`last_sample` as a required keyword argument")
747
+ if this_sample is None:
748
+ if len(args) > 2:
749
+ this_sample = args[2]
750
+ else:
751
+ raise ValueError(" missing`this_sample` as a required keyword argument")
752
+ if order is None:
753
+ if len(args) > 3:
754
+ order = args[3]
755
+ else:
756
+ raise ValueError(" missing`order` as a required keyword argument")
757
+ if this_timestep is not None:
758
+ print(
759
+ "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
760
+ )
761
+
762
+ model_output_list = self.model_outputs
763
+
764
+ m0 = model_output_list[-1]
765
+ x = last_sample
766
+ # x_t = this_sample
767
+ model_t = this_model_output
768
+
769
+ sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
770
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
771
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
772
+
773
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
774
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
775
+
776
+ h = lambda_t - lambda_s0
777
+ device = this_sample.device
778
+
779
+ rks = []
780
+ d1s = []
781
+ for i in range(1, order):
782
+ si = self.step_index - (i + 1)
783
+ mi = model_output_list[-(i + 1)]
784
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
785
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
786
+ rk = (lambda_si - lambda_s0) / h
787
+ rks.append(rk)
788
+ d1s.append((mi - m0) / rk)
789
+
790
+ rks.append(1.0)
791
+ rks = torch.tensor(rks, device=device)
792
+
793
+ r = []
794
+ b = []
795
+
796
+ hh = -h if self.predict_x0 else h
797
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
798
+ h_phi_k = h_phi_1 / hh - 1
799
+
800
+ factorial_i = 1
801
+
802
+ if self.solver_type == "bh1":
803
+ b_h = hh
804
+ elif self.solver_type == "bh2":
805
+ b_h = torch.expm1(hh)
806
+ else:
807
+ raise NotImplementedError()
808
+
809
+ for i in range(1, order + 1):
810
+ r.append(torch.pow(rks, i - 1))
811
+ b.append(h_phi_k * factorial_i / b_h)
812
+ factorial_i *= i + 1
813
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
814
+
815
+ r = torch.stack(r)
816
+ b = torch.tensor(b, device=device)
817
+
818
+ if len(d1s) > 0:
819
+ d1s = torch.stack(d1s, dim=1)
820
+ else:
821
+ d1s = None
822
+
823
+ # for order 1, we use a simplified version
824
+ if order == 1:
825
+ rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
826
+ else:
827
+ rhos_c = torch.linalg.solve(r, b)
828
+
829
+ if self.predict_x0:
830
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
831
+ if d1s is not None:
832
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], d1s)
833
+ else:
834
+ corr_res = 0
835
+ d1_t = model_t - m0
836
+ x_t = x_t_ - alpha_t * b_h * (corr_res + rhos_c[-1] * d1_t)
837
+ else:
838
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
839
+ if d1s is not None:
840
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], d1s)
841
+ else:
842
+ corr_res = 0
843
+ d1_t = model_t - m0
844
+ x_t = x_t_ - sigma_t * b_h * (corr_res + rhos_c[-1] * d1_t)
845
+ x_t = x_t.to(x.dtype)
846
+ return x_t
847
+
848
+ def _init_step_index(self, timestep):
849
+ if isinstance(timestep, torch.Tensor):
850
+ timestep = timestep.to(self.timesteps.device)
851
+
852
+ index_candidates = (self.timesteps == timestep).nonzero()
853
+
854
+ if len(index_candidates) == 0:
855
+ step_index = len(self.timesteps) - 1
856
+ # The sigma index that is taken for the **very** first `step`
857
+ # is always the second index (or the last index if there is only 1)
858
+ # This way we can ensure we don't accidentally skip a sigma in
859
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
860
+ elif len(index_candidates) > 1:
861
+ step_index = index_candidates[1].item()
862
+ else:
863
+ step_index = index_candidates[0].item()
864
+
865
+ self._step_index = step_index
866
+
867
+ def step(
868
+ self,
869
+ model_output: torch.FloatTensor,
870
+ timestep: int,
871
+ sample: torch.FloatTensor,
872
+ return_dict: bool = True,
873
+ ):
874
+ if self.num_inference_steps is None:
875
+ raise ValueError(
876
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
877
+ )
878
+
879
+ if self.step_index is None:
880
+ self._init_step_index(timestep)
881
+
882
+ use_corrector = (
883
+ self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
884
+ )
885
+
886
+ model_output_convert = self.convert_model_output(model_output, sample=sample)
887
+ if use_corrector:
888
+ sample = self.multistep_uni_c_bh_update(
889
+ this_model_output=model_output_convert,
890
+ last_sample=self.last_sample,
891
+ this_sample=sample,
892
+ order=self.this_order,
893
+ )
894
+
895
+ for i in range(self.solver_order - 1):
896
+ self.model_outputs[i] = self.model_outputs[i + 1]
897
+ self.timestep_list[i] = self.timestep_list[i + 1]
898
+
899
+ self.model_outputs[-1] = model_output_convert
900
+ self.timestep_list[-1] = timestep
901
+
902
+ if self.lower_order_final:
903
+ this_order = min(self.solver_order, len(self.timesteps) - self.step_index)
904
+ else:
905
+ this_order = self.solver_order
906
+
907
+ self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
908
+ assert self.this_order > 0
909
+
910
+ self.last_sample = sample
911
+ prev_sample = self.multistep_uni_p_bh_update(
912
+ model_output=model_output, # pass the original non-converted model output, in case solver-p is used
913
+ sample=sample,
914
+ order=self.this_order,
915
+ )
916
+
917
+ if self.lower_order_nums < self.solver_order:
918
+ self.lower_order_nums += 1
919
+
920
+ # upon completion increase step index by one
921
+ self._step_index += 1
922
+
923
+ if not return_dict:
924
+ return (prev_sample,)
925
+
926
+ return prev_sample
927
+
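A minimal sketch of how the predictor-corrector `step` above might be driven from a sampling loop; `unet`, `latents`, and `sample_with_scheduler` are placeholder names rather than anything defined in this file, and the `set_timesteps` signature is assumed from the error message in `step`:

    # hypothetical driver loop (editorial sketch; `unet` is any callable predicting noise)
    def sample_with_scheduler(scheduler, unet, latents, num_inference_steps=20):
        scheduler.set_timesteps(num_inference_steps)                # assumed signature, per the error message in step()
        for t in scheduler.timesteps:
            latent_input = scheduler.scale_model_input(latents, t)  # identity for this scheduler
            noise_pred = unet(latent_input, t)                      # placeholder model call
            latents = scheduler.step(noise_pred, t, latents, return_dict=False)[0]
        return latents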
928
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
929
+ return sample
930
+
931
+ def add_noise(
932
+ self,
933
+ original_samples: torch.FloatTensor,
934
+ noise: torch.FloatTensor,
935
+ idx,
936
+ timesteps: torch.IntTensor,
937
+ ) -> torch.FloatTensor:
938
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
939
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
940
+ schedule_timesteps = self.timesteps.to(original_samples.device)
941
+ timesteps = timesteps.to(original_samples.device)
942
+
943
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
944
+ sigma = sigmas[step_indices].flatten()
945
+ while len(sigma.shape) < len(original_samples.shape):
946
+ sigma = sigma.unsqueeze(-1)
947
+
948
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
949
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
950
+ return noisy_samples
951
+
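The noising performed above is the usual \((\alpha, \sigma)\) reparameterization of the forward process, with \(\alpha_t\) and \(\sigma_t\) derived from the per-timestep sigma via `_sigma_to_alpha_sigma_t`:

\[
x_t \;=\; \alpha_t\,x_0 \;+\; \sigma_t\,\varepsilon, \qquad \varepsilon \sim \mathcal{N}(0, I)
\]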
952
+ def configure(self):
953
+ pass
954
+
955
+ def __len__(self):
956
+ return self.num_train_timesteps
957
+
958
+
959
+ # Modified from diffusers.schedulers.LCMScheduler
960
+ class LCMScheduler:
961
+ def __init__(
962
+ self,
963
+ device="cuda",
964
+ num_train_timesteps: int = 1000,
965
+ beta_start: float = 0.00085,
966
+ beta_end: float = 0.012,
967
+ original_inference_steps: int = 50,
968
+ clip_sample: bool = False,
969
+ clip_sample_range: float = 1.0,
970
+ steps_offset: int = 0,
971
+ prediction_type: str = "epsilon",
972
+ thresholding: bool = False,
973
+ dynamic_thresholding_ratio: float = 0.995,
974
+ sample_max_value: float = 1.0,
975
+ timestep_spacing: str = "leading",
976
+ timestep_scaling: float = 10.0,
977
+ ):
978
+ self.device = device
979
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
980
+ self.alphas = 1.0 - self.betas
981
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
982
+ self.final_alpha_cumprod = self.alphas_cumprod[0]
983
+ # standard deviation of the initial noise distribution
984
+ self.init_noise_sigma = 1.0
985
+ # setable values
986
+ self.num_inference_steps = None
987
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
988
+
989
+ self.num_train_timesteps = num_train_timesteps
990
+ self.clip_sample = clip_sample
991
+ self.clip_sample_range = clip_sample_range
992
+ self.steps_offset = steps_offset
993
+ self.prediction_type = prediction_type
994
+ self.thresholding = thresholding
995
+ self.timestep_spacing = timestep_spacing
996
+ self.timestep_scaling = timestep_scaling
997
+ self.original_inference_steps = original_inference_steps
998
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
999
+ self.sample_max_value = sample_max_value
1000
+
1001
+ self._step_index = None
1002
+
1003
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
1004
+ def _init_step_index(self, timestep):
1005
+ if isinstance(timestep, torch.Tensor):
1006
+ timestep = timestep.to(self.timesteps.device)
1007
+
1008
+ index_candidates = (self.timesteps == timestep).nonzero()
1009
+
1010
+ if len(index_candidates) > 1:
1011
+ step_index = index_candidates[1]
1012
+ else:
1013
+ step_index = index_candidates[0]
1014
+
1015
+ self._step_index = step_index.item()
1016
+
1017
+ @property
1018
+ def step_index(self):
1019
+ return self._step_index
1020
+
1021
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1022
+ return sample
1023
+
1024
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
1025
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
1026
+ dtype = sample.dtype
1027
+ batch_size, channels, *remaining_dims = sample.shape
1028
+
1029
+ if dtype not in (torch.float32, torch.float64):
1030
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
1031
+
1032
+ # Flatten sample for doing quantile calculation along each image
1033
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
1034
+
1035
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
1036
+
1037
+ s = torch.quantile(abs_sample, self.dynamic_thresholding_ratio, dim=1)
1038
+ s = torch.clamp(
1039
+ s, min=1, max=self.sample_max_value
1040
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
1041
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
1042
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
1043
+
1044
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
1045
+ sample = sample.to(dtype)
1046
+
1047
+ return sample
1048
+
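The dynamic thresholding above is a clamp-and-rescale rule applied per batch element, with \(p\) the `dynamic_thresholding_ratio`:

\[
s \;=\; \operatorname{clip}\!\big(\operatorname{quantile}_{p}(|x_0|),\; 1,\; \texttt{sample\_max\_value}\big), \qquad x_0 \;\leftarrow\; \operatorname{clamp}(x_0, -s, s)\,/\,s
\]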
1049
+ def set_timesteps(
1050
+ self,
1051
+ num_inference_steps: int,
1052
+ strength: float = 1.0,
1053
+ ):
1054
+ assert num_inference_steps <= self.num_train_timesteps
1055
+
1056
+ self.num_inference_steps = num_inference_steps
1057
+ original_steps = self.original_inference_steps
1058
+
1059
+ assert original_steps <= self.num_train_timesteps
1060
+ assert num_inference_steps <= original_steps
1061
+
1062
+ # LCM Timesteps Setting
1063
+ # Currently, only linear spacing is supported.
1064
+ c = self.num_train_timesteps // original_steps
1065
+ # LCM Training Steps Schedule
1066
+ lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1
1067
+ skipping_step = len(lcm_origin_timesteps) // num_inference_steps
1068
+ # LCM Inference Steps Schedule
1069
+ timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]
1070
+
1071
+ self.timesteps = torch.from_numpy(timesteps.copy()).to(device=self.device, dtype=torch.long)
1072
+
1073
+ self._step_index = None
1074
+
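A small, self-contained check of the spacing arithmetic above, using the class defaults (num_train_timesteps=1000, original_inference_steps=50) and four inference steps; the printed values are only an illustrative trace:

    import numpy as np

    # reproduce the LCM timestep selection above with default settings (editorial sketch)
    num_train_timesteps, original_steps = 1000, 50
    num_inference_steps, strength = 4, 1.0
    c = num_train_timesteps // original_steps                             # 20
    lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1
    skipping_step = len(lcm_origin_timesteps) // num_inference_steps      # 12
    print(lcm_origin_timesteps[::-skipping_step][:num_inference_steps])   # [999 759 519 279]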
1075
+ def get_scalings_for_boundary_condition_discrete(self, timestep):
1076
+ self.sigma_data = 0.5 # Default: 0.5
1077
+ scaled_timestep = timestep * self.timestep_scaling
1078
+
1079
+ c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
1080
+ c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
1081
+ return c_skip, c_out
1082
+
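These are consistency-model boundary-condition scalings, with the scaled timestep playing the role of the noise level and `sigma_data` fixed at 0.5:

\[
c_{\mathrm{skip}} \;=\; \frac{\sigma_{\mathrm{data}}^{2}}{(t\,s)^{2} + \sigma_{\mathrm{data}}^{2}}, \qquad c_{\mathrm{out}} \;=\; \frac{t\,s}{\sqrt{(t\,s)^{2} + \sigma_{\mathrm{data}}^{2}}}
\]

where \(s\) is `timestep_scaling`; for large timesteps \(c_{\mathrm{skip}} \to 0\) and \(c_{\mathrm{out}} \to 1\), so the denoised estimate is dominated by the predicted \(x_0\).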
1083
+ def step(
1084
+ self,
1085
+ model_output: torch.FloatTensor,
1086
+ timestep: int,
1087
+ sample: torch.FloatTensor,
1088
+ generator: Optional[torch.Generator] = None,
1089
+ ):
1090
+ if self.num_inference_steps is None:
1091
+ raise ValueError(
1092
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
1093
+ )
1094
+
1095
+ if self.step_index is None:
1096
+ self._init_step_index(timestep)
1097
+
1098
+ # 1. get previous step value
1099
+ prev_step_index = self.step_index + 1
1100
+ if prev_step_index < len(self.timesteps):
1101
+ prev_timestep = self.timesteps[prev_step_index]
1102
+ else:
1103
+ prev_timestep = timestep
1104
+
1105
+ # 2. compute alphas, betas
1106
+ alpha_prod_t = self.alphas_cumprod[timestep]
1107
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
1108
+
1109
+ beta_prod_t = 1 - alpha_prod_t
1110
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
1111
+
1112
+ # 3. Get scalings for boundary conditions
1113
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
1114
+
1115
+ # 4. Compute the predicted original sample x_0 based on the model parameterization
1116
+ if self.prediction_type == "epsilon": # noise-prediction
1117
+ predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
1118
+ elif self.prediction_type == "sample": # x-prediction
1119
+ predicted_original_sample = model_output
1120
+ elif self.prediction_type == "v_prediction": # v-prediction
1121
+ predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
1122
+ else:
1123
+ raise ValueError(
1124
+ f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample` or"
1125
+ " `v_prediction` for `LCMScheduler`."
1126
+ )
1127
+
1128
+ # 5. Clip or threshold "predicted x_0"
1129
+ if self.thresholding:
1130
+ predicted_original_sample = self._threshold_sample(predicted_original_sample)
1131
+ elif self.clip_sample:
1132
+ predicted_original_sample = predicted_original_sample.clamp(-self.clip_sample_range, self.clip_sample_range)
1133
+
1134
+ # 6. Denoise model output using boundary conditions
1135
+ denoised = c_out * predicted_original_sample + c_skip * sample
1136
+
1137
+ # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
1138
+ # Noise is not used on the final timestep of the timestep schedule.
1139
+ # This also means that noise is not used for one-step sampling.
1140
+ if self.step_index != self.num_inference_steps - 1:
1141
+ noise = torch.randn(
1142
+ model_output.shape, device=model_output.device, dtype=denoised.dtype, generator=generator
1143
+ )
1144
+ prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
1145
+ else:
1146
+ prev_sample = denoised
1147
+
1148
+ # upon completion increase step index by one
1149
+ self._step_index += 1
1150
+
1151
+ return (prev_sample,)
1152
+
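A minimal usage sketch of the LCMScheduler defined here; `unet`, `lcm_sample`, and the latent tensor are placeholders for whatever model and shapes the surrounding pipeline supplies:

    # editorial sketch: few-step LCM sampling with the scheduler above
    def lcm_sample(unet, latents, num_inference_steps=4, device="cuda"):
        scheduler = LCMScheduler(device=device)
        scheduler.set_timesteps(num_inference_steps)
        for t in scheduler.timesteps:
            model_output = unet(scheduler.scale_model_input(latents, t), t)  # placeholder model call
            latents = scheduler.step(model_output, t, latents)[0]            # step() returns a 1-tuple
        return latents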
1153
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
1154
+ def add_noise(
1155
+ self,
1156
+ original_samples: torch.FloatTensor,
1157
+ noise: torch.FloatTensor,
1158
+ timesteps: torch.IntTensor,
1159
+ ) -> torch.FloatTensor:
1160
+ # Make sure alphas_cumprod and timesteps have the same device and dtype as original_samples
1161
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
1162
+ timesteps = timesteps.to(original_samples.device)
1163
+
1164
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
1165
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
1166
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
1167
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
1168
+
1169
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
1170
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
1171
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
1172
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
1173
+
1174
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
1175
+ return noisy_samples
1176
+
1177
+ def configure(self):
1178
+ pass
1179
+
1180
+ def __len__(self):
1181
+ return self.num_train_timesteps