diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -101,10 +101,10 @@ class DDPMPipeline(DiffusionPipeline):
101
101
 
102
102
  if self.device.type == "mps":
103
103
  # randn does not work reproducibly on mps
104
- image = randn_tensor(image_shape, generator=generator)
104
+ image = randn_tensor(image_shape, generator=generator, dtype=self.unet.dtype)
105
105
  image = image.to(self.device)
106
106
  else:
107
- image = randn_tensor(image_shape, generator=generator, device=self.device)
107
+ image = randn_tensor(image_shape, generator=generator, device=self.device, dtype=self.unet.dtype)
108
108
 
109
109
  # set step values
110
110
  self.scheduler.set_timesteps(num_inference_steps)
@@ -9,16 +9,17 @@ from ...utils import BaseOutput
9
9
 
10
10
  @dataclass
11
11
  class IFPipelineOutput(BaseOutput):
12
- """
13
- Args:
12
+ r"""
14
13
  Output class for Stable Diffusion pipelines.
15
- images (`List[PIL.Image.Image]` or `np.ndarray`)
14
+
15
+ Args:
16
+ images (`List[PIL.Image.Image]` or `np.ndarray`):
16
17
  List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
17
18
  num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
18
- nsfw_detected (`List[bool]`)
19
+ nsfw_detected (`List[bool]`):
19
20
  List of flags denoting whether the corresponding generated image likely represents "not-safe-for-work"
20
21
  (nsfw) content or a watermark. `None` if safety checking could not be performed.
21
- watermark_detected (`List[bool]`)
22
+ watermark_detected (`List[bool]`):
22
23
  List of flags denoting whether the corresponding generated image likely has a watermark. `None` if safety
23
24
  checking could not be performed.
24
25
  """
@@ -65,9 +65,21 @@ EXAMPLE_DOC_STRING = """
65
65
 
66
66
  # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
67
67
  def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
68
- """
69
- Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
70
- Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
68
+ r"""
69
+ Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
70
+ Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
71
+ Flawed](https://arxiv.org/pdf/2305.08891.pdf).
72
+
73
+ Args:
74
+ noise_cfg (`torch.Tensor`):
75
+ The predicted noise tensor for the guided diffusion process.
76
+ noise_pred_text (`torch.Tensor`):
77
+ The predicted noise tensor for the text-guided diffusion process.
78
+ guidance_rescale (`float`, *optional*, defaults to 0.0):
79
+ A rescale factor applied to the noise predictions.
80
+
81
+ Returns:
82
+ noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
71
83
  """
72
84
  std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
73
85
  std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
@@ -87,7 +99,7 @@ def retrieve_timesteps(
87
99
  sigmas: Optional[List[float]] = None,
88
100
  **kwargs,
89
101
  ):
90
- """
102
+ r"""
91
103
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
92
104
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
93
105
 
@@ -127,7 +127,7 @@ def retrieve_timesteps(
127
127
  sigmas: Optional[List[float]] = None,
128
128
  **kwargs,
129
129
  ):
130
- """
130
+ r"""
131
131
  Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
132
132
  custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
133
133
 
@@ -546,7 +546,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
546
546
  )
547
547
  elif encoder_hid_dim_type is not None:
548
548
  raise ValueError(
549
- f"encoder_hid_dim_type: {encoder_hid_dim_type} must be None, 'text_proj' or 'text_image_proj'."
549
+ f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj' or 'image_proj'."
550
550
  )
551
551
  else:
552
552
  self.encoder_hid_proj = None
@@ -1595,7 +1595,7 @@ class DownBlockFlat(nn.Module):
1595
1595
  output_states = ()
1596
1596
 
1597
1597
  for resnet in self.resnets:
1598
- if self.training and self.gradient_checkpointing:
1598
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1599
1599
 
1600
1600
  def create_custom_forward(module):
1601
1601
  def custom_forward(*inputs):
@@ -1732,7 +1732,7 @@ class CrossAttnDownBlockFlat(nn.Module):
1732
1732
  blocks = list(zip(self.resnets, self.attentions))
1733
1733
 
1734
1734
  for i, (resnet, attn) in enumerate(blocks):
1735
- if self.training and self.gradient_checkpointing:
1735
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1736
1736
 
1737
1737
  def create_custom_forward(module, return_dict=None):
1738
1738
  def custom_forward(*inputs):
@@ -1874,7 +1874,7 @@ class UpBlockFlat(nn.Module):
1874
1874
 
1875
1875
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1876
1876
 
1877
- if self.training and self.gradient_checkpointing:
1877
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1878
1878
 
1879
1879
  def create_custom_forward(module):
1880
1880
  def custom_forward(*inputs):
@@ -2033,7 +2033,7 @@ class CrossAttnUpBlockFlat(nn.Module):
2033
2033
 
2034
2034
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2035
2035
 
2036
- if self.training and self.gradient_checkpointing:
2036
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2037
2037
 
2038
2038
  def create_custom_forward(module, return_dict=None):
2039
2039
  def custom_forward(*inputs):
@@ -2223,12 +2223,35 @@ class UNetMidBlockFlat(nn.Module):
2223
2223
  self.attentions = nn.ModuleList(attentions)
2224
2224
  self.resnets = nn.ModuleList(resnets)
2225
2225
 
2226
+ self.gradient_checkpointing = False
2227
+
2226
2228
  def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
2227
2229
  hidden_states = self.resnets[0](hidden_states, temb)
2228
2230
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
2229
- if attn is not None:
2230
- hidden_states = attn(hidden_states, temb=temb)
2231
- hidden_states = resnet(hidden_states, temb)
2231
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2232
+
2233
+ def create_custom_forward(module, return_dict=None):
2234
+ def custom_forward(*inputs):
2235
+ if return_dict is not None:
2236
+ return module(*inputs, return_dict=return_dict)
2237
+ else:
2238
+ return module(*inputs)
2239
+
2240
+ return custom_forward
2241
+
2242
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
2243
+ if attn is not None:
2244
+ hidden_states = attn(hidden_states, temb=temb)
2245
+ hidden_states = torch.utils.checkpoint.checkpoint(
2246
+ create_custom_forward(resnet),
2247
+ hidden_states,
2248
+ temb,
2249
+ **ckpt_kwargs,
2250
+ )
2251
+ else:
2252
+ if attn is not None:
2253
+ hidden_states = attn(hidden_states, temb=temb)
2254
+ hidden_states = resnet(hidden_states, temb)
2232
2255
 
2233
2256
  return hidden_states
2234
2257
 
@@ -2352,7 +2375,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
2352
2375
 
2353
2376
  hidden_states = self.resnets[0](hidden_states, temb)
2354
2377
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
2355
- if self.training and self.gradient_checkpointing:
2378
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2356
2379
 
2357
2380
  def create_custom_forward(module, return_dict=None):
2358
2381
  def custom_forward(*inputs):
@@ -12,7 +12,7 @@ from ...utils import (
12
12
 
13
13
  _dummy_objects = {}
14
14
  _additional_imports = {}
15
- _import_structure = {"pipeline_output": ["FluxPipelineOutput"]}
15
+ _import_structure = {"pipeline_output": ["FluxPipelineOutput", "FluxPriorReduxPipelineOutput"]}
16
16
 
17
17
  try:
18
18
  if not (is_transformers_available() and is_torch_available()):
@@ -22,7 +22,18 @@ except OptionalDependencyNotAvailable:
22
22
 
23
23
  _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
24
24
  else:
25
+ _import_structure["modeling_flux"] = ["ReduxImageEncoder"]
25
26
  _import_structure["pipeline_flux"] = ["FluxPipeline"]
27
+ _import_structure["pipeline_flux_control"] = ["FluxControlPipeline"]
28
+ _import_structure["pipeline_flux_control_img2img"] = ["FluxControlImg2ImgPipeline"]
29
+ _import_structure["pipeline_flux_control_inpaint"] = ["FluxControlInpaintPipeline"]
30
+ _import_structure["pipeline_flux_controlnet"] = ["FluxControlNetPipeline"]
31
+ _import_structure["pipeline_flux_controlnet_image_to_image"] = ["FluxControlNetImg2ImgPipeline"]
32
+ _import_structure["pipeline_flux_controlnet_inpainting"] = ["FluxControlNetInpaintPipeline"]
33
+ _import_structure["pipeline_flux_fill"] = ["FluxFillPipeline"]
34
+ _import_structure["pipeline_flux_img2img"] = ["FluxImg2ImgPipeline"]
35
+ _import_structure["pipeline_flux_inpaint"] = ["FluxInpaintPipeline"]
36
+ _import_structure["pipeline_flux_prior_redux"] = ["FluxPriorReduxPipeline"]
26
37
  if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
27
38
  try:
28
39
  if not (is_transformers_available() and is_torch_available()):
@@ -30,7 +41,18 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
30
41
  except OptionalDependencyNotAvailable:
31
42
  from ...utils.dummy_torch_and_transformers_objects import * # noqa F403
32
43
  else:
44
+ from .modeling_flux import ReduxImageEncoder
33
45
  from .pipeline_flux import FluxPipeline
46
+ from .pipeline_flux_control import FluxControlPipeline
47
+ from .pipeline_flux_control_img2img import FluxControlImg2ImgPipeline
48
+ from .pipeline_flux_control_inpaint import FluxControlInpaintPipeline
49
+ from .pipeline_flux_controlnet import FluxControlNetPipeline
50
+ from .pipeline_flux_controlnet_image_to_image import FluxControlNetImg2ImgPipeline
51
+ from .pipeline_flux_controlnet_inpainting import FluxControlNetInpaintPipeline
52
+ from .pipeline_flux_fill import FluxFillPipeline
53
+ from .pipeline_flux_img2img import FluxImg2ImgPipeline
54
+ from .pipeline_flux_inpaint import FluxInpaintPipeline
55
+ from .pipeline_flux_prior_redux import FluxPriorReduxPipeline
34
56
  else:
35
57
  import sys
36
58
 
@@ -0,0 +1,47 @@
1
+ # Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Optional
18
+
19
+ import torch
20
+ import torch.nn as nn
21
+
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...models.modeling_utils import ModelMixin
24
+ from ...utils import BaseOutput
25
+
26
+
27
+ @dataclass
28
+ class ReduxImageEncoderOutput(BaseOutput):
29
+ image_embeds: Optional[torch.Tensor] = None
30
+
31
+
32
+ class ReduxImageEncoder(ModelMixin, ConfigMixin):
33
+ @register_to_config
34
+ def __init__(
35
+ self,
36
+ redux_dim: int = 1152,
37
+ txt_in_features: int = 4096,
38
+ ) -> None:
39
+ super().__init__()
40
+
41
+ self.redux_up = nn.Linear(redux_dim, txt_in_features * 3)
42
+ self.redux_down = nn.Linear(txt_in_features * 3, txt_in_features)
43
+
44
+ def forward(self, x: torch.Tensor) -> ReduxImageEncoderOutput:
45
+ projected_x = self.redux_down(nn.functional.silu(self.redux_up(x)))
46
+
47
+ return ReduxImageEncoderOutput(image_embeds=projected_x)