diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (268) hide show
  1. diffusers/__init__.py +97 -4
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +13 -1
  4. diffusers/image_processor.py +282 -71
  5. diffusers/loaders/__init__.py +24 -3
  6. diffusers/loaders/ip_adapter.py +543 -16
  7. diffusers/loaders/lora_base.py +138 -125
  8. diffusers/loaders/lora_conversion_utils.py +647 -0
  9. diffusers/loaders/lora_pipeline.py +2216 -230
  10. diffusers/loaders/peft.py +380 -0
  11. diffusers/loaders/single_file_model.py +71 -4
  12. diffusers/loaders/single_file_utils.py +597 -10
  13. diffusers/loaders/textual_inversion.py +5 -3
  14. diffusers/loaders/transformer_flux.py +181 -0
  15. diffusers/loaders/transformer_sd3.py +89 -0
  16. diffusers/loaders/unet.py +56 -12
  17. diffusers/models/__init__.py +49 -12
  18. diffusers/models/activations.py +22 -9
  19. diffusers/models/adapter.py +53 -53
  20. diffusers/models/attention.py +98 -13
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2160 -346
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +73 -12
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +70 -0
  36. diffusers/models/controlnet_sd3.py +26 -376
  37. diffusers/models/controlnet_sparsectrl.py +46 -719
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +996 -92
  49. diffusers/models/embeddings_flax.py +23 -9
  50. diffusers/models/model_loading_utils.py +264 -14
  51. diffusers/models/modeling_flax_utils.py +1 -1
  52. diffusers/models/modeling_utils.py +334 -51
  53. diffusers/models/normalization.py +157 -13
  54. diffusers/models/transformers/__init__.py +6 -0
  55. diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
  56. diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
  57. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  58. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  59. diffusers/models/transformers/pixart_transformer_2d.py +10 -2
  60. diffusers/models/transformers/sana_transformer.py +488 -0
  61. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  62. diffusers/models/transformers/transformer_2d.py +1 -1
  63. diffusers/models/transformers/transformer_allegro.py +422 -0
  64. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  65. diffusers/models/transformers/transformer_flux.py +189 -51
  66. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  67. diffusers/models/transformers/transformer_ltx.py +469 -0
  68. diffusers/models/transformers/transformer_mochi.py +499 -0
  69. diffusers/models/transformers/transformer_sd3.py +112 -18
  70. diffusers/models/transformers/transformer_temporal.py +1 -1
  71. diffusers/models/unets/unet_1d_blocks.py +1 -1
  72. diffusers/models/unets/unet_2d.py +8 -1
  73. diffusers/models/unets/unet_2d_blocks.py +88 -21
  74. diffusers/models/unets/unet_2d_condition.py +9 -9
  75. diffusers/models/unets/unet_3d_blocks.py +9 -7
  76. diffusers/models/unets/unet_motion_model.py +46 -68
  77. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  78. diffusers/models/unets/unet_stable_cascade.py +2 -2
  79. diffusers/models/unets/uvit_2d.py +1 -1
  80. diffusers/models/upsampling.py +14 -6
  81. diffusers/pipelines/__init__.py +69 -6
  82. diffusers/pipelines/allegro/__init__.py +48 -0
  83. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  84. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  85. diffusers/pipelines/animatediff/__init__.py +2 -0
  86. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  87. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
  88. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  89. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
  90. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
  91. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  92. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  93. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
  94. diffusers/pipelines/auto_pipeline.py +88 -10
  95. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  96. diffusers/pipelines/cogvideo/__init__.py +2 -0
  97. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
  98. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
  99. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
  100. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
  101. diffusers/pipelines/cogview3/__init__.py +47 -0
  102. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  103. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  104. diffusers/pipelines/controlnet/__init__.py +86 -80
  105. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  106. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
  107. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
  108. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
  109. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
  110. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
  111. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  113. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  114. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  115. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
  116. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  117. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
  118. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  119. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  120. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  121. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  122. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  123. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
  124. diffusers/pipelines/flux/__init__.py +23 -1
  125. diffusers/pipelines/flux/modeling_flux.py +47 -0
  126. diffusers/pipelines/flux/pipeline_flux.py +256 -48
  127. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  128. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  129. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  130. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
  131. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
  132. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
  133. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  134. diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
  135. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
  136. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  137. diffusers/pipelines/flux/pipeline_output.py +16 -0
  138. diffusers/pipelines/free_noise_utils.py +365 -5
  139. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  140. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  141. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  142. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
  143. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  144. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  145. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  146. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  147. diffusers/pipelines/kolors/text_encoder.py +2 -2
  148. diffusers/pipelines/kolors/tokenizer.py +4 -0
  149. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  150. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  151. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  152. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  153. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  154. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  155. diffusers/pipelines/ltx/__init__.py +50 -0
  156. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  157. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  158. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  159. diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
  160. diffusers/pipelines/mochi/__init__.py +48 -0
  161. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  162. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  163. diffusers/pipelines/pag/__init__.py +13 -0
  164. diffusers/pipelines/pag/pag_utils.py +8 -2
  165. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
  166. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
  167. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
  168. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
  169. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
  170. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  171. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
  172. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  173. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  174. diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
  175. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  176. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  177. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
  178. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  179. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  180. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  181. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  182. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  183. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  184. diffusers/pipelines/pipeline_loading_utils.py +250 -31
  185. diffusers/pipelines/pipeline_utils.py +158 -186
  186. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
  187. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
  188. diffusers/pipelines/sana/__init__.py +47 -0
  189. diffusers/pipelines/sana/pipeline_output.py +21 -0
  190. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  191. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  192. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  193. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  194. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
  195. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  196. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  197. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  198. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
  199. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
  200. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
  201. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  202. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  203. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  204. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  205. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
  206. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
  207. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
  208. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  209. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  210. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  211. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  212. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  213. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  214. diffusers/quantizers/__init__.py +16 -0
  215. diffusers/quantizers/auto.py +139 -0
  216. diffusers/quantizers/base.py +233 -0
  217. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  218. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
  219. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  220. diffusers/quantizers/gguf/__init__.py +1 -0
  221. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  222. diffusers/quantizers/gguf/utils.py +456 -0
  223. diffusers/quantizers/quantization_config.py +669 -0
  224. diffusers/quantizers/torchao/__init__.py +15 -0
  225. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  226. diffusers/schedulers/scheduling_ddim.py +4 -1
  227. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  228. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  229. diffusers/schedulers/scheduling_ddpm.py +6 -7
  230. diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
  231. diffusers/schedulers/scheduling_deis_multistep.py +102 -6
  232. diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
  233. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
  234. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  235. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
  236. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  237. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  238. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  239. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  240. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  241. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  242. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  243. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  244. diffusers/schedulers/scheduling_lcm.py +2 -6
  245. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  246. diffusers/schedulers/scheduling_repaint.py +1 -1
  247. diffusers/schedulers/scheduling_sasolver.py +102 -6
  248. diffusers/schedulers/scheduling_tcd.py +2 -6
  249. diffusers/schedulers/scheduling_unclip.py +4 -1
  250. diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
  251. diffusers/training_utils.py +63 -19
  252. diffusers/utils/__init__.py +7 -1
  253. diffusers/utils/constants.py +1 -0
  254. diffusers/utils/dummy_pt_objects.py +240 -0
  255. diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
  256. diffusers/utils/dynamic_modules_utils.py +3 -3
  257. diffusers/utils/hub_utils.py +44 -40
  258. diffusers/utils/import_utils.py +98 -8
  259. diffusers/utils/loading_utils.py +28 -4
  260. diffusers/utils/peft_utils.py +6 -3
  261. diffusers/utils/testing_utils.py +115 -1
  262. diffusers/utils/torch_utils.py +3 -0
  263. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
  264. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
  265. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  266. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  267. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  268. {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1204 @@
1
+ import inspect
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import numpy as np
5
+ import PIL
6
+ import torch
7
+ from transformers import (
8
+ CLIPTextModel,
9
+ CLIPTokenizer,
10
+ T5EncoderModel,
11
+ T5TokenizerFast,
12
+ )
13
+
14
+ from ...image_processor import PipelineImageInput, VaeImageProcessor
15
+ from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
16
+ from ...models.autoencoders import AutoencoderKL
17
+ from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
18
+ from ...models.transformers import FluxTransformer2DModel
19
+ from ...schedulers import FlowMatchEulerDiscreteScheduler
20
+ from ...utils import (
21
+ USE_PEFT_BACKEND,
22
+ is_torch_xla_available,
23
+ logging,
24
+ replace_example_docstring,
25
+ scale_lora_layers,
26
+ unscale_lora_layers,
27
+ )
28
+ from ...utils.torch_utils import randn_tensor
29
+ from ..pipeline_utils import DiffusionPipeline
30
+ from .pipeline_output import FluxPipelineOutput
31
+
32
+
33
# Import the torch_xla model helpers only when `is_torch_xla_available()` reports them as usable, and
# record the outcome in `XLA_AVAILABLE` so later code can branch on the flag instead of re-checking.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger, named after this module.
logger = logging.get_logger(__name__)
41
+
42
+ EXAMPLE_DOC_STRING = """
43
+ Examples:
44
+ ```py
45
+ >>> import torch
46
+ >>> from diffusers import FluxControlNetInpaintPipeline
47
+ >>> from diffusers.models import FluxControlNetModel
48
+ >>> from diffusers.utils import load_image
49
+
50
+ >>> controlnet = FluxControlNetModel.from_pretrained(
51
+ ... "InstantX/FLUX.1-dev-controlnet-canny", torch_dtype=torch.float16
52
+ ... )
53
+ >>> pipe = FluxControlNetInpaintPipeline.from_pretrained(
54
+ ... "black-forest-labs/FLUX.1-schnell", controlnet=controlnet, torch_dtype=torch.float16
55
+ ... )
56
+ >>> pipe.to("cuda")
57
+
58
+ >>> control_image = load_image(
59
+ ... "https://huggingface.co/InstantX/FLUX.1-dev-Controlnet-Canny-alpha/resolve/main/canny.jpg"
60
+ ... )
61
+ >>> init_image = load_image(
62
+ ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo.png"
63
+ ... )
64
+ >>> mask_image = load_image(
65
+ ... "https://raw.githubusercontent.com/CompVis/latent-diffusion/main/data/inpainting_examples/overture-creations-5sI6fQgYIuo_mask.png"
66
+ ... )
67
+
68
+ >>> prompt = "A girl holding a sign that says InstantX"
69
+ >>> image = pipe(
70
+ ... prompt,
71
+ ... image=init_image,
72
+ ... mask_image=mask_image,
73
+ ... control_image=control_image,
74
+ ... control_guidance_start=0.2,
75
+ ... control_guidance_end=0.8,
76
+ ... controlnet_conditioning_scale=0.7,
77
+ ... strength=0.7,
78
+ ... num_inference_steps=28,
79
+ ... guidance_scale=3.5,
80
+ ... ).images[0]
81
+ >>> image.save("flux_controlnet_inpaint.png")
82
+ ```
83
+ """
84
+
85
+
86
+ # Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
87
def calculate_shift(
    image_seq_len,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
):
    """
    Linearly map an image sequence length to a shift value.

    The mapping is the straight line through the points `(base_seq_len, base_shift)` and
    `(max_seq_len, max_shift)`; values outside that range are extrapolated, not clamped.

    Args:
        image_seq_len: Sequence length to evaluate the line at.
        base_seq_len (`int`): Sequence length that maps to `base_shift`.
        max_seq_len (`int`): Sequence length that maps to `max_shift`.
        base_shift (`float`): Shift at `base_seq_len`.
        max_shift (`float`): Shift at `max_seq_len`.

    Returns:
        The interpolated (or extrapolated) shift for `image_seq_len`.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
98
+
99
+
100
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
101
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    """
    Extract latents from an encoder output object.

    Args:
        encoder_output: Object exposing either a `latent_dist` attribute (with `sample(generator)` and
            `mode()` methods) or a `latents` attribute.
        generator (`torch.Generator`, *optional*): Forwarded to `latent_dist.sample` when sampling.
        sample_mode (`str`): `"sample"` draws from `latent_dist`; `"argmax"` takes `latent_dist.mode()`.
            Any other value skips `latent_dist` and falls through to `latents`.

    Returns:
        The sampled latents, the distribution mode, or the stored `latents`.

    Raises:
        AttributeError: If none of the lookups above applies to `encoder_output`.
    """
    if hasattr(encoder_output, "latent_dist"):
        distribution = encoder_output.latent_dist
        if sample_mode == "sample":
            return distribution.sample(generator)
        if sample_mode == "argmax":
            return distribution.mode()

    # Either there is no distribution, or the requested mode is not one it can serve.
    if hasattr(encoder_output, "latents"):
        return encoder_output.latents

    raise AttributeError("Could not access latents of provided encoder_output")
112
+
113
+
114
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
115
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` through its `set_timesteps` method and return the resulting schedule.

    Exactly one of three modes is used: a custom `timesteps` list, a custom `sigmas` list, or a plain
    `num_inference_steps` count. Extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to configure and read timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps. Only used when neither `timesteps` nor `sigmas` is given.
        device (`str` or `torch.device`, *optional*):
            Passed to `scheduler.set_timesteps` as `device`.
        timesteps (`List[int]`, *optional*):
            Custom timesteps. Mutually exclusive with `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas. Mutually exclusive with `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's `timesteps` and the number of inference steps. For custom
        schedules the step count is the length of the resulting `timesteps`; otherwise it is
        `num_inference_steps` unchanged.

    Raises:
        ValueError: If both `timesteps` and `sigmas` are given, or if the scheduler's `set_timesteps` does not
            accept the custom argument that was supplied.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    # Default mode: no custom schedule, so the requested step count is returned as-is.
    if timesteps is None and sigmas is None:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        return scheduler.timesteps, num_inference_steps

    if timesteps is not None:
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
    else:
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)

    # Custom mode: the step count is whatever the scheduler produced.
    schedule = scheduler.timesteps
    return schedule, len(schedule)
172
+
173
+
174
+ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleFileMixin):
175
+ r"""
176
+ The Flux controlnet pipeline for inpainting.
177
+
178
+ Reference: https://blackforestlabs.ai/announcing-black-forest-labs/
179
+
180
+ Args:
181
+ transformer ([`FluxTransformer2DModel`]):
182
+ Conditional Transformer (MMDiT) architecture to denoise the encoded image latents.
183
+ scheduler ([`FlowMatchEulerDiscreteScheduler`]):
184
+ A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
185
+ vae ([`AutoencoderKL`]):
186
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
187
+ text_encoder ([`CLIPTextModel`]):
188
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
189
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
190
+ text_encoder_2 ([`T5EncoderModel`]):
191
+ [T5](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5EncoderModel), specifically
192
+ the [google/t5-v1_1-xxl](https://huggingface.co/google/t5-v1_1-xxl) variant.
193
+ tokenizer (`CLIPTokenizer`):
194
+ Tokenizer of class
195
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/en/model_doc/clip#transformers.CLIPTokenizer).
196
+ tokenizer_2 (`T5TokenizerFast`):
197
+ Second Tokenizer of class
198
+ [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
199
+ """
200
+
201
+ model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
202
+ _optional_components = []
203
+ _callback_tensor_inputs = ["latents", "prompt_embeds"]
204
+
205
def __init__(
    self,
    scheduler: FlowMatchEulerDiscreteScheduler,
    vae: AutoencoderKL,
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    text_encoder_2: T5EncoderModel,
    tokenizer_2: T5TokenizerFast,
    transformer: FluxTransformer2DModel,
    controlnet: Union[
        FluxControlNetModel, List[FluxControlNetModel], Tuple[FluxControlNetModel], FluxMultiControlNetModel
    ],
):
    """
    Register the pipeline components and build the image and mask pre-processors.

    A list or tuple of controlnets is wrapped into a single `FluxMultiControlNetModel` before
    registration. Every value derived from the VAE or the CLIP tokenizer falls back to a default when
    that component is missing, so the pipeline can be constructed without them.
    """
    super().__init__()
    if isinstance(controlnet, (list, tuple)):
        controlnet = FluxMultiControlNetModel(controlnet)

    self.register_modules(
        scheduler=scheduler,
        vae=vae,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
        text_encoder_2=text_encoder_2,
        tokenizer_2=tokenizer_2,
        transformer=transformer,
        controlnet=controlnet,
    )

    has_vae = hasattr(self, "vae") and self.vae is not None
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if has_vae else 8
    # Same guard as for the scale factor: reading `self.vae.config` unconditionally would raise when the
    # VAE is absent, which would make the fallback above pointless. 16 is used as the default latent
    # channel count — NOTE(review): confirm it matches the Flux VAE configuration in use.
    latent_channels = self.vae.config.latent_channels if has_vae else 16

    # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
    # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
    self.mask_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor * 2,
        vae_latent_channels=latent_channels,
        do_normalize=False,
        do_binarize=True,
        do_convert_grayscale=True,
    )
    self.tokenizer_max_length = (
        self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
    )
    self.default_sample_size = 128
250
+
251
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
252
def _get_t5_prompt_embeds(
    self,
    prompt: Union[str, List[str]] = None,
    num_images_per_prompt: int = 1,
    max_sequence_length: int = 512,
    device: Optional[torch.device] = None,
    dtype: Optional[torch.dtype] = None,
):
    r"""
    Encode `prompt` with `tokenizer_2` / `text_encoder_2` and repeat the result per generated image.

    Args:
        prompt (`str` or `List[str]`):
            Prompt or prompts to encode. A single string is treated as a batch of one.
        num_images_per_prompt (`int`):
            How many times each prompt's embedding is repeated.
        max_sequence_length (`int`):
            Length the tokenized prompt is padded or truncated to.
        device (`torch.device`, *optional*):
            Device for the encoder input and the returned tensor. Defaults to `self._execution_device`.
        dtype (`torch.dtype`, *optional*):
            Accepted for interface compatibility only. The result is always cast to
            `self.text_encoder_2.dtype`, as in the original implementation.

    Returns:
        `torch.Tensor`: Embeddings reshaped to `(batch_size * num_images_per_prompt, seq_len, -1)`.
    """
    device = device or self._execution_device

    prompt = [prompt] if isinstance(prompt, str) else prompt
    batch_size = len(prompt)

    if isinstance(self, TextualInversionLoaderMixin):
        prompt = self.maybe_convert_prompt(prompt, self.tokenizer_2)

    text_inputs = self.tokenizer_2(
        prompt,
        padding="max_length",
        max_length=max_sequence_length,
        truncation=True,
        return_length=False,
        return_overflowing_tokens=False,
        return_tensors="pt",
    )
    text_input_ids = text_inputs.input_ids
    # Second, untruncated tokenization; used only to detect and report what truncation dropped.
    untruncated_ids = self.tokenizer_2(prompt, padding="longest", return_tensors="pt").input_ids

    if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
        # Slice at the limit that was actually applied above (`max_sequence_length`). Using the CLIP
        # tokenizer's limit here would report tokens that were not truncated.
        removed_text = self.tokenizer_2.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because `max_sequence_length` is set to "
            f" {max_sequence_length} tokens: {removed_text}"
        )

    prompt_embeds = self.text_encoder_2(text_input_ids.to(device), output_hidden_states=False)[0]
    prompt_embeds = prompt_embeds.to(dtype=self.text_encoder_2.dtype, device=device)

    _, seq_len, _ = prompt_embeds.shape

    # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
    prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
    prompt_embeds = prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)

    return prompt_embeds
300
+
301
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_clip_prompt_embeds
302
def _get_clip_prompt_embeds(
    self,
    prompt: Union[str, List[str]],
    num_images_per_prompt: int = 1,
    device: Optional[torch.device] = None,
):
    r"""
    Encode `prompt` with `tokenizer` / `text_encoder` and return the pooled output, repeated per image.

    Args:
        prompt (`str` or `List[str]`):
            Prompt or prompts to encode. A single string is treated as a batch of one.
        num_images_per_prompt (`int`):
            How many times each prompt's pooled embedding is repeated.
        device (`torch.device`, *optional*):
            Device for the encoder input and the returned tensor. Defaults to `self._execution_device`.

    Returns:
        `torch.Tensor`: Pooled embeddings reshaped to `(batch_size * num_images_per_prompt, -1)`, cast to
        `self.text_encoder.dtype`.
    """
    device = device or self._execution_device

    prompts = [prompt] if isinstance(prompt, str) else prompt
    num_prompts = len(prompts)

    if isinstance(self, TextualInversionLoaderMixin):
        prompts = self.maybe_convert_prompt(prompts, self.tokenizer)

    # Tokenize twice: once padded/truncated to the tokenizer limit (fed to the encoder), and once
    # untruncated so that any dropped text can be reported.
    clipped_ids = self.tokenizer(
        prompts,
        padding="max_length",
        max_length=self.tokenizer_max_length,
        truncation=True,
        return_overflowing_tokens=False,
        return_length=False,
        return_tensors="pt",
    ).input_ids
    full_ids = self.tokenizer(prompts, padding="longest", return_tensors="pt").input_ids

    was_truncated = full_ids.shape[-1] >= clipped_ids.shape[-1] and not torch.equal(clipped_ids, full_ids)
    if was_truncated:
        removed_text = self.tokenizer.batch_decode(full_ids[:, self.tokenizer_max_length - 1 : -1])
        logger.warning(
            "The following part of your input was truncated because CLIP can only handle sequences up to"
            f" {self.tokenizer_max_length} tokens: {removed_text}"
        )

    # Only the pooled output of the CLIP text model is used.
    encoder_output = self.text_encoder(clipped_ids.to(device), output_hidden_states=False)
    pooled = encoder_output.pooler_output.to(dtype=self.text_encoder.dtype, device=device)

    # duplicate text embeddings for each generation per prompt, using mps friendly method
    tiled = pooled.repeat(1, num_images_per_prompt)
    return tiled.view(num_prompts * num_images_per_prompt, -1)
345
+
346
# Adapted from diffusers.pipelines.flux.pipeline_flux.FluxPipeline.encode_prompt
def encode_prompt(
    self,
    prompt: Union[str, List[str]],
    prompt_2: Union[str, List[str]],
    device: Optional[torch.device] = None,
    num_images_per_prompt: int = 1,
    prompt_embeds: Optional[torch.FloatTensor] = None,
    pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
    max_sequence_length: int = 512,
    lora_scale: Optional[float] = None,
):
    r"""
    Produce the text conditioning tensors used by the pipeline.

    Args:
        prompt (`str` or `List[str]`, *optional*):
            Prompt passed to `_get_clip_prompt_embeds` to obtain the pooled embeddings.
        prompt_2 (`str` or `List[str]`, *optional*):
            Prompt passed to `_get_t5_prompt_embeds`. Falls back to `prompt` when empty or `None`.
        device (`torch.device`, *optional*):
            Target device; falls back to `self._execution_device` when not given.
        num_images_per_prompt (`int`):
            Forwarded to both embedding helpers.
        prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated text embeddings. When given, no text encoder is run and both
            `prompt_embeds` and `pooled_prompt_embeds` are returned as passed.
        pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
            Pre-generated pooled text embeddings; only overwritten when `prompt_embeds` is `None`.
        max_sequence_length (`int`, defaults to 512):
            Forwarded to `_get_t5_prompt_embeds`.
        lora_scale (`float`, *optional*):
            Scale applied to the text encoders' LoRA layers while encoding, when the pipeline is a
            `FluxLoraLoaderMixin` and the PEFT backend is in use.

    Returns:
        `(prompt_embeds, pooled_prompt_embeds, text_ids)` where `text_ids` is an all-zero tensor of
        shape `(prompt_embeds.shape[1], 3)`.
    """
    device = device or self._execution_device

    # Record the LoRA scale so the monkey patched LoRA function of the text encoder can access
    # it, then scale whichever text encoders are present.
    if lora_scale is not None and isinstance(self, FluxLoraLoaderMixin):
        self._lora_scale = lora_scale

        for encoder in (self.text_encoder, self.text_encoder_2):
            if encoder is not None and USE_PEFT_BACKEND:
                scale_lora_layers(encoder, lora_scale)

    if isinstance(prompt, str):
        prompt = [prompt]

    if prompt_embeds is None:
        prompt_2 = prompt_2 or prompt
        if isinstance(prompt_2, str):
            prompt_2 = [prompt_2]

        # The CLIP helper contributes only its pooled output.
        pooled_prompt_embeds = self._get_clip_prompt_embeds(
            prompt=prompt,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
        )
        prompt_embeds = self._get_t5_prompt_embeds(
            prompt=prompt_2,
            num_images_per_prompt=num_images_per_prompt,
            max_sequence_length=max_sequence_length,
            device=device,
        )

    # Scale the LoRA layers back to retrieve their original scale.
    for encoder in (self.text_encoder, self.text_encoder_2):
        if encoder is not None and isinstance(self, FluxLoraLoaderMixin) and USE_PEFT_BACKEND:
            unscale_lora_layers(encoder, lora_scale)

    if self.text_encoder is not None:
        dtype = self.text_encoder.dtype
    else:
        dtype = self.transformer.dtype
    text_ids = torch.zeros(prompt_embeds.shape[1], 3).to(device=device, dtype=dtype)

    return prompt_embeds, pooled_prompt_embeds, text_ids
425
+
426
# Adapted from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_inpaint.StableDiffusion3InpaintPipeline._encode_vae_image
def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
    """
    Encode `image` with `self.vae` and normalise the result with the VAE shift and scaling factors.

    When `generator` is a list, each image along dim 0 is encoded on its own with the generator at
    the same index and the results are concatenated; otherwise the whole batch is encoded at once.

    Returns:
        `(latents - vae.config.shift_factor) * vae.config.scaling_factor`.
    """
    if not isinstance(generator, list):
        encoded = retrieve_latents(self.vae.encode(image), generator=generator)
    else:
        per_sample = []
        for idx in range(image.shape[0]):
            sample = image[idx : idx + 1]
            per_sample.append(retrieve_latents(self.vae.encode(sample), generator=generator[idx]))
        encoded = torch.cat(per_sample, dim=0)

    vae_config = self.vae.config
    return (encoded - vae_config.shift_factor) * vae_config.scaling_factor
440
+
441
# Adapted from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3_img2img.StableDiffusion3Img2ImgPipeline.get_timesteps
def get_timesteps(self, num_inference_steps, strength, device):
    """
    Trim the scheduler's timesteps according to `strength`.

    Args:
        num_inference_steps: Number of steps the scheduler was configured with.
        strength: Fraction of the schedule to keep; the first
            `int(num_inference_steps - num_inference_steps * strength)` steps are skipped.
        device: Not used by this method.

    Returns:
        `(timesteps, steps)`: the remaining slice of `self.scheduler.timesteps` and the number of
        steps left after skipping.
    """
    scheduler = self.scheduler

    # Number of steps that will actually run, capped at the configured total.
    kept_steps = min(num_inference_steps * strength, num_inference_steps)
    skipped_steps = int(max(num_inference_steps - kept_steps, 0))

    first_index = skipped_steps * scheduler.order
    remaining = scheduler.timesteps[first_index:]

    # Not every scheduler has `set_begin_index`; call it only when present.
    if hasattr(scheduler, "set_begin_index"):
        scheduler.set_begin_index(first_index)

    return remaining, num_inference_steps - skipped_steps
452
+
453
def check_inputs(
    self,
    prompt,
    prompt_2,
    image,
    mask_image,
    strength,
    height,
    width,
    output_type,
    prompt_embeds=None,
    pooled_prompt_embeds=None,
    callback_on_step_end_tensor_inputs=None,
    padding_mask_crop=None,
    max_sequence_length=None,
):
    """
    Validate the arguments of a pipeline call.

    Raises:
        ValueError: if `strength` is outside `[0, 1]`; if a requested callback tensor input is not
            listed in `self._callback_tensor_inputs`; if prompts and `prompt_embeds` are both given
            or both missing; if a prompt is neither `str` nor `list`; if `prompt_embeds` is given
            without `pooled_prompt_embeds`; if `padding_mask_crop` is set while `image` or
            `mask_image` is not a PIL image or `output_type` is not `"pil"`; or if
            `max_sequence_length` exceeds 512.

    A `height`/`width` that is not a multiple of `self.vae_scale_factor * 2` only logs a warning.
    """
    if strength < 0 or strength > 1:
        raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")

    required_multiple = self.vae_scale_factor * 2
    if height % required_multiple != 0 or width % required_multiple != 0:
        logger.warning(
            f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
        )

    if callback_on_step_end_tensor_inputs is not None:
        if any(k not in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

    # Each guard below raises, so checking them one after another reports the first violation.
    has_embeds = prompt_embeds is not None
    if prompt is not None and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt_2 is not None and has_embeds:
        raise ValueError(
            f"Cannot forward both `prompt_2`: {prompt_2} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
            " only forward one of the two."
        )
    if prompt is None and not has_embeds:
        raise ValueError(
            "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
        )
    if prompt is not None and not isinstance(prompt, (str, list)):
        raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
    if prompt_2 is not None and not isinstance(prompt_2, (str, list)):
        raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")

    if has_embeds and pooled_prompt_embeds is None:
        raise ValueError(
            "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
        )

    # Mask cropping works on PIL inputs and PIL output only.
    if padding_mask_crop is not None:
        if not isinstance(image, PIL.Image.Image):
            raise ValueError(
                f"The image should be a PIL image when inpainting mask crop, but is of type" f" {type(image)}."
            )
        if not isinstance(mask_image, PIL.Image.Image):
            raise ValueError(
                f"The mask image should be a PIL image when inpainting mask crop, but is of type"
                f" {type(mask_image)}."
            )
        if output_type != "pil":
            raise ValueError(f"The output type should be PIL when inpainting mask crop, but is" f" {output_type}.")

    if max_sequence_length is not None and max_sequence_length > 512:
        raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
523
+
524
+ @staticmethod
525
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
526
+ def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
527
+ latent_image_ids = torch.zeros(height, width, 3)
528
+ latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
529
+ latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
530
+
531
+ latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
532
+
533
+ latent_image_ids = latent_image_ids.reshape(
534
+ latent_image_id_height * latent_image_id_width, latent_image_id_channels
535
+ )
536
+
537
+ return latent_image_ids.to(device=device, dtype=dtype)
538
+
539
+ @staticmethod
540
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._pack_latents
541
+ def _pack_latents(latents, batch_size, num_channels_latents, height, width):
542
+ latents = latents.view(batch_size, num_channels_latents, height // 2, 2, width // 2, 2)
543
+ latents = latents.permute(0, 2, 4, 1, 3, 5)
544
+ latents = latents.reshape(batch_size, (height // 2) * (width // 2), num_channels_latents * 4)
545
+
546
+ return latents
547
+
548
+ @staticmethod
549
+ # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._unpack_latents
550
+ def _unpack_latents(latents, height, width, vae_scale_factor):
551
+ batch_size, num_patches, channels = latents.shape
552
+
553
+ # VAE applies 8x compression on images but we must also account for packing which requires
554
+ # latent height and width to be divisible by 2.
555
+ height = 2 * (int(height) // (vae_scale_factor * 2))
556
+ width = 2 * (int(width) // (vae_scale_factor * 2))
557
+
558
+ latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
559
+ latents = latents.permute(0, 3, 1, 4, 2, 5)
560
+
561
+ latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
562
+
563
+ return latents
564
+
565
# Adapted from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_latents
def prepare_latents(
    self,
    image,
    timestep,
    batch_size,
    num_channels_latents,
    height,
    width,
    dtype,
    device,
    generator,
    latents=None,
):
    """
    Build the packed starting latents, noise, image latents and latent image ids.

    Args:
        image: Image tensor; moved to `device`/`dtype` and encoded with `_encode_vae_image`.
        timestep: Passed to `self.scheduler.scale_noise` when fresh noise is sampled.
        batch_size: Effective batch size. Image latents are tiled up to it when it is a multiple
            of their batch size.
        num_channels_latents: Channel count used for the noise shape and for packing.
        height, width: Image size in pixels; converted to an even latent size.
        dtype, device: Target dtype and device.
        generator: A generator or a list of generators (one per batch item).
        latents: Optional pre-made latents. When given they are moved to `device` and used as both
            the noise and the starting latents.

    Returns:
        `(latents, noise, image_latents, latent_image_ids)`; the first three are packed with
        `_pack_latents`.

    Raises:
        ValueError: if `generator` is a list whose length differs from `batch_size`, or if the
            image batch cannot be tiled evenly to `batch_size`.
    """
    if isinstance(generator, list) and len(generator) != batch_size:
        raise ValueError(
            f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
            f" size of {batch_size}. Make sure the batch size matches the length of the generators."
        )

    # The VAE compresses the image and packing additionally needs an even latent height/width.
    pixels_per_patch = self.vae_scale_factor * 2
    height = 2 * (int(height) // pixels_per_patch)
    width = 2 * (int(width) // pixels_per_patch)
    noise_shape = (batch_size, num_channels_latents, height, width)
    latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)

    image = image.to(device=device, dtype=dtype)
    image_latents = self._encode_vae_image(image=image, generator=generator)

    if batch_size > image_latents.shape[0]:
        if batch_size % image_latents.shape[0] != 0:
            raise ValueError(
                f"Cannot duplicate `image` of batch size {image_latents.shape[0]} to {batch_size} text prompts."
            )
        # Tile the image latents so there is one per requested sample.
        copies = batch_size // image_latents.shape[0]
        image_latents = torch.cat([image_latents] * copies, dim=0)
    else:
        image_latents = torch.cat([image_latents], dim=0)

    if latents is not None:
        noise = latents.to(device)
        latents = noise
    else:
        noise = randn_tensor(noise_shape, generator=generator, device=device, dtype=dtype)
        latents = self.scheduler.scale_noise(image_latents, timestep, noise)

    pack_args = (batch_size, num_channels_latents, height, width)
    noise = self._pack_latents(noise, *pack_args)
    image_latents = self._pack_latents(image_latents, *pack_args)
    latents = self._pack_latents(latents, *pack_args)
    return latents, noise, image_latents, latent_image_ids
617
+
618
# Adapted from diffusers.pipelines.flux.pipeline_flux_inpaint.FluxInpaintPipeline.prepare_mask_latents
def prepare_mask_latents(
    self,
    mask,
    masked_image,
    batch_size,
    num_channels_latents,
    num_images_per_prompt,
    height,
    width,
    dtype,
    device,
    generator,
):
    """
    Resize, tile and pack the mask and the masked-image latents.

    Args:
        mask: Mask tensor; interpolated to the latent height/width.
        masked_image: Either an image to encode with the VAE, or a tensor with 16 channels in
            dim 1, which is used as-is instead of being encoded.
        batch_size: Prompt batch size; multiplied by `num_images_per_prompt` internally.
        num_channels_latents: Channel count used for packing; the mask is repeated to this many
            channels before packing.
        num_images_per_prompt: Number of images generated per prompt.
        height, width: Image size in pixels; converted to an even latent size.
        dtype, device: Target dtype and device.
        generator: Passed to `retrieve_latents` when the masked image is encoded.

    Returns:
        `(mask, masked_image_latents)`, both packed with `_pack_latents`.

    Raises:
        ValueError: if the total batch size is not a multiple of the number of masks or of the
            number of masked images.
    """
    # The VAE compresses the image and packing additionally needs an even latent height/width.
    pixels_per_patch = self.vae_scale_factor * 2
    height = 2 * (int(height) // pixels_per_patch)
    width = 2 * (int(width) // pixels_per_patch)

    # Resize before the dtype cast so this keeps working with cpu_offload and half precision.
    mask = torch.nn.functional.interpolate(mask, size=(height, width))
    mask = mask.to(device=device, dtype=dtype)

    batch_size = batch_size * num_images_per_prompt

    masked_image = masked_image.to(device=device, dtype=dtype)

    already_latents = masked_image.shape[1] == 16
    if already_latents:
        masked_image_latents = masked_image
    else:
        masked_image_latents = retrieve_latents(self.vae.encode(masked_image), generator=generator)

    vae_config = self.vae.config
    masked_image_latents = (masked_image_latents - vae_config.shift_factor) * vae_config.scaling_factor

    # Tile mask and masked-image latents up to the total batch size (mps friendly).
    num_masks = mask.shape[0]
    if num_masks < batch_size:
        if batch_size % num_masks != 0:
            raise ValueError(
                "The passed mask and the required batch size don't match. Masks are supposed to be duplicated to"
                f" a total batch size of {batch_size}, but {mask.shape[0]} masks were passed. Make sure the number"
                " of masks that you pass is divisible by the total requested batch size."
            )
        mask = mask.repeat(batch_size // num_masks, 1, 1, 1)

    num_images = masked_image_latents.shape[0]
    if num_images < batch_size:
        if batch_size % num_images != 0:
            raise ValueError(
                "The passed images and the required batch size don't match. Images are supposed to be duplicated"
                f" to a total batch size of {batch_size}, but {masked_image_latents.shape[0]} images were passed."
                " Make sure the number of images that you pass is divisible by the total requested batch size."
            )
        masked_image_latents = masked_image_latents.repeat(batch_size // num_images, 1, 1, 1)

    # Align device and dtype to prevent errors when concatenating with the latent model input.
    masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)

    pack_args = (batch_size, num_channels_latents, height, width)
    masked_image_latents = self._pack_latents(masked_image_latents, *pack_args)
    mask = self._pack_latents(mask.repeat(1, num_channels_latents, 1, 1), *pack_args)

    return mask, masked_image_latents
689
+
690
# Adapted from diffusers.pipelines.controlnet_sd3.pipeline_stable_diffusion_3_controlnet.StableDiffusion3ControlNetPipeline.prepare_image
def prepare_image(
    self,
    image,
    width,
    height,
    batch_size,
    num_images_per_prompt,
    device,
    dtype,
    do_classifier_free_guidance=False,
    guess_mode=False,
):
    """
    Turn the control image into a batched tensor on `device` with `dtype`.

    Tensors are used as given; anything else goes through `self.image_processor.preprocess` with
    the requested `height` and `width`. A single image is repeated `batch_size` times, a batch of
    images is repeated `num_images_per_prompt` times each. When `do_classifier_free_guidance` is
    set and `guess_mode` is not, the result is concatenated with itself along dim 0.
    """
    if not isinstance(image, torch.Tensor):
        image = self.image_processor.preprocess(image, height=height, width=width)

    # With more than one image, the image batch size is the same as the prompt batch size.
    repeat_by = batch_size if image.shape[0] == 1 else num_images_per_prompt

    image = image.repeat_interleave(repeat_by, dim=0)
    image = image.to(device=device, dtype=dtype)

    if do_classifier_free_guidance and not guess_mode:
        image = torch.cat([image, image])

    return image
724
+
725
@property
def guidance_scale(self):
    """Value held in `self._guidance_scale`, which `__call__` sets from its `guidance_scale` argument."""
    return self._guidance_scale
728
+
729
@property
def joint_attention_kwargs(self):
    """Value held in `self._joint_attention_kwargs`, which `__call__` sets from its argument of the same name."""
    return self._joint_attention_kwargs
732
+
733
@property
def num_timesteps(self):
    """Value held in `self._num_timesteps`, which `__call__` sets to the length of its timestep schedule."""
    return self._num_timesteps
736
+
737
@property
def interrupt(self):
    """Value held in `self._interrupt`; `__call__` resets it to `False` and checks it on every denoising step."""
    return self._interrupt
740
+
741
+ @torch.no_grad()
742
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
743
+ def __call__(
744
+ self,
745
+ prompt: Union[str, List[str]] = None,
746
+ prompt_2: Optional[Union[str, List[str]]] = None,
747
+ image: PipelineImageInput = None,
748
+ mask_image: PipelineImageInput = None,
749
+ masked_image_latents: PipelineImageInput = None,
750
+ control_image: PipelineImageInput = None,
751
+ height: Optional[int] = None,
752
+ width: Optional[int] = None,
753
+ strength: float = 0.6,
754
+ padding_mask_crop: Optional[int] = None,
755
+ sigmas: Optional[List[float]] = None,
756
+ num_inference_steps: int = 28,
757
+ guidance_scale: float = 7.0,
758
+ control_guidance_start: Union[float, List[float]] = 0.0,
759
+ control_guidance_end: Union[float, List[float]] = 1.0,
760
+ control_mode: Optional[Union[int, List[int]]] = None,
761
+ controlnet_conditioning_scale: Union[float, List[float]] = 1.0,
762
+ num_images_per_prompt: Optional[int] = 1,
763
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
764
+ latents: Optional[torch.FloatTensor] = None,
765
+ prompt_embeds: Optional[torch.FloatTensor] = None,
766
+ pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
767
+ output_type: Optional[str] = "pil",
768
+ return_dict: bool = True,
769
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
770
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
771
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
772
+ max_sequence_length: int = 512,
773
+ ):
774
+ """
775
+ Function invoked when calling the pipeline for generation.
776
+
777
+ Args:
778
+ prompt (`str` or `List[str]`, *optional*):
779
+ The prompt or prompts to guide the image generation.
780
+ prompt_2 (`str` or `List[str]`, *optional*):
781
+ The prompt or prompts to be sent to the `tokenizer_2` and `text_encoder_2`.
782
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
783
+ The image(s) to inpaint.
784
+ mask_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
785
+ The mask image(s) to use for inpainting. White pixels in the mask will be repainted, while black pixels
786
+ will be preserved.
787
+ masked_image_latents (`torch.FloatTensor`, *optional*):
788
+ Pre-generated masked image latents.
789
+ control_image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
790
+ The ControlNet input condition. Image to control the generation.
791
+ height (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
792
+ The height in pixels of the generated image.
793
+ width (`int`, *optional*, defaults to self.default_sample_size * self.vae_scale_factor):
794
+ The width in pixels of the generated image.
795
+ strength (`float`, *optional*, defaults to 0.6):
796
+ Conceptually, indicates how much to inpaint the masked area. Must be between 0 and 1.
797
+ padding_mask_crop (`int`, *optional*):
798
+ The size of the padding to use when cropping the mask.
799
+ num_inference_steps (`int`, *optional*, defaults to 28):
800
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
801
+ expense of slower inference.
802
+ sigmas (`List[float]`, *optional*):
803
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
804
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
805
+ will be used.
806
+ guidance_scale (`float`, *optional*, defaults to 7.0):
807
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
808
+ control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
809
+ The percentage of total steps at which the ControlNet starts applying.
810
+ control_guidance_end (`float` or `List[float]`, *optional*, defaults to 1.0):
811
+ The percentage of total steps at which the ControlNet stops applying.
812
+ control_mode (`int` or `List[int]`, *optional*):
813
+ The mode for the ControlNet. If multiple ControlNets are used, this should be a list.
814
+ controlnet_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
815
+ The outputs of the ControlNet are multiplied by `controlnet_conditioning_scale` before they are added
816
+ to the residual in the original transformer.
817
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
818
+ The number of images to generate per prompt.
819
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
820
+ One or more [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) to
821
+ make generation deterministic.
822
+ latents (`torch.FloatTensor`, *optional*):
823
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
824
+ generation. Can be used to tweak the same generation with different prompts.
825
+ prompt_embeds (`torch.FloatTensor`, *optional*):
826
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
827
+ pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
828
+ Pre-generated pooled text embeddings.
829
+ output_type (`str`, *optional*, defaults to `"pil"`):
830
+ The output format of the generate image. Choose between `PIL.Image` or `np.array`.
831
+ return_dict (`bool`, *optional*, defaults to `True`):
832
+ Whether or not to return a [`~pipelines.flux.FluxPipelineOutput`] instead of a plain tuple.
833
+ joint_attention_kwargs (`dict`, *optional*):
834
+ Additional keyword arguments to be passed to the joint attention mechanism.
835
+ callback_on_step_end (`Callable`, *optional*):
836
+ A function that calls at the end of each denoising step during the inference.
837
+ callback_on_step_end_tensor_inputs (`List[str]`, *optional*):
838
+ The list of tensor inputs for the `callback_on_step_end` function.
839
+ max_sequence_length (`int`, *optional*, defaults to 512):
840
+ The maximum length of the sequence to be generated.
841
+
842
+ Examples:
843
+
844
+ Returns:
845
+ [`~pipelines.flux.FluxPipelineOutput`] or `tuple`: [`~pipelines.flux.FluxPipelineOutput`] if `return_dict`
846
+ is True, otherwise a `tuple`. When returning a tuple, the first element is a list with the generated
847
+ images.
848
+ """
849
+ height = height or self.default_sample_size * self.vae_scale_factor
850
+ width = width or self.default_sample_size * self.vae_scale_factor
851
+
852
+ global_height = height
853
+ global_width = width
854
+
855
+ if not isinstance(control_guidance_start, list) and isinstance(control_guidance_end, list):
856
+ control_guidance_start = len(control_guidance_end) * [control_guidance_start]
857
+ elif not isinstance(control_guidance_end, list) and isinstance(control_guidance_start, list):
858
+ control_guidance_end = len(control_guidance_start) * [control_guidance_end]
859
+ elif not isinstance(control_guidance_start, list) and not isinstance(control_guidance_end, list):
860
+ mult = len(self.controlnet.nets) if isinstance(self.controlnet, FluxMultiControlNetModel) else 1
861
+ control_guidance_start, control_guidance_end = (
862
+ mult * [control_guidance_start],
863
+ mult * [control_guidance_end],
864
+ )
865
+
866
+ # 1. Check inputs
867
+ self.check_inputs(
868
+ prompt,
869
+ prompt_2,
870
+ image,
871
+ mask_image,
872
+ strength,
873
+ height,
874
+ width,
875
+ output_type=output_type,
876
+ prompt_embeds=prompt_embeds,
877
+ pooled_prompt_embeds=pooled_prompt_embeds,
878
+ callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
879
+ padding_mask_crop=padding_mask_crop,
880
+ max_sequence_length=max_sequence_length,
881
+ )
882
+
883
+ self._guidance_scale = guidance_scale
884
+ self._joint_attention_kwargs = joint_attention_kwargs
885
+ self._interrupt = False
886
+
887
+ # 2. Define call parameters
888
+ if prompt is not None and isinstance(prompt, str):
889
+ batch_size = 1
890
+ elif prompt is not None and isinstance(prompt, list):
891
+ batch_size = len(prompt)
892
+ else:
893
+ batch_size = prompt_embeds.shape[0]
894
+
895
+ device = self._execution_device
896
+ dtype = self.transformer.dtype
897
+
898
+ # 3. Encode input prompt
899
+ lora_scale = (
900
+ self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
901
+ )
902
+ prompt_embeds, pooled_prompt_embeds, text_ids = self.encode_prompt(
903
+ prompt=prompt,
904
+ prompt_2=prompt_2,
905
+ prompt_embeds=prompt_embeds,
906
+ pooled_prompt_embeds=pooled_prompt_embeds,
907
+ device=device,
908
+ num_images_per_prompt=num_images_per_prompt,
909
+ max_sequence_length=max_sequence_length,
910
+ lora_scale=lora_scale,
911
+ )
912
+
913
+ # 4. Preprocess mask and image
914
+ if padding_mask_crop is not None:
915
+ crops_coords = self.mask_processor.get_crop_region(
916
+ mask_image, global_width, global_height, pad=padding_mask_crop
917
+ )
918
+ resize_mode = "fill"
919
+ else:
920
+ crops_coords = None
921
+ resize_mode = "default"
922
+
923
+ original_image = image
924
+ init_image = self.image_processor.preprocess(
925
+ image, height=global_height, width=global_width, crops_coords=crops_coords, resize_mode=resize_mode
926
+ )
927
+ init_image = init_image.to(dtype=torch.float32)
928
+
929
+ # 5. Prepare control image
930
+ num_channels_latents = self.transformer.config.in_channels // 4
931
+ if isinstance(self.controlnet, FluxControlNetModel):
932
+ control_image = self.prepare_image(
933
+ image=control_image,
934
+ width=height,
935
+ height=width,
936
+ batch_size=batch_size * num_images_per_prompt,
937
+ num_images_per_prompt=num_images_per_prompt,
938
+ device=device,
939
+ dtype=self.vae.dtype,
940
+ )
941
+ height, width = control_image.shape[-2:]
942
+
943
+ # xlab controlnet has a input_hint_block and instantx controlnet does not
944
+ controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
945
+ if self.controlnet.input_hint_block is None:
946
+ # vae encode
947
+ control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
948
+ control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
949
+
950
+ # pack
951
+ height_control_image, width_control_image = control_image.shape[2:]
952
+ control_image = self._pack_latents(
953
+ control_image,
954
+ batch_size * num_images_per_prompt,
955
+ num_channels_latents,
956
+ height_control_image,
957
+ width_control_image,
958
+ )
959
+
960
+ # set control mode
961
+ if control_mode is not None:
962
+ control_mode = torch.tensor(control_mode).to(device, dtype=torch.long)
963
+ control_mode = control_mode.reshape([-1, 1])
964
+
965
+ elif isinstance(self.controlnet, FluxMultiControlNetModel):
966
+ control_images = []
967
+
968
+ # xlab controlnet has a input_hint_block and instantx controlnet does not
969
+ controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
970
+ for i, control_image_ in enumerate(control_image):
971
+ control_image_ = self.prepare_image(
972
+ image=control_image_,
973
+ width=width,
974
+ height=height,
975
+ batch_size=batch_size * num_images_per_prompt,
976
+ num_images_per_prompt=num_images_per_prompt,
977
+ device=device,
978
+ dtype=self.vae.dtype,
979
+ )
980
+ height, width = control_image_.shape[-2:]
981
+
982
+ if self.controlnet.nets[0].input_hint_block is None:
983
+ # vae encode
984
+ control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
985
+ control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
986
+
987
+ # pack
988
+ height_control_image, width_control_image = control_image_.shape[2:]
989
+ control_image_ = self._pack_latents(
990
+ control_image_,
991
+ batch_size * num_images_per_prompt,
992
+ num_channels_latents,
993
+ height_control_image,
994
+ width_control_image,
995
+ )
996
+
997
+ control_images.append(control_image_)
998
+
999
+ control_image = control_images
1000
+
1001
+ # set control mode
1002
+ control_mode_ = []
1003
+ if isinstance(control_mode, list):
1004
+ for cmode in control_mode:
1005
+ if cmode is None:
1006
+ control_mode_.append(-1)
1007
+ else:
1008
+ control_mode_.append(cmode)
1009
+ control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
1010
+ control_mode = control_mode.reshape([-1, 1])
1011
+
1012
+ # 6. Prepare timesteps
1013
+
1014
+ sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
1015
+ image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * (
1016
+ int(global_width) // self.vae_scale_factor // 2
1017
+ )
1018
+ mu = calculate_shift(
1019
+ image_seq_len,
1020
+ self.scheduler.config.base_image_seq_len,
1021
+ self.scheduler.config.max_image_seq_len,
1022
+ self.scheduler.config.base_shift,
1023
+ self.scheduler.config.max_shift,
1024
+ )
1025
+ timesteps, num_inference_steps = retrieve_timesteps(
1026
+ self.scheduler,
1027
+ num_inference_steps,
1028
+ device,
1029
+ sigmas=sigmas,
1030
+ mu=mu,
1031
+ )
1032
+ timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
1033
+
1034
+ if num_inference_steps < 1:
1035
+ raise ValueError(
1036
+ f"After adjusting the num_inference_steps by strength parameter: {strength}, the number of pipeline"
1037
+ f"steps is {num_inference_steps} which is < 1 and not appropriate for this pipeline."
1038
+ )
1039
+ latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
1040
+
1041
+ # 7. Prepare latent variables
1042
+
1043
+ latents, noise, image_latents, latent_image_ids = self.prepare_latents(
1044
+ init_image,
1045
+ latent_timestep,
1046
+ batch_size * num_images_per_prompt,
1047
+ num_channels_latents,
1048
+ global_height,
1049
+ global_width,
1050
+ prompt_embeds.dtype,
1051
+ device,
1052
+ generator,
1053
+ latents,
1054
+ )
1055
+
1056
+ # 8. Prepare mask latents
1057
+ mask_condition = self.mask_processor.preprocess(
1058
+ mask_image, height=global_height, width=global_width, resize_mode=resize_mode, crops_coords=crops_coords
1059
+ )
1060
+ if masked_image_latents is None:
1061
+ masked_image = init_image * (mask_condition < 0.5)
1062
+ else:
1063
+ masked_image = masked_image_latents
1064
+
1065
+ mask, masked_image_latents = self.prepare_mask_latents(
1066
+ mask_condition,
1067
+ masked_image,
1068
+ batch_size,
1069
+ num_channels_latents,
1070
+ num_images_per_prompt,
1071
+ global_height,
1072
+ global_width,
1073
+ prompt_embeds.dtype,
1074
+ device,
1075
+ generator,
1076
+ )
1077
+
1078
+ controlnet_keep = []
1079
+ for i in range(len(timesteps)):
1080
+ keeps = [
1081
+ 1.0 - float(i / len(timesteps) < s or (i + 1) / len(timesteps) > e)
1082
+ for s, e in zip(control_guidance_start, control_guidance_end)
1083
+ ]
1084
+ controlnet_keep.append(keeps[0] if isinstance(self.controlnet, FluxControlNetModel) else keeps)
1085
+
1086
+ # 9. Denoising loop
1087
+ num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
1088
+ self._num_timesteps = len(timesteps)
1089
+
1090
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
1091
+ for i, t in enumerate(timesteps):
1092
+ if self.interrupt:
1093
+ continue
1094
+
1095
+ timestep = t.expand(latents.shape[0]).to(latents.dtype)
1096
+
1097
+ # predict the noise residual
1098
+ if isinstance(self.controlnet, FluxMultiControlNetModel):
1099
+ use_guidance = self.controlnet.nets[0].config.guidance_embeds
1100
+ else:
1101
+ use_guidance = self.controlnet.config.guidance_embeds
1102
+ if use_guidance:
1103
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1104
+ guidance = guidance.expand(latents.shape[0])
1105
+ else:
1106
+ guidance = None
1107
+
1108
+ if isinstance(controlnet_keep[i], list):
1109
+ cond_scale = [c * s for c, s in zip(controlnet_conditioning_scale, controlnet_keep[i])]
1110
+ else:
1111
+ controlnet_cond_scale = controlnet_conditioning_scale
1112
+ if isinstance(controlnet_cond_scale, list):
1113
+ controlnet_cond_scale = controlnet_cond_scale[0]
1114
+ cond_scale = controlnet_cond_scale * controlnet_keep[i]
1115
+
1116
+ controlnet_block_samples, controlnet_single_block_samples = self.controlnet(
1117
+ hidden_states=latents,
1118
+ controlnet_cond=control_image,
1119
+ controlnet_mode=control_mode,
1120
+ conditioning_scale=cond_scale,
1121
+ timestep=timestep / 1000,
1122
+ guidance=guidance,
1123
+ pooled_projections=pooled_prompt_embeds,
1124
+ encoder_hidden_states=prompt_embeds,
1125
+ txt_ids=text_ids,
1126
+ img_ids=latent_image_ids,
1127
+ joint_attention_kwargs=self.joint_attention_kwargs,
1128
+ return_dict=False,
1129
+ )
1130
+
1131
+ if self.transformer.config.guidance_embeds:
1132
+ guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
1133
+ guidance = guidance.expand(latents.shape[0])
1134
+ else:
1135
+ guidance = None
1136
+
1137
+ noise_pred = self.transformer(
1138
+ hidden_states=latents,
1139
+ timestep=timestep / 1000,
1140
+ guidance=guidance,
1141
+ pooled_projections=pooled_prompt_embeds,
1142
+ encoder_hidden_states=prompt_embeds,
1143
+ controlnet_block_samples=controlnet_block_samples,
1144
+ controlnet_single_block_samples=controlnet_single_block_samples,
1145
+ txt_ids=text_ids,
1146
+ img_ids=latent_image_ids,
1147
+ joint_attention_kwargs=self.joint_attention_kwargs,
1148
+ return_dict=False,
1149
+ controlnet_blocks_repeat=controlnet_blocks_repeat,
1150
+ )[0]
1151
+
1152
+ # compute the previous noisy sample x_t -> x_t-1
1153
+ latents_dtype = latents.dtype
1154
+ latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
1155
+
1156
+ # For inpainting, we need to apply the mask and add the masked image latents
1157
+ init_latents_proper = image_latents
1158
+ init_mask = mask
1159
+
1160
+ if i < len(timesteps) - 1:
1161
+ noise_timestep = timesteps[i + 1]
1162
+ init_latents_proper = self.scheduler.scale_noise(
1163
+ init_latents_proper, torch.tensor([noise_timestep]), noise
1164
+ )
1165
+
1166
+ latents = (1 - init_mask) * init_latents_proper + init_mask * latents
1167
+
1168
+ if latents.dtype != latents_dtype:
1169
+ if torch.backends.mps.is_available():
1170
+ # some platforms (eg. apple mps) misbehave due to a pytorch bug: https://github.com/pytorch/pytorch/pull/99272
1171
+ latents = latents.to(latents_dtype)
1172
+
1173
+ # call the callback, if provided
1174
+ if callback_on_step_end is not None:
1175
+ callback_kwargs = {}
1176
+ for k in callback_on_step_end_tensor_inputs:
1177
+ callback_kwargs[k] = locals()[k]
1178
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
1179
+
1180
+ latents = callback_outputs.pop("latents", latents)
1181
+ prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
1182
+
1183
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
1184
+ progress_bar.update()
1185
+
1186
+ if XLA_AVAILABLE:
1187
+ xm.mark_step()
1188
+
1189
+ # Post-processing
1190
+ if output_type == "latent":
1191
+ image = latents
1192
+ else:
1193
+ latents = self._unpack_latents(latents, global_height, global_width, self.vae_scale_factor)
1194
+ latents = (latents / self.vae.config.scaling_factor) + self.vae.config.shift_factor
1195
+ image = self.vae.decode(latents, return_dict=False)[0]
1196
+ image = self.image_processor.postprocess(image, output_type=output_type)
1197
+
1198
+ # Offload all models
1199
+ self.maybe_free_model_hooks()
1200
+
1201
+ if not return_dict:
1202
+ return (image,)
1203
+
1204
+ return FluxPipelineOutput(images=image)