diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff compares the contents of two publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in their public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +72 -26
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
diffusers/pipelines/flux/pipeline_flux.py

@@ -17,10 +17,17 @@ from typing import Any, Callable, Dict, List, Optional, Union
 
 import numpy as np
 import torch
-from transformers import CLIPTextModel, CLIPTokenizer, T5EncoderModel, T5TokenizerFast
+from transformers import (
+    CLIPImageProcessor,
+    CLIPTextModel,
+    CLIPTokenizer,
+    CLIPVisionModelWithProjection,
+    T5EncoderModel,
+    T5TokenizerFast,
+)
 
-from ...image_processor import VaeImageProcessor
-from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
+from ...image_processor import PipelineImageInput, VaeImageProcessor
+from ...loaders import FluxIPAdapterMixin, FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
 from ...models.transformers import FluxTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
@@ -142,6 +149,7 @@ class FluxPipeline(
     FluxLoraLoaderMixin,
     FromSingleFileMixin,
     TextualInversionLoaderMixin,
+    FluxIPAdapterMixin,
 ):
     r"""
     The Flux pipeline for text-to-image generation.
@@ -169,8 +177,8 @@ class FluxPipeline(
             [T5TokenizerFast](https://huggingface.co/docs/transformers/en/model_doc/t5#transformers.T5TokenizerFast).
     """
 
-    model_cpu_offload_seq = "text_encoder->text_encoder_2->transformer->vae"
-    _optional_components = []
+    model_cpu_offload_seq = "text_encoder->text_encoder_2->image_encoder->transformer->vae"
+    _optional_components = ["image_encoder", "feature_extractor"]
     _callback_tensor_inputs = ["latents", "prompt_embeds"]
 
     def __init__(
@@ -182,6 +190,8 @@ class FluxPipeline(
         text_encoder_2: T5EncoderModel,
         tokenizer_2: T5TokenizerFast,
         transformer: FluxTransformer2DModel,
+        image_encoder: CLIPVisionModelWithProjection = None,
+        feature_extractor: CLIPImageProcessor = None,
     ):
         super().__init__()
 
@@ -193,15 +203,19 @@ class FluxPipeline(
             tokenizer_2=tokenizer_2,
             transformer=transformer,
             scheduler=scheduler,
+            image_encoder=image_encoder,
+            feature_extractor=feature_extractor,
         )
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
        )
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128
 
     def _get_t5_prompt_embeds(
         self,
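Taken together, these constructor changes make the effective pixel-space granularity `vae_scale_factor * 2 = 16`: the VAE compresses 8x and the 2x2 patch packing halves each latent dimension again, while `default_sample_size = 128` keeps the default output at 1024x1024. A quick sanity check of that arithmetic (a standalone sketch, not part of the diff):

    # How the new defaults map pixel dimensions to packed latent tokens.
    vae_scale_factor = 8                      # 2 ** (len(block_out_channels) - 1) for the Flux VAE
    height = width = 128 * vae_scale_factor   # default_sample_size * vae_scale_factor = 1024

    latent_height = 2 * (height // (vae_scale_factor * 2))   # 128
    latent_width = 2 * (width // (vae_scale_factor * 2))     # 128
    num_tokens = (latent_height // 2) * (latent_width // 2)  # 4096 packed 2x2 patches
    print(latent_height, latent_width, num_tokens)           # 128 128 4096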
@@ -375,19 +389,67 @@ class FluxPipeline(
 
         return prompt_embeds, pooled_prompt_embeds, text_ids
 
+    def encode_image(self, image, device, num_images_per_prompt):
+        dtype = next(self.image_encoder.parameters()).dtype
+
+        if not isinstance(image, torch.Tensor):
+            image = self.feature_extractor(image, return_tensors="pt").pixel_values
+
+        image = image.to(device=device, dtype=dtype)
+        image_embeds = self.image_encoder(image).image_embeds
+        image_embeds = image_embeds.repeat_interleave(num_images_per_prompt, dim=0)
+        return image_embeds
+
+    def prepare_ip_adapter_image_embeds(
+        self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt
+    ):
+        image_embeds = []
+        if ip_adapter_image_embeds is None:
+            if not isinstance(ip_adapter_image, list):
+                ip_adapter_image = [ip_adapter_image]
+
+            if len(ip_adapter_image) != len(self.transformer.encoder_hid_proj.image_projection_layers):
+                raise ValueError(
+                    f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.transformer.encoder_hid_proj.image_projection_layers)} IP Adapters."
+                )
+
+            for single_ip_adapter_image, image_proj_layer in zip(
+                ip_adapter_image, self.transformer.encoder_hid_proj.image_projection_layers
+            ):
+                single_image_embeds = self.encode_image(single_ip_adapter_image, device, 1)
+
+                image_embeds.append(single_image_embeds[None, :])
+        else:
+            for single_image_embeds in ip_adapter_image_embeds:
+                image_embeds.append(single_image_embeds)
+
+        ip_adapter_image_embeds = []
+        for i, single_image_embeds in enumerate(image_embeds):
+            single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
+            single_image_embeds = single_image_embeds.to(device=device)
+            ip_adapter_image_embeds.append(single_image_embeds)
+
+        return ip_adapter_image_embeds
+
     def check_inputs(
         self,
         prompt,
         prompt_2,
         height,
         width,
+        negative_prompt=None,
+        negative_prompt_2=None,
         prompt_embeds=None,
+        negative_prompt_embeds=None,
         pooled_prompt_embeds=None,
+        negative_pooled_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+            )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
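With FluxIPAdapterMixin wired in and the new encode_image / prepare_ip_adapter_image_embeds helpers above, the pipeline can consume IP-Adapter reference images. A minimal usage sketch; the checkpoint, weight name, and image-encoder IDs are illustrative assumptions, not taken from this diff:

    import torch
    from diffusers import FluxPipeline
    from diffusers.utils import load_image

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.load_ip_adapter(
        "XLabs-AI/flux-ip-adapter",  # assumed IP-Adapter repo
        weight_name="ip_adapter.safetensors",
        image_encoder_pretrained_model_name_or_path="openai/clip-vit-large-patch14",
    )
    pipe.set_ip_adapter_scale(1.0)

    style_image = load_image("https://example.com/reference.png")  # placeholder URL
    image = pipe(prompt="a dog, styled like the reference", ip_adapter_image=style_image).images[0]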
@@ -415,19 +477,42 @@ class FluxPipeline(
         elif prompt_2 is not None and (not isinstance(prompt_2, str) and not isinstance(prompt_2, list)):
             raise ValueError(f"`prompt_2` has to be of type `str` or `list` but is {type(prompt_2)}")
 
+        if negative_prompt is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+        elif negative_prompt_2 is not None and negative_prompt_embeds is not None:
+            raise ValueError(
+                f"Cannot forward both `negative_prompt_2`: {negative_prompt_2} and `negative_prompt_embeds`:"
+                f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+            )
+
+        if prompt_embeds is not None and negative_prompt_embeds is not None:
+            if prompt_embeds.shape != negative_prompt_embeds.shape:
+                raise ValueError(
+                    "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+                    f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+                    f" {negative_prompt_embeds.shape}."
+                )
+
         if prompt_embeds is not None and pooled_prompt_embeds is None:
             raise ValueError(
                 "If `prompt_embeds` are provided, `pooled_prompt_embeds` also have to be passed. Make sure to generate `pooled_prompt_embeds` from the same text encoder that was used to generate `prompt_embeds`."
             )
+        if negative_prompt_embeds is not None and negative_pooled_prompt_embeds is None:
+            raise ValueError(
+                "If `negative_prompt_embeds` are provided, `negative_pooled_prompt_embeds` also have to be passed. Make sure to generate `negative_pooled_prompt_embeds` from the same text encoder that was used to generate `negative_prompt_embeds`."
+            )
 
         if max_sequence_length is not None and max_sequence_length > 512:
             raise ValueError(f"`max_sequence_length` cannot be greater than 512 but is {max_sequence_length}")
 
     @staticmethod
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
 
         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
 
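Note that the halving moved out of _prepare_latent_image_ids: callers now pass height // 2 and width // 2 themselves (see the prepare_latents hunks below), and the helper simply builds a (height * width, 3) grid of per-token coordinates used for Flux's rotary position embeddings. A sketch of what it produces:

    import torch

    height, width = 4, 3  # packed-latent grid, i.e. latent_height // 2 by latent_width // 2
    ids = torch.zeros(height, width, 3)
    ids[..., 1] = ids[..., 1] + torch.arange(height)[:, None]  # row index of each token
    ids[..., 2] = ids[..., 2] + torch.arange(width)[None, :]   # column index of each token
    ids = ids.reshape(height * width, 3)
    print(ids[:4])  # rows: [0,0,0], [0,0,1], [0,0,2], [0,1,0]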
@@ -449,13 +534,15 @@
     def _unpack_latents(latents, height, width, vae_scale_factor):
         batch_size, num_patches, channels = latents.shape
 
-        height = height // vae_scale_factor
-        width = width // vae_scale_factor
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (vae_scale_factor * 2))
+        width = 2 * (int(width) // (vae_scale_factor * 2))
 
-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)
 
-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
 
         return latents
 
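The packing is a pure reshape: each 2x2 latent patch becomes one token with four times the channels, and _unpack_latents inverts it. A standalone round-trip check mirroring the pipeline's _pack_latents / updated _unpack_latents math (a sketch, not part of the diff):

    import torch

    batch, channels, h, w = 1, 16, 128, 128  # latent dims for a 1024x1024 image

    # pack: (B, C, H, W) -> (B, H/2 * W/2, C * 4)
    latents = torch.randn(batch, channels, h, w)
    packed = latents.view(batch, channels, h // 2, 2, w // 2, 2)
    packed = packed.permute(0, 2, 4, 1, 3, 5).reshape(batch, (h // 2) * (w // 2), channels * 4)

    # unpack: the inverse, as in the updated _unpack_latents above
    unpacked = packed.view(batch, h // 2, w // 2, channels, 2, 2)
    unpacked = unpacked.permute(0, 3, 1, 4, 2, 5).reshape(batch, channels, h, w)

    assert torch.equal(latents, unpacked)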
@@ -499,13 +586,15 @@
         generator,
         latents=None,
     ):
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (self.vae_scale_factor * 2))
+        width = 2 * (int(width) // (self.vae_scale_factor * 2))
 
         shape = (batch_size, num_channels_latents, height, width)
 
         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
             return latents.to(device=device, dtype=dtype), latent_image_ids
 
         if isinstance(generator, list) and len(generator) != batch_size:
@@ -517,7 +606,7 @@
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
 
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
         return latents, latent_image_ids
 
@@ -543,16 +632,25 @@
         self,
         prompt: Union[str, List[str]] = None,
         prompt_2: Optional[Union[str, List[str]]] = None,
+        negative_prompt: Union[str, List[str]] = None,
+        negative_prompt_2: Optional[Union[str, List[str]]] = None,
+        true_cfg_scale: float = 1.0,
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
-        timesteps: List[int] = None,
+        sigmas: Optional[List[float]] = None,
         guidance_scale: float = 3.5,
         num_images_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.FloatTensor] = None,
         prompt_embeds: Optional[torch.FloatTensor] = None,
         pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
+        ip_adapter_image: Optional[PipelineImageInput] = None,
+        ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_ip_adapter_image: Optional[PipelineImageInput] = None,
+        negative_ip_adapter_image_embeds: Optional[List[torch.Tensor]] = None,
+        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
         output_type: Optional[str] = "pil",
         return_dict: bool = True,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
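A minimal sketch of the new call-site arguments: true_cfg_scale > 1 combined with a negative prompt enables "true" classifier-free guidance (a second, negative transformer pass per step) on top of Flux's embedded guidance_scale. Model ID and prompts are illustrative:

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
    pipe.enable_model_cpu_offload()

    image = pipe(
        prompt="a photo of a cat wearing a space suit",
        negative_prompt="blurry, low quality",
        true_cfg_scale=4.0,   # ignored unless > 1 and a negative prompt is given
        guidance_scale=3.5,   # Flux's embedded (distilled) guidance, unchanged
        num_inference_steps=28,
    ).images[0]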
@@ -577,10 +675,10 @@
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 7.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -602,6 +700,17 @@
             pooled_prompt_embeds (`torch.FloatTensor`, *optional*):
                 Pre-generated pooled text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting.
                 If not provided, pooled text embeddings will be generated from `prompt` input argument.
+            ip_adapter_image: (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
+            negative_ip_adapter_image:
+                (`PipelineImageInput`, *optional*): Optional image input to work with IP Adapters.
+            negative_ip_adapter_image_embeds (`List[torch.Tensor]`, *optional*):
+                Pre-generated image embeddings for IP-Adapter. It should be a list of length same as number of
+                IP-adapters. Each element should be a tensor of shape `(batch_size, num_images, emb_dim)`. If not
+                provided, embeddings are computed from the `ip_adapter_image` input argument.
             output_type (`str`, *optional*, defaults to `"pil"`):
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
@@ -639,8 +748,12 @@
             prompt_2,
             height,
             width,
+            negative_prompt=negative_prompt,
+            negative_prompt_2=negative_prompt_2,
             prompt_embeds=prompt_embeds,
+            negative_prompt_embeds=negative_prompt_embeds,
             pooled_prompt_embeds=pooled_prompt_embeds,
+            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
             callback_on_step_end_tensor_inputs=callback_on_step_end_tensor_inputs,
             max_sequence_length=max_sequence_length,
         )
@@ -662,6 +775,7 @@
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
+        do_true_cfg = true_cfg_scale > 1 and negative_prompt is not None
         (
             prompt_embeds,
             pooled_prompt_embeds,
@@ -676,6 +790,21 @@
             max_sequence_length=max_sequence_length,
             lora_scale=lora_scale,
         )
+        if do_true_cfg:
+            (
+                negative_prompt_embeds,
+                negative_pooled_prompt_embeds,
+                _,
+            ) = self.encode_prompt(
+                prompt=negative_prompt,
+                prompt_2=negative_prompt_2,
+                prompt_embeds=negative_prompt_embeds,
+                pooled_prompt_embeds=negative_pooled_prompt_embeds,
+                device=device,
+                num_images_per_prompt=num_images_per_prompt,
+                max_sequence_length=max_sequence_length,
+                lora_scale=lora_scale,
+            )
 
         # 4. Prepare latent variables
         num_channels_latents = self.transformer.config.in_channels // 4
@@ -691,7 +820,7 @@
         )
 
         # 5. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = latents.shape[1]
         mu = calculate_shift(
             image_seq_len,
@@ -704,8 +833,7 @@
             self.scheduler,
             num_inference_steps,
             device,
-            timesteps,
-            sigmas,
+            sigmas=sigmas,
             mu=mu,
         )
         num_warmup_steps = max(len(timesteps) - num_inference_steps * self.scheduler.order, 0)
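Since the `timesteps` argument is gone, custom schedules are now expressed as sigmas, which retrieve_timesteps forwards to the scheduler's set_timesteps together with the computed mu. A hedged sketch, reusing `pipe` from the example above:

    import numpy as np

    # Any descending schedule in (0, 1]; the default is a linear ramp, this one is shifted.
    sigmas = (np.linspace(1.0, 1 / 28, 28) ** 1.5).tolist()
    image = pipe(prompt="a mountain lake at dawn", sigmas=sigmas, num_inference_steps=28).images[0]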
@@ -718,12 +846,43 @@
         else:
             guidance = None
 
+        if (ip_adapter_image is not None or ip_adapter_image_embeds is not None) and (
+            negative_ip_adapter_image is None and negative_ip_adapter_image_embeds is None
+        ):
+            negative_ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+        elif (ip_adapter_image is None and ip_adapter_image_embeds is None) and (
+            negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None
+        ):
+            ip_adapter_image = np.zeros((width, height, 3), dtype=np.uint8)
+
+        if self.joint_attention_kwargs is None:
+            self._joint_attention_kwargs = {}
+
+        image_embeds = None
+        negative_image_embeds = None
+        if ip_adapter_image is not None or ip_adapter_image_embeds is not None:
+            image_embeds = self.prepare_ip_adapter_image_embeds(
+                ip_adapter_image,
+                ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+        if negative_ip_adapter_image is not None or negative_ip_adapter_image_embeds is not None:
+            negative_image_embeds = self.prepare_ip_adapter_image_embeds(
+                negative_ip_adapter_image,
+                negative_ip_adapter_image_embeds,
+                device,
+                batch_size * num_images_per_prompt,
+            )
+
         # 6. Denoising loop
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue
 
+                if image_embeds is not None:
+                    self._joint_attention_kwargs["ip_adapter_image_embeds"] = image_embeds
                 # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
@@ -739,6 +898,22 @@
                     return_dict=False,
                 )[0]
 
+                if do_true_cfg:
+                    if negative_image_embeds is not None:
+                        self._joint_attention_kwargs["ip_adapter_image_embeds"] = negative_image_embeds
+                    neg_noise_pred = self.transformer(
+                        hidden_states=latents,
+                        timestep=timestep / 1000,
+                        guidance=guidance,
+                        pooled_projections=negative_pooled_prompt_embeds,
+                        encoder_hidden_states=negative_prompt_embeds,
+                        txt_ids=text_ids,
+                        img_ids=latent_image_ids,
+                        joint_attention_kwargs=self.joint_attention_kwargs,
+                        return_dict=False,
+                    )[0]
+                    noise_pred = neg_noise_pred + true_cfg_scale * (noise_pred - neg_noise_pred)
+
                 # compute the previous noisy sample x_t -> x_t-1
                 latents_dtype = latents.dtype
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
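For reference, the combination in the last added block is textbook classifier-free guidance applied to the flow-matching prediction: with s = true_cfg_scale,

    noise_pred = neg_noise_pred + s * (noise_pred - neg_noise_pred)

so s = 1 returns the positive prediction unchanged, which is why the extra negative pass is gated on true_cfg_scale > 1 and a supplied negative prompt.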