diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
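Among the larger additions in the list above are the new Marigold dense-prediction pipelines (items 138–141). A minimal sketch of how the depth pipeline might be invoked; the checkpoint id, `MarigoldDepthPipeline` export, and `visualize_depth` helper are assumptions based on the files added here, not shown in this diff:

    import torch
    from diffusers import MarigoldDepthPipeline
    from diffusers.utils import load_image

    # Assumed checkpoint id and API names; see pipeline_marigold_depth.py above.
    pipe = MarigoldDepthPipeline.from_pretrained(
        "prs-eth/marigold-depth-lcm-v1-0", torch_dtype=torch.float16
    ).to("cuda")

    image = load_image("https://marigoldmonodepth.github.io/images/einstein.jpg")
    out = pipe(image)                                   # depth prediction(s)
    vis = pipe.image_processor.visualize_depth(out.prediction)
    vis[0].save("depth.png")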
diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py
@@ -40,15 +40,13 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
     @register_to_config
     def __init__(self, c_in=16, c=1280, c_cond=1024, c_r=64, depth=16, nhead=16, dropout=0.1):
         super().__init__()
-        conv_cls = nn.Conv2d
-        linear_cls = nn.Linear
 
         self.c_r = c_r
-        self.projection = conv_cls(c_in, c, kernel_size=1)
+        self.projection = nn.Conv2d(c_in, c, kernel_size=1)
         self.cond_mapper = nn.Sequential(
-            linear_cls(c_cond, c),
+            nn.Linear(c_cond, c),
             nn.LeakyReLU(0.2),
-            linear_cls(c, c),
+            nn.Linear(c, c),
         )
 
         self.blocks = nn.ModuleList()
@@ -58,7 +56,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
         self.blocks.append(AttnBlock(c, c, nhead, self_attn=True, dropout=dropout))
         self.out = nn.Sequential(
             WuerstchenLayerNorm(c, elementwise_affine=False, eps=1e-6),
-            conv_cls(c, c_in * 2, kernel_size=1),
+            nn.Conv2d(c, c_in * 2, kernel_size=1),
         )
 
         self.gradient_checkpointing = False
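The removed `conv_cls`/`linear_cls` aliases were a layer-class indirection (historically used to swap in LoRA-compatible layer variants); since both were already hard-bound to `nn.Conv2d`/`nn.Linear` here, the refactor is behavior-preserving. A minimal check illustrating the equivalence:

    import torch.nn as nn

    conv_cls = nn.Conv2d                           # 0.27.1: local alias indirection
    old_proj = conv_cls(16, 1280, kernel_size=1)
    new_proj = nn.Conv2d(16, 1280, kernel_size=1)  # 0.28.0: direct construction
    assert type(old_proj) is type(new_proj)        # identical module type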
diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py
@@ -209,7 +209,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
     @replace_example_docstring(EXAMPLE_DOC_STRING)
     def __call__(
         self,
-        image_embeddings: Union[torch.FloatTensor, List[torch.FloatTensor]],
+        image_embeddings: Union[torch.Tensor, List[torch.Tensor]],
         prompt: Union[str, List[str]] = None,
         num_inference_steps: int = 12,
         timesteps: Optional[List[float]] = None,
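This `torch.FloatTensor` → `torch.Tensor` annotation change recurs throughout the release. `torch.FloatTensor` names only the float32 CPU tensor type, so the old hints were technically wrong for the fp16/bf16 and CUDA tensors these pipelines routinely accept; `torch.Tensor` covers every dtype and device. A quick demonstration:

    import torch

    x = torch.randn(1, 77, 1024, dtype=torch.float16)
    print(isinstance(x, torch.FloatTensor))  # False: fp16 never matched the old hint
    print(isinstance(x, torch.Tensor))       # True: matches the new hint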
@@ -217,7 +217,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
         negative_prompt: Optional[Union[str, List[str]]] = None,
         num_images_per_prompt: int = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -228,7 +228,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
         Function invoked when calling the pipeline for generation.
 
         Args:
-            image_embedding (`torch.FloatTensor` or `List[torch.FloatTensor]`):
+            image_embedding (`torch.Tensor` or `List[torch.Tensor]`):
                 Image Embeddings either extracted from an image or generated by a Prior Model.
             prompt (`str` or `List[str]`):
                 The prompt or prompts to guide the image generation.
@@ -252,7 +252,7 @@ class WuerstchenDecoderPipeline(DiffusionPipeline):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will ge generated by sampling using the supplied random `generator`.
diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py
@@ -112,25 +112,25 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         self.decoder_pipe.enable_xformers_memory_efficient_attention(attention_op)
 
-    def enable_model_cpu_offload(self, gpu_id=0):
+    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
         method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
         `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
         """
-        self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_model_cpu_offload(gpu_id=gpu_id, device=device)
 
-    def enable_sequential_cpu_offload(self, gpu_id=0):
+    def enable_sequential_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
         r"""
         Offloads all models (`unet`, `text_encoder`, `vae`, and `safety checker` state dicts) to CPU using 🤗
         Accelerate, significantly reducing memory usage. Models are moved to a `torch.device('meta')` and loaded on a
         GPU only when their specific submodule's `forward` method is called. Offloading happens on a submodule basis.
         Memory savings are higher than using `enable_model_cpu_offload`, but performance is lower.
         """
-        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
-        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id)
+        self.prior_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
+        self.decoder_pipe.enable_sequential_cpu_offload(gpu_id=gpu_id, device=device)
 
     def progress_bar(self, iterable=None, total=None):
         self.prior_pipe.progress_bar(iterable=iterable, total=total)
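The offload methods previously hard-coded an integer `gpu_id`; they now also forward an explicit `device` to the underlying prior and decoder pipelines, so non-default accelerators can be targeted. A minimal usage sketch (the checkpoint id is an assumption):

    import torch
    from diffusers import WuerstchenCombinedPipeline

    pipe = WuerstchenCombinedPipeline.from_pretrained(
        "warp-ai/wuerstchen", torch_dtype=torch.float16  # assumed checkpoint id
    )
    # New in 0.28.0: pass a device string/torch.device instead of only a CUDA gpu_id.
    pipe.enable_model_cpu_offload(device="cuda")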
@@ -154,11 +154,11 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
         decoder_timesteps: Optional[List[float]] = None,
         decoder_guidance_scale: float = 0.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         num_images_per_prompt: int = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         prior_callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -176,10 +176,10 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `guidance_scale` is less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings for the prior. Can be used to easily tweak text inputs, *e.g.*
                 prompt weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt`
                 input argument.
@@ -218,7 +218,7 @@ class WuerstchenCombinedPipeline(DiffusionPipeline):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will ge generated by sampling using the supplied random `generator`.
diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py
@@ -54,12 +54,12 @@ class WuerstchenPriorPipelineOutput(BaseOutput):
     Output class for WuerstchenPriorPipeline.
 
     Args:
-        image_embeddings (`torch.FloatTensor` or `np.ndarray`)
+        image_embeddings (`torch.Tensor` or `np.ndarray`)
            Prior image embeddings for text prompt
 
    """
 
-    image_embeddings: Union[torch.FloatTensor, np.ndarray]
+    image_embeddings: Union[torch.Tensor, np.ndarray]
 
 
 class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
@@ -136,8 +136,8 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
         do_classifier_free_guidance,
         prompt=None,
         negative_prompt=None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
     ):
         if prompt is not None and isinstance(prompt, str):
             batch_size = 1
@@ -288,11 +288,11 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
         timesteps: List[float] = None,
         guidance_scale: float = 8.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        prompt_embeds: Optional[torch.FloatTensor] = None,
-        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
-        latents: Optional[torch.FloatTensor] = None,
+        latents: Optional[torch.Tensor] = None,
         output_type: Optional[str] = "pt",
         return_dict: bool = True,
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
@@ -324,10 +324,10 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
             negative_prompt (`str` or `List[str]`, *optional*):
                 The prompt or prompts not to guide the image generation. Ignored when not using guidance (i.e., ignored
                 if `decoder_guidance_scale` is less than `1`).
-            prompt_embeds (`torch.FloatTensor`, *optional*):
+            prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
                 provided, text embeddings will be generated from `prompt` input argument.
-            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+            negative_prompt_embeds (`torch.Tensor`, *optional*):
                 Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
                 weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
                 argument.
@@ -336,7 +336,7 @@ class WuerstchenPriorPipeline(DiffusionPipeline, LoraLoaderMixin):
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                 One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
                 to make generation deterministic.
-            latents (`torch.FloatTensor`, *optional*):
+            latents (`torch.Tensor`, *optional*):
                 Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
                 generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                 tensor will ge generated by sampling using the supplied random `generator`.
diffusers/schedulers/__init__.py
@@ -68,7 +68,7 @@ else:
     _import_structure["scheduling_tcd"] = ["TCDScheduler"]
     _import_structure["scheduling_unclip"] = ["UnCLIPScheduler"]
     _import_structure["scheduling_unipc_multistep"] = ["UniPCMultistepScheduler"]
-    _import_structure["scheduling_utils"] = ["KarrasDiffusionSchedulers", "SchedulerMixin"]
+    _import_structure["scheduling_utils"] = ["AysSchedules", "KarrasDiffusionSchedulers", "SchedulerMixin"]
     _import_structure["scheduling_vq_diffusion"] = ["VQDiffusionScheduler"]
 
 try:
@@ -163,7 +163,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     from .scheduling_tcd import TCDScheduler
     from .scheduling_unclip import UnCLIPScheduler
     from .scheduling_unipc_multistep import UniPCMultistepScheduler
-    from .scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+    from .scheduling_utils import AysSchedules, KarrasDiffusionSchedulers, SchedulerMixin
     from .scheduling_vq_diffusion import VQDiffusionScheduler
 
 try:
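`AysSchedules` is a newly exported collection of "Align Your Steps" timestep tables. A sketch of the intended use with a pipeline that accepts explicit `timesteps`; the dictionary key and checkpoint id below are assumptions, not shown in this diff:

    import torch
    from diffusers import StableDiffusionXLPipeline
    from diffusers.schedulers import AysSchedules

    sampling_schedule = AysSchedules["StableDiffusionXLTimesteps"]  # assumed key

    pipe = StableDiffusionXLPipeline.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
    ).to("cuda")
    # The precomputed 10-step AYS schedule is passed straight to the scheduler.
    image = pipe("a photo of a cat", timesteps=sampling_schedule).images[0]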
diffusers/schedulers/deprecated/__init__.py
@@ -30,7 +30,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         raise OptionalDependencyNotAvailable()
 
     except OptionalDependencyNotAvailable:
-        from ..utils.dummy_pt_objects import *  # noqa F403
+        from ...utils.dummy_pt_objects import *  # noqa F403
     else:
         from .scheduling_karras_ve import KarrasVeScheduler
         from .scheduling_sde_vp import ScoreSdeVpScheduler
diffusers/schedulers/deprecated/scheduling_karras_ve.py
@@ -31,19 +31,19 @@ class KarrasVeOutput(BaseOutput):
     Output class for the scheduler's step function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
-        derivative (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        derivative (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Derivative of predicted original image sample (x_0).
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample (x_{0}) based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """
 
-    prev_sample: torch.FloatTensor
-    derivative: torch.FloatTensor
-    pred_original_sample: Optional[torch.FloatTensor] = None
+    prev_sample: torch.Tensor
+    derivative: torch.Tensor
+    pred_original_sample: Optional[torch.Tensor] = None
 
 
 class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
@@ -94,21 +94,21 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
         # setable values
         self.num_inference_steps: int = None
         self.timesteps: np.IntTensor = None
-        self.schedule: torch.FloatTensor = None  # sigma(t_i)
+        self.schedule: torch.Tensor = None  # sigma(t_i)
 
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample
@@ -136,14 +136,14 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
         self.schedule = torch.tensor(schedule, dtype=torch.float32, device=device)
 
     def add_noise_to_input(
-        self, sample: torch.FloatTensor, sigma: float, generator: Optional[torch.Generator] = None
-    ) -> Tuple[torch.FloatTensor, float]:
+        self, sample: torch.Tensor, sigma: float, generator: Optional[torch.Generator] = None
+    ) -> Tuple[torch.Tensor, float]:
         """
         Explicit Langevin-like "churn" step of adding noise to the sample according to a `gamma_i ≥ 0` to reach a
         higher noise level `sigma_hat = sigma_i + gamma_i*sigma_i`.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             sigma (`float`):
             generator (`torch.Generator`, *optional*):
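For context on the docstring above: the churn step raises the noise level before the solver step. A toy rendering of the quoted formula; the noise scaling term follows the Karras et al. (2022) EDM formulation and is an assumption here, as the hunk does not show the method body:

    import torch

    sigma, gamma = 1.0, 0.05
    sigma_hat = sigma + gamma * sigma                 # the docstring's formula

    sample = torch.randn(1, 3, 64, 64)
    eps = torch.randn_like(sample)
    sample_hat = sample + (sigma_hat**2 - sigma**2) ** 0.5 * eps  # assumed scaling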
@@ -163,10 +163,10 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         sigma_hat: float,
         sigma_prev: float,
-        sample_hat: torch.FloatTensor,
+        sample_hat: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[KarrasVeOutput, Tuple]:
         """
@@ -174,11 +174,11 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             sigma_hat (`float`):
             sigma_prev (`float`):
-            sample_hat (`torch.FloatTensor`):
+            sample_hat (`torch.Tensor`):
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_karras_ve.KarrasVESchedulerOutput`] or `tuple`.
 
@@ -202,25 +202,25 @@ class KarrasVeScheduler(SchedulerMixin, ConfigMixin):
 
     def step_correct(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         sigma_hat: float,
         sigma_prev: float,
-        sample_hat: torch.FloatTensor,
-        sample_prev: torch.FloatTensor,
-        derivative: torch.FloatTensor,
+        sample_hat: torch.Tensor,
+        sample_prev: torch.Tensor,
+        derivative: torch.Tensor,
         return_dict: bool = True,
     ) -> Union[KarrasVeOutput, Tuple]:
         """
         Corrects the predicted sample based on the `model_output` of the network.
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from learned diffusion model.
             sigma_hat (`float`): TODO
             sigma_prev (`float`): TODO
-            sample_hat (`torch.FloatTensor`): TODO
-            sample_prev (`torch.FloatTensor`): TODO
-            derivative (`torch.FloatTensor`): TODO
+            sample_hat (`torch.Tensor`): TODO
+            sample_prev (`torch.Tensor`): TODO
+            derivative (`torch.Tensor`): TODO
             return_dict (`bool`, *optional*, defaults to `True`):
                 Whether or not to return a [`~schedulers.scheduling_ddpm.DDPMSchedulerOutput`] or `tuple`.
 
diffusers/schedulers/scheduling_amused.py
@@ -29,16 +29,16 @@ class AmusedSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function output.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
-        pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        pred_original_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            The predicted denoised sample `(x_{0})` based on the model output from the current timestep.
            `pred_original_sample` can be used to preview progress or for guidance.
    """
 
-    prev_sample: torch.FloatTensor
-    pred_original_sample: torch.FloatTensor = None
+    prev_sample: torch.Tensor
+    pred_original_sample: torch.Tensor = None
 
 
 class AmusedScheduler(SchedulerMixin, ConfigMixin):
@@ -70,7 +70,7 @@ class AmusedScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
+        model_output: torch.Tensor,
         timestep: torch.long,
         sample: torch.LongTensor,
         starting_mask_ratio: int = 1,
diffusers/schedulers/scheduling_consistency_decoder.py
@@ -45,7 +45,7 @@ def betas_for_alpha_bar(
             return math.exp(t * -12.0)
 
     else:
-        raise ValueError(f"Unsupported alpha_tranform_type: {alpha_transform_type}")
+        raise ValueError(f"Unsupported alpha_transform_type: {alpha_transform_type}")
 
     betas = []
     for i in range(num_diffusion_timesteps):
@@ -61,12 +61,12 @@ class ConsistencyDecoderSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """
 
-    prev_sample: torch.FloatTensor
+    prev_sample: torch.Tensor
 
 
 class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
@@ -113,28 +113,28 @@ class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
     def init_noise_sigma(self):
         return self.sqrt_one_minus_alphas_cumprod[self.timesteps[0]]
 
-    def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Optional[int] = None) -> torch.Tensor:
         """
         Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
         current timestep.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
             timestep (`int`, *optional*):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         return sample * self.c_in[timestep]
 
     def step(
         self,
-        model_output: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-        sample: torch.FloatTensor,
+        model_output: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+        sample: torch.Tensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[ConsistencyDecoderSchedulerOutput, Tuple]:
@@ -143,11 +143,11 @@ class ConsistencyDecoderScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
             timestep (`float`):
                 The current timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
diffusers/schedulers/scheduling_consistency_models.py
@@ -33,12 +33,12 @@ class CMStochasticIterativeSchedulerOutput(BaseOutput):
     Output class for the scheduler's `step` function.
 
     Args:
-        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+        prev_sample (`torch.Tensor` of shape `(batch_size, num_channels, height, width)` for images):
            Computed sample `(x_{t-1})` of previous timestep. `prev_sample` should be used as next model input in the
            denoising loop.
    """
 
-    prev_sample: torch.FloatTensor
+    prev_sample: torch.Tensor
 
 
 class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
@@ -104,7 +104,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
     @property
     def step_index(self):
         """
-        The index counter for current timestep. It will increae 1 after each scheduler step.
+        The index counter for current timestep. It will increase 1 after each scheduler step.
         """
         return self._step_index
 
@@ -126,20 +126,18 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         """
         self._begin_index = begin_index
 
-    def scale_model_input(
-        self, sample: torch.FloatTensor, timestep: Union[float, torch.FloatTensor]
-    ) -> torch.FloatTensor:
+    def scale_model_input(self, sample: torch.Tensor, timestep: Union[float, torch.Tensor]) -> torch.Tensor:
         """
         Scales the consistency model input by `(sigma**2 + sigma_data**2) ** 0.5`.
 
         Args:
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 The input sample.
-            timestep (`float` or `torch.FloatTensor`):
+            timestep (`float` or `torch.Tensor`):
                 The current timestep in the diffusion chain.
 
         Returns:
-            `torch.FloatTensor`:
+            `torch.Tensor`:
                 A scaled input sample.
         """
         # Get sigma corresponding to timestep
@@ -233,7 +231,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         sigmas = self._convert_to_karras(ramp)
         timesteps = self.sigma_to_t(sigmas)
 
-        sigmas = np.concatenate([sigmas, [self.sigma_min]]).astype(np.float32)
+        sigmas = np.concatenate([sigmas, [self.config.sigma_min]]).astype(np.float32)
         self.sigmas = torch.from_numpy(sigmas).to(device=device)
 
         if str(device).startswith("mps"):
@@ -278,7 +276,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         </Tip>
 
         Args:
-            sigma (`torch.FloatTensor`):
+            sigma (`torch.Tensor`):
                 The current sigma in the Karras sigma schedule.
 
         Returns:
@@ -319,9 +317,9 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
 
     def step(
         self,
-        model_output: torch.FloatTensor,
-        timestep: Union[float, torch.FloatTensor],
-        sample: torch.FloatTensor,
+        model_output: torch.Tensor,
+        timestep: Union[float, torch.Tensor],
+        sample: torch.Tensor,
         generator: Optional[torch.Generator] = None,
         return_dict: bool = True,
     ) -> Union[CMStochasticIterativeSchedulerOutput, Tuple]:
@@ -330,11 +328,11 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         process from the learned model outputs (most often the predicted noise).
 
         Args:
-            model_output (`torch.FloatTensor`):
+            model_output (`torch.Tensor`):
                 The direct output from the learned diffusion model.
             timestep (`float`):
                 The current timestep in the diffusion chain.
-            sample (`torch.FloatTensor`):
+            sample (`torch.Tensor`):
                 A current instance of a sample created by the diffusion process.
             generator (`torch.Generator`, *optional*):
                 A random number generator.
@@ -349,11 +347,7 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
             otherwise a tuple is returned where the first element is the sample tensor.
         """
 
-        if (
-            isinstance(timestep, int)
-            or isinstance(timestep, torch.IntTensor)
-            or isinstance(timestep, torch.LongTensor)
-        ):
+        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
             raise ValueError(
                 (
                     "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
@@ -417,10 +411,10 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
     # Copied from diffusers.schedulers.scheduling_euler_discrete.EulerDiscreteScheduler.add_noise
     def add_noise(
         self,
-        original_samples: torch.FloatTensor,
-        noise: torch.FloatTensor,
-        timesteps: torch.FloatTensor,
-    ) -> torch.FloatTensor:
+        original_samples: torch.Tensor,
+        noise: torch.Tensor,
+        timesteps: torch.Tensor,
+    ) -> torch.Tensor:
         # Make sure sigmas and timesteps have the same device and dtype as original_samples
         sigmas = self.sigmas.to(device=original_samples.device, dtype=original_samples.dtype)
         if original_samples.device.type == "mps" and torch.is_floating_point(timesteps):
@@ -434,7 +428,11 @@ class CMStochasticIterativeScheduler(SchedulerMixin, ConfigMixin):
         # self.begin_index is None when scheduler is used for training, or pipeline does not implement set_begin_index
         if self.begin_index is None:
             step_indices = [self.index_for_timestep(t, schedule_timesteps) for t in timesteps]
+        elif self.step_index is not None:
+            # add_noise is called after first denoising step (for inpainting)
+            step_indices = [self.step_index] * timesteps.shape[0]
         else:
+            # add noise is called before first denoising step to create initial latent(img2img)
             step_indices = [self.begin_index] * timesteps.shape[0]
 
         sigma = sigmas[step_indices].flatten()
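The new branch distinguishes the two call sites of `add_noise`: before the denoising loop (img2img, where `begin_index` applies) and after the first step (inpainting, where the live `step_index` must be reused). A self-contained sketch of the selection logic as it now reads:

    def select_step_indices(begin_index, step_index, timesteps, index_for_timestep):
        if begin_index is None:
            # training, or the pipeline never called set_begin_index
            return [index_for_timestep(t) for t in timesteps]
        if step_index is not None:
            # add_noise called after the first denoising step (inpainting)
            return [step_index] * len(timesteps)
        # add_noise called before the first denoising step (img2img)
        return [begin_index] * len(timesteps)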