diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registries. It is provided for informational purposes only.
Files changed (268)
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
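
The two hunks reproduced below cover the new Flux Redux prior pipeline (items 136 and 137). Among the other larger additions in this range is the new `diffusers/quantizers` package (items 214–225), which wires bitsandbytes, GGUF, and torchao backends into model loading. As a minimal sketch of how that surface is typically used — assuming the `BitsAndBytesConfig` / `quantization_config=` entry points introduced in this release range behave like their transformers counterparts; the repo id is illustrative:

```py
# Hedged sketch: 4-bit NF4 quantization of the Flux transformer through the new
# diffusers/quantizers backend. Assumes `bitsandbytes` is installed; the repo id
# is illustrative.
import torch
from diffusers import BitsAndBytesConfig, FluxPipeline, FluxTransformer2DModel

nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", transformer=transformer, torch_dtype=torch.bfloat16
)
pipe.enable_model_cpu_offload()  # keep the quantized transformer on GPU only while in use
```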
diffusers/pipelines/flux/pipeline_flux_prior_redux.py (new file)
@@ -0,0 +1,492 @@
+# Copyright 2024 Black Forest Labs and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import List, Optional, Union
+
+import torch
+from PIL import Image
+from transformers import (
+    CLIPTextModel,
+    CLIPTokenizer,
+    SiglipImageProcessor,
+    SiglipVisionModel,
+    T5EncoderModel,
+    T5TokenizerFast,
+)
+
+from ...image_processor import PipelineImageInput
+from ...loaders import FluxLoraLoaderMixin, TextualInversionLoaderMixin
+from ...utils import (
+    USE_PEFT_BACKEND,
+    is_torch_xla_available,
+    logging,
+    replace_example_docstring,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+from ..pipeline_utils import DiffusionPipeline
+from .modeling_flux import ReduxImageEncoder
+from .pipeline_output import FluxPriorReduxPipelineOutput
+
+
+if is_torch_xla_available():
+    XLA_AVAILABLE = True
+else:
+    XLA_AVAILABLE = False
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+    Examples:
+        ```py
+        >>> import torch
+        >>> from diffusers import FluxPriorReduxPipeline, FluxPipeline
+        >>> from diffusers.utils import load_image
+
+        >>> device = "cuda"
+        >>> dtype = torch.bfloat16
+
+        >>> repo_redux = "black-forest-labs/FLUX.1-Redux-dev"
+        >>> repo_base = "black-forest-labs/FLUX.1-dev"
+        >>> pipe_prior_redux = FluxPriorReduxPipeline.from_pretrained(repo_redux, torch_dtype=dtype).to(device)
+        >>> pipe = FluxPipeline.from_pretrained(
+        ...     repo_base, text_encoder=None, text_encoder_2=None, torch_dtype=torch.bfloat16
+        ... ).to(device)
+
+        >>> image = load_image(
+        ...     "https://huggingface.co/datasets/YiYiXu/testing-images/resolve/main/style_ziggy/img5.png"
+        ... )
+        >>> pipe_prior_output = pipe_prior_redux(image)
+        >>> images = pipe(
+        ...     guidance_scale=2.5,
+        ...     num_inference_steps=50,
+        ...     generator=torch.Generator("cpu").manual_seed(0),
+        ...     **pipe_prior_output,
+        ... ).images
+        >>> images[0].save("flux-redux.png")
+        ```
+"""
+
+
+class FluxPriorReduxPipeline(DiffusionPipeline):
+    r"""
+    The Flux Redux pipeline for image-to-image generation.
+
+    Reference: https://blackforestlabs.ai/flux-1-tools/
+
+    Args:
+        image_encoder ([`SiglipVisionModel`]):
+            SIGLIP vision model to encode the input image.
+        feature_extractor ([`SiglipImageProcessor`]):
+            Image processor for preprocessing images for the SIGLIP model.
+        image_embedder ([`ReduxImageEncoder`]):
+            Redux image encoder to process the SIGLIP embeddings.
+        text_encoder ([`CLIPTextModel`], *optional*):
+            [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+            the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+        text_encoder_2 ([`T5EncoderModel`], *optional*):
+            [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
+            the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
+        tokenizer (`CLIPTokenizer`, *optional*):
+            Tokenizer of class
+            [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
+        tokenizer_2 (`T5TokenizerFast`, *optional*):
+            Second Tokenizer of class
+            [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
+    """
+
+    model_cpu_offload_seq = "image_encoder->image_embedder"
+    _optional_components = [
+        "text_encoder",
+        "tokenizer",
+        "text_encoder_2",
+        "tokenizer_2",
+    ]
+    _callback_tensor_inputs = []
+
+    def __init__(
+        self,
+        image_encoder: SiglipVisionModel,
+        feature_extractor: SiglipImageProcessor,
+        image_embedder: ReduxImageEncoder,
+        text_encoder: CLIPTextModel = None,
+        tokenizer: CLIPTokenizer = None,
+        text_encoder_2: T5EncoderModel = None,
+        tokenizer_2: T5TokenizerFast = None,
+    ):
+        super().__init__()
+
+        self.register_modules(
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
+            image_embedder=image_embedder,
+            text_encoder=text_encoder,
+            tokenizer=tokenizer,
+            text_encoder_2=text_encoder_2,
+            tokenizer_2=tokenizer_2,
+        )
+        self.tokenizer_max_length = (
+            self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
+        )
+
+    def check_inputs(
+        self,
+        image,
+        prompt,
+        prompt_2,
+        prompt_embeds=None,
+        pooled_prompt_embeds=None,
+        prompt_embeds_scale=1.0,
+        pooled_prompt_embeds_scale=1.0,
+    ):
+        if prompt is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt_2 is not None and prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+                " only forward one of the two."
+            )
+        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+        elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
+            raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
+        if prompt is not None and (isinstance(prompt, list) and isinstance(image, list) and len(prompt) != len(image)):
+            raise ValueError(
+                f"number of prompts must be equal to number of images, but {len(prompt)} prompts were provided and {len(image)} images"
+            )
+        if prompt_embeds is not None and pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
+            )
+        if isinstance(prompt_embeds_scale, list) and (
+            isinstance(image, list) and len(prompt_embeds_scale) != len(image)
+        ):
+            raise ValueError(
+                f"number of weights must be equal to number of images, but {len(prompt_embeds_scale)} weights were provided and {len(image)} images"
+            )
+
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+        image = self.feature_extractor.preprocess(
+            images=image, do_resize=True, return_tensors="pt", do_convert_rgb=True
+        )
+        image = image.to(device=device, dtype=dtype)
+
+        image_enc_hidden_states = self.image_encoder(**image).last_hidden_state
+        image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
+
+        return image_enc_hidden_states
+
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
+    def _get_t5_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]] = None,
+        num_images_per_prompt: int = 1,
+        max_sequence_length: int = 512,
+        device: Optional[torch.device] = None,
+        dtype: Optional[torch.dtype] = None,
+    ):
+        device = device or self._execution_device
+        dtype = dtype or self.text_encoder.dtype
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)
+
+        text_inputs = self.tokenizer_2(
+            prompt,
+            padding="max_length",
+            max_length=max_sequence_length,
+            truncation=True,
+            return_length=False,
+            return_overflowing_tokens=False,
+            return_tensors="pt",
+        )
+        text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids
+
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+            logger.warning(
+                "The following part of your input was truncated because `max_sequence_length` is set to "
+                f" {max_sequence_length} tokens: {removed_text}"
+            )
+
+        prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
+
+        dtype = self.text_encoder_2.dtype
+        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
+
+        _, seq_len, _ = prompt_embeds.shape
+
+        # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+        return prompt_embeds
+
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
+    def _get_clip_prompt_embeds(
+        self,
+        prompt: Union[str, List[str]],
+        num_images_per_prompt: int = 1,
+        device: Optional[torch.device] = None,
+    ):
+        device = device or self._execution_device
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+        batch_size = len(prompt)
+
+        if isinstance(self, TextualInversionLoaderMixin):
+            prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+        text_inputs = self.tokenizer(
+            prompt,
+            padding="max_length",
+            max_length=self.tokenizer_max_length,
+            truncation=True,
+            return_overflowing_tokens=False,
+            return_length=False,
+            return_tensors="pt",
+        )
+
+        text_input_ids = text_inputs.input_ids
+        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
+            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, self.tokenizer_max_length - 1 : -1])
+            logger.warning(
+                "The following part of your input was truncated because CLIP can only handle sequences up to"
+                f" {self.tokenizer_max_length} tokens: {removed_text}"
+            )
+        prompt_embeds = self.text_encoder(text_input_ids.to(device), output_hidden_states=False)
+
+        # Use pooled output of CLIPTextModel
+        prompt_embeds = prompt_embeds.pooler_output
+        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+        # duplicate text embeddings for each generation per prompt, using mps friendly method
+        prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt)
+        prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, -1)
+
+        return prompt_embeds
+
+    # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
+    def encode_prompt(
+        self,
+        prompt: Union[str, List[str]],
+        prompt_2: Union[str, List[str]],
+        device: Optional[torch.device] = None,
+        num_images_per_prompt: int = 1,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        max_sequence_length: int = 512,
+        lora_scale: Optional[float] = None,
+    ):
+        r"""
+
+        Args:
+            prompt (`str` or `List[str]`, *optional*):
+                prompt to be encoded
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`. If not defined, `prompt` is
+                used in all text-encoders
+            device: (`torch.device`):
+                torch device
+            num_images_per_prompt (`int`):
+                number of images that should be generated per prompt
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+                provided, text embeddings will be generated from `prompt` input argument.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+                If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            lora_scale (`float`, *optional*):
+                A lora scale that will be applied to all LoRA layers of the text encoder if LoRA layers are loaded.
+        """
+        device = device or self._execution_device
+
+        # set lora scale so that monkey patched LoRA
+        # function of text encoder can correctly access it
+        if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
+            self._lora_scale = lora_scale
+
+            # dynamically adjust the LoRA scale
+            if self.text_encoder is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder, lora_scale)
+            if self.text_encoder_2 is not None and USE_PEFT_BACKEND:
+                scale_lora_layers(self.text_encoder_2, lora_scale)
+
+        prompt = [prompt] if isinstance(prompt, str) else prompt
+
+        if prompt_embeds is None:
+            prompt_2 = prompt_2 or prompt
+            prompt_2 = [prompt_2] if isinstance(prompt_2, str) else prompt_2
+
+            # We only use the pooled prompt output from the CLIPTextModel
+            pooled_prompt_embeds = self._get_clip_prompt_embeds(
+                prompt=prompt,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+            )
+            prompt_embeds = self._get_t5_prompt_embeds(
+                prompt=prompt_2,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                device=device,
+            )
+
+        if self.text_encoder is not None:
+            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder, lora_scale)
+
+        if self.text_encoder_2 is not None:
+            if isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
+                # Retrieve the original scale by scaling back the LoRA layers
+                unscale_lora_layers(self.text_encoder_2, lora_scale)
+
+        dtype = self.text_encoder.dtype if self.text_encoder is not None else self.transformer.dtype
+        text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)
+
+        return prompt_embeds, pooled_prompt_embeds, text_ids
+
+    @torch.no_grad()
+    @replace_example_docstring(EXAMPLE_DOC_STRING)
+    def __call__(
+        self,
+        image: PipelineImageInput,
+        prompt: Union[str, List[str]] = None,
+        prompt_2: Optional[Union[str, List[str]]] = None,
+        prompt_embeds: Optional[torch.FloatTensor] = None,
+        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
+        pooled_prompt_embeds_scale: Optional[Union[float, List[float]]] = 1.0,
+        return_dict: bool = True,
+    ):
+        r"""
+        Function invoked when calling the pipeline for generation.
+
+        Args:
+            image (`torch.Tensor`, `PIL.Image.Image`, `np.ndarray`, `List[torch.Tensor]`, `List[PIL.Image.Image]`, or `List[np.ndarray]`):
+                `Image`, numpy array or tensor representing an image batch to be used as the starting point. For both
+                numpy array and pytorch tensor, the expected value range is between `[0, 1]`. If it's a tensor or a list
+                of tensors, the expected shape should be `(B, C, H, W)` or `(C, H, W)`. If it is a numpy array or a
+                list of arrays, the expected shape should be `(B, H, W, C)` or `(H, W, C)`.
+            prompt (`str` or `List[str]`, *optional*):
+                The prompt or prompts to guide the image generation. **experimental feature**: to use this feature,
+                make sure to explicitly load text encoders to the pipeline. Prompts will be ignored if text encoders
+                are not loaded.
+            prompt_2 (`str` or `List[str]`, *optional*):
+                The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
+            prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
+            pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
+                Pre-generated pooled text embeddings.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~pipelines.flux.FluxPriorReduxPipelineOutput`] instead of a plain tuple.
+
+        Examples:
+
+        Returns:
+            [`~pipelines.flux.FluxPriorReduxPipelineOutput`] or `tuple`:
+                [`~pipelines.flux.FluxPriorReduxPipelineOutput`] if `return_dict` is True, otherwise a `tuple`. When
+                returning a tuple, the first element is the prompt embeddings and the second the pooled prompt embeddings.
+        """
+
+        # 1. Check inputs. Raise error if not correct
+        self.check_inputs(
+            image,
+            prompt,
+            prompt_2,
+            prompt_embeds=prompt_embeds,
+            pooled_prompt_embeds=pooled_prompt_embeds,
+            prompt_embeds_scale=prompt_embeds_scale,
+            pooled_prompt_embeds_scale=pooled_prompt_embeds_scale,
+        )
+
+        # 2. Define call parameters
+        if image is not None and isinstance(image, Image.Image):
+            batch_size = 1
+        elif image is not None and isinstance(image, list):
+            batch_size = len(image)
+        else:
+            batch_size = image.shape[0]
+        if prompt is not None and isinstance(prompt, str):
+            prompt = batch_size * [prompt]
+        if isinstance(prompt_embeds_scale, float):
+            prompt_embeds_scale = batch_size * [prompt_embeds_scale]
+        if isinstance(pooled_prompt_embeds_scale, float):
+            pooled_prompt_embeds_scale = batch_size * [pooled_prompt_embeds_scale]
+
+        device = self._execution_device
+
+        # 3. Prepare image embeddings
+        image_latents = self.encode_image(image, device, 1)
+
+        image_embeds = self.image_embedder(image_latents).image_embeds
+        image_embeds = image_embeds.to(device=device)
+
+        # 4. Prepare (dummy) text embeddings
+        if hasattr(self, "text_encoder") and self.text_encoder is not None:
+            (
+                prompt_embeds,
+                pooled_prompt_embeds,
+                _,
+            ) = self.encode_prompt(
+                prompt=prompt,
+                prompt_2=prompt_2,
+                prompt_embeds=prompt_embeds,
+                pooled_prompt_embeds=pooled_prompt_embeds,
+                device=device,
+                num_images_per_prompt=1,
+                max_sequence_length=512,
+                lora_scale=None,
+            )
+        else:
+            if prompt is not None:
+                logger.warning(
+                    "prompt input is ignored when text encoders are not loaded into the pipeline. "
+                    "Make sure to explicitly load the text encoders to enable prompt input. "
+                )
+            # max_sequence_length is 512, t5 encoder hidden size is 4096
+            prompt_embeds = torch.zeros((batch_size, 512, 4096), device=device, dtype=image_embeds.dtype)
+            # pooled_prompt_embeds is 768, clip text encoder hidden size
+            pooled_prompt_embeds = torch.zeros((batch_size, 768), device=device, dtype=image_embeds.dtype)
+
+        # scale & concatenate image and text embeddings
+        prompt_embeds = torch.cat([prompt_embeds, image_embeds], dim=1)
+
+        prompt_embeds *= torch.tensor(prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[:, None, None]
+        pooled_prompt_embeds *= torch.tensor(pooled_prompt_embeds_scale, device=device, dtype=image_embeds.dtype)[
+            :, None
+        ]
+
+        # weighted sum
+        prompt_embeds = torch.sum(prompt_embeds, dim=0, keepdim=True)
+        pooled_prompt_embeds = torch.sum(pooled_prompt_embeds, dim=0, keepdim=True)
+
+        # Offload all models
+        self.maybe_free_model_hooks()
+
+        if not return_dict:
+            return (prompt_embeds, pooled_prompt_embeds)
+
+        return FluxPriorReduxPipelineOutput(prompt_embeds=prompt_embeds, pooled_prompt_embeds=pooled_prompt_embeds)
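
A note on the weighted-sum step at the end of `__call__` above: when a list of reference images is passed together with per-image `prompt_embeds_scale` / `pooled_prompt_embeds_scale`, each image's embeddings are scaled by its weight and then summed over the batch dimension, so the pipeline emits a single blended prior. A hedged usage sketch, reusing `pipe_prior_redux` and `pipe` from the docstring example; `img_a` and `img_b` are illustrative PIL images:

```py
# Blend two reference images 70/30 into one Redux prior; each image's embeddings
# are scaled by its weight and summed over the batch dimension (see `__call__` above).
out = pipe_prior_redux(
    image=[img_a, img_b],
    prompt_embeds_scale=[0.7, 0.3],         # per-image weight for prompt_embeds
    pooled_prompt_embeds_scale=[0.7, 0.3],  # per-image weight for pooled_prompt_embeds
)
images = pipe(**out, guidance_scale=2.5, num_inference_steps=50).images
```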
diffusers/pipelines/flux/pipeline_output.py
@@ -3,6 +3,7 @@ from typing import List, Union
 
 import numpy as np
 import PIL.Image
+import torch
 
 from ...utils import BaseOutput
 
@@ -19,3 +20,18 @@ class FluxPipelineOutput(BaseOutput):
     """
 
     images: Union[List[PIL.Image.Image], np.ndarray]
+
+
+@dataclass
+class FluxPriorReduxPipelineOutput(BaseOutput):
+    """
+    Output class for Flux Prior Redux pipelines.
+
+    Args:
+        prompt_embeds (`torch.Tensor`):
+            The image-conditioned prompt embeddings, paired with `pooled_prompt_embeds` (`torch.Tensor`), to
+            pass to a base Flux pipeline in place of text-derived embeddings.
+    """
+
+    prompt_embeds: torch.Tensor
+    pooled_prompt_embeds: torch.Tensor
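
Finally, since `FluxPriorReduxPipelineOutput` is a `BaseOutput`, it can be `**`-unpacked into a base Flux pipeline as in the docstring example, or bypassed entirely with `return_dict=False`, which yields the plain `(prompt_embeds, pooled_prompt_embeds)` tuple shown in `__call__`. A short sketch of the tuple form, again reusing the docstring example's objects:

```py
# Tuple form of the Redux prior output; equivalent to the dataclass fields.
prompt_embeds, pooled_prompt_embeds = pipe_prior_redux(image, return_dict=False)
images = pipe(
    prompt_embeds=prompt_embeds,
    pooled_prompt_embeds=pooled_prompt_embeds,
    guidance_scale=2.5,
    num_inference_steps=50,
).images
```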