onnxruntime-directml 1.24.1__cp314-cp314-win_amd64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (322)
  1. onnxruntime/LICENSE +21 -0
  2. onnxruntime/Privacy.md +21 -0
  3. onnxruntime/ThirdPartyNotices.txt +6121 -0
  4. onnxruntime/__init__.py +418 -0
  5. onnxruntime/backend/__init__.py +6 -0
  6. onnxruntime/backend/backend.py +175 -0
  7. onnxruntime/backend/backend_rep.py +52 -0
  8. onnxruntime/capi/DirectML.dll +0 -0
  9. onnxruntime/capi/__init__.py +4 -0
  10. onnxruntime/capi/_ld_preload.py +7 -0
  11. onnxruntime/capi/_pybind_state.py +33 -0
  12. onnxruntime/capi/build_and_package_info.py +2 -0
  13. onnxruntime/capi/convert_npz_to_onnx_adapter.py +48 -0
  14. onnxruntime/capi/onnxruntime.dll +0 -0
  15. onnxruntime/capi/onnxruntime_collect_build_info.py +47 -0
  16. onnxruntime/capi/onnxruntime_inference_collection.py +1440 -0
  17. onnxruntime/capi/onnxruntime_providers_shared.dll +0 -0
  18. onnxruntime/capi/onnxruntime_pybind11_state.pyd +0 -0
  19. onnxruntime/capi/onnxruntime_validation.py +154 -0
  20. onnxruntime/capi/version_info.py +2 -0
  21. onnxruntime/datasets/__init__.py +18 -0
  22. onnxruntime/datasets/logreg_iris.onnx +0 -0
  23. onnxruntime/datasets/mul_1.onnx +0 -0
  24. onnxruntime/datasets/sigmoid.onnx +13 -0
  25. onnxruntime/quantization/CalTableFlatBuffers/KeyValue.py +78 -0
  26. onnxruntime/quantization/CalTableFlatBuffers/TrtTable.py +90 -0
  27. onnxruntime/quantization/CalTableFlatBuffers/__init__.py +0 -0
  28. onnxruntime/quantization/__init__.py +19 -0
  29. onnxruntime/quantization/base_quantizer.py +529 -0
  30. onnxruntime/quantization/calibrate.py +1267 -0
  31. onnxruntime/quantization/execution_providers/qnn/__init__.py +2 -0
  32. onnxruntime/quantization/execution_providers/qnn/fusion_lpnorm.py +132 -0
  33. onnxruntime/quantization/execution_providers/qnn/fusion_spacetodepth.py +162 -0
  34. onnxruntime/quantization/execution_providers/qnn/mixed_precision_overrides_utils.py +413 -0
  35. onnxruntime/quantization/execution_providers/qnn/preprocess.py +353 -0
  36. onnxruntime/quantization/execution_providers/qnn/quant_config.py +389 -0
  37. onnxruntime/quantization/fusions/__init__.py +4 -0
  38. onnxruntime/quantization/fusions/fusion.py +311 -0
  39. onnxruntime/quantization/fusions/fusion_gelu.py +272 -0
  40. onnxruntime/quantization/fusions/fusion_layernorm.py +146 -0
  41. onnxruntime/quantization/fusions/replace_upsample_with_resize.py +96 -0
  42. onnxruntime/quantization/matmul_bnb4_quantizer.py +239 -0
  43. onnxruntime/quantization/matmul_nbits_quantizer.py +1638 -0
  44. onnxruntime/quantization/neural_compressor/__init__.py +1 -0
  45. onnxruntime/quantization/neural_compressor/onnx_model.py +1251 -0
  46. onnxruntime/quantization/neural_compressor/util.py +80 -0
  47. onnxruntime/quantization/neural_compressor/weight_only.py +932 -0
  48. onnxruntime/quantization/onnx_model.py +600 -0
  49. onnxruntime/quantization/onnx_quantizer.py +1163 -0
  50. onnxruntime/quantization/operators/__init__.py +2 -0
  51. onnxruntime/quantization/operators/activation.py +119 -0
  52. onnxruntime/quantization/operators/argmax.py +18 -0
  53. onnxruntime/quantization/operators/attention.py +73 -0
  54. onnxruntime/quantization/operators/base_operator.py +26 -0
  55. onnxruntime/quantization/operators/binary_op.py +72 -0
  56. onnxruntime/quantization/operators/concat.py +62 -0
  57. onnxruntime/quantization/operators/conv.py +260 -0
  58. onnxruntime/quantization/operators/direct_q8.py +78 -0
  59. onnxruntime/quantization/operators/embed_layernorm.py +121 -0
  60. onnxruntime/quantization/operators/gather.py +64 -0
  61. onnxruntime/quantization/operators/gavgpool.py +62 -0
  62. onnxruntime/quantization/operators/gemm.py +172 -0
  63. onnxruntime/quantization/operators/lstm.py +121 -0
  64. onnxruntime/quantization/operators/matmul.py +231 -0
  65. onnxruntime/quantization/operators/maxpool.py +34 -0
  66. onnxruntime/quantization/operators/norm.py +40 -0
  67. onnxruntime/quantization/operators/pad.py +172 -0
  68. onnxruntime/quantization/operators/pooling.py +67 -0
  69. onnxruntime/quantization/operators/qdq_base_operator.py +22 -0
  70. onnxruntime/quantization/operators/resize.py +34 -0
  71. onnxruntime/quantization/operators/softmax.py +74 -0
  72. onnxruntime/quantization/operators/split.py +63 -0
  73. onnxruntime/quantization/operators/where.py +87 -0
  74. onnxruntime/quantization/preprocess.py +141 -0
  75. onnxruntime/quantization/qdq_loss_debug.py +389 -0
  76. onnxruntime/quantization/qdq_quantizer.py +1477 -0
  77. onnxruntime/quantization/quant_utils.py +1051 -0
  78. onnxruntime/quantization/quantize.py +953 -0
  79. onnxruntime/quantization/registry.py +110 -0
  80. onnxruntime/quantization/shape_inference.py +204 -0
  81. onnxruntime/quantization/static_quantize_runner.py +256 -0
  82. onnxruntime/quantization/tensor_quant_overrides.py +520 -0
  83. onnxruntime/tools/__init__.py +10 -0
  84. onnxruntime/tools/check_onnx_model_mobile_usability.py +47 -0
  85. onnxruntime/tools/convert_onnx_models_to_ort.py +380 -0
  86. onnxruntime/tools/file_utils.py +47 -0
  87. onnxruntime/tools/logger.py +11 -0
  88. onnxruntime/tools/make_dynamic_shape_fixed.py +73 -0
  89. onnxruntime/tools/mobile_helpers/__init__.py +0 -0
  90. onnxruntime/tools/mobile_helpers/coreml_supported_mlprogram_ops.md +53 -0
  91. onnxruntime/tools/mobile_helpers/coreml_supported_neuralnetwork_ops.md +43 -0
  92. onnxruntime/tools/mobile_helpers/nnapi_supported_ops.md +58 -0
  93. onnxruntime/tools/mobile_helpers/usability_checker.py +738 -0
  94. onnxruntime/tools/offline_tuning.py +169 -0
  95. onnxruntime/tools/onnx_model_utils.py +416 -0
  96. onnxruntime/tools/onnx_randomizer.py +85 -0
  97. onnxruntime/tools/onnxruntime_test.py +164 -0
  98. onnxruntime/tools/optimize_onnx_model.py +56 -0
  99. onnxruntime/tools/ort_format_model/__init__.py +27 -0
  100. onnxruntime/tools/ort_format_model/operator_type_usage_processors.py +653 -0
  101. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/__init__.py +0 -0
  102. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgType.py +7 -0
  103. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ArgTypeAndIndex.py +67 -0
  104. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Attribute.py +337 -0
  105. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/AttributeType.py +18 -0
  106. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Checkpoint.py +125 -0
  107. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedKernelCreateInfos.py +120 -0
  108. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedNodeIndexAndKernelDefHash.py +68 -0
  109. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSessionState.py +96 -0
  110. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DeprecatedSubGraphSessionState.py +72 -0
  111. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Dimension.py +71 -0
  112. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValue.py +80 -0
  113. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/DimensionValueType.py +8 -0
  114. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/EdgeEnd.py +32 -0
  115. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/FloatProperty.py +67 -0
  116. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Graph.py +320 -0
  117. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/InferenceSession.py +88 -0
  118. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/IntProperty.py +67 -0
  119. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrArgsEntry.py +91 -0
  120. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/KernelTypeStrResolver.py +78 -0
  121. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/MapType.py +71 -0
  122. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Model.py +223 -0
  123. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ModuleState.py +141 -0
  124. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Node.py +317 -0
  125. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeEdge.py +126 -0
  126. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodeType.py +7 -0
  127. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/NodesToOptimizeIndices.py +160 -0
  128. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OpIdKernelTypeStrArgsEntry.py +91 -0
  129. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OperatorSetId.py +67 -0
  130. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/OptimizerGroup.py +117 -0
  131. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ParameterOptimizerState.py +91 -0
  132. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/PropertyBag.py +152 -0
  133. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecord.py +105 -0
  134. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizationRecordContainerEntry.py +91 -0
  135. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/RuntimeOptimizations.py +79 -0
  136. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SequenceType.py +58 -0
  137. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Shape.py +78 -0
  138. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/SparseTensor.py +114 -0
  139. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringProperty.py +67 -0
  140. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/StringStringEntry.py +67 -0
  141. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/Tensor.py +203 -0
  142. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorDataType.py +26 -0
  143. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TensorTypeAndShape.py +71 -0
  144. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfo.py +83 -0
  145. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/TypeInfoValue.py +9 -0
  146. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/ValueInfo.py +84 -0
  147. onnxruntime/tools/ort_format_model/ort_flatbuffers_py/fbs/__init__.py +6 -0
  148. onnxruntime/tools/ort_format_model/ort_model_processor.py +86 -0
  149. onnxruntime/tools/ort_format_model/types.py +85 -0
  150. onnxruntime/tools/ort_format_model/utils.py +61 -0
  151. onnxruntime/tools/pytorch_export_contrib_ops.py +129 -0
  152. onnxruntime/tools/pytorch_export_helpers.py +131 -0
  153. onnxruntime/tools/qdq_helpers/__init__.py +0 -0
  154. onnxruntime/tools/qdq_helpers/optimize_qdq_model.py +37 -0
  155. onnxruntime/tools/qnn/add_trans_cast.py +292 -0
  156. onnxruntime/tools/qnn/gen_qnn_ctx_onnx_model.py +364 -0
  157. onnxruntime/tools/qnn/preprocess.py +165 -0
  158. onnxruntime/tools/reduced_build_config_parser.py +203 -0
  159. onnxruntime/tools/remove_initializer_from_input.py +37 -0
  160. onnxruntime/tools/symbolic_shape_infer.py +3094 -0
  161. onnxruntime/tools/update_onnx_opset.py +31 -0
  162. onnxruntime/transformers/__init__.py +8 -0
  163. onnxruntime/transformers/affinity_helper.py +40 -0
  164. onnxruntime/transformers/benchmark.py +942 -0
  165. onnxruntime/transformers/benchmark_helper.py +643 -0
  166. onnxruntime/transformers/bert_perf_test.py +629 -0
  167. onnxruntime/transformers/bert_test_data.py +641 -0
  168. onnxruntime/transformers/compare_bert_results.py +256 -0
  169. onnxruntime/transformers/constants.py +47 -0
  170. onnxruntime/transformers/convert_generation.py +3605 -0
  171. onnxruntime/transformers/convert_tf_models_to_pytorch.py +205 -0
  172. onnxruntime/transformers/convert_to_packing_mode.py +385 -0
  173. onnxruntime/transformers/dynamo_onnx_helper.py +205 -0
  174. onnxruntime/transformers/float16.py +501 -0
  175. onnxruntime/transformers/fusion_attention.py +1189 -0
  176. onnxruntime/transformers/fusion_attention_clip.py +340 -0
  177. onnxruntime/transformers/fusion_attention_sam2.py +533 -0
  178. onnxruntime/transformers/fusion_attention_unet.py +1307 -0
  179. onnxruntime/transformers/fusion_attention_vae.py +300 -0
  180. onnxruntime/transformers/fusion_bart_attention.py +435 -0
  181. onnxruntime/transformers/fusion_base.py +141 -0
  182. onnxruntime/transformers/fusion_bias_add.py +57 -0
  183. onnxruntime/transformers/fusion_biasgelu.py +66 -0
  184. onnxruntime/transformers/fusion_biassplitgelu.py +110 -0
  185. onnxruntime/transformers/fusion_conformer_attention.py +222 -0
  186. onnxruntime/transformers/fusion_constant_fold.py +144 -0
  187. onnxruntime/transformers/fusion_embedlayer.py +810 -0
  188. onnxruntime/transformers/fusion_fastgelu.py +492 -0
  189. onnxruntime/transformers/fusion_gelu.py +258 -0
  190. onnxruntime/transformers/fusion_gelu_approximation.py +25 -0
  191. onnxruntime/transformers/fusion_gemmfastgelu.py +121 -0
  192. onnxruntime/transformers/fusion_gpt_attention.py +546 -0
  193. onnxruntime/transformers/fusion_gpt_attention_megatron.py +355 -0
  194. onnxruntime/transformers/fusion_gpt_attention_no_past.py +260 -0
  195. onnxruntime/transformers/fusion_group_norm.py +180 -0
  196. onnxruntime/transformers/fusion_layernorm.py +489 -0
  197. onnxruntime/transformers/fusion_mha_mmdit.py +667 -0
  198. onnxruntime/transformers/fusion_nhwc_conv.py +99 -0
  199. onnxruntime/transformers/fusion_options.py +340 -0
  200. onnxruntime/transformers/fusion_qordered_attention.py +420 -0
  201. onnxruntime/transformers/fusion_qordered_gelu.py +118 -0
  202. onnxruntime/transformers/fusion_qordered_layernorm.py +122 -0
  203. onnxruntime/transformers/fusion_qordered_matmul.py +216 -0
  204. onnxruntime/transformers/fusion_quickgelu.py +74 -0
  205. onnxruntime/transformers/fusion_reshape.py +173 -0
  206. onnxruntime/transformers/fusion_rotary_attention.py +1591 -0
  207. onnxruntime/transformers/fusion_shape.py +109 -0
  208. onnxruntime/transformers/fusion_simplified_layernorm.py +165 -0
  209. onnxruntime/transformers/fusion_skip_group_norm.py +254 -0
  210. onnxruntime/transformers/fusion_skiplayernorm.py +209 -0
  211. onnxruntime/transformers/fusion_transpose.py +167 -0
  212. onnxruntime/transformers/fusion_utils.py +321 -0
  213. onnxruntime/transformers/huggingface_models.py +74 -0
  214. onnxruntime/transformers/import_utils.py +20 -0
  215. onnxruntime/transformers/io_binding_helper.py +487 -0
  216. onnxruntime/transformers/large_model_exporter.py +395 -0
  217. onnxruntime/transformers/machine_info.py +230 -0
  218. onnxruntime/transformers/metrics.py +163 -0
  219. onnxruntime/transformers/models/bart/__init__.py +12 -0
  220. onnxruntime/transformers/models/bart/export.py +98 -0
  221. onnxruntime/transformers/models/bert/__init__.py +12 -0
  222. onnxruntime/transformers/models/bert/eval_squad.py +329 -0
  223. onnxruntime/transformers/models/gpt2/__init__.py +12 -0
  224. onnxruntime/transformers/models/gpt2/benchmark_gpt2.py +413 -0
  225. onnxruntime/transformers/models/gpt2/convert_to_onnx.py +566 -0
  226. onnxruntime/transformers/models/gpt2/gpt2_helper.py +1031 -0
  227. onnxruntime/transformers/models/gpt2/gpt2_parity.py +513 -0
  228. onnxruntime/transformers/models/gpt2/gpt2_tester.py +501 -0
  229. onnxruntime/transformers/models/gpt2/parity_check_helper.py +146 -0
  230. onnxruntime/transformers/models/llama/__init__.py +12 -0
  231. onnxruntime/transformers/models/llama/benchmark.py +700 -0
  232. onnxruntime/transformers/models/llama/benchmark_all.py +488 -0
  233. onnxruntime/transformers/models/llama/benchmark_e2e.py +608 -0
  234. onnxruntime/transformers/models/llama/convert_to_onnx.py +1064 -0
  235. onnxruntime/transformers/models/llama/dist_settings.py +57 -0
  236. onnxruntime/transformers/models/llama/llama_inputs.py +504 -0
  237. onnxruntime/transformers/models/llama/llama_parity.py +343 -0
  238. onnxruntime/transformers/models/llama/llama_torch.py +47 -0
  239. onnxruntime/transformers/models/llama/quant_kv_dataloader.py +108 -0
  240. onnxruntime/transformers/models/longformer/__init__.py +12 -0
  241. onnxruntime/transformers/models/longformer/benchmark_longformer.py +821 -0
  242. onnxruntime/transformers/models/longformer/convert_to_onnx.py +413 -0
  243. onnxruntime/transformers/models/longformer/generate_test_data.py +347 -0
  244. onnxruntime/transformers/models/longformer/longformer_helper.py +76 -0
  245. onnxruntime/transformers/models/phi2/__init__.py +12 -0
  246. onnxruntime/transformers/models/phi2/convert_to_onnx.py +590 -0
  247. onnxruntime/transformers/models/phi2/inference_example.py +414 -0
  248. onnxruntime/transformers/models/sam2/__init__.py +12 -0
  249. onnxruntime/transformers/models/sam2/benchmark_sam2.py +638 -0
  250. onnxruntime/transformers/models/sam2/convert_to_onnx.py +270 -0
  251. onnxruntime/transformers/models/sam2/image_decoder.py +272 -0
  252. onnxruntime/transformers/models/sam2/image_encoder.py +236 -0
  253. onnxruntime/transformers/models/sam2/mask_decoder.py +208 -0
  254. onnxruntime/transformers/models/sam2/nvtx_helper.py +33 -0
  255. onnxruntime/transformers/models/sam2/prompt_encoder.py +189 -0
  256. onnxruntime/transformers/models/sam2/sam2_demo.py +321 -0
  257. onnxruntime/transformers/models/sam2/sam2_image_onnx_predictor.py +279 -0
  258. onnxruntime/transformers/models/sam2/sam2_utils.py +147 -0
  259. onnxruntime/transformers/models/stable_diffusion/__init__.py +12 -0
  260. onnxruntime/transformers/models/stable_diffusion/benchmark.py +1519 -0
  261. onnxruntime/transformers/models/stable_diffusion/benchmark_controlnet.py +426 -0
  262. onnxruntime/transformers/models/stable_diffusion/demo_txt2img.py +103 -0
  263. onnxruntime/transformers/models/stable_diffusion/demo_txt2img_xl.py +269 -0
  264. onnxruntime/transformers/models/stable_diffusion/demo_utils.py +778 -0
  265. onnxruntime/transformers/models/stable_diffusion/diffusion_models.py +1318 -0
  266. onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py +1179 -0
  267. onnxruntime/transformers/models/stable_diffusion/engine_builder.py +295 -0
  268. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_cuda.py +387 -0
  269. onnxruntime/transformers/models/stable_diffusion/engine_builder_ort_trt.py +288 -0
  270. onnxruntime/transformers/models/stable_diffusion/engine_builder_tensorrt.py +395 -0
  271. onnxruntime/transformers/models/stable_diffusion/engine_builder_torch.py +108 -0
  272. onnxruntime/transformers/models/stable_diffusion/optimize_pipeline.py +590 -0
  273. onnxruntime/transformers/models/stable_diffusion/ort_optimizer.py +136 -0
  274. onnxruntime/transformers/models/stable_diffusion/pipeline_stable_diffusion.py +831 -0
  275. onnxruntime/transformers/models/stable_diffusion/trt_utilities.py +12 -0
  276. onnxruntime/transformers/models/t5/__init__.py +12 -0
  277. onnxruntime/transformers/models/t5/convert_to_onnx.py +318 -0
  278. onnxruntime/transformers/models/t5/t5_decoder.py +437 -0
  279. onnxruntime/transformers/models/t5/t5_encoder.py +70 -0
  280. onnxruntime/transformers/models/t5/t5_encoder_decoder_init.py +361 -0
  281. onnxruntime/transformers/models/t5/t5_helper.py +302 -0
  282. onnxruntime/transformers/models/whisper/__init__.py +12 -0
  283. onnxruntime/transformers/models/whisper/benchmark.py +585 -0
  284. onnxruntime/transformers/models/whisper/benchmark_all.py +526 -0
  285. onnxruntime/transformers/models/whisper/convert_to_onnx.py +609 -0
  286. onnxruntime/transformers/models/whisper/whisper_chain.py +334 -0
  287. onnxruntime/transformers/models/whisper/whisper_decoder.py +464 -0
  288. onnxruntime/transformers/models/whisper/whisper_encoder.py +164 -0
  289. onnxruntime/transformers/models/whisper/whisper_encoder_decoder_init.py +371 -0
  290. onnxruntime/transformers/models/whisper/whisper_helper.py +1035 -0
  291. onnxruntime/transformers/models/whisper/whisper_inputs.py +380 -0
  292. onnxruntime/transformers/models/whisper/whisper_jump_times.py +477 -0
  293. onnxruntime/transformers/onnx_exporter.py +719 -0
  294. onnxruntime/transformers/onnx_model.py +1636 -0
  295. onnxruntime/transformers/onnx_model_bart.py +141 -0
  296. onnxruntime/transformers/onnx_model_bert.py +488 -0
  297. onnxruntime/transformers/onnx_model_bert_keras.py +474 -0
  298. onnxruntime/transformers/onnx_model_bert_tf.py +588 -0
  299. onnxruntime/transformers/onnx_model_clip.py +42 -0
  300. onnxruntime/transformers/onnx_model_conformer.py +32 -0
  301. onnxruntime/transformers/onnx_model_gpt2.py +101 -0
  302. onnxruntime/transformers/onnx_model_mmdit.py +112 -0
  303. onnxruntime/transformers/onnx_model_phi.py +929 -0
  304. onnxruntime/transformers/onnx_model_sam2.py +137 -0
  305. onnxruntime/transformers/onnx_model_t5.py +985 -0
  306. onnxruntime/transformers/onnx_model_tnlr.py +226 -0
  307. onnxruntime/transformers/onnx_model_unet.py +258 -0
  308. onnxruntime/transformers/onnx_model_vae.py +42 -0
  309. onnxruntime/transformers/onnx_utils.py +55 -0
  310. onnxruntime/transformers/optimizer.py +620 -0
  311. onnxruntime/transformers/past_helper.py +149 -0
  312. onnxruntime/transformers/profile_result_processor.py +358 -0
  313. onnxruntime/transformers/profiler.py +434 -0
  314. onnxruntime/transformers/quantize_helper.py +76 -0
  315. onnxruntime/transformers/shape_infer_helper.py +121 -0
  316. onnxruntime/transformers/shape_optimizer.py +400 -0
  317. onnxruntime/transformers/torch_onnx_export_helper.py +74 -0
  318. onnxruntime_directml-1.24.1.dist-info/METADATA +216 -0
  319. onnxruntime_directml-1.24.1.dist-info/RECORD +322 -0
  320. onnxruntime_directml-1.24.1.dist-info/WHEEL +5 -0
  321. onnxruntime_directml-1.24.1.dist-info/entry_points.txt +2 -0
  322. onnxruntime_directml-1.24.1.dist-info/top_level.txt +1 -0
onnxruntime/transformers/models/stable_diffusion/diffusion_schedulers.py
@@ -0,0 +1,1179 @@
1
+ # -------------------------------------------------------------------------
2
+ # Copyright (c) Microsoft Corporation. All rights reserved.
3
+ # Licensed under the MIT License.
4
+ # --------------------------------------------------------------------------
5
+ # Modified from utilities.py of TensorRT demo diffusion, which has the following license:
6
+ #
7
+ # Copyright 2022 The HuggingFace Inc. team.
8
+ # SPDX-FileCopyrightText: Copyright (c) 1993-2023 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
9
+ # SPDX-License-Identifier: Apache-2.0
10
+ #
11
+ # Licensed under the Apache License, Version 2.0 (the "License");
12
+ # you may not use this file except in compliance with the License.
13
+ # You may obtain a copy of the License at
14
+ #
15
+ # http://www.apache.org/licenses/LICENSE-2.0
16
+ #
17
+ # Unless required by applicable law or agreed to in writing, software
18
+ # distributed under the License is distributed on an "AS IS" BASIS,
19
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
20
+ # See the License for the specific language governing permissions and
21
+ # limitations under the License.
22
+ # --------------------------------------------------------------------------
23
+
24
+
25
+ import numpy as np
26
+ import torch
27
+
28
+
29
+ class DDIMScheduler:
30
+ def __init__(
31
+ self,
32
+ device="cuda",
33
+ num_train_timesteps: int = 1000,
34
+ beta_start: float = 0.0001,
35
+ beta_end: float = 0.02,
36
+ clip_sample: bool = False,
37
+ set_alpha_to_one: bool = False,
38
+ steps_offset: int = 1,
39
+ prediction_type: str = "epsilon",
40
+ timestep_spacing: str = "leading",
41
+ ):
42
+ # this schedule is very specific to the latent diffusion model.
43
+ betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
44
+
45
+ alphas = 1.0 - betas
46
+ self.alphas_cumprod = torch.cumprod(alphas, dim=0)
47
+ # standard deviation of the initial noise distribution
48
+ self.init_noise_sigma = 1.0
49
+
50
+ # At every step in ddim, we are looking into the previous alphas_cumprod
51
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
52
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
53
+ # whether we use the final alpha of the "non-previous" one.
54
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
55
+
56
+ # setable values
57
+ self.num_inference_steps = None
58
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
59
+ self.steps_offset = steps_offset
60
+ self.num_train_timesteps = num_train_timesteps
61
+ self.clip_sample = clip_sample
62
+ self.prediction_type = prediction_type
63
+ self.device = device
64
+ self.timestep_spacing = timestep_spacing
65
+
66
+ def configure(self):
67
+ variance = np.zeros(self.num_inference_steps, dtype=np.float32)
68
+ for idx, timestep in enumerate(self.timesteps):
69
+ prev_timestep = timestep - self.num_train_timesteps // self.num_inference_steps
70
+ variance[idx] = self._get_variance(timestep, prev_timestep)
71
+ self.variance = torch.from_numpy(variance).to(self.device)
72
+
73
+ timesteps = self.timesteps.long().cpu()
74
+ self.filtered_alphas_cumprod = self.alphas_cumprod[timesteps].to(self.device)
75
+ self.final_alpha_cumprod = self.final_alpha_cumprod.to(self.device)
76
+
77
+ def scale_model_input(self, sample: torch.FloatTensor, idx, *args, **kwargs) -> torch.FloatTensor:
78
+ return sample
79
+
80
+ def _get_variance(self, timestep, prev_timestep):
81
+ alpha_prod_t = self.alphas_cumprod[timestep]
82
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
83
+ beta_prod_t = 1 - alpha_prod_t
84
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
85
+
86
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
87
+
88
+ return variance
89
+
90
+ def set_timesteps(self, num_inference_steps: int):
91
+ self.num_inference_steps = num_inference_steps
92
+ if self.timestep_spacing == "leading":
93
+ step_ratio = self.num_train_timesteps // self.num_inference_steps
94
+ # creates integer timesteps by multiplying by ratio
95
+ # casting to int to avoid issues when num_inference_step is power of 3
96
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
97
+ timesteps += self.steps_offset
98
+ elif self.timestep_spacing == "trailing":
99
+ step_ratio = self.num_train_timesteps / self.num_inference_steps
100
+ # creates integer timesteps by multiplying by ratio
101
+ # casting to int to avoid issues when num_inference_step is power of 3
102
+ timesteps = np.round(np.arange(self.num_train_timesteps, 0, -step_ratio)).astype(np.int64)
103
+ timesteps -= 1
104
+ else:
105
+ raise ValueError(
106
+ f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
107
+ )
108
+
109
+ self.timesteps = torch.from_numpy(timesteps).to(self.device)
110
+
111
+ def step(
112
+ self,
113
+ model_output,
114
+ sample,
115
+ idx,
116
+ timestep,
117
+ eta: float = 0.0,
118
+ use_clipped_model_output: bool = False,
119
+ generator=None,
120
+ variance_noise: torch.FloatTensor = None,
121
+ ):
122
+ if self.num_inference_steps is None:
123
+ raise ValueError(
124
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
125
+ )
126
+
127
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
128
+ # Ideally, read the DDIM paper for an in-depth understanding.
129
+
130
+ # Notation (<variable name> -> <name in paper>)
131
+ # - pred_noise_t -> e_theta(x_t, t)
132
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
133
+ # - std_dev_t -> sigma_t
134
+ # - eta -> η
135
+ # - pred_sample_direction -> "direction pointing to x_t"
136
+ # - pred_prev_sample -> "x_t-1"
137
+
138
+ prev_idx = idx + 1
139
+ alpha_prod_t = self.filtered_alphas_cumprod[idx]
140
+ alpha_prod_t_prev = (
141
+ self.filtered_alphas_cumprod[prev_idx] if prev_idx < self.num_inference_steps else self.final_alpha_cumprod
142
+ )
143
+
144
+ beta_prod_t = 1 - alpha_prod_t
145
+
146
+ # 3. compute predicted original sample from predicted noise also called
147
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
148
+ if self.prediction_type == "epsilon":
149
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
150
+ elif self.prediction_type == "sample":
151
+ pred_original_sample = model_output
152
+ elif self.prediction_type == "v_prediction":
153
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
154
+ # predict V
155
+ model_output = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
156
+ else:
157
+ raise ValueError(
158
+ f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or `v_prediction`"
159
+ )
160
+
161
+ # 4. Clip "predicted x_0"
162
+ if self.clip_sample:
163
+ pred_original_sample = torch.clamp(pred_original_sample, -1, 1)
164
+
165
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
166
+ # o_t = sqrt((1 - a_t-1)/(1 - a_t)) * sqrt(1 - a_t/a_t-1)
167
+ variance = self.variance[idx]
168
+ std_dev_t = eta * variance ** (0.5)
169
+
170
+ if use_clipped_model_output:
171
+ # the model_output is always re-derived from the clipped x_0 in Glide
172
+ model_output = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
173
+
174
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
175
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * model_output
176
+
177
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
178
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
179
+
180
+ if eta > 0:
181
+ # randn_like does not support generator https://github.com/pytorch/pytorch/issues/27072
182
+ device = model_output.device
183
+ if variance_noise is not None and generator is not None:
184
+ raise ValueError(
185
+ "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
186
+ " `variance_noise` stays `None`."
187
+ )
188
+
189
+ if variance_noise is None:
190
+ variance_noise = torch.randn(
191
+ model_output.shape, generator=generator, device=device, dtype=model_output.dtype
192
+ )
193
+ variance = std_dev_t * variance_noise
194
+
195
+ prev_sample = prev_sample + variance
196
+
197
+ return prev_sample
198
+
199
+ def add_noise(self, init_latents, noise, idx, latent_timestep):
200
+ sqrt_alpha_prod = self.filtered_alphas_cumprod[idx] ** 0.5
201
+ sqrt_one_minus_alpha_prod = (1 - self.filtered_alphas_cumprod[idx]) ** 0.5
202
+ noisy_latents = sqrt_alpha_prod * init_latents + sqrt_one_minus_alpha_prod * noise
203
+
204
+ return noisy_latents
205
+
206
+
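A minimal driver sketch for the DDIMScheduler interface above (not from the package file; `unet`, the latent shape, and the step count are illustrative placeholders): set_timesteps() picks the schedule, configure() precomputes per-step tensors, then step() advances the latent.

    import torch

    scheduler = DDIMScheduler(device="cpu")
    scheduler.set_timesteps(10)   # choose the inference timesteps first
    scheduler.configure()         # then precompute per-step variance / filtered alphas

    latents = torch.randn(1, 4, 64, 64) * scheduler.init_noise_sigma

    def unet(latent, timestep):   # placeholder for the real ONNX Runtime UNet call
        return torch.randn_like(latent)

    for idx, t in enumerate(scheduler.timesteps):
        model_input = scheduler.scale_model_input(latents, idx)  # identity for DDIM
        noise_pred = unet(model_input, t)
        latents = scheduler.step(noise_pred, latents, idx, t)    # x_t -> x_(t-1)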
207
+ class EulerAncestralDiscreteScheduler:
208
+ def __init__(
209
+ self,
210
+ num_train_timesteps: int = 1000,
211
+ beta_start: float = 0.0001,
212
+ beta_end: float = 0.02,
213
+ device="cuda",
214
+ steps_offset: int = 1,
215
+ prediction_type: str = "epsilon",
216
+ timestep_spacing: str = "trailing", # set default to trailing for SDXL Turbo
217
+ ):
218
+ betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
219
+ alphas = 1.0 - betas
220
+ self.alphas_cumprod = torch.cumprod(alphas, dim=0)
221
+
222
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
223
+ sigmas = np.concatenate([sigmas[::-1], [0.0]]).astype(np.float32)
224
+ self.sigmas = torch.from_numpy(sigmas)
225
+
226
+ # standard deviation of the initial noise distribution
227
+ self.init_noise_sigma = self.sigmas.max()
228
+
229
+ # setable values
230
+ self.num_inference_steps = None
231
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=float)[::-1].copy()
232
+ self.timesteps = torch.from_numpy(timesteps)
233
+ self.is_scale_input_called = False
234
+
235
+ self._step_index = None
236
+
237
+ self.device = device
238
+ self.num_train_timesteps = num_train_timesteps
239
+ self.steps_offset = steps_offset
240
+ self.prediction_type = prediction_type
241
+ self.timestep_spacing = timestep_spacing
242
+
243
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
244
+ def _init_step_index(self, timestep):
245
+ if isinstance(timestep, torch.Tensor):
246
+ timestep = timestep.to(self.timesteps.device)
247
+
248
+ index_candidates = (self.timesteps == timestep).nonzero()
249
+
250
+ # The sigma index that is taken for the **very** first `step`
251
+ # is always the second index (or the last index if there is only 1)
252
+ # This way we can ensure we don't accidentally skip a sigma in
253
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
254
+ if len(index_candidates) > 1:
255
+ step_index = index_candidates[1]
256
+ else:
257
+ step_index = index_candidates[0]
258
+
259
+ self._step_index = step_index.item()
260
+
261
+ def scale_model_input(self, sample: torch.FloatTensor, idx, timestep, *args, **kwargs) -> torch.FloatTensor:
262
+ if self._step_index is None:
263
+ self._init_step_index(timestep)
264
+
265
+ sigma = self.sigmas[self._step_index]
266
+ sample = sample / ((sigma**2 + 1) ** 0.5)
267
+ self.is_scale_input_called = True
268
+ return sample
269
+
270
+ def set_timesteps(self, num_inference_steps: int):
271
+ self.num_inference_steps = num_inference_steps
272
+
273
+ if self.timestep_spacing == "linspace":
274
+ timesteps = np.linspace(0, self.num_train_timesteps - 1, num_inference_steps, dtype=np.float32)[::-1].copy()
275
+ elif self.timestep_spacing == "leading":
276
+ step_ratio = self.num_train_timesteps // self.num_inference_steps
277
+ # creates integer timesteps by multiplying by ratio
278
+ # casting to int to avoid issues when num_inference_step is power of 3
279
+ timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.float32)
280
+ timesteps += self.steps_offset
281
+ elif self.timestep_spacing == "trailing":
282
+ step_ratio = self.num_train_timesteps / self.num_inference_steps
283
+ # creates integer timesteps by multiplying by ratio
284
+ # casting to int to avoid issues when num_inference_step is power of 3
285
+ timesteps = (np.arange(self.num_train_timesteps, 0, -step_ratio)).round().copy().astype(np.float32)
286
+ timesteps -= 1
287
+ else:
288
+ raise ValueError(
289
+ f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
290
+ )
291
+
292
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
293
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
294
+ sigmas = np.concatenate([sigmas, [0.0]]).astype(np.float32)
295
+ self.sigmas = torch.from_numpy(sigmas).to(device=self.device)
296
+ self.timesteps = torch.from_numpy(timesteps).to(device=self.device)
297
+
298
+ self._step_index = None
299
+
300
+ def configure(self):
301
+ dts = np.zeros(self.num_inference_steps, dtype=np.float32)
302
+ sigmas_up = np.zeros(self.num_inference_steps, dtype=np.float32)
303
+ for idx, timestep in enumerate(self.timesteps):
304
+ step_index = (self.timesteps == timestep).nonzero().item()
305
+ sigma = self.sigmas[step_index]
306
+
307
+ sigma_from = self.sigmas[step_index]
308
+ sigma_to = self.sigmas[step_index + 1]
309
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
310
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
311
+ dt = sigma_down - sigma
312
+ dts[idx] = dt
313
+ sigmas_up[idx] = sigma_up
314
+
315
+ self.dts = torch.from_numpy(dts).to(self.device)
316
+ self.sigmas_up = torch.from_numpy(sigmas_up).to(self.device)
317
+
318
+ def step(
319
+ self,
320
+ model_output,
321
+ sample,
322
+ idx,
323
+ timestep,
324
+ generator=None,
325
+ ):
326
+ if self._step_index is None:
327
+ self._init_step_index(timestep)
328
+ sigma = self.sigmas[self._step_index]
329
+
330
+ # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
331
+ if self.prediction_type == "epsilon":
332
+ pred_original_sample = sample - sigma * model_output
333
+ elif self.prediction_type == "v_prediction":
334
+ # * c_out + input * c_skip
335
+ pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (sample / (sigma**2 + 1))
336
+ else:
337
+ raise ValueError(
338
+ f"prediction_type given as {self.prediction_type} must be one of `epsilon`, or `v_prediction`"
339
+ )
340
+
341
+ sigma_from = self.sigmas[self._step_index]
342
+ sigma_to = self.sigmas[self._step_index + 1]
343
+ sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
344
+ sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
345
+
346
+ # 2. Convert to an ODE derivative
347
+ derivative = (sample - pred_original_sample) / sigma
348
+
349
+ dt = sigma_down - sigma
350
+
351
+ prev_sample = sample + derivative * dt
352
+
353
+ device = model_output.device
354
+ noise = torch.randn(model_output.shape, dtype=model_output.dtype, device=device, generator=generator).to(device)
355
+
356
+ prev_sample = prev_sample + noise * sigma_up
357
+
358
+ # upon completion increase step index by one
359
+ self._step_index += 1
360
+
361
+ return prev_sample
362
+
363
+ def add_noise(self, original_samples, noise, idx, timestep=None):
364
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
365
+ schedule_timesteps = self.timesteps.to(original_samples.device)
366
+ timesteps = timestep.to(original_samples.device)
367
+
368
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
369
+
370
+ sigma = sigmas[step_indices].flatten()
371
+ while len(sigma.shape) < len(original_samples.shape):
372
+ sigma = sigma.unsqueeze(-1)
373
+
374
+ noisy_samples = original_samples + noise * sigma
375
+ return noisy_samples
376
+
377
+
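A small numeric sketch (made-up values) of the identity that configure() and step() above rely on: the move from sigma_from to sigma_to is split into a deterministic part of size sigma_down plus fresh noise of scale sigma_up, so that sigma_down**2 + sigma_up**2 == sigma_to**2.

    import torch

    sigma_from, sigma_to = torch.tensor(14.6), torch.tensor(9.7)  # illustrative values
    sigma_up = (sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    assert torch.isclose(sigma_down**2 + sigma_up**2, sigma_to**2)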
378
+ class UniPCMultistepScheduler:
379
+ def __init__(
380
+ self,
381
+ device="cuda",
382
+ num_train_timesteps: int = 1000,
383
+ beta_start: float = 0.00085,
384
+ beta_end: float = 0.012,
385
+ solver_order: int = 2,
386
+ prediction_type: str = "epsilon",
387
+ thresholding: bool = False,
388
+ dynamic_thresholding_ratio: float = 0.995,
389
+ sample_max_value: float = 1.0,
390
+ predict_x0: bool = True,
391
+ solver_type: str = "bh2",
392
+ lower_order_final: bool = True,
393
+ disable_corrector: list[int] | None = None,
394
+ use_karras_sigmas: bool | None = False,
395
+ timestep_spacing: str = "linspace",
396
+ steps_offset: int = 0,
397
+ sigma_min=None,
398
+ sigma_max=None,
399
+ ):
400
+ self.device = device
401
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
402
+
403
+ self.alphas = 1.0 - self.betas
404
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
405
+ # Currently we only support VP-type noise schedule
406
+ self.alpha_t = torch.sqrt(self.alphas_cumprod)
407
+ self.sigma_t = torch.sqrt(1 - self.alphas_cumprod)
408
+ self.lambda_t = torch.log(self.alpha_t) - torch.log(self.sigma_t)
409
+
410
+ # standard deviation of the initial noise distribution
411
+ self.init_noise_sigma = 1.0
412
+
413
+ self.predict_x0 = predict_x0
414
+ # setable values
415
+ self.num_inference_steps = None
416
+ timesteps = np.linspace(0, num_train_timesteps - 1, num_train_timesteps, dtype=np.float32)[::-1].copy()
417
+ self.timesteps = torch.from_numpy(timesteps)
418
+ self.model_outputs = [None] * solver_order
419
+ self.timestep_list = [None] * solver_order
420
+ self.lower_order_nums = 0
421
+ self.disable_corrector = disable_corrector if disable_corrector else []
422
+ self.last_sample = None
423
+
424
+ self._step_index = None
425
+
426
+ self.num_train_timesteps = num_train_timesteps
427
+ self.solver_order = solver_order
428
+ self.prediction_type = prediction_type
429
+ self.thresholding = thresholding
430
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
431
+ self.sample_max_value = sample_max_value
432
+ self.solver_type = solver_type
433
+ self.lower_order_final = lower_order_final
434
+ self.use_karras_sigmas = use_karras_sigmas
435
+ self.timestep_spacing = timestep_spacing
436
+ self.steps_offset = steps_offset
437
+ self.sigma_min = sigma_min
438
+ self.sigma_max = sigma_max
439
+
440
+ @property
441
+ def step_index(self):
442
+ """
443
+ The index counter for the current timestep. It will increase by 1 after each scheduler step.
444
+ """
445
+ return self._step_index
446
+
447
+ def set_timesteps(self, num_inference_steps: int):
448
+ if self.timestep_spacing == "linspace":
449
+ timesteps = (
450
+ np.linspace(0, self.num_train_timesteps - 1, num_inference_steps + 1)
451
+ .round()[::-1][:-1]
452
+ .copy()
453
+ .astype(np.int64)
454
+ )
455
+ elif self.timestep_spacing == "leading":
456
+ step_ratio = self.num_train_timesteps // (num_inference_steps + 1)
457
+ # creates integer timesteps by multiplying by ratio
458
+ # casting to int to avoid issues when num_inference_step is power of 3
459
+ timesteps = (np.arange(0, num_inference_steps + 1) * step_ratio).round()[::-1][:-1].copy().astype(np.int64)
460
+ timesteps += self.steps_offset
461
+ elif self.timestep_spacing == "trailing":
462
+ step_ratio = self.num_train_timesteps / num_inference_steps
463
+ # creates integer timesteps by multiplying by ratio
464
+ # casting to int to avoid issues when num_inference_step is power of 3
465
+ timesteps = np.arange(self.num_train_timesteps, 0, -step_ratio).round().copy().astype(np.int64)
466
+ timesteps -= 1
467
+ else:
468
+ raise ValueError(
469
+ f"{self.timestep_spacing} is not supported. Please make sure to choose one of 'linspace', 'leading' or 'trailing'."
470
+ )
471
+
472
+ sigmas = np.array(((1 - self.alphas_cumprod) / self.alphas_cumprod) ** 0.5)
473
+ if self.use_karras_sigmas:
474
+ log_sigmas = np.log(sigmas)
475
+ sigmas = np.flip(sigmas).copy()
476
+ sigmas = self._convert_to_karras(in_sigmas=sigmas, num_inference_steps=num_inference_steps)
477
+ timesteps = np.array([self._sigma_to_t(sigma, log_sigmas) for sigma in sigmas]).round()
478
+ sigmas = np.concatenate([sigmas, sigmas[-1:]]).astype(np.float32)
479
+ else:
480
+ sigmas = np.interp(timesteps, np.arange(0, len(sigmas)), sigmas)
481
+ sigma_last = ((1 - self.alphas_cumprod[0]) / self.alphas_cumprod[0]) ** 0.5
482
+ sigmas = np.concatenate([sigmas, [sigma_last]]).astype(np.float32)
483
+
484
+ self.sigmas = torch.from_numpy(sigmas)
485
+ self.timesteps = torch.from_numpy(timesteps).to(device=self.device, dtype=torch.int64)
486
+
487
+ self.num_inference_steps = len(timesteps)
488
+
489
+ self.model_outputs = [
490
+ None,
491
+ ] * self.solver_order
492
+ self.lower_order_nums = 0
493
+ self.last_sample = None
494
+
495
+ # add an index counter for schedulers that allow duplicated timesteps
496
+ self._step_index = None
497
+
498
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
499
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
500
+ dtype = sample.dtype
501
+ batch_size, channels, *remaining_dims = sample.shape
502
+
503
+ if dtype not in (torch.float32, torch.float64):
504
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
505
+
506
+ # Flatten sample for doing quantile calculation along each image
507
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
508
+
509
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
510
+
511
+ s = torch.quantile(abs_sample, self.dynamic_thresholding_ratio, dim=1)
512
+ s = torch.clamp(
513
+ s, min=1, max=self.sample_max_value
514
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
515
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
516
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
517
+
518
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
519
+ sample = sample.to(dtype)
520
+
521
+ return sample
522
+
523
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._sigma_to_t
524
+ def _sigma_to_t(self, sigma, log_sigmas):
525
+ # get log sigma
526
+ log_sigma = np.log(np.maximum(sigma, 1e-10))
527
+
528
+ # get distribution
529
+ dists = log_sigma - log_sigmas[:, np.newaxis]
530
+
531
+ # get sigmas range
532
+ low_idx = np.cumsum((dists >= 0), axis=0).argmax(axis=0).clip(max=log_sigmas.shape[0] - 2)
533
+ high_idx = low_idx + 1
534
+
535
+ low = log_sigmas[low_idx]
536
+ high = log_sigmas[high_idx]
537
+
538
+ # interpolate sigmas
539
+ w = (low - log_sigma) / (low - high)
540
+ w = np.clip(w, 0, 1)
541
+
542
+ # transform interpolation to time range
543
+ t = (1 - w) * low_idx + w * high_idx
544
+ t = t.reshape(sigma.shape)
545
+ return t
546
+
547
+ # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler._sigma_to_alpha_sigma_t
548
+ def _sigma_to_alpha_sigma_t(self, sigma):
549
+ alpha_t = 1 / ((sigma**2 + 1) ** 0.5)
550
+ sigma_t = sigma * alpha_t
551
+
552
+ return alpha_t, sigma_t
553
+
554
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._convert_to_karras
555
+ def _convert_to_karras(self, in_sigmas: torch.FloatTensor, num_inference_steps) -> torch.FloatTensor:
556
+ """Constructs the noise schedule of Karras et al. (2022)."""
557
+
558
+ sigma_min = self.sigma_min
559
+ sigma_max = self.sigma_max
560
+
561
+ sigma_min = sigma_min if sigma_min is not None else in_sigmas[-1].item()
562
+ sigma_max = sigma_max if sigma_max is not None else in_sigmas[0].item()
563
+
564
+ rho = 7.0 # 7.0 is the value used in the paper
565
+ ramp = np.linspace(0, 1, num_inference_steps)
566
+ min_inv_rho = sigma_min ** (1 / rho)
567
+ max_inv_rho = sigma_max ** (1 / rho)
568
+ sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
569
+ return sigmas
570
+
571
+ def convert_model_output(
572
+ self,
573
+ model_output: torch.FloatTensor,
574
+ *args,
575
+ sample: torch.FloatTensor = None,
576
+ **kwargs,
577
+ ) -> torch.FloatTensor:
578
+ timestep = args[0] if len(args) > 0 else kwargs.pop("timestep", None)
579
+ if sample is None:
580
+ if len(args) > 1:
581
+ sample = args[1]
582
+ else:
583
+ raise ValueError("missing `sample` as a required keyword argument")
584
+ if timestep is not None:
585
+ print(
586
+ "Passing `timesteps` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
587
+ )
588
+
589
+ sigma = self.sigmas[self.step_index]
590
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
591
+
592
+ if self.predict_x0:
593
+ if self.prediction_type == "epsilon":
594
+ x0_pred = (sample - sigma_t * model_output) / alpha_t
595
+ elif self.prediction_type == "sample":
596
+ x0_pred = model_output
597
+ elif self.prediction_type == "v_prediction":
598
+ x0_pred = alpha_t * sample - sigma_t * model_output
599
+ else:
600
+ raise ValueError(
601
+ f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
602
+ " `v_prediction` for the UniPCMultistepScheduler."
603
+ )
604
+
605
+ if self.thresholding:
606
+ x0_pred = self._threshold_sample(x0_pred)
607
+
608
+ return x0_pred
609
+ else:
610
+ if self.prediction_type == "epsilon":
611
+ return model_output
612
+ elif self.prediction_type == "sample":
613
+ epsilon = (sample - alpha_t * model_output) / sigma_t
614
+ return epsilon
615
+ elif self.prediction_type == "v_prediction":
616
+ epsilon = alpha_t * model_output + sigma_t * sample
617
+ return epsilon
618
+ else:
619
+ raise ValueError(
620
+ f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample`, or"
621
+ " `v_prediction` for the UniPCMultistepScheduler."
622
+ )
623
+
624
+ def multistep_uni_p_bh_update(
625
+ self,
626
+ model_output: torch.FloatTensor,
627
+ *args,
628
+ sample: torch.FloatTensor = None,
629
+ order: int | None = None,
630
+ **kwargs,
631
+ ) -> torch.FloatTensor:
632
+ prev_timestep = args[0] if len(args) > 0 else kwargs.pop("prev_timestep", None)
633
+ if sample is None:
634
+ if len(args) > 1:
635
+ sample = args[1]
636
+ else:
637
+ raise ValueError(" missing `sample` as a required keyword argument")
638
+ if order is None:
639
+ if len(args) > 2:
640
+ order = args[2]
641
+ else:
642
+ raise ValueError(" missing `order` as a required keyword argument")
643
+ if prev_timestep is not None:
644
+ print(
645
+ "Passing `prev_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
646
+ )
647
+ model_output_list = self.model_outputs
648
+
649
+ # s0 = self.timestep_list[-1]
650
+ m0 = model_output_list[-1]
651
+ x = sample
652
+
653
+ sigma_t, sigma_s0 = self.sigmas[self.step_index + 1], self.sigmas[self.step_index]
654
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
655
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
656
+
657
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
658
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
659
+
660
+ h = lambda_t - lambda_s0
661
+ device = sample.device
662
+
663
+ rks = []
664
+ d1s = []
665
+ for i in range(1, order):
666
+ si = self.step_index - i
667
+ mi = model_output_list[-(i + 1)]
668
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
669
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
670
+ rk = (lambda_si - lambda_s0) / h
671
+ rks.append(rk)
672
+ d1s.append((mi - m0) / rk)
673
+
674
+ rks.append(1.0)
675
+ rks = torch.tensor(rks, device=device)
676
+
677
+ r = []
678
+ b = []
679
+
680
+ hh = -h if self.predict_x0 else h
681
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
682
+ h_phi_k = h_phi_1 / hh - 1
683
+
684
+ factorial_i = 1
685
+
686
+ if self.solver_type == "bh1":
687
+ b_h = hh
688
+ elif self.solver_type == "bh2":
689
+ b_h = torch.expm1(hh)
690
+ else:
691
+ raise NotImplementedError()
692
+
693
+ for i in range(1, order + 1):
694
+ r.append(torch.pow(rks, i - 1))
695
+ b.append(h_phi_k * factorial_i / b_h)
696
+ factorial_i *= i + 1
697
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
698
+
699
+ r = torch.stack(r)
700
+ b = torch.tensor(b, device=device)
701
+
702
+ if len(d1s) > 0:
703
+ d1s = torch.stack(d1s, dim=1) # (B, K)
704
+ # for order 2, we use a simplified version
705
+ if order == 2:
706
+ rhos_p = torch.tensor([0.5], dtype=x.dtype, device=device)
707
+ else:
708
+ rhos_p = torch.linalg.solve(r[:-1, :-1], b[:-1])
709
+ else:
710
+ d1s = None
711
+
712
+ if self.predict_x0:
713
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
714
+ if d1s is not None:
715
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, d1s)
716
+ else:
717
+ pred_res = 0
718
+ x_t = x_t_ - alpha_t * b_h * pred_res
719
+ else:
720
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
721
+ if d1s is not None:
722
+ pred_res = torch.einsum("k,bkc...->bc...", rhos_p, d1s)
723
+ else:
724
+ pred_res = 0
725
+ x_t = x_t_ - sigma_t * b_h * pred_res
726
+
727
+ x_t = x_t.to(x.dtype)
728
+ return x_t
729
+
730
+ def multistep_uni_c_bh_update(
731
+ self,
732
+ this_model_output: torch.FloatTensor,
733
+ *args,
734
+ last_sample: torch.FloatTensor = None,
735
+ this_sample: torch.FloatTensor = None,
736
+ order: int | None = None,
737
+ **kwargs,
738
+ ) -> torch.FloatTensor:
739
+ this_timestep = args[0] if len(args) > 0 else kwargs.pop("this_timestep", None)
740
+ if last_sample is None:
741
+ if len(args) > 1:
742
+ last_sample = args[1]
743
+ else:
744
+ raise ValueError(" missing`last_sample` as a required keyword argument")
745
+ if this_sample is None:
746
+ if len(args) > 2:
747
+ this_sample = args[2]
748
+ else:
749
+ raise ValueError(" missing`this_sample` as a required keyword argument")
750
+ if order is None:
751
+ if len(args) > 3:
752
+ order = args[3]
753
+ else:
754
+ raise ValueError(" missing`order` as a required keyword argument")
755
+ if this_timestep is not None:
756
+ print(
757
+ "Passing `this_timestep` is deprecated and has no effect as model output conversion is now handled via an internal counter `self.step_index`",
758
+ )
759
+
760
+ model_output_list = self.model_outputs
761
+
762
+ m0 = model_output_list[-1]
763
+ x = last_sample
764
+ # x_t = this_sample
765
+ model_t = this_model_output
766
+
767
+ sigma_t, sigma_s0 = self.sigmas[self.step_index], self.sigmas[self.step_index - 1]
768
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma_t)
769
+ alpha_s0, sigma_s0 = self._sigma_to_alpha_sigma_t(sigma_s0)
770
+
771
+ lambda_t = torch.log(alpha_t) - torch.log(sigma_t)
772
+ lambda_s0 = torch.log(alpha_s0) - torch.log(sigma_s0)
773
+
774
+ h = lambda_t - lambda_s0
775
+ device = this_sample.device
776
+
777
+ rks = []
778
+ d1s = []
779
+ for i in range(1, order):
780
+ si = self.step_index - (i + 1)
781
+ mi = model_output_list[-(i + 1)]
782
+ alpha_si, sigma_si = self._sigma_to_alpha_sigma_t(self.sigmas[si])
783
+ lambda_si = torch.log(alpha_si) - torch.log(sigma_si)
784
+ rk = (lambda_si - lambda_s0) / h
785
+ rks.append(rk)
786
+ d1s.append((mi - m0) / rk)
787
+
788
+ rks.append(1.0)
789
+ rks = torch.tensor(rks, device=device)
790
+
791
+ r = []
792
+ b = []
793
+
794
+ hh = -h if self.predict_x0 else h
795
+ h_phi_1 = torch.expm1(hh) # h\phi_1(h) = e^h - 1
796
+ h_phi_k = h_phi_1 / hh - 1
797
+
798
+ factorial_i = 1
799
+
800
+ if self.solver_type == "bh1":
801
+ b_h = hh
802
+ elif self.solver_type == "bh2":
803
+ b_h = torch.expm1(hh)
804
+ else:
805
+ raise NotImplementedError()
806
+
807
+ for i in range(1, order + 1):
808
+ r.append(torch.pow(rks, i - 1))
809
+ b.append(h_phi_k * factorial_i / b_h)
810
+ factorial_i *= i + 1
811
+ h_phi_k = h_phi_k / hh - 1 / factorial_i
812
+
813
+ r = torch.stack(r)
814
+ b = torch.tensor(b, device=device)
815
+
816
+ if len(d1s) > 0:
817
+ d1s = torch.stack(d1s, dim=1)
818
+ else:
819
+ d1s = None
820
+
821
+ # for order 1, we use a simplified version
822
+ if order == 1:
823
+ rhos_c = torch.tensor([0.5], dtype=x.dtype, device=device)
824
+ else:
825
+ rhos_c = torch.linalg.solve(r, b)
826
+
827
+ if self.predict_x0:
828
+ x_t_ = sigma_t / sigma_s0 * x - alpha_t * h_phi_1 * m0
829
+ if d1s is not None:
830
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], d1s)
831
+ else:
832
+ corr_res = 0
833
+ d1_t = model_t - m0
834
+ x_t = x_t_ - alpha_t * b_h * (corr_res + rhos_c[-1] * d1_t)
835
+ else:
836
+ x_t_ = alpha_t / alpha_s0 * x - sigma_t * h_phi_1 * m0
837
+ if d1s is not None:
838
+ corr_res = torch.einsum("k,bkc...->bc...", rhos_c[:-1], d1s)
839
+ else:
840
+ corr_res = 0
841
+ d1_t = model_t - m0
842
+ x_t = x_t_ - sigma_t * b_h * (corr_res + rhos_c[-1] * d1_t)
843
+ x_t = x_t.to(x.dtype)
844
+ return x_t
845
+
846
+ def _init_step_index(self, timestep):
847
+ if isinstance(timestep, torch.Tensor):
848
+ timestep = timestep.to(self.timesteps.device)
849
+
850
+ index_candidates = (self.timesteps == timestep).nonzero()
851
+
852
+ if len(index_candidates) == 0:
853
+ step_index = len(self.timesteps) - 1
854
+ # The sigma index that is taken for the **very** first `step`
855
+ # is always the second index (or the last index if there is only 1)
856
+ # This way we can ensure we don't accidentally skip a sigma in
857
+ # case we start in the middle of the denoising schedule (e.g. for image-to-image)
858
+ elif len(index_candidates) > 1:
859
+ step_index = index_candidates[1].item()
860
+ else:
861
+ step_index = index_candidates[0].item()
862
+
863
+ self._step_index = step_index
864
+
865
+ def step(
866
+ self,
867
+ model_output: torch.FloatTensor,
868
+ timestep: int,
869
+ sample: torch.FloatTensor,
870
+ return_dict: bool = True,
871
+ ):
872
+ if self.num_inference_steps is None:
873
+ raise ValueError(
874
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
875
+ )
876
+
877
+ if self.step_index is None:
878
+ self._init_step_index(timestep)
879
+
880
+ use_corrector = (
881
+ self.step_index > 0 and self.step_index - 1 not in self.disable_corrector and self.last_sample is not None
882
+ )
883
+
884
+ model_output_convert = self.convert_model_output(model_output, sample=sample)
885
+ if use_corrector:
886
+ sample = self.multistep_uni_c_bh_update(
887
+ this_model_output=model_output_convert,
888
+ last_sample=self.last_sample,
889
+ this_sample=sample,
890
+ order=self.this_order,
891
+ )
892
+
893
+ for i in range(self.solver_order - 1):
894
+ self.model_outputs[i] = self.model_outputs[i + 1]
895
+ self.timestep_list[i] = self.timestep_list[i + 1]
896
+
897
+ self.model_outputs[-1] = model_output_convert
898
+ self.timestep_list[-1] = timestep
899
+
900
+ if self.lower_order_final:
901
+ this_order = min(self.solver_order, len(self.timesteps) - self.step_index)
902
+ else:
903
+ this_order = self.solver_order
904
+
905
+ self.this_order = min(this_order, self.lower_order_nums + 1) # warmup for multistep
906
+ assert self.this_order > 0
907
+
908
+ self.last_sample = sample
909
+ prev_sample = self.multistep_uni_p_bh_update(
910
+ model_output=model_output, # pass the original non-converted model output, in case solver-p is used
911
+ sample=sample,
912
+ order=self.this_order,
913
+ )
914
+
915
+ if self.lower_order_nums < self.solver_order:
916
+ self.lower_order_nums += 1
917
+
918
+ # upon completion increase step index by one
919
+ self._step_index += 1
920
+
921
+ if not return_dict:
922
+ return (prev_sample,)
923
+
924
+ return prev_sample
925
+
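The order-selection lines in step() can be checked in isolation; this small sketch (assumed solver_order=2 and a 10-step schedule) shows the multistep warmup ramping the effective order up and the lower_order_final cap bringing it back down at the end.

solver_order, num_timesteps = 2, 10                              # assumed settings
for step_index, lower_order_nums in [(0, 0), (1, 1), (8, 2), (9, 2)]:
    this_order = min(solver_order, num_timesteps - step_index)   # lower_order_final cap
    this_order = min(this_order, lower_order_nums + 1)           # multistep warmup
    print(step_index, this_order)                                # -> 0 1, 1 2, 8 2, 9 1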
926
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
927
+ return sample
928
+
929
+ def add_noise(
930
+ self,
931
+ original_samples: torch.FloatTensor,
932
+ noise: torch.FloatTensor,
933
+ idx,
934
+ timesteps: torch.IntTensor,
935
+ ) -> torch.FloatTensor:
936
+ # Make sure sigmas and timesteps have the same device and dtype as original_samples
937
+ sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
938
+ schedule_timesteps = self.timesteps.to(original_samples.device)
939
+ timesteps = timesteps.to(original_samples.device)
940
+
941
+ step_indices = [(schedule_timesteps == t).nonzero().item() for t in timesteps]
942
+ sigma = sigmas[step_indices].flatten()
943
+ while len(sigma.shape) < len(original_samples.shape):
944
+ sigma = sigma.unsqueeze(-1)
945
+
946
+ alpha_t, sigma_t = self._sigma_to_alpha_sigma_t(sigma)
947
+ noisy_samples = alpha_t * original_samples + sigma_t * noise
948
+ return noisy_samples
949
+
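Broadcast sketch for add_noise above, with assumed sigmas and the common variance-preserving convention alpha = 1 / sqrt(1 + sigma^2) standing in for _sigma_to_alpha_sigma_t (defined earlier in this file).

import torch

original = torch.randn(2, 4, 8, 8)
sigma = torch.tensor([0.3, 1.2])                  # one sigma per batch element
while sigma.dim() < original.dim():
    sigma = sigma.unsqueeze(-1)                   # -> shape (2, 1, 1, 1)
alpha_t = 1.0 / (1.0 + sigma**2).sqrt()           # assumed convention, see note above
sigma_t = sigma * alpha_t
noisy = alpha_t * original + sigma_t * torch.randn_like(original)
print(noisy.shape)                                # torch.Size([2, 4, 8, 8])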
950
+ def configure(self):
951
+ pass
952
+
953
+ def __len__(self):
954
+ return self.num_train_timesteps
955
+
956
+
957
+ # Modified from diffusers.schedulers.LCMScheduler
958
+ class LCMScheduler:
959
+ def __init__(
960
+ self,
961
+ device="cuda",
962
+ num_train_timesteps: int = 1000,
963
+ beta_start: float = 0.00085,
964
+ beta_end: float = 0.012,
965
+ original_inference_steps: int = 50,
966
+ clip_sample: bool = False,
967
+ clip_sample_range: float = 1.0,
968
+ steps_offset: int = 0,
969
+ prediction_type: str = "epsilon",
970
+ thresholding: bool = False,
971
+ dynamic_thresholding_ratio: float = 0.995,
972
+ sample_max_value: float = 1.0,
973
+ timestep_spacing: str = "leading",
974
+ timestep_scaling: float = 10.0,
975
+ ):
976
+ self.device = device
977
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
978
+ self.alphas = 1.0 - self.betas
979
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
980
+ self.final_alpha_cumprod = self.alphas_cumprod[0]
981
+ # standard deviation of the initial noise distribution
982
+ self.init_noise_sigma = 1.0
983
+ # setable values
984
+ self.num_inference_steps = None
985
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
986
+
987
+ self.num_train_timesteps = num_train_timesteps
988
+ self.clip_sample = clip_sample
989
+ self.clip_sample_range = clip_sample_range
990
+ self.steps_offset = steps_offset
991
+ self.prediction_type = prediction_type
992
+ self.thresholding = thresholding
993
+ self.timestep_spacing = timestep_spacing
994
+ self.timestep_scaling = timestep_scaling
995
+ self.original_inference_steps = original_inference_steps
996
+ self.dynamic_thresholding_ratio = dynamic_thresholding_ratio
997
+ self.sample_max_value = sample_max_value
998
+
999
+ self._step_index = None
1000
+
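The constructor's scaled-linear beta schedule is easy to sanity-check on its own; this sketch simply reproduces the three tensor lines with the default hyperparameters.

import torch

num_train_timesteps, beta_start, beta_end = 1000, 0.00085, 0.012
betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps) ** 2
alphas_cumprod = torch.cumprod(1.0 - betas, dim=0)
print(betas[0].item(), betas[-1].item())   # endpoints recover beta_start and beta_end
print(alphas_cumprod[-1].item())           # near zero: the fully-noised end of the schedule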
1001
+ # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler._init_step_index
1002
+ def _init_step_index(self, timestep):
1003
+ if isinstance(timestep, torch.Tensor):
1004
+ timestep = timestep.to(self.timesteps.device)
1005
+
1006
+ index_candidates = (self.timesteps == timestep).nonzero()
1007
+
1008
+ if len(index_candidates) > 1:
1009
+ step_index = index_candidates[1]
1010
+ else:
1011
+ step_index = index_candidates[0]
1012
+
1013
+ self._step_index = step_index.item()
1014
+
1015
+ @property
1016
+ def step_index(self):
1017
+ return self._step_index
1018
+
1019
+ def scale_model_input(self, sample: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
1020
+ return sample
1021
+
1022
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
1023
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
1024
+ dtype = sample.dtype
1025
+ batch_size, channels, *remaining_dims = sample.shape
1026
+
1027
+ if dtype not in (torch.float32, torch.float64):
1028
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
1029
+
1030
+ # Flatten sample for doing quantile calculation along each image
1031
+ sample = sample.reshape(batch_size, channels * np.prod(remaining_dims))
1032
+
1033
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
1034
+
1035
+ s = torch.quantile(abs_sample, self.dynamic_thresholding_ratio, dim=1)
1036
+ s = torch.clamp(
1037
+ s, min=1, max=self.sample_max_value
1038
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
1039
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
1040
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
1041
+
1042
+ sample = sample.reshape(batch_size, channels, *remaining_dims)
1043
+ sample = sample.to(dtype)
1044
+
1045
+ return sample
1046
+
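A self-contained sketch of the dynamic-thresholding math above, using random data: the per-image quantile s is clamped to at least 1, then the prediction is clipped to [-s, s] and rescaled, so the result never exceeds 1 in magnitude.

import torch

x0 = torch.randn(2, 4, 8, 8) * 3.0                               # assumed over-range predictions
flat = x0.reshape(2, -1)
s = torch.quantile(flat.abs(), 0.995, dim=1).clamp(min=1.0).unsqueeze(1)
x0_thresh = (torch.clamp(flat, -s, s) / s).reshape_as(x0)
print(x0_thresh.abs().max().item())                              # <= 1.0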
1047
+ def set_timesteps(
1048
+ self,
1049
+ num_inference_steps: int,
1050
+ strength: float = 1.0,
1051
+ ):
1052
+ assert num_inference_steps <= self.num_train_timesteps
1053
+
1054
+ self.num_inference_steps = num_inference_steps
1055
+ original_steps = self.original_inference_steps
1056
+
1057
+ assert original_steps <= self.num_train_timesteps
1058
+ assert num_inference_steps <= original_steps
1059
+
1060
+ # LCM Timesteps Setting
1061
+ # Currently, only linear spacing is supported.
1062
+ c = self.num_train_timesteps // original_steps
1063
+ # LCM Training Steps Schedule
1064
+ lcm_origin_timesteps = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1
1065
+ skipping_step = len(lcm_origin_timesteps) // num_inference_steps
1066
+ # LCM Inference Steps Schedule
1067
+ timesteps = lcm_origin_timesteps[::-skipping_step][:num_inference_steps]
1068
+
1069
+ self.timesteps = torch.from_numpy(timesteps.copy()).to(device=self.device, dtype=torch.long)
1070
+
1071
+ self._step_index = None
1072
+
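Worked example of the timestep selection above with the default settings (1000 training steps, original_inference_steps=50) and an assumed num_inference_steps=4:

import numpy as np

num_train_timesteps, original_steps, num_inference_steps, strength = 1000, 50, 4, 1.0
c = num_train_timesteps // original_steps                               # 20
lcm_origin = np.asarray(list(range(1, int(original_steps * strength) + 1))) * c - 1
skip = len(lcm_origin) // num_inference_steps                           # 12
print(lcm_origin[::-skip][:num_inference_steps])                        # [999 759 519 279]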
1073
+ def get_scalings_for_boundary_condition_discrete(self, timestep):
1074
+ self.sigma_data = 0.5 # Default: 0.5
1075
+ scaled_timestep = timestep * self.timestep_scaling
1076
+
1077
+ c_skip = self.sigma_data**2 / (scaled_timestep**2 + self.sigma_data**2)
1078
+ c_out = scaled_timestep / (scaled_timestep**2 + self.sigma_data**2) ** 0.5
1079
+ return c_skip, c_out
1080
+
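Plugging in assumed values shows how these scalings behave: at a high timestep nearly all of the denoised output comes from the network (c_out close to 1) and almost none from the skip connection.

sigma_data, timestep_scaling, timestep = 0.5, 10.0, 999    # assumed values
scaled = timestep * timestep_scaling
c_skip = sigma_data**2 / (scaled**2 + sigma_data**2)       # ~2.5e-9
c_out = scaled / (scaled**2 + sigma_data**2) ** 0.5        # ~1.0
print(c_skip, c_out)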
1081
+ def step(
1082
+ self,
1083
+ model_output: torch.FloatTensor,
1084
+ timestep: int,
1085
+ sample: torch.FloatTensor,
1086
+ generator: torch.Generator | None = None,
1087
+ ):
1088
+ if self.num_inference_steps is None:
1089
+ raise ValueError(
1090
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
1091
+ )
1092
+
1093
+ if self.step_index is None:
1094
+ self._init_step_index(timestep)
1095
+
1096
+ # 1. get previous step value
1097
+ prev_step_index = self.step_index + 1
1098
+ if prev_step_index < len(self.timesteps):
1099
+ prev_timestep = self.timesteps[prev_step_index]
1100
+ else:
1101
+ prev_timestep = timestep
1102
+
1103
+ # 2. compute alphas, betas
1104
+ alpha_prod_t = self.alphas_cumprod[timestep]
1105
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
1106
+
1107
+ beta_prod_t = 1 - alpha_prod_t
1108
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
1109
+
1110
+ # 3. Get scalings for boundary conditions
1111
+ c_skip, c_out = self.get_scalings_for_boundary_condition_discrete(timestep)
1112
+
1113
+ # 4. Compute the predicted original sample x_0 based on the model parameterization
1114
+ if self.prediction_type == "epsilon": # noise-prediction
1115
+ predicted_original_sample = (sample - beta_prod_t.sqrt() * model_output) / alpha_prod_t.sqrt()
1116
+ elif self.prediction_type == "sample": # x-prediction
1117
+ predicted_original_sample = model_output
1118
+ elif self.prediction_type == "v_prediction": # v-prediction
1119
+ predicted_original_sample = alpha_prod_t.sqrt() * sample - beta_prod_t.sqrt() * model_output
1120
+ else:
1121
+ raise ValueError(
1122
+ f"prediction_type given as {self.prediction_type} must be one of `epsilon`, `sample` or"
1123
+ " `v_prediction` for `LCMScheduler`."
1124
+ )
1125
+
1126
+ # 5. Clip or threshold "predicted x_0"
1127
+ if self.thresholding:
1128
+ predicted_original_sample = self._threshold_sample(predicted_original_sample)
1129
+ elif self.clip_sample:
1130
+ predicted_original_sample = predicted_original_sample.clamp(-self.clip_sample_range, self.clip_sample_range)
1131
+
1132
+ # 6. Denoise model output using boundary conditions
1133
+ denoised = c_out * predicted_original_sample + c_skip * sample
1134
+
1135
+ # 7. Sample and inject noise z ~ N(0, I) for MultiStep Inference
1136
+ # Noise is not used on the final timestep of the timestep schedule.
1137
+ # This also means that noise is not used for one-step sampling.
1138
+ if self.step_index != self.num_inference_steps - 1:
1139
+ noise = torch.randn(
1140
+ model_output.shape, device=model_output.device, dtype=denoised.dtype, generator=generator
1141
+ )
1142
+ prev_sample = alpha_prod_t_prev.sqrt() * denoised + beta_prod_t_prev.sqrt() * noise
1143
+ else:
1144
+ prev_sample = denoised
1145
+
1146
+ # upon completion increase step index by one
1147
+ self._step_index += 1
1148
+
1149
+ return (prev_sample,)
1150
+
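A quick consistency check (assumed values) for the epsilon branch of step(): noising a sample and then applying the same reconstruction formula with the exact noise recovers the original x0.

import torch

alpha_prod_t = torch.tensor(0.25)                           # assumed cumulative alpha
x0 = torch.randn(1, 4, 8, 8)
noise = torch.randn_like(x0)
sample = alpha_prod_t.sqrt() * x0 + (1 - alpha_prod_t).sqrt() * noise
pred_x0 = (sample - (1 - alpha_prod_t).sqrt() * noise) / alpha_prod_t.sqrt()
print(torch.allclose(pred_x0, x0, atol=1e-5))               # True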
1151
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
1152
+ def add_noise(
1153
+ self,
1154
+ original_samples: torch.FloatTensor,
1155
+ noise: torch.FloatTensor,
1156
+ timesteps: torch.IntTensor,
1157
+ ) -> torch.FloatTensor:
1158
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
1159
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
1160
+ timesteps = timesteps.to(original_samples.device)
1161
+
1162
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
1163
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
1164
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
1165
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
1166
+
1167
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
1168
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
1169
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
1170
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
1171
+
1172
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
1173
+ return noisy_samples
1174
+
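Shape/broadcast sketch for this DDPM-style add_noise, with an assumed cumulative-alpha schedule: the per-timestep scalars are reshaped so they broadcast over the sample dimensions.

import torch

alphas_cumprod = torch.linspace(0.999, 0.01, 1000)          # assumed schedule, not the real one
timesteps = torch.tensor([10, 500])
x0 = torch.randn(2, 4, 8, 8)
a = alphas_cumprod[timesteps].sqrt().reshape(-1, 1, 1, 1)
b = (1 - alphas_cumprod[timesteps]).sqrt().reshape(-1, 1, 1, 1)
print((a * x0 + b * torch.randn_like(x0)).shape)            # torch.Size([2, 4, 8, 8])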
1175
+ def configure(self):
1176
+ pass
1177
+
1178
+ def __len__(self):
1179
+ return self.num_train_timesteps
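Finally, a minimal driver sketch for the LCMScheduler defined above (assuming the class is in scope); the random tensor stands in for a real UNet noise prediction, so only the control flow is meaningful.

import torch

scheduler = LCMScheduler(device="cpu")
scheduler.set_timesteps(num_inference_steps=4)
latents = torch.randn(1, 4, 64, 64)
for t in scheduler.timesteps:
    noise_pred = torch.randn_like(latents)                  # stand-in for unet(latents, t)
    latents = scheduler.step(noise_pred, t, latents)[0]     # step returns a 1-tuple
print(latents.shape)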