diffusers 0.26.3__py3-none-any.whl → 0.27.0__py3-none-any.whl

Files changed (299)
  1. diffusers/__init__.py +20 -1
  2. diffusers/commands/__init__.py +1 -1
  3. diffusers/commands/diffusers_cli.py +1 -1
  4. diffusers/commands/env.py +1 -1
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +7 -3
  7. diffusers/dependency_versions_check.py +1 -1
  8. diffusers/dependency_versions_table.py +2 -2
  9. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  10. diffusers/image_processor.py +110 -4
  11. diffusers/loaders/autoencoder.py +7 -8
  12. diffusers/loaders/controlnet.py +17 -8
  13. diffusers/loaders/ip_adapter.py +86 -23
  14. diffusers/loaders/lora.py +105 -310
  15. diffusers/loaders/lora_conversion_utils.py +1 -1
  16. diffusers/loaders/peft.py +1 -1
  17. diffusers/loaders/single_file.py +51 -12
  18. diffusers/loaders/single_file_utils.py +274 -49
  19. diffusers/loaders/textual_inversion.py +23 -4
  20. diffusers/loaders/unet.py +195 -41
  21. diffusers/loaders/utils.py +1 -1
  22. diffusers/models/__init__.py +3 -1
  23. diffusers/models/activations.py +9 -9
  24. diffusers/models/attention.py +26 -36
  25. diffusers/models/attention_flax.py +1 -1
  26. diffusers/models/attention_processor.py +171 -114
  27. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  28. diffusers/models/autoencoders/autoencoder_kl.py +3 -1
  29. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  30. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  31. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  32. diffusers/models/autoencoders/vae.py +1 -1
  33. diffusers/models/controlnet.py +1 -1
  34. diffusers/models/controlnet_flax.py +1 -1
  35. diffusers/models/downsampling.py +8 -12
  36. diffusers/models/dual_transformer_2d.py +1 -1
  37. diffusers/models/embeddings.py +3 -4
  38. diffusers/models/embeddings_flax.py +1 -1
  39. diffusers/models/lora.py +33 -10
  40. diffusers/models/modeling_flax_pytorch_utils.py +1 -1
  41. diffusers/models/modeling_flax_utils.py +1 -1
  42. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  43. diffusers/models/modeling_utils.py +4 -6
  44. diffusers/models/normalization.py +1 -1
  45. diffusers/models/resnet.py +31 -58
  46. diffusers/models/resnet_flax.py +1 -1
  47. diffusers/models/t5_film_transformer.py +1 -1
  48. diffusers/models/transformer_2d.py +1 -1
  49. diffusers/models/transformer_temporal.py +1 -1
  50. diffusers/models/transformers/dual_transformer_2d.py +1 -1
  51. diffusers/models/transformers/t5_film_transformer.py +1 -1
  52. diffusers/models/transformers/transformer_2d.py +29 -31
  53. diffusers/models/transformers/transformer_temporal.py +1 -1
  54. diffusers/models/unet_1d.py +1 -1
  55. diffusers/models/unet_1d_blocks.py +1 -1
  56. diffusers/models/unet_2d.py +1 -1
  57. diffusers/models/unet_2d_blocks.py +1 -1
  58. diffusers/models/unet_2d_condition.py +1 -1
  59. diffusers/models/unets/__init__.py +1 -0
  60. diffusers/models/unets/unet_1d.py +1 -1
  61. diffusers/models/unets/unet_1d_blocks.py +1 -1
  62. diffusers/models/unets/unet_2d.py +4 -4
  63. diffusers/models/unets/unet_2d_blocks.py +238 -98
  64. diffusers/models/unets/unet_2d_blocks_flax.py +1 -1
  65. diffusers/models/unets/unet_2d_condition.py +420 -323
  66. diffusers/models/unets/unet_2d_condition_flax.py +21 -12
  67. diffusers/models/unets/unet_3d_blocks.py +50 -40
  68. diffusers/models/unets/unet_3d_condition.py +47 -8
  69. diffusers/models/unets/unet_i2vgen_xl.py +75 -30
  70. diffusers/models/unets/unet_kandinsky3.py +1 -1
  71. diffusers/models/unets/unet_motion_model.py +48 -8
  72. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  73. diffusers/models/unets/unet_stable_cascade.py +610 -0
  74. diffusers/models/unets/uvit_2d.py +1 -1
  75. diffusers/models/upsampling.py +10 -16
  76. diffusers/models/vae_flax.py +1 -1
  77. diffusers/models/vq_model.py +1 -1
  78. diffusers/optimization.py +1 -1
  79. diffusers/pipelines/__init__.py +26 -0
  80. diffusers/pipelines/amused/pipeline_amused.py +1 -1
  81. diffusers/pipelines/amused/pipeline_amused_img2img.py +1 -1
  82. diffusers/pipelines/amused/pipeline_amused_inpaint.py +1 -1
  83. diffusers/pipelines/animatediff/pipeline_animatediff.py +162 -417
  84. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +165 -137
  85. diffusers/pipelines/animatediff/pipeline_output.py +7 -6
  86. diffusers/pipelines/audioldm/pipeline_audioldm.py +3 -19
  87. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  88. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +3 -3
  89. diffusers/pipelines/auto_pipeline.py +7 -16
  90. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +2 -2
  93. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  94. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  95. diffusers/pipelines/controlnet/pipeline_controlnet.py +90 -90
  96. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  97. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +98 -90
  98. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +92 -90
  99. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +145 -70
  100. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +126 -89
  101. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +108 -96
  102. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  103. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -1
  104. diffusers/pipelines/ddim/pipeline_ddim.py +1 -1
  105. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -1
  106. diffusers/pipelines/deepfloyd_if/pipeline_if.py +4 -4
  107. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +4 -4
  108. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +5 -5
  109. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +4 -4
  110. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +5 -5
  111. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +5 -5
  112. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +10 -120
  113. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -91
  114. diffusers/pipelines/deprecated/audio_diffusion/mel.py +1 -1
  115. diffusers/pipelines/deprecated/audio_diffusion/pipeline_audio_diffusion.py +1 -1
  116. diffusers/pipelines/deprecated/latent_diffusion_uncond/pipeline_latent_diffusion_uncond.py +1 -1
  117. diffusers/pipelines/deprecated/pndm/pipeline_pndm.py +1 -1
  118. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +1 -1
  119. diffusers/pipelines/deprecated/score_sde_ve/pipeline_score_sde_ve.py +1 -1
  120. diffusers/pipelines/deprecated/spectrogram_diffusion/continuous_encoder.py +1 -1
  121. diffusers/pipelines/deprecated/spectrogram_diffusion/midi_utils.py +1 -1
  122. diffusers/pipelines/deprecated/spectrogram_diffusion/notes_encoder.py +1 -1
  123. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -1
  124. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +5 -4
  125. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +5 -4
  126. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +7 -22
  127. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +5 -39
  128. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +5 -5
  129. diffusers/pipelines/deprecated/stochastic_karras_ve/pipeline_stochastic_karras_ve.py +1 -1
  130. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -22
  131. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -1
  132. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -1
  133. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -2
  134. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +1 -1
  135. diffusers/pipelines/dit/pipeline_dit.py +1 -1
  136. diffusers/pipelines/free_init_utils.py +184 -0
  137. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +22 -104
  138. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +1 -1
  139. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +1 -1
  140. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +1 -1
  141. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +2 -2
  142. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -1
  143. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +1 -1
  145. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +1 -1
  146. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  147. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  148. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +2 -2
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +104 -93
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +112 -74
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/ledits_pp/__init__.py +55 -0
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +1505 -0
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +1797 -0
  155. diffusers/pipelines/ledits_pp/pipeline_output.py +43 -0
  156. diffusers/pipelines/musicldm/pipeline_musicldm.py +3 -19
  157. diffusers/pipelines/onnx_utils.py +1 -1
  158. diffusers/pipelines/paint_by_example/image_encoder.py +1 -1
  159. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +3 -3
  160. diffusers/pipelines/pia/pipeline_pia.py +168 -327
  161. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  162. diffusers/pipelines/pipeline_loading_utils.py +508 -0
  163. diffusers/pipelines/pipeline_utils.py +188 -534
  164. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +56 -10
  165. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +3 -3
  166. diffusers/pipelines/shap_e/camera.py +1 -1
  167. diffusers/pipelines/shap_e/pipeline_shap_e.py +1 -1
  168. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +1 -1
  169. diffusers/pipelines/shap_e/renderer.py +1 -1
  170. diffusers/pipelines/stable_cascade/__init__.py +50 -0
  171. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +482 -0
  172. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +311 -0
  173. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +638 -0
  174. diffusers/pipelines/stable_diffusion/clip_image_project_model.py +1 -1
  175. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +4 -1
  176. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  177. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_img2img.py +2 -2
  178. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion_inpaint.py +1 -1
  179. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +1 -1
  180. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +1 -1
  181. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  182. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  183. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +90 -146
  184. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  185. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +4 -32
  186. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -119
  187. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -119
  188. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +13 -59
  189. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +3 -31
  190. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +5 -33
  191. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -21
  192. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -21
  193. diffusers/pipelines/stable_diffusion/safety_checker.py +1 -1
  194. diffusers/pipelines/stable_diffusion/safety_checker_flax.py +1 -1
  195. diffusers/pipelines/stable_diffusion/stable_unclip_image_normalizer.py +1 -1
  196. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +5 -21
  197. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +9 -38
  198. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -34
  199. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +6 -35
  200. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +7 -6
  201. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +4 -124
  202. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +282 -80
  203. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +94 -46
  204. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +3 -3
  205. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  206. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +6 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_flax_stable_diffusion_xl.py +1 -1
  208. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +96 -148
  209. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +98 -154
  210. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +98 -153
  211. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +25 -87
  212. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +89 -80
  213. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +5 -49
  214. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +80 -88
  215. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +8 -6
  216. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +15 -86
  217. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +20 -93
  218. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +5 -5
  219. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +3 -19
  220. diffusers/pipelines/unclip/pipeline_unclip.py +1 -1
  221. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -1
  222. diffusers/pipelines/unclip/text_proj.py +1 -1
  223. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +35 -35
  224. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  225. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +4 -21
  226. diffusers/pipelines/wuerstchen/modeling_wuerstchen_diffnext.py +2 -2
  227. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -5
  228. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +8 -8
  229. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +1 -1
  230. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +2 -2
  231. diffusers/schedulers/__init__.py +7 -1
  232. diffusers/schedulers/deprecated/scheduling_karras_ve.py +1 -1
  233. diffusers/schedulers/deprecated/scheduling_sde_vp.py +1 -1
  234. diffusers/schedulers/scheduling_consistency_models.py +42 -19
  235. diffusers/schedulers/scheduling_ddim.py +2 -4
  236. diffusers/schedulers/scheduling_ddim_flax.py +13 -5
  237. diffusers/schedulers/scheduling_ddim_inverse.py +2 -4
  238. diffusers/schedulers/scheduling_ddim_parallel.py +2 -4
  239. diffusers/schedulers/scheduling_ddpm.py +2 -4
  240. diffusers/schedulers/scheduling_ddpm_flax.py +1 -1
  241. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -4
  242. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +1 -1
  243. diffusers/schedulers/scheduling_deis_multistep.py +46 -19
  244. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -21
  245. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +1 -1
  246. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +9 -7
  247. diffusers/schedulers/scheduling_dpmsolver_sde.py +35 -35
  248. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +49 -18
  249. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +683 -0
  250. diffusers/schedulers/scheduling_edm_euler.py +381 -0
  251. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +43 -15
  252. diffusers/schedulers/scheduling_euler_discrete.py +42 -17
  253. diffusers/schedulers/scheduling_euler_discrete_flax.py +1 -1
  254. diffusers/schedulers/scheduling_heun_discrete.py +35 -35
  255. diffusers/schedulers/scheduling_ipndm.py +37 -11
  256. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +44 -44
  257. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +44 -44
  258. diffusers/schedulers/scheduling_karras_ve_flax.py +1 -1
  259. diffusers/schedulers/scheduling_lcm.py +38 -14
  260. diffusers/schedulers/scheduling_lms_discrete.py +43 -15
  261. diffusers/schedulers/scheduling_lms_discrete_flax.py +1 -1
  262. diffusers/schedulers/scheduling_pndm.py +2 -4
  263. diffusers/schedulers/scheduling_pndm_flax.py +2 -4
  264. diffusers/schedulers/scheduling_repaint.py +1 -1
  265. diffusers/schedulers/scheduling_sasolver.py +41 -9
  266. diffusers/schedulers/scheduling_sde_ve.py +1 -1
  267. diffusers/schedulers/scheduling_sde_ve_flax.py +1 -1
  268. diffusers/schedulers/scheduling_tcd.py +686 -0
  269. diffusers/schedulers/scheduling_unclip.py +1 -1
  270. diffusers/schedulers/scheduling_unipc_multistep.py +46 -19
  271. diffusers/schedulers/scheduling_utils.py +2 -1
  272. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  273. diffusers/schedulers/scheduling_vq_diffusion.py +1 -1
  274. diffusers/training_utils.py +9 -2
  275. diffusers/utils/__init__.py +2 -1
  276. diffusers/utils/accelerate_utils.py +1 -1
  277. diffusers/utils/constants.py +1 -1
  278. diffusers/utils/doc_utils.py +1 -1
  279. diffusers/utils/dummy_pt_objects.py +60 -0
  280. diffusers/utils/dummy_torch_and_transformers_objects.py +75 -0
  281. diffusers/utils/dynamic_modules_utils.py +1 -1
  282. diffusers/utils/export_utils.py +3 -3
  283. diffusers/utils/hub_utils.py +60 -16
  284. diffusers/utils/import_utils.py +15 -1
  285. diffusers/utils/loading_utils.py +2 -0
  286. diffusers/utils/logging.py +1 -1
  287. diffusers/utils/model_card_template.md +24 -0
  288. diffusers/utils/outputs.py +14 -7
  289. diffusers/utils/peft_utils.py +1 -1
  290. diffusers/utils/state_dict_utils.py +1 -1
  291. diffusers/utils/testing_utils.py +2 -0
  292. diffusers/utils/torch_utils.py +1 -1
  293. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/METADATA +46 -46
  294. diffusers-0.27.0.dist-info/RECORD +399 -0
  295. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/WHEEL +1 -1
  296. diffusers-0.26.3.dist-info/RECORD +0 -384
  297. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/LICENSE +0 -0
  298. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/entry_points.txt +0 -0
  299. {diffusers-0.26.3.dist-info → diffusers-0.27.0.dist-info}/top_level.txt +0 -0
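
The headline additions in this release are visible in the manifest above: the Stable Cascade model and pipelines (unet_stable_cascade.py plus the new stable_cascade pipeline package), the LEDITS++ editing pipelines, the EDM DPM-Solver, EDM Euler and TCD schedulers, and the FreeInit utilities (free_init_utils.py, reproduced at the end of this diff). As rough orientation, the new Stable Cascade pipelines follow a two-stage prior/decoder pattern; the sketch below is an assumption based on the 0.27.0 release notes, and the checkpoint ids, the image_embeddings hand-off, and the call parameters are not shown anywhere in this manifest.

from diffusers import StableCascadePriorPipeline, StableCascadeDecoderPipeline

# Assumed checkpoint ids and two-stage call pattern; not taken from this diff.
prior = StableCascadePriorPipeline.from_pretrained("stabilityai/stable-cascade-prior").to("cuda")
decoder = StableCascadeDecoderPipeline.from_pretrained("stabilityai/stable-cascade").to("cuda")

prompt = "an astronaut riding a horse, cinematic lighting"
prior_out = prior(prompt=prompt, num_inference_steps=20, guidance_scale=4.0)
image = decoder(
    image_embeddings=prior_out.image_embeddings,  # prior output feeds the decoder stage
    prompt=prompt,
    num_inference_steps=10,
    guidance_scale=0.0,
).images[0]
image.save("stable_cascade.png")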
@@ -1,4 +1,4 @@
- # Copyright 2023 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
+ # Copyright 2024 Pix2Pix Zero Authors and The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -46,7 +46,7 @@ from ....utils import (
  unscale_lora_layers,
  )
  from ....utils.torch_utils import randn_tensor
- from ...pipeline_utils import DiffusionPipeline
+ from ...pipeline_utils import DiffusionPipeline, StableDiffusionMixin
  from ...stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
  from ...stable_diffusion.safety_checker import StableDiffusionSafetyChecker

@@ -280,7 +280,7 @@ class Pix2PixZeroAttnProcessor:
  return hidden_states


- class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
+ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline, StableDiffusionMixin):
  r"""
  Pipeline for pixel-level image editing using Pix2Pix Zero. Based on Stable Diffusion.

@@ -463,7 +463,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
  batch_size = prompt_embeds.shape[0]

  if prompt_embeds is None:
- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
  if isinstance(self, TextualInversionLoaderMixin):
  prompt = self.maybe_convert_prompt(prompt, self.tokenizer)

@@ -545,7 +545,7 @@ class StableDiffusionPix2PixZeroPipeline(DiffusionPipeline):
  else:
  uncond_tokens = negative_prompt

- # textual inversion: procecss multi-vector tokens if necessary
+ # textual inversion: process multi-vector tokens if necessary
  if isinstance(self, TextualInversionLoaderMixin):
  uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)

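The recurring change in the hunks above is that StableDiffusionPix2PixZeroPipeline now also inherits from StableDiffusionMixin, imported from pipeline_utils. The same refactor appears across most pipelines in this release and goes a long way toward explaining the large negative line counts in the manifest: shared helpers are defined once on the mixin instead of being copied into every pipeline. A minimal sketch of what that looks like from the user side, assuming the mixin carries helpers such as enable_vae_slicing() and enable_freeu(); the checkpoint id is illustrative.

from diffusers import StableDiffusionPipeline

# Illustrative checkpoint; the helpers are assumed to come from StableDiffusionMixin.
pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5").to("cuda")
pipe.enable_vae_slicing()                          # decode latents one image at a time
pipe.enable_freeu(s1=0.9, s2=0.2, b1=1.2, b2=1.4)  # FreeU backbone/skip re-weighting
image = pipe("a photo of a red panda in the snow").images[0]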
@@ -1,4 +1,4 @@
- # Copyright 2023 The HuggingFace Team. All rights reserved.
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -268,7 +268,6 @@ class GLIGENTextBoundingboxProjection(nn.Module):
  return objs


- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
  class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  r"""
  A conditional 2D UNet model that takes a noisy sample, conditional state, and a timestep and returns a sample
@@ -1334,7 +1333,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
  **additional_residuals,
  )
  else:
- sample, res_samples = downsample_block(hidden_states=sample, temb=emb, scale=lora_scale)
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
  if is_adapter and len(down_intrablock_additional_residuals) > 0:
  sample += down_intrablock_additional_residuals.pop(0)

@@ -1590,7 +1589,7 @@ class DownBlockFlat(nn.Module):
  self.gradient_checkpointing = False

  def forward(
- self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None, scale: float = 1.0
+ self, hidden_states: torch.FloatTensor, temb: Optional[torch.FloatTensor] = None
  ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
  output_states = ()

@@ -1612,13 +1611,13 @@ class DownBlockFlat(nn.Module):
  create_custom_forward(resnet), hidden_states, temb
  )
  else:
- hidden_states = resnet(hidden_states, temb, scale=scale)
+ hidden_states = resnet(hidden_states, temb)

  output_states = output_states + (hidden_states,)

  if self.downsamplers is not None:
  for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states, scale=scale)
+ hidden_states = downsampler(hidden_states)

  output_states = output_states + (hidden_states,)

@@ -1729,8 +1728,6 @@ class CrossAttnDownBlockFlat(nn.Module):
  ) -> Tuple[torch.FloatTensor, Tuple[torch.FloatTensor, ...]]:
  output_states = ()

- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
-
  blocks = list(zip(self.resnets, self.attentions))

  for i, (resnet, attn) in enumerate(blocks):
@@ -1761,7 +1758,7 @@ class CrossAttnDownBlockFlat(nn.Module):
  return_dict=False,
  )[0]
  else:
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = resnet(hidden_states, temb)
  hidden_states = attn(
  hidden_states,
  encoder_hidden_states=encoder_hidden_states,
@@ -1779,7 +1776,7 @@ class CrossAttnDownBlockFlat(nn.Module):

  if self.downsamplers is not None:
  for downsampler in self.downsamplers:
- hidden_states = downsampler(hidden_states, scale=lora_scale)
+ hidden_states = downsampler(hidden_states)

  output_states = output_states + (hidden_states,)

@@ -1843,8 +1840,13 @@ class UpBlockFlat(nn.Module):
  res_hidden_states_tuple: Tuple[torch.FloatTensor, ...],
  temb: Optional[torch.FloatTensor] = None,
  upsample_size: Optional[int] = None,
- scale: float = 1.0,
+ *args,
+ **kwargs,
  ) -> torch.FloatTensor:
+ if len(args) > 0 or kwargs.get("scale", None) is not None:
+ deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
+ deprecate("scale", "1.0.0", deprecation_message)
+
  is_freeu_enabled = (
  getattr(self, "s1", None)
  and getattr(self, "s2", None)
@@ -1888,11 +1890,11 @@ class UpBlockFlat(nn.Module):
  create_custom_forward(resnet), hidden_states, temb
  )
  else:
- hidden_states = resnet(hidden_states, temb, scale=scale)
+ hidden_states = resnet(hidden_states, temb)

  if self.upsamplers is not None:
  for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, upsample_size, scale=scale)
+ hidden_states = upsampler(hidden_states, upsample_size)

  return hidden_states

@@ -2000,7 +2002,10 @@ class CrossAttnUpBlockFlat(nn.Module):
  attention_mask: Optional[torch.FloatTensor] = None,
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
  ) -> torch.FloatTensor:
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+
  is_freeu_enabled = (
  getattr(self, "s1", None)
  and getattr(self, "s2", None)
@@ -2054,7 +2059,7 @@ class CrossAttnUpBlockFlat(nn.Module):
  return_dict=False,
  )[0]
  else:
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = resnet(hidden_states, temb)
  hidden_states = attn(
  hidden_states,
  encoder_hidden_states=encoder_hidden_states,
@@ -2066,7 +2071,7 @@ class CrossAttnUpBlockFlat(nn.Module):

  if self.upsamplers is not None:
  for upsampler in self.upsamplers:
- hidden_states = upsampler(hidden_states, upsample_size, scale=lora_scale)
+ hidden_states = upsampler(hidden_states, upsample_size)

  return hidden_states

@@ -2159,7 +2164,7 @@ class UNetMidBlockFlat(nn.Module):
  attentions = []

  if attention_head_dim is None:
- logger.warn(
+ logger.warning(
  f"It is not recommend to pass `attention_head_dim=None`. Defaulting `attention_head_dim` to `in_channels`: {in_channels}."
  )
  attention_head_dim = in_channels
@@ -2331,8 +2336,11 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
  cross_attention_kwargs: Optional[Dict[str, Any]] = None,
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
  ) -> torch.FloatTensor:
- lora_scale = cross_attention_kwargs.get("scale", 1.0) if cross_attention_kwargs is not None else 1.0
- hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+ if cross_attention_kwargs is not None:
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")
+
+ hidden_states = self.resnets[0](hidden_states, temb)
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
  if self.training and self.gradient_checkpointing:

@@ -2369,7 +2377,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
  encoder_attention_mask=encoder_attention_mask,
  return_dict=False,
  )[0]
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = resnet(hidden_states, temb)

  return hidden_states

@@ -2470,7 +2478,8 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
  encoder_attention_mask: Optional[torch.FloatTensor] = None,
  ) -> torch.FloatTensor:
  cross_attention_kwargs = cross_attention_kwargs if cross_attention_kwargs is not None else {}
- lora_scale = cross_attention_kwargs.get("scale", 1.0)
+ if cross_attention_kwargs.get("scale", None) is not None:
+ logger.warning("Passing `scale` to `cross_attention_kwargs` is depcrecated. `scale` will be ignored.")

  if attention_mask is None:
  # if encoder_hidden_states is defined: we are doing cross-attn, so we should use cross-attn mask.
@@ -2483,7 +2492,7 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
  # mask = attention_mask if encoder_hidden_states is None else encoder_attention_mask
  mask = attention_mask

- hidden_states = self.resnets[0](hidden_states, temb, scale=lora_scale)
+ hidden_states = self.resnets[0](hidden_states, temb)
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
  # attn
  hidden_states = attn(
@@ -2494,6 +2503,6 @@ class UNetMidBlockFlatSimpleCrossAttn(nn.Module):
  )

  # resnet
- hidden_states = resnet(hidden_states, temb, scale=lora_scale)
+ hidden_states = resnet(hidden_states, temb)

  return hidden_states
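The pattern in the modeling_text_unet.py hunks above is the removal of the per-forward scale argument: resnets, downsamplers and upsamplers are now called without it, UpBlockFlat.forward only accepts it through *args/**kwargs in order to emit a deprecation warning, and passing "scale" inside cross_attention_kwargs at the block level is ignored with a warning. With the peft-backed LoRA handling, the scale is applied once at the pipeline call instead of being threaded through every block. A hedged usage sketch follows; the LoRA repository id is hypothetical, and only the cross_attention_kwargs={"scale": ...} entry point is the documented pattern.

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("some-user/some-sd15-lora")  # hypothetical LoRA repository

# LoRA strength is set once per call via cross_attention_kwargs; the UNet blocks
# no longer accept a per-forward `scale` argument.
image = pipe(
    "a watercolor painting of a fox",
    cross_attention_kwargs={"scale": 0.7},
).images[0]
image.save("fox_lora.png")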
@@ -1,4 +1,4 @@
- # Copyright 2023 The HuggingFace Team. All rights reserved.
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
- # Copyright 2023 The HuggingFace Team. All rights reserved.
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -1,4 +1,4 @@
- # Copyright 2023 The HuggingFace Team. All rights reserved.
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -246,7 +246,6 @@ class VersatileDiffusionTextToImagePipeline(DiffusionPipeline):
  extra_step_kwargs["generator"] = generator
  return extra_step_kwargs

- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.check_inputs
  def check_inputs(
  self,
  prompt,
@@ -1,4 +1,4 @@
- # Copyright 2023 Microsoft and The HuggingFace Team. All rights reserved.
+ # Copyright 2024 Microsoft and The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -4,7 +4,7 @@
  # Copyright (c) 2021 OpenAI
  # MIT License
  #
- # Copyright 2023 The HuggingFace Team. All rights reserved.
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
  #
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -0,0 +1,184 @@
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ import math
+ from typing import Tuple, Union
+
+ import torch
+ import torch.fft as fft
+
+ from ..utils.torch_utils import randn_tensor
+
+
+ class FreeInitMixin:
+     r"""Mixin class for FreeInit."""
+
+     def enable_free_init(
+         self,
+         num_iters: int = 3,
+         use_fast_sampling: bool = False,
+         method: str = "butterworth",
+         order: int = 4,
+         spatial_stop_frequency: float = 0.25,
+         temporal_stop_frequency: float = 0.25,
+     ):
+         """Enables the FreeInit mechanism as in https://arxiv.org/abs/2312.07537.
+
+         This implementation has been adapted from the [official repository](https://github.com/TianxingWu/FreeInit).
+
+         Args:
+             num_iters (`int`, *optional*, defaults to `3`):
+                 Number of FreeInit noise re-initialization iterations.
+             use_fast_sampling (`bool`, *optional*, defaults to `False`):
+                 Whether or not to speedup sampling procedure at the cost of probably lower quality results. Enables
+                 the "Coarse-to-Fine Sampling" strategy, as mentioned in the paper, if set to `True`.
+             method (`str`, *optional*, defaults to `butterworth`):
+                 Must be one of `butterworth`, `ideal` or `gaussian` to use as the filtering method for the
+                 FreeInit low pass filter.
+             order (`int`, *optional*, defaults to `4`):
+                 Order of the filter used in `butterworth` method. Larger values lead to `ideal` method behaviour
+                 whereas lower values lead to `gaussian` method behaviour.
+             spatial_stop_frequency (`float`, *optional*, defaults to `0.25`):
+                 Normalized stop frequency for spatial dimensions. Must be between 0 to 1. Referred to as `d_s` in
+                 the original implementation.
+             temporal_stop_frequency (`float`, *optional*, defaults to `0.25`):
+                 Normalized stop frequency for temporal dimensions. Must be between 0 to 1. Referred to as `d_t` in
+                 the original implementation.
+         """
+         self._free_init_num_iters = num_iters
+         self._free_init_use_fast_sampling = use_fast_sampling
+         self._free_init_method = method
+         self._free_init_order = order
+         self._free_init_spatial_stop_frequency = spatial_stop_frequency
+         self._free_init_temporal_stop_frequency = temporal_stop_frequency
+
+     def disable_free_init(self):
+         """Disables the FreeInit mechanism if enabled."""
+         self._free_init_num_iters = None
+
+     @property
+     def free_init_enabled(self):
+         return hasattr(self, "_free_init_num_iters") and self._free_init_num_iters is not None
+
+     def _get_free_init_freq_filter(
+         self,
+         shape: Tuple[int, ...],
+         device: Union[str, torch.dtype],
+         filter_type: str,
+         order: float,
+         spatial_stop_frequency: float,
+         temporal_stop_frequency: float,
+     ) -> torch.Tensor:
+         r"""Returns the FreeInit filter based on filter type and other input conditions."""
+
+         time, height, width = shape[-3], shape[-2], shape[-1]
+         mask = torch.zeros(shape)
+
+         if spatial_stop_frequency == 0 or temporal_stop_frequency == 0:
+             return mask
+
+         if filter_type == "butterworth":
+
+             def retrieve_mask(x):
+                 return 1 / (1 + (x / spatial_stop_frequency**2) ** order)
+         elif filter_type == "gaussian":
+
+             def retrieve_mask(x):
+                 return math.exp(-1 / (2 * spatial_stop_frequency**2) * x)
+         elif filter_type == "ideal":
+
+             def retrieve_mask(x):
+                 return 1 if x <= spatial_stop_frequency * 2 else 0
+         else:
+             raise NotImplementedError("`filter_type` must be one of gaussian, butterworth or ideal")
+
+         for t in range(time):
+             for h in range(height):
+                 for w in range(width):
+                     d_square = (
+                         ((spatial_stop_frequency / temporal_stop_frequency) * (2 * t / time - 1)) ** 2
+                         + (2 * h / height - 1) ** 2
+                         + (2 * w / width - 1) ** 2
+                     )
+                     mask[..., t, h, w] = retrieve_mask(d_square)
+
+         return mask.to(device)
+
+     def _apply_freq_filter(self, x: torch.Tensor, noise: torch.Tensor, low_pass_filter: torch.Tensor) -> torch.Tensor:
+         r"""Noise reinitialization."""
+         # FFT
+         x_freq = fft.fftn(x, dim=(-3, -2, -1))
+         x_freq = fft.fftshift(x_freq, dim=(-3, -2, -1))
+         noise_freq = fft.fftn(noise, dim=(-3, -2, -1))
+         noise_freq = fft.fftshift(noise_freq, dim=(-3, -2, -1))
+
+         # frequency mix
+         high_pass_filter = 1 - low_pass_filter
+         x_freq_low = x_freq * low_pass_filter
+         noise_freq_high = noise_freq * high_pass_filter
+         x_freq_mixed = x_freq_low + noise_freq_high  # mix in freq domain
+
+         # IFFT
+         x_freq_mixed = fft.ifftshift(x_freq_mixed, dim=(-3, -2, -1))
+         x_mixed = fft.ifftn(x_freq_mixed, dim=(-3, -2, -1)).real
+
+         return x_mixed
+
+     def _apply_free_init(
+         self,
+         latents: torch.Tensor,
+         free_init_iteration: int,
+         num_inference_steps: int,
+         device: torch.device,
+         dtype: torch.dtype,
+         generator: torch.Generator,
+     ):
+         if free_init_iteration == 0:
+             self._free_init_initial_noise = latents.detach().clone()
+             return latents, self.scheduler.timesteps
+
+         latent_shape = latents.shape
+
+         free_init_filter_shape = (1, *latent_shape[1:])
+         free_init_freq_filter = self._get_free_init_freq_filter(
+             shape=free_init_filter_shape,
+             device=device,
+             filter_type=self._free_init_method,
+             order=self._free_init_order,
+             spatial_stop_frequency=self._free_init_spatial_stop_frequency,
+             temporal_stop_frequency=self._free_init_temporal_stop_frequency,
+         )
+
+         current_diffuse_timestep = self.scheduler.config.num_train_timesteps - 1
+         diffuse_timesteps = torch.full((latent_shape[0],), current_diffuse_timestep).long()
+
+         z_t = self.scheduler.add_noise(
+             original_samples=latents, noise=self._free_init_initial_noise, timesteps=diffuse_timesteps.to(device)
+         ).to(dtype=torch.float32)
+
+         z_rand = randn_tensor(
+             shape=latent_shape,
+             generator=generator,
+             device=device,
+             dtype=torch.float32,
+         )
+         latents = self._apply_freq_filter(z_t, z_rand, low_pass_filter=free_init_freq_filter)
+         latents = latents.to(dtype)
+
+         # Coarse-to-Fine Sampling for faster inference (can lead to lower quality)
+         if self._free_init_use_fast_sampling:
+             num_inference_steps = int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
+             self.scheduler.set_timesteps(num_inference_steps, device=device)
+
+         return latents, self.scheduler.timesteps
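
The new free_init_utils.py adds FreeInitMixin, which the video pipelines touched in this release (for example pipeline_animatediff.py and pipeline_pia.py in the manifest) inherit to expose enable_free_init() and disable_free_init(). The mixin re-noises the low-frequency component of the latents over several sampling passes, trading extra compute for better temporal consistency. A usage sketch under stated assumptions: the motion-adapter and base-model checkpoint ids are illustrative, and only the enable_free_init() arguments shown above come from this diff.

import torch
from diffusers import AnimateDiffPipeline, DDIMScheduler, MotionAdapter
from diffusers.utils import export_to_gif

# Illustrative checkpoints; any SD 1.5-style base model paired with a motion adapter is assumed to work.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2")
pipe = AnimateDiffPipeline.from_pretrained(
    "SG161222/Realistic_Vision_V5.1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
).to("cuda")
pipe.scheduler = DDIMScheduler.from_config(
    pipe.scheduler.config, beta_schedule="linear", clip_sample=False, timestep_spacing="linspace"
)

# Re-initialize low-frequency noise over 3 passes (roughly 3x the sampling cost).
pipe.enable_free_init(num_iters=3, use_fast_sampling=False, method="butterworth")
output = pipe("a panda surfing on a wave, best quality", num_inference_steps=25)
pipe.disable_free_init()

export_to_gif(output.frames[0], "freeinit_panda.gif")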