diffusers 0.23.1__py3-none-any.whl → 0.25.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (238) hide show
  1. diffusers/__init__.py +26 -2
  2. diffusers/commands/fp16_safetensors.py +10 -11
  3. diffusers/configuration_utils.py +13 -8
  4. diffusers/dependency_versions_check.py +0 -1
  5. diffusers/dependency_versions_table.py +5 -5
  6. diffusers/experimental/rl/value_guided_sampling.py +1 -1
  7. diffusers/image_processor.py +463 -51
  8. diffusers/loaders/__init__.py +82 -0
  9. diffusers/loaders/ip_adapter.py +159 -0
  10. diffusers/loaders/lora.py +1553 -0
  11. diffusers/loaders/lora_conversion_utils.py +284 -0
  12. diffusers/loaders/single_file.py +637 -0
  13. diffusers/loaders/textual_inversion.py +455 -0
  14. diffusers/loaders/unet.py +828 -0
  15. diffusers/loaders/utils.py +59 -0
  16. diffusers/models/__init__.py +26 -9
  17. diffusers/models/activations.py +9 -6
  18. diffusers/models/attention.py +301 -29
  19. diffusers/models/attention_flax.py +9 -1
  20. diffusers/models/attention_processor.py +378 -6
  21. diffusers/models/autoencoders/__init__.py +5 -0
  22. diffusers/models/{autoencoder_asym_kl.py → autoencoders/autoencoder_asym_kl.py} +17 -12
  23. diffusers/models/{autoencoder_kl.py → autoencoders/autoencoder_kl.py} +47 -23
  24. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +402 -0
  25. diffusers/models/{autoencoder_tiny.py → autoencoders/autoencoder_tiny.py} +24 -28
  26. diffusers/models/{consistency_decoder_vae.py → autoencoders/consistency_decoder_vae.py} +51 -44
  27. diffusers/models/{vae.py → autoencoders/vae.py} +71 -17
  28. diffusers/models/controlnet.py +59 -39
  29. diffusers/models/controlnet_flax.py +19 -18
  30. diffusers/models/downsampling.py +338 -0
  31. diffusers/models/embeddings.py +112 -29
  32. diffusers/models/embeddings_flax.py +2 -0
  33. diffusers/models/lora.py +131 -1
  34. diffusers/models/modeling_flax_utils.py +14 -8
  35. diffusers/models/modeling_outputs.py +17 -0
  36. diffusers/models/modeling_utils.py +37 -29
  37. diffusers/models/normalization.py +110 -4
  38. diffusers/models/resnet.py +299 -652
  39. diffusers/models/transformer_2d.py +22 -5
  40. diffusers/models/transformer_temporal.py +183 -1
  41. diffusers/models/unet_2d_blocks_flax.py +5 -0
  42. diffusers/models/unet_2d_condition.py +46 -0
  43. diffusers/models/unet_2d_condition_flax.py +13 -13
  44. diffusers/models/unet_3d_blocks.py +957 -173
  45. diffusers/models/unet_3d_condition.py +16 -8
  46. diffusers/models/unet_kandinsky3.py +535 -0
  47. diffusers/models/unet_motion_model.py +48 -33
  48. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  49. diffusers/models/upsampling.py +454 -0
  50. diffusers/models/uvit_2d.py +471 -0
  51. diffusers/models/vae_flax.py +7 -0
  52. diffusers/models/vq_model.py +12 -3
  53. diffusers/optimization.py +16 -9
  54. diffusers/pipelines/__init__.py +137 -76
  55. diffusers/pipelines/amused/__init__.py +62 -0
  56. diffusers/pipelines/amused/pipeline_amused.py +328 -0
  57. diffusers/pipelines/amused/pipeline_amused_img2img.py +347 -0
  58. diffusers/pipelines/amused/pipeline_amused_inpaint.py +378 -0
  59. diffusers/pipelines/animatediff/pipeline_animatediff.py +66 -8
  60. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  61. diffusers/pipelines/auto_pipeline.py +23 -13
  62. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  63. diffusers/pipelines/controlnet/pipeline_controlnet.py +238 -35
  64. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +148 -37
  65. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +155 -41
  66. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +123 -43
  67. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +216 -39
  68. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +106 -34
  69. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  70. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  71. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  72. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  73. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  74. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  75. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  76. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  77. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  78. diffusers/pipelines/deprecated/__init__.py +153 -0
  79. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/__init__.py +3 -3
  80. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion.py +177 -34
  81. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_alt_diffusion_img2img.py +182 -37
  82. diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/pipeline_output.py +1 -1
  83. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/__init__.py +1 -1
  84. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/mel.py +2 -2
  85. diffusers/pipelines/{audio_diffusion → deprecated/audio_diffusion}/pipeline_audio_diffusion.py +4 -4
  86. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/__init__.py +1 -1
  87. diffusers/pipelines/{latent_diffusion_uncond → deprecated/latent_diffusion_uncond}/pipeline_latent_diffusion_uncond.py +4 -4
  88. diffusers/pipelines/{pndm → deprecated/pndm}/__init__.py +1 -1
  89. diffusers/pipelines/{pndm → deprecated/pndm}/pipeline_pndm.py +4 -4
  90. diffusers/pipelines/{repaint → deprecated/repaint}/__init__.py +1 -1
  91. diffusers/pipelines/{repaint → deprecated/repaint}/pipeline_repaint.py +5 -5
  92. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/__init__.py +1 -1
  93. diffusers/pipelines/{score_sde_ve → deprecated/score_sde_ve}/pipeline_score_sde_ve.py +5 -4
  94. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/__init__.py +6 -6
  95. diffusers/pipelines/{spectrogram_diffusion/continous_encoder.py → deprecated/spectrogram_diffusion/continuous_encoder.py} +2 -2
  96. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/midi_utils.py +1 -1
  97. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/notes_encoder.py +2 -2
  98. diffusers/pipelines/{spectrogram_diffusion → deprecated/spectrogram_diffusion}/pipeline_spectrogram_diffusion.py +8 -7
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/__init__.py +55 -0
  100. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_cycle_diffusion.py +34 -13
  101. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_onnx_stable_diffusion_inpaint_legacy.py +7 -6
  102. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_inpaint_legacy.py +12 -11
  103. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_model_editing.py +17 -11
  104. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_paradigms.py +11 -10
  105. diffusers/pipelines/{stable_diffusion → deprecated/stable_diffusion_variants}/pipeline_stable_diffusion_pix2pix_zero.py +14 -13
  106. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/__init__.py +1 -1
  107. diffusers/pipelines/{stochastic_karras_ve → deprecated/stochastic_karras_ve}/pipeline_stochastic_karras_ve.py +4 -4
  108. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/__init__.py +3 -3
  109. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/modeling_text_unet.py +83 -51
  110. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion.py +4 -4
  111. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_dual_guided.py +7 -6
  112. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_image_variation.py +7 -6
  113. diffusers/pipelines/{versatile_diffusion → deprecated/versatile_diffusion}/pipeline_versatile_diffusion_text_to_image.py +7 -6
  114. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/__init__.py +3 -3
  115. diffusers/pipelines/{vq_diffusion → deprecated/vq_diffusion}/pipeline_vq_diffusion.py +5 -5
  116. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  117. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  118. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  119. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  120. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  122. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  123. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +98 -0
  124. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +589 -0
  125. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +654 -0
  126. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +111 -11
  127. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +102 -9
  128. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  129. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  130. diffusers/pipelines/onnx_utils.py +8 -5
  131. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  132. diffusers/pipelines/pipeline_flax_utils.py +11 -8
  133. diffusers/pipelines/pipeline_utils.py +63 -42
  134. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +247 -38
  135. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +3 -3
  136. diffusers/pipelines/stable_diffusion/__init__.py +37 -65
  137. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +75 -78
  138. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  139. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  140. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  141. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +174 -11
  142. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  143. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  144. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +178 -11
  145. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +224 -13
  146. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +74 -20
  147. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +4 -0
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +7 -0
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +5 -0
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +5 -0
  151. diffusers/pipelines/stable_diffusion_attend_and_excite/__init__.py +48 -0
  152. diffusers/pipelines/{stable_diffusion → stable_diffusion_attend_and_excite}/pipeline_stable_diffusion_attend_and_excite.py +6 -2
  153. diffusers/pipelines/stable_diffusion_diffedit/__init__.py +48 -0
  154. diffusers/pipelines/{stable_diffusion → stable_diffusion_diffedit}/pipeline_stable_diffusion_diffedit.py +3 -3
  155. diffusers/pipelines/stable_diffusion_gligen/__init__.py +50 -0
  156. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen.py +3 -2
  157. diffusers/pipelines/{stable_diffusion → stable_diffusion_gligen}/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  158. diffusers/pipelines/stable_diffusion_k_diffusion/__init__.py +60 -0
  159. diffusers/pipelines/{stable_diffusion → stable_diffusion_k_diffusion}/pipeline_stable_diffusion_k_diffusion.py +7 -1
  160. diffusers/pipelines/stable_diffusion_ldm3d/__init__.py +48 -0
  161. diffusers/pipelines/{stable_diffusion → stable_diffusion_ldm3d}/pipeline_stable_diffusion_ldm3d.py +51 -7
  162. diffusers/pipelines/stable_diffusion_panorama/__init__.py +48 -0
  163. diffusers/pipelines/{stable_diffusion → stable_diffusion_panorama}/pipeline_stable_diffusion_panorama.py +57 -8
  164. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +58 -6
  165. diffusers/pipelines/stable_diffusion_sag/__init__.py +48 -0
  166. diffusers/pipelines/{stable_diffusion → stable_diffusion_sag}/pipeline_stable_diffusion_sag.py +68 -10
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +194 -17
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +205 -16
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +206 -17
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +23 -17
  171. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  172. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +652 -0
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +115 -14
  175. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +6 -0
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +23 -3
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +334 -10
  179. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +1331 -0
  180. diffusers/pipelines/unclip/pipeline_unclip.py +2 -1
  181. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +1 -0
  182. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +1 -1
  183. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  184. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  185. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  186. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  187. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -1
  188. diffusers/schedulers/__init__.py +4 -4
  189. diffusers/schedulers/deprecated/__init__.py +50 -0
  190. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  191. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  192. diffusers/schedulers/scheduling_amused.py +162 -0
  193. diffusers/schedulers/scheduling_consistency_models.py +2 -0
  194. diffusers/schedulers/scheduling_ddim.py +1 -3
  195. diffusers/schedulers/scheduling_ddim_inverse.py +2 -7
  196. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  197. diffusers/schedulers/scheduling_ddpm.py +47 -3
  198. diffusers/schedulers/scheduling_ddpm_parallel.py +47 -3
  199. diffusers/schedulers/scheduling_deis_multistep.py +28 -6
  200. diffusers/schedulers/scheduling_dpmsolver_multistep.py +28 -6
  201. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +28 -6
  202. diffusers/schedulers/scheduling_dpmsolver_sde.py +3 -3
  203. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +28 -6
  204. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +59 -3
  205. diffusers/schedulers/scheduling_euler_discrete.py +102 -16
  206. diffusers/schedulers/scheduling_heun_discrete.py +17 -5
  207. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +17 -5
  208. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +17 -5
  209. diffusers/schedulers/scheduling_lcm.py +123 -29
  210. diffusers/schedulers/scheduling_lms_discrete.py +3 -3
  211. diffusers/schedulers/scheduling_pndm.py +1 -3
  212. diffusers/schedulers/scheduling_repaint.py +1 -3
  213. diffusers/schedulers/scheduling_unipc_multistep.py +28 -6
  214. diffusers/schedulers/scheduling_utils.py +3 -1
  215. diffusers/schedulers/scheduling_utils_flax.py +3 -1
  216. diffusers/training_utils.py +1 -1
  217. diffusers/utils/__init__.py +1 -2
  218. diffusers/utils/constants.py +10 -12
  219. diffusers/utils/dummy_pt_objects.py +75 -0
  220. diffusers/utils/dummy_torch_and_transformers_objects.py +105 -0
  221. diffusers/utils/dynamic_modules_utils.py +18 -22
  222. diffusers/utils/export_utils.py +8 -3
  223. diffusers/utils/hub_utils.py +24 -36
  224. diffusers/utils/logging.py +11 -11
  225. diffusers/utils/outputs.py +5 -5
  226. diffusers/utils/peft_utils.py +88 -44
  227. diffusers/utils/state_dict_utils.py +8 -0
  228. diffusers/utils/testing_utils.py +199 -1
  229. diffusers/utils/torch_utils.py +4 -4
  230. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/METADATA +86 -69
  231. diffusers-0.25.0.dist-info/RECORD +360 -0
  232. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/WHEEL +1 -1
  233. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/entry_points.txt +0 -1
  234. diffusers/loaders.py +0 -3336
  235. diffusers-0.23.1.dist-info/RECORD +0 -323
  236. /diffusers/pipelines/{alt_diffusion → deprecated/alt_diffusion}/modeling_roberta_series.py +0 -0
  237. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/LICENSE +0 -0
  238. {diffusers-0.23.1.dist-info → diffusers-0.25.0.dist-info}/top_level.txt +0 -0
@@ -477,8 +477,9 @@ class UnCLIPPipeline(DiffusionPipeline):
477
477
  image = super_res_latents
478
478
  # done super res
479
479
 
480
- # post processing
480
+ self.maybe_free_model_hooks()
481
481
 
482
+ # post processing
482
483
  image = image * 0.5 + 0.5
483
484
  image = image.clamp(0, 1)
484
485
  image = image.cpu().permute(0, 2, 3, 1).float().numpy()
@@ -403,6 +403,7 @@ class UnCLIPImageVariationPipeline(DiffusionPipeline):
403
403
  image = super_res_latents
404
404
 
405
405
  # done super res
406
+ self.maybe_free_model_hooks()
406
407
 
407
408
  # post processing
408
409
 
@@ -19,8 +19,8 @@ import torch
19
19
  import torch.nn as nn
20
20
 
21
21
  from ...configuration_utils import ConfigMixin, register_to_config
22
+ from ...models.autoencoders.vae import DecoderOutput, VectorQuantizer
22
23
  from ...models.modeling_utils import ModelMixin
23
- from ...models.vae import DecoderOutput, VectorQuantizer
24
24
  from ...models.vq_model import VQEncoderOutput
25
25
  from ...utils.accelerate_utils import apply_forward_hook
26
26
 
@@ -17,6 +17,8 @@ import torch
17
17
  import torch.nn as nn
18
18
 
19
19
  from ...models.attention_processor import Attention
20
+ from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
21
+ from ...utils import USE_PEFT_BACKEND
20
22
 
21
23
 
22
24
  class WuerstchenLayerNorm(nn.LayerNorm):
@@ -32,7 +34,8 @@ class WuerstchenLayerNorm(nn.LayerNorm):
32
34
  class TimestepBlock(nn.Module):
33
35
  def __init__(self, c, c_timestep):
34
36
  super().__init__()
35
- self.mapper = nn.Linear(c_timestep, c * 2)
37
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
38
+ self.mapper = linear_cls(c_timestep, c * 2)
36
39
 
37
40
  def forward(self, x, t):
38
41
  a, b = self.mapper(t)[:, :, None, None].chunk(2, dim=1)
@@ -42,10 +45,14 @@ class TimestepBlock(nn.Module):
42
45
  class ResBlock(nn.Module):
43
46
  def __init__(self, c, c_skip=0, kernel_size=3, dropout=0.0):
44
47
  super().__init__()
45
- self.depthwise = nn.Conv2d(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
48
+
49
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
50
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
51
+
52
+ self.depthwise = conv_cls(c + c_skip, c, kernel_size=kernel_size, padding=kernel_size // 2, groups=c)
46
53
  self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
47
54
  self.channelwise = nn.Sequential(
48
- nn.Linear(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), nn.Linear(c * 4, c)
55
+ linear_cls(c, c * 4), nn.GELU(), GlobalResponseNorm(c * 4), nn.Dropout(dropout), linear_cls(c * 4, c)
49
56
  )
50
57
 
51
58
  def forward(self, x, x_skip=None):
@@ -73,10 +80,13 @@ class GlobalResponseNorm(nn.Module):
73
80
  class AttnBlock(nn.Module):
74
81
  def __init__(self, c, c_cond, nhead, self_attn=True, dropout=0.0):
75
82
  super().__init__()
83
+
84
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
85
+
76
86
  self.self_attn = self_attn
77
87
  self.norm = WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6)
78
88
  self.attention = Attention(query_dim=c, heads=nhead, dim_head=c // nhead, dropout=dropout, bias=True)
79
- self.kv_mapper = nn.Sequential(nn.SiLU(), nn.Linear(c_cond, c))
89
+ self.kv_mapper = nn.Sequential(nn.SiLU(), linear_cls(c_cond, c))
80
90
 
81
91
  def forward(self, x, kv):
82
92
  kv = self.kv_mapper(kv)
@@ -28,8 +28,9 @@ from ...models.attention_processor import (
28
28
  AttnAddedKVProcessor,
29
29
  AttnProcessor,
30
30
  )
31
+ from ...models.lora import LoRACompatibleConv, LoRACompatibleLinear
31
32
  from ...models.modeling_utils import ModelMixin
32
- from ...utils import is_torch_version
33
+ from ...utils import USE_PEFT_BACKEND, is_torch_version
33
34
  from .modeling_wuerstchen_common import AttnBlock, ResBlock, TimestepBlock, WuerstchenLayerNorm
34
35
 
35
36
 
@@ -40,12 +41,15 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
40
41
  @register_to_config
41
42
  def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
42
43
  super().__init__()
44
+ conv_cls = nn.Conv2d if USE_PEFT_BACKEND else LoRACompatibleConv
45
+ linear_cls = nn.Linear if USE_PEFT_BACKEND else LoRACompatibleLinear
46
+
43
47
  self.c_r = c_r
44
- self.projection = nn.Conv2d(c_in, c, kernel_size=1)
48
+ self.projection = conv_cls(c_in, c, kernel_size=1)
45
49
  self.cond_mapper = nn.Sequential(
46
- nn.Linear(c_cond, c),
50
+ linear_cls(c_cond, c),
47
51
  nn.LeakyReLU(0.2),
48
- nn.Linear(c, c),
52
+ linear_cls(c, c),
49
53
  )
50
54
 
51
55
  self.blocks = nn.ModuleList()
@@ -55,7 +59,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
55
59
  self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
56
60
  self.out = nn.Sequential(
57
61
  WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
58
- nn.Conv2d(c, c_in * 2, kernel_size=1),
62
+ conv_cls(c, c_in * 2, kernel_size=1),
59
63
  )
60
64
 
61
65
  self.gradient_checkpointing = False
@@ -269,7 +269,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
269
269
  callback_on_step_end_tensor_inputs (`List`, *optional*):
270
270
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
271
271
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
272
- `._callback_tensor_inputs` attribute of your pipeine class.
272
+ `._callback_tensor_inputs` attribute of your pipeline class.
273
273
 
274
274
  Examples:
275
275
 
@@ -234,7 +234,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
234
234
  prior_callback_on_step_end_tensor_inputs (`List`, *optional*):
235
235
  The list of tensor inputs for the `prior_callback_on_step_end` function. The tensors specified in the
236
236
  list will be passed as `callback_kwargs` argument. You will only be able to include variables listed in
237
- the `._callback_tensor_inputs` attribute of your pipeine class.
237
+ the `._callback_tensor_inputs` attribute of your pipeline class.
238
238
  callback_on_step_end (`Callable`, *optional*):
239
239
  A function that calls at the end of each denoising steps during the inference. The function is called
240
240
  with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -243,7 +243,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
243
243
  callback_on_step_end_tensor_inputs (`List`, *optional*):
244
244
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
245
245
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
246
- `._callback_tensor_inputs` attribute of your pipeine class.
246
+ `._callback_tensor_inputs` attribute of your pipeline class.
247
247
 
248
248
  Examples:
249
249
 
@@ -69,6 +69,10 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
69
69
  This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
70
70
  library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
71
71
 
72
+ The pipeline also inherits the following loading methods:
73
+ - [`~loaders.LoraLoaderMixin.load_lora_weights`] for loading LoRA weights
74
+ - [`~loaders.LoraLoaderMixin.save_lora_weights`] for saving LoRA weights
75
+
72
76
  Args:
73
77
  prior ([`Prior`]):
74
78
  The canonical unCLIP prior to approximate the image embedding from the text embedding.
@@ -349,7 +353,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
349
353
  callback_on_step_end_tensor_inputs (`List`, *optional*):
350
354
  The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
351
355
  will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
352
- `._callback_tensor_inputs` attribute of your pipeine class.
356
+ `._callback_tensor_inputs` attribute of your pipeline class.
353
357
 
354
358
  Examples:
355
359
 
@@ -38,6 +38,8 @@ except OptionalDependencyNotAvailable:
38
38
  _dummy_modules.update(get_objects_from_module(dummy_pt_objects))
39
39
 
40
40
  else:
41
+ _import_structure["deprecated"] = ["KarrasVeScheduler", "ScoreSdeVpScheduler"]
42
+ _import_structure["scheduling_amused"] = ["AmusedScheduler"]
41
43
  _import_structure["scheduling_consistency_decoder"] = ["ConsistencyDecoderScheduler"]
42
44
  _import_structure["scheduling_consistency_models"] = ["CMStochasticIterativeScheduler"]
43
45
  _import_structure["scheduling_ddim"] = ["DDIMScheduler"]
@@ -56,12 +58,10 @@ else:
56
58
  _import_structure["scheduling_ipndm"] = ["IPNDMScheduler"]
57
59
  _import_structure["scheduling_k_dpm_2_ancestral_discrete"] = ["KDPM2AncestralDiscreteScheduler"]
58
60
  _import_structure["scheduling_k_dpm_2_discrete"] = ["KDPM2DiscreteScheduler"]
59
- _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"]
60
61
  _import_structure["scheduling_lcm"] = ["LCMScheduler"]
61
62
  _import_structure["scheduling_pndm"] = ["PNDMScheduler"]
62
63
  _import_structure["scheduling_repaint"] = ["RePaintScheduler"]
63
64
  _import_structure["scheduling_sde_ve"] = ["ScoreSdeVeScheduler"]
64
- _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"]
65
65
  _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
66
66
  _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
67
67
  _import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
@@ -129,6 +129,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
129
129
  except OptionalDependencyNotAvailable:
130
130
  from ..utils.dummy_pt_objects import * # noqa F403
131
131
  else:
132
+ from .deprecated import KarrasVeScheduler, ScoreSdeVpScheduler
133
+ from .scheduling_amused import AmusedScheduler
132
134
  from .scheduling_consistency_decoder import ConsistencyDecoderScheduler
133
135
  from .scheduling_consistency_models import CMStochasticIterativeScheduler
134
136
  from .scheduling_ddim import DDIMScheduler
@@ -147,12 +149,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
147
149
  from .scheduling_ipndm import IPNDMScheduler
148
150
  from .scheduling_k_dpm_2_ancestral_discrete import KDPM2AncestralDiscreteScheduler
149
151
  from .scheduling_k_dpm_2_discrete import KDPM2DiscreteScheduler
150
- from .scheduling_karras_ve import KarrasVeScheduler
151
152
  from .scheduling_lcm import LCMScheduler
152
153
  from .scheduling_pndm import PNDMScheduler
153
154
  from .scheduling_repaint import RePaintScheduler
154
155
  from .scheduling_sde_ve import ScoreSdeVeScheduler
155
- from .scheduling_sde_vp import ScoreSdeVpScheduler
156
156
  from .scheduling_unclip import UnCLIPScheduler
157
157
  from .scheduling_unipc_multistep import UniPCMultistepScheduler
158
158
  from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
@@ -0,0 +1,50 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ...utils import (
4
+ DIFFUSERS_SLOW_IMPORT,
5
+ OptionalDependencyNotAvailable,
6
+ _LazyModule,
7
+ get_objects_from_module,
8
+ is_torch_available,
9
+ is_transformers_available,
10
+ )
11
+
12
+
13
+ _dummy_objects = {}
14
+ _import_structure = {}
15
+
16
+ try:
17
+ if not (is_transformers_available() and is_torch_available()):
18
+ raise OptionalDependencyNotAvailable()
19
+ except OptionalDependencyNotAvailable:
20
+ from ...utils import dummy_pt_objects # noqa F403
21
+
22
+ _dummy_objects.update(get_objects_from_module(dummy_pt_objects))
23
+ else:
24
+ _import_structure["scheduling_karras_ve"] = ["KarrasVeScheduler"]
25
+ _import_structure["scheduling_sde_vp"] = ["ScoreSdeVpScheduler"]
26
+
27
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
28
+ try:
29
+ if not is_torch_available():
30
+ raise OptionalDependencyNotAvailable()
31
+
32
+ except OptionalDependencyNotAvailable:
33
+ from ..utils.dummy_pt_objects import * # noqa F403
34
+ else:
35
+ from .scheduling_karras_ve import KarrasVeScheduler
36
+ from .scheduling_sde_vp import ScoreSdeVpScheduler
37
+
38
+
39
+ else:
40
+ import sys
41
+
42
+ sys.modules[__name__] = _LazyModule(
43
+ __name__,
44
+ globals()["__file__"],
45
+ _import_structure,
46
+ module_spec=__spec__,
47
+ )
48
+
49
+ for name, value in _dummy_objects.items():
50
+ setattr(sys.modules[__name__], name, value)
@@ -19,10 +19,10 @@ from typing import Optional, Tuple, Union
19
19
  import numpy as np
20
20
  import torch
21
21
 
22
- from ..configuration_utils import ConfigMixin, register_to_config
23
- from ..utils import BaseOutput
24
- from ..utils.torch_utils import randn_tensor
25
- from .scheduling_utils import SchedulerMixin
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...utils import BaseOutput
24
+ from ...utils.torch_utils import randn_tensor
25
+ from ..scheduling_utils import SchedulerMixin
26
26
 
27
27
 
28
28
  @dataclass
@@ -19,9 +19,9 @@ from typing import Union
19
19
 
20
20
  import torch
21
21
 
22
- from ..configuration_utils import ConfigMixin, register_to_config
23
- from ..utils.torch_utils import randn_tensor
24
- from .scheduling_utils import SchedulerMixin
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...utils.torch_utils import randn_tensor
24
+ from ..scheduling_utils import SchedulerMixin
25
25
 
26
26
 
27
27
  class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
@@ -79,9 +79,7 @@ class ScoreSdeVpScheduler(SchedulerMixin, ConfigMixin):
79
79
 
80
80
  # TODO(Patrick) better comments + non-PyTorch
81
81
  # postprocess model score
82
- log_mean_coeff = (
83
- -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
84
- )
82
+ log_mean_coeff = -0.25 * t**2 * (self.config.beta_max - self.config.beta_min) - 0.5 * t * self.config.beta_min
85
83
  std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))
86
84
  std = std.flatten()
87
85
  while len(std.shape) < len(score.shape):
@@ -0,0 +1,162 @@
1
+ import math
2
+ from dataclasses import dataclass
3
+ from typing import List, Optional, Tuple, Union
4
+
5
+ import torch
6
+
7
+ from ..configuration_utils import ConfigMixin, register_to_config
8
+ from ..utils import BaseOutput
9
+ from .scheduling_utils import SchedulerMixin
10
+
11
+
12
def gumbel_noise(t, generator=None):
    """
    Draw standard Gumbel noise with the same shape and dtype as `t`.

    Uniform samples `u` are drawn on `generator.device` (or on `t.device` when no generator is given), moved to
    `t.device`, and transformed as `-log(-log(u))`. Both logarithm inputs are clamped to at least `1e-20` so that
    `u == 0` or `u == 1` never produce infinities.

    Args:
        t (`torch.Tensor`):
            Template tensor; only its shape, dtype and device are used.
        generator (`torch.Generator`, *optional*):
            RNG used for the uniform draw. Sampling happens on the generator's device.

    Returns:
        `torch.Tensor`: Gumbel(0, 1) noise located on `t.device`.
    """
    sampling_device = t.device if generator is None else generator.device
    uniform = torch.zeros_like(t, device=sampling_device)
    uniform = uniform.uniform_(0, 1, generator=generator).to(t.device)
    neg_log_u = -torch.log(uniform.clamp(1e-20))
    return -torch.log(neg_log_u.clamp(1e-20))
16
+
17
+
18
def mask_by_random_topk(mask_len, probs, temperature=1.0, generator=None):
    """
    Choose which tokens to re-mask, using Gumbel-perturbed log-confidence.

    Each entry's confidence is `log(probs) + temperature * gumbel_noise(probs)`. Per row, the entry at 0-based
    position `mask_len` of the ascending-sorted confidences is the cut-off, and every entry strictly below it is
    flagged. With no ties this flags exactly the `mask_len` least-confident entries of the row.

    Args:
        mask_len (`torch.Tensor`):
            Number of entries to flag per row. It is cast with `.long()` and used as a `gather` index along dim 1,
            so it presumably has shape `(batch_size, 1)` for a 2-D `probs` — TODO confirm for other ranks.
        probs (`torch.Tensor`):
            Per-entry confidence values, clamped to at least `1e-20` before taking the log.
        temperature (`float` or `torch.Tensor`):
            Scale of the Gumbel noise. The scheduler passes a slice of its temperature schedule here.
        generator (`torch.Generator`, *optional*):
            Forwarded to `gumbel_noise`.

    Returns:
        `torch.BoolTensor`: same shape as `probs`; `True` marks entries to mask.
    """
    log_confidence = torch.log(probs.clamp(1e-20))
    perturbation = temperature * gumbel_noise(probs, generator=generator)
    confidence = log_confidence + perturbation

    # NOTE: sorting is along dim -1 while the cut-off lookup is along dim 1; they coincide for 2-D inputs.
    ascending, _ = torch.sort(confidence, dim=-1)
    threshold = ascending.gather(1, mask_len.long())
    return confidence < threshold
24
+
25
+
26
@dataclass
class AmusedSchedulerOutput(BaseOutput):
    """
    Output class for the scheduler's `step` function output.

    Both fields hold integer token ids (not pixel/latent values) and have the same shape as the `sample` passed to
    `AmusedScheduler.step`, e.g. `(batch_size, height, width)` or `(batch_size, sequence_length)`.

    Args:
        prev_sample (`torch.LongTensor`):
            Token ids for the next step `(x_{t-1})`. Positions the scheduler decided to re-mask hold
            `mask_token_id`. `prev_sample` should be used as next model input in the denoising loop.
        pred_original_sample (`torch.LongTensor`, *optional*):
            Token ids sampled from the model output at every masked position of the input (tokens that were not
            masked are copied through unchanged), i.e. the current estimate of the fully denoised sample
            `(x_{0})`. `pred_original_sample` can be used to preview progress or for guidance.
    """

    prev_sample: torch.LongTensor
    pred_original_sample: Optional[torch.LongTensor] = None
42
+
43
+
44
class AmusedScheduler(SchedulerMixin, ConfigMixin):
    """
    Masked-token scheduler that iteratively un-masks a grid/sequence of discrete tokens.

    Each `step` samples token ids for every masked position from the model's logits. Unless the timestep is 0, it
    then re-masks the lowest-confidence newly predicted tokens. The fraction re-masked follows `masking_schedule`
    and decreases as the loop walks `self.timesteps` from its first entry to its last.

    Args:
        mask_token_id (`int`):
            Token id that marks a masked (unknown) position in `sample`.
        masking_schedule (`str`, defaults to `"cosine"`):
            Either `"cosine"` or `"linear"`. Any other value raises a `ValueError` when `step` or `add_noise`
            computes the mask ratio.
    """

    order = 1

    # Both are populated by `set_timesteps`; they are `None` until then.
    temperatures: Optional[torch.Tensor]
    timesteps: Optional[torch.Tensor]

    @register_to_config
    def __init__(
        self,
        mask_token_id: int,
        masking_schedule: str = "cosine",
    ):
        self.temperatures = None
        self.timesteps = None

    def set_timesteps(
        self,
        num_inference_steps: int,
        temperature: Union[float, Tuple[float, float], List[float]] = (2, 0),
        device: Optional[Union[str, torch.device]] = None,
    ):
        """
        Build the timestep schedule and the per-step Gumbel-noise temperatures.

        Args:
            num_inference_steps (`int`):
                Number of denoising steps. `self.timesteps` becomes `num_inference_steps - 1, ..., 1, 0`.
            temperature (`float`, or a `(start, end)` tuple/list, defaults to `(2, 0)`):
                A `(start, end)` pair is linearly interpolated over the steps. A single number is used as the start
                value and interpolated down to `0.01`.
            device (`str` or `torch.device`, *optional*):
                Device on which `self.timesteps` and `self.temperatures` are created.
        """
        self.timesteps = torch.arange(num_inference_steps, device=device).flip(0)

        if isinstance(temperature, (tuple, list)):
            self.temperatures = torch.linspace(temperature[0], temperature[1], num_inference_steps, device=device)
        else:
            self.temperatures = torch.linspace(temperature, 0.01, num_inference_steps, device=device)

    def _timestep_to_step_idx(self, timestep):
        # Index of `timestep` within `self.timesteps`. `nonzero()` returns a 2-D `(num_matches, 1)` tensor, and the
        # values derived from it stay 2-D so they broadcast against per-row `(batch_size, 1)` tensors in `step`.
        return (self.timesteps == timestep).nonzero()

    def _mask_ratio_from_step_idx(self, step_idx):
        # Fraction of tokens that should stay masked after the step at `step_idx` (shared by `step`/`add_noise`).
        ratio = (step_idx + 1) / len(self.timesteps)

        if self.config.masking_schedule == "cosine":
            return torch.cos(ratio * math.pi / 2)
        if self.config.masking_schedule == "linear":
            return 1 - ratio
        raise ValueError(f"unknown masking schedule {self.config.masking_schedule}")

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[int, torch.Tensor],
        sample: torch.LongTensor,
        starting_mask_ratio: float = 1,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[AmusedSchedulerOutput, Tuple]:
        """
        Sample tokens for masked positions and re-mask the least confident ones.

        Args:
            model_output (`torch.FloatTensor`):
                Logits over the codebook. Either `(batch_size, codebook_size, height, width)` paired with a 3-D
                `sample`, or `(batch_size, sequence_length, codebook_size)` paired with a 2-D `sample`.
            timestep (`int` or `torch.Tensor`):
                Current timestep; expected to be an entry of `self.timesteps`. At timestep 0 nothing is re-masked.
            sample (`torch.LongTensor`):
                Current token ids, with `mask_token_id` marking unknown positions.
            starting_mask_ratio (`float`, defaults to 1):
                Multiplier applied to the scheduled mask ratio.
            generator (`torch.Generator`, *optional*):
                RNG for both token sampling and the Gumbel noise. When given, multinomial sampling runs on
                `generator.device`.
            return_dict (`bool`, defaults to `True`):
                Whether to return an `AmusedSchedulerOutput` or a plain tuple.

        Returns:
            `AmusedSchedulerOutput` or `tuple`: `(prev_sample, pred_original_sample)`, both shaped like `sample`.
            The number of re-masked tokens per row is `max(1, min(num_masked_in_input - 1, floor(seq_len *
            mask_ratio)))`.
        """
        two_dim_input = sample.ndim == 3 and model_output.ndim == 4

        if two_dim_input:
            # Flatten the token grid to a sequence and move the codebook axis last.
            batch_size, codebook_size, height, width = model_output.shape
            sample = sample.reshape(batch_size, height * width)
            model_output = model_output.reshape(batch_size, codebook_size, height * width).permute(0, 2, 1)

        unknown_map = sample == self.config.mask_token_id

        probs = model_output.softmax(dim=-1)

        device = probs.device
        probs_ = probs.to(generator.device) if generator is not None else probs  # handles when generator is on CPU
        if probs_.device.type == "cpu" and probs_.dtype != torch.float32:
            probs_ = probs_.float()  # multinomial is not implemented for cpu half precision
        probs_ = probs_.reshape(-1, probs.size(-1))
        pred_original_sample = torch.multinomial(probs_, 1, generator=generator).to(device=device)
        pred_original_sample = pred_original_sample[:, 0].view(*probs.shape[:-1])
        # Only masked positions take the sampled token; known tokens are kept.
        pred_original_sample = torch.where(unknown_map, pred_original_sample, sample)

        if timestep == 0:
            prev_sample = pred_original_sample
        else:
            seq_len = sample.shape[1]
            step_idx = self._timestep_to_step_idx(timestep)
            mask_ratio = starting_mask_ratio * self._mask_ratio_from_step_idx(step_idx)

            mask_len = (seq_len * mask_ratio).floor()
            # do not mask more than amount previously masked
            mask_len = torch.min(unknown_map.sum(dim=-1, keepdim=True) - 1, mask_len)
            # mask at least one
            mask_len = torch.max(torch.tensor([1], device=model_output.device), mask_len)

            selected_probs = torch.gather(probs, -1, pred_original_sample[:, :, None])[:, :, 0]
            # Ignores the tokens given in the input by overwriting their confidence.
            selected_probs = torch.where(unknown_map, selected_probs, torch.finfo(selected_probs.dtype).max)

            masking = mask_by_random_topk(mask_len, selected_probs, self.temperatures[step_idx], generator)

            # Masks tokens with lower confidence.
            prev_sample = torch.where(masking, self.config.mask_token_id, pred_original_sample)

        if two_dim_input:
            prev_sample = prev_sample.reshape(batch_size, height, width)
            pred_original_sample = pred_original_sample.reshape(batch_size, height, width)

        if not return_dict:
            return (prev_sample, pred_original_sample)

        return AmusedSchedulerOutput(prev_sample, pred_original_sample)

    def add_noise(self, sample, timesteps, generator=None):
        """
        Return a copy of `sample` with a random subset of positions set to `mask_token_id`.

        Each position is masked independently with probability equal to the schedule's mask ratio at `timesteps`.
        The input tensor is not modified.

        Args:
            sample (`torch.LongTensor`):
                Token ids to corrupt.
            timesteps:
                Presumably a single entry of `self.timesteps` (it is looked up by equality) — TODO confirm.
            generator (`torch.Generator`, *optional*):
                RNG for the uniform draw. Sampling happens on its device and the result is moved to `sample.device`.
        """
        step_idx = self._timestep_to_step_idx(timesteps)
        mask_ratio = self._mask_ratio_from_step_idx(step_idx)

        mask_indices = (
            torch.rand(
                sample.shape, device=generator.device if generator is not None else sample.device, generator=generator
            ).to(sample.device)
            < mask_ratio
        )

        masked_sample = sample.clone()

        masked_sample[mask_indices] = self.config.mask_token_id

        return masked_sample
@@ -98,6 +98,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
98
98
  self.custom_timesteps = False
99
99
  self.is_scale_input_called = False
100
100
  self._step_index = None
101
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
101
102
 
102
103
  def index_for_timestep(self, timestep, schedule_timesteps=None):
103
104
  if schedule_timesteps is None:
@@ -230,6 +231,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
230
231
  self.timesteps = torch.from_numpy(timesteps).to(device=device)
231
232
 
232
233
  self._step_index = None
234
+ self.sigmas.to("cpu") # to avoid too much CPU/GPU communication
233
235
 
234
236
  # Modified _convert_to_karras implementation that takes in ramp as argument
235
237
  def _convert_to_karras(self, ramp):
@@ -208,9 +208,7 @@ class DDIMScheduler(SchedulerMixin, ConfigMixin):
208
208
  self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
209
209
  elif beta_schedule == "scaled_linear":
210
210
  # this schedule is very specific to the latent diffusion model.
211
- self.betas = (
212
- torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
213
- )
211
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
214
212
  elif beta_schedule == "squaredcos_cap_v2":
215
213
  # Glide cosine schedule
216
214
  self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -204,9 +204,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
204
204
  self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
205
205
  elif beta_schedule == "scaled_linear":
206
206
  # this schedule is very specific to the latent diffusion model.
207
- self.betas = (
208
- torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
209
- )
207
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
210
208
  elif beta_schedule == "squaredcos_cap_v2":
211
209
  # Glide cosine schedule
212
210
  self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -295,9 +293,6 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
295
293
  model_output: torch.FloatTensor,
296
294
  timestep: int,
297
295
  sample: torch.FloatTensor,
298
- eta: float = 0.0,
299
- use_clipped_model_output: bool = False,
300
- variance_noise: Optional[torch.FloatTensor] = None,
301
296
  return_dict: bool = True,
302
297
  ) -> Union[DDIMSchedulerOutput, Tuple]:
303
298
  """
@@ -334,7 +329,7 @@ class DDIMInverseScheduler(SchedulerMixin, ConfigMixin):
334
329
  # 1. get previous step value (=t+1)
335
330
  prev_timestep = timestep
336
331
  timestep = min(
337
- timestep - self.config.num_train_timesteps // self.num_inference_steps, self.num_train_timesteps - 1
332
+ timestep - self.config.num_train_timesteps // self.num_inference_steps, self.config.num_train_timesteps - 1
338
333
  )
339
334
 
340
335
  # 2. compute alphas, betas
@@ -215,9 +215,7 @@ class DDIMParallelScheduler(SchedulerMixin, ConfigMixin):
215
215
  self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
216
216
  elif beta_schedule == "scaled_linear":
217
217
  # this schedule is very specific to the latent diffusion model.
218
- self.betas = (
219
- torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
220
- )
218
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
221
219
  elif beta_schedule == "squaredcos_cap_v2":
222
220
  # Glide cosine schedule
223
221
  self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -89,6 +89,43 @@ def betas_for_alpha_bar(
89
89
  return torch.tensor(betas, dtype=torch.float32)
90
90
 
91
91
 
92
+ # Copied from diffusers.schedulers.scheduling_ddim.rescale_zero_terminal_snr
93
+ def rescale_zero_terminal_snr(betas):
94
+ """
95
+ Rescales betas to have zero terminal SNR Based on https://arxiv.org/pdf/2305.08891.pdf (Algorithm 1)
96
+
97
+
98
+ Args:
99
+ betas (`torch.FloatTensor`):
100
+ the betas that the scheduler is being initialized with.
101
+
102
+ Returns:
103
+ `torch.FloatTensor`: rescaled betas with zero terminal SNR
104
+ """
105
+ # Convert betas to alphas_bar_sqrt
106
+ alphas = 1.0 - betas
107
+ alphas_cumprod = torch.cumprod(alphas, dim=0)
108
+ alphas_bar_sqrt = alphas_cumprod.sqrt()
109
+
110
+ # Store old values.
111
+ alphas_bar_sqrt_0 = alphas_bar_sqrt[0].clone()
112
+ alphas_bar_sqrt_T = alphas_bar_sqrt[-1].clone()
113
+
114
+ # Shift so the last timestep is zero.
115
+ alphas_bar_sqrt -= alphas_bar_sqrt_T
116
+
117
+ # Scale so the first timestep is back to the old value.
118
+ alphas_bar_sqrt *= alphas_bar_sqrt_0 / (alphas_bar_sqrt_0 - alphas_bar_sqrt_T)
119
+
120
+ # Convert alphas_bar_sqrt to betas
121
+ alphas_bar = alphas_bar_sqrt**2 # Revert sqrt
122
+ alphas = alphas_bar[1:] / alphas_bar[:-1] # Revert cumprod
123
+ alphas = torch.cat([alphas_bar[0:1], alphas])
124
+ betas = 1 - alphas
125
+
126
+ return betas
127
+
128
+
92
129
  class DDPMScheduler(SchedulerMixin, ConfigMixin):
93
130
  """
94
131
  `DDPMScheduler` explores the connections between denoising score matching and Langevin dynamics sampling.
@@ -131,6 +168,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
131
168
  An offset added to the inference steps. You can use a combination of `offset=1` and
132
169
  `set_alpha_to_one=False` to make the last step use step 0 for the previous alpha product like in Stable
133
170
  Diffusion.
171
+ rescale_betas_zero_snr (`bool`, defaults to `False`):
172
+ Whether to rescale the betas to have zero terminal SNR. This enables the model to generate very bright and
173
+ dark samples instead of limiting it to samples with medium brightness. Loosely related to
174
+ [`--offset_noise`](https://github.com/huggingface/diffusers/blob/74fd735eb073eb1d774b1ab4154a0876eb82f055/examples/dreambooth/train_dreambooth.py#L506).
134
175
  """
135
176
 
136
177
  _compatibles = [e.name for e in KarrasDiffusionSchedulers]
@@ -153,6 +194,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
153
194
  sample_max_value: float = 1.0,
154
195
  timestep_spacing: str = "leading",
155
196
  steps_offset: int = 0,
197
+ rescale_betas_zero_snr: int = False,
156
198
  ):
157
199
  if trained_betas is not None:
158
200
  self.betas = torch.tensor(trained_betas, dtype=torch.float32)
@@ -160,9 +202,7 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
160
202
  self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
161
203
  elif beta_schedule == "scaled_linear":
162
204
  # this schedule is very specific to the latent diffusion model.
163
- self.betas = (
164
- torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
165
- )
205
+ self.betas = torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
166
206
  elif beta_schedule == "squaredcos_cap_v2":
167
207
  # Glide cosine schedule
168
208
  self.betas = betas_for_alpha_bar(num_train_timesteps)
@@ -173,6 +213,10 @@ class DDPMScheduler(SchedulerMixin, ConfigMixin):
173
213
  else:
174
214
  raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}")
175
215
 
216
+ # Rescale for zero SNR
217
+ if rescale_betas_zero_snr:
218
+ self.betas = rescale_zero_terminal_snr(self.betas)
219
+
176
220
  self.alphas = 1.0 - self.betas
177
221
  self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
178
222
  self.one = torch.tensor(1.0)