diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +74 -28
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/flux/pipeline_flux_controlnet.py
@@ -27,7 +27,7 @@ from transformers import (
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
+from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
 from ...models.transformers import FluxTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
@@ -97,6 +97,20 @@ def calculate_shift(
     return mu
 
 
+# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
+def retrieve_latents(
+    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
+):
+    if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
+        return encoder_output.latent_dist.sample(generator)
+    elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
+        return encoder_output.latent_dist.mode()
+    elif hasattr(encoder_output, "latents"):
+        return encoder_output.latents
+    else:
+        raise AttributeError("Could not access latents of provided encoder_output")
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
 def retrieve_timesteps(
     scheduler,
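The `retrieve_latents` helper added above replaces direct `.latent_dist.sample()` calls throughout these pipelines. A minimal sketch of what it buys, assuming a Flux-style `AutoencoderKL` (the checkpoint ID is an example, not taken from this diff):

```python
import torch
from diffusers import AutoencoderKL

# Example checkpoint; any AutoencoderKL whose encode() returns a latent_dist works.
vae = AutoencoderKL.from_pretrained("black-forest-labs/FLUX.1-dev", subfolder="vae")
image = torch.randn(1, 3, 512, 512)  # stand-in for a preprocessed control image

# A seeded generator makes the VAE's stochastic sampling reproducible;
# sample_mode="argmax" would deterministically take the distribution mode instead.
generator = torch.Generator().manual_seed(0)
latents = retrieve_latents(vae.encode(image), generator=generator)
```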
@@ -216,13 +230,15 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
             controlnet=controlnet,
         )
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128
 
     def _get_t5_prompt_embeds(
         self,
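The constructor change above fixes an off-by-one in the scale-factor exponent and folds the 2x2 packing into the image processor. A sketch of the arithmetic, with illustrative `block_out_channels` values for a four-level VAE (not read from this diff):

```python
block_out_channels = [128, 256, 512, 512]  # illustrative four-level VAE config

# Each level after the first halves spatial resolution, so the true compression
# is 2 ** (levels - 1) = 8, not 2 ** levels = 16 as the old code computed.
vae_scale_factor = 2 ** (len(block_out_channels) - 1)  # 8

# The image processor additionally multiplies by the 2x2 patch size, so
# requested heights and widths should be multiples of 16.
pixel_multiple = vae_scale_factor * 2  # 16
```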
@@ -410,8 +426,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+            )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -450,9 +468,9 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
 
         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
 
@@ -476,13 +494,15 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
     def _unpack_latents(latents, height, width, vae_scale_factor):
         batch_size, num_patches, channels = latents.shape
 
-        height = height // vae_scale_factor
-        width = width // vae_scale_factor
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (vae_scale_factor * 2))
+        width = 2 * (int(width) // (vae_scale_factor * 2))
 
-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)
 
-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
 
         return latents
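`_unpack_latents` above is the inverse of `_pack_latents`' 2x2 patching. A self-contained round-trip sketch of that packing math (the function names here are local stand-ins, not the pipeline's methods):

```python
import torch

def pack(latents):
    # (B, C, H, W) -> (B, H/2 * W/2, C * 4): fold each 2x2 patch into channels
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

def unpack(latents, h, w):
    # mirrors the rewritten _unpack_latents math above
    b, num_patches, channels = latents.shape
    latents = latents.view(b, h // 2, w // 2, channels // 4, 2, 2)
    latents = latents.permute(0, 3, 1, 4, 2, 5)
    return latents.reshape(b, channels // 4, h, w)

x = torch.randn(1, 16, 128, 128)  # e.g. Flux latents for a 1024px image
assert torch.equal(unpack(pack(x), 128, 128), x)  # lossless round trip
```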
@@ -498,13 +518,15 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
         generator,
         latents=None,
     ):
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (self.vae_scale_factor * 2))
+        width = 2 * (int(width) // (self.vae_scale_factor * 2))
 
         shape = (batch_size, num_channels_latents, height, width)
 
         if latents is not None:
-            latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+            latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
             return latents.to(device=device, dtype=dtype), latent_image_ids
 
         if isinstance(generator, list) and len(generator) != batch_size:
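`prepare_latents` now hands `_prepare_latent_image_ids` the packed grid size (`height // 2`, `width // 2`) instead of letting the helper halve it again. A standalone sketch of the (row, col) position-id grid that results, assuming a 1024-pixel image:

```python
import torch

h = w = 64  # packed patch grid for a 1024px image: 1024 // 8 (VAE) // 2 (packing)
ids = torch.zeros(h, w, 3)
ids[..., 1] = torch.arange(h)[:, None]  # row index per packed patch
ids[..., 2] = torch.arange(w)[None, :]  # column index per packed patch
ids = ids.reshape(h * w, 3)             # one (pad, row, col) id per latent token
print(ids.shape)                        # torch.Size([4096, 3])
```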
@@ -516,7 +538,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
         latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
         latents = self._pack_latents(latents, batch_size, num_channels_latents, height, width)
 
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
         return latents, latent_image_ids
 
@@ -580,7 +602,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
         height: Optional[int] = None,
         width: Optional[int] = None,
         num_inference_steps: int = 28,
-        timesteps: List[int] = None,
+        sigmas: Optional[List[float]] = None,
         guidance_scale: float = 7.0,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
@@ -616,10 +638,10 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 7.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
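With `timesteps` dropped in favor of `sigmas`, callers now steer the noise schedule directly. A hedged usage sketch (checkpoint IDs and the image URL are examples, not taken from this diff):

```python
import numpy as np
import torch
from diffusers import FluxControlNetModel, FluxControlNetPipeline
from diffusers.utils import load_image

controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Canny", torch_dtype=torch.bfloat16
)
pipe = FluxControlNetPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

num_steps = 28
sigmas = np.linspace(1.0, 1 / num_steps, num_steps)  # the pipeline default, made explicit

image = pipe(
    prompt="a futuristic city at dusk",
    control_image=load_image("https://example.com/canny.png"),  # placeholder URL
    sigmas=sigmas,  # new in 0.32.0; the old `timesteps` kwarg is gone
    num_inference_steps=num_steps,
).images[0]
```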
@@ -728,6 +750,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
         device = self._execution_device
         dtype = self.transformer.dtype
 
+        # 3. Prepare text embeddings
         lora_scale = (
             self.joint_attention_kwargs.get("scale", None) if self.joint_attention_kwargs is not None else None
         )
@@ -764,7 +787,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
             controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
             if self.controlnet.input_hint_block is None:
                 # vae encode
-                control_image = self.vae.encode(control_image).latent_dist.sample()
+                control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
                 control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
 
                 # pack
@@ -802,7 +825,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
 
                 if self.controlnet.nets[0].input_hint_block is None:
                     # vae encode
-                    control_image_ = self.vae.encode(control_image_).latent_dist.sample()
+                    control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
                     control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
 
                     # pack
@@ -849,7 +872,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
         )
 
         # 5. Prepare timesteps
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
         image_seq_len = latents.shape[1]
         mu = calculate_shift(
             image_seq_len,
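The default schedule above feeds `calculate_shift`, whose body this diff truncates. A sketch with the default values as found in the 0.32.x source (cited from memory, worth double-checking against your install):

```python
import numpy as np

def calculate_shift(image_seq_len, base_seq_len=256, max_seq_len=4096, base_shift=0.5, max_shift=1.16):
    # Linear interpolation of the flow-match shift in the packed sequence length:
    # larger images denoise under a stronger timestep shift.
    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    b = base_shift - m * base_seq_len
    return image_seq_len * m + b

num_inference_steps = 28
sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)  # used when `sigmas` is None
mu = calculate_shift(4096)  # 1024x1024 image -> 4096 packed tokens -> mu == max_shift
```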
@@ -862,8 +885,7 @@ class FluxControlNetPipeline(DiffusionPipeline, FluxLoraLoaderMixin, FromSingleF
             self.scheduler,
             num_inference_steps,
             device,
-            timesteps,
-            sigmas,
+            sigmas=sigmas,
             mu=mu,
         )
 
diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py
@@ -13,7 +13,7 @@ from transformers import (
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
+from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
 from ...models.transformers import FluxTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
@@ -228,13 +228,15 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
             controlnet=controlnet,
         )
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128
 
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
@@ -453,8 +455,10 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % self.vae_scale_factor * 2 != 0 or width % self.vae_scale_factor * 2 != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+            )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -493,9 +497,9 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
 
         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
 
@@ -519,13 +523,15 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
     def _unpack_latents(latents, height, width, vae_scale_factor):
         batch_size, num_patches, channels = latents.shape
 
-        height = height // vae_scale_factor
-        width = width // vae_scale_factor
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (vae_scale_factor * 2))
+        width = 2 * (int(width) // (vae_scale_factor * 2))
 
-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)
 
-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
 
         return latents
 
@@ -549,11 +555,12 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
-
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (self.vae_scale_factor * 2))
+        width = 2 * (int(width) // (self.vae_scale_factor * 2))
         shape = (batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
         if latents is not None:
             return latents.to(device=device, dtype=dtype), latent_image_ids
@@ -639,7 +646,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         width: Optional[int] = None,
         strength: float = 0.6,
         num_inference_steps: int = 28,
-        timesteps: List[int] = None,
+        sigmas: Optional[List[float]] = None,
         guidance_scale: float = 7.0,
         control_guidance_start: Union[float, List[float]] = 0.0,
         control_guidance_end: Union[float, List[float]] = 1.0,
@@ -678,8 +685,10 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
             num_inference_steps (`int`, *optional*, defaults to 28):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 7.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
             control_mode (`int` or `List[int]`, *optional*):
@@ -794,7 +803,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
             )
             height, width = control_image.shape[-2:]
 
-            control_image = self.vae.encode(control_image).latent_dist.sample()
+            control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
             control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
 
             height_control_image, width_control_image = control_image.shape[2:]
@@ -825,7 +834,7 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
                 )
                 height, width = control_image_.shape[-2:]
 
-                control_image_ = self.vae.encode(control_image_).latent_dist.sample()
+                control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
                 control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
 
                 height_control_image, width_control_image = control_image_.shape[2:]
@@ -851,8 +860,8 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
             control_mode = torch.tensor(control_mode_).to(device, dtype=torch.long)
             control_mode = control_mode.reshape([-1, 1])
 
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        image_seq_len = (int(height) // self.vae_scale_factor) * (int(width) // self.vae_scale_factor)
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        image_seq_len = (int(height) // self.vae_scale_factor // 2) * (int(width) // self.vae_scale_factor // 2)
         mu = calculate_shift(
             image_seq_len,
             self.scheduler.config.base_image_seq_len,
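A quick check of the corrected sequence-length arithmetic above:

```python
height = width = 1024
vae_scale_factor = 8  # corrected value from the constructor fix earlier in this diff

image_seq_len = (height // vae_scale_factor // 2) * (width // vae_scale_factor // 2)
print(image_seq_len)  # 4096

# Note: the old expression paired the old (wrong) scale factor of 16 with no
# extra // 2, so the two errors cancelled numerically; the new form makes the
# 2x2 packing explicit while keeping the same result.
```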
@@ -864,14 +873,12 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
             self.scheduler,
             num_inference_steps,
             device,
-            timesteps,
-            sigmas,
+            sigmas=sigmas,
             mu=mu,
         )
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
 
         latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
-
         latents, latent_image_ids = self.prepare_latents(
             init_image,
             latent_timestep,
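For context on `get_timesteps` and `latent_timestep` above: img2img starts denoising partway through the schedule according to `strength`. An illustrative sketch of the usual diffusers convention (not code from this file):

```python
num_inference_steps, strength = 28, 0.6

# Run only the final `strength` fraction of the schedule on the noised init image.
init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 16
t_start = max(num_inference_steps - init_timestep, 0)                          # 12
# timesteps = scheduler.timesteps[t_start:]  -> 16 executed steps;
# latent_timestep = timesteps[0] sets how much noise is added to the init image.
```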
@@ -903,9 +910,12 @@ class FluxControlNetImg2ImgPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
 
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
-                guidance = (
-                    torch.tensor([guidance_scale], device=device) if self.controlnet.config.guidance_embeds else None
-                )
+                if isinstance(self.controlnet, FluxMultiControlNetModel):
+                    use_guidance = self.controlnet.nets[0].config.guidance_embeds
+                else:
+                    use_guidance = self.controlnet.config.guidance_embeds
+
+                guidance = torch.tensor([guidance_scale], device=device) if use_guidance else None
                 guidance = guidance.expand(latents.shape[0]) if guidance is not None else None
 
                 if isinstance(controlnet_keep[i], list):
diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py
@@ -14,7 +14,7 @@ from transformers import (
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FluxLoraLoaderMixin, FromSingleFileMixin, TextualInversionLoaderMixin
 from ...models.autoencoders import AutoencoderKL
-from ...models.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
+from ...models.controlnets.controlnet_flux import FluxControlNetModel, FluxMultiControlNetModel
 from ...models.transformers import FluxTransformer2DModel
 from ...schedulers import FlowMatchEulerDiscreteScheduler
 from ...utils import (
@@ -231,11 +231,13 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         )
 
         self.vae_scale_factor = (
-            2 ** (len(self.vae.config.block_out_channels)) if hasattr(self, "vae") and self.vae is not None else 16
+            2 ** (len(self.vae.config.block_out_channels) - 1) if hasattr(self, "vae") and self.vae is not None else 8
         )
-        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
+        # Flux latents are turned into 2x2 patches and packed. This means the latent width and height has to be divisible
+        # by the patch size. So the vae scale factor is multiplied by the patch size to account for this
+        self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor * 2)
         self.mask_processor = VaeImageProcessor(
-            vae_scale_factor=self.vae_scale_factor,
+            vae_scale_factor=self.vae_scale_factor * 2,
             vae_latent_channels=self.vae.config.latent_channels,
             do_normalize=False,
             do_binarize=True,
@@ -244,7 +246,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         self.tokenizer_max_length = (
             self.tokenizer.model_max_length if hasattr(self, "tokenizer") and self.tokenizer is not None else 77
         )
-        self.default_sample_size = 64
+        self.default_sample_size = 128
 
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
@@ -467,8 +469,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
-        if height % 8 != 0 or width % 8 != 0:
-            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+        if height % (self.vae_scale_factor * 2) != 0 or width % (self.vae_scale_factor * 2) != 0:
+            logger.warning(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * 2} but are {height} and {width}. Dimensions will be resized accordingly"
+            )
 
         if callback_on_step_end_tensor_inputs is not None and not all(
             k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
@@ -520,9 +524,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
     @staticmethod
     # Copied from diffusers.pipelines.flux.pipeline_flux.FluxPipeline._prepare_latent_image_ids
     def _prepare_latent_image_ids(batch_size, height, width, device, dtype):
-        latent_image_ids = torch.zeros(height // 2, width // 2, 3)
-        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height // 2)[:, None]
-        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width // 2)[None, :]
+        latent_image_ids = torch.zeros(height, width, 3)
+        latent_image_ids[..., 1] = latent_image_ids[..., 1] + torch.arange(height)[:, None]
+        latent_image_ids[..., 2] = latent_image_ids[..., 2] + torch.arange(width)[None, :]
 
         latent_image_id_height, latent_image_id_width, latent_image_id_channels = latent_image_ids.shape
 
@@ -546,13 +550,15 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
     def _unpack_latents(latents, height, width, vae_scale_factor):
         batch_size, num_patches, channels = latents.shape
 
-        height = height // vae_scale_factor
-        width = width // vae_scale_factor
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (vae_scale_factor * 2))
+        width = 2 * (int(width) // (vae_scale_factor * 2))
 
-        latents = latents.view(batch_size, height, width, channels // 4, 2, 2)
+        latents = latents.view(batch_size, height // 2, width // 2, channels // 4, 2, 2)
         latents = latents.permute(0, 3, 1, 4, 2, 5)
 
-        latents = latents.reshape(batch_size, channels // (2 * 2), height * 2, width * 2)
+        latents = latents.reshape(batch_size, channels // (2 * 2), height, width)
 
         return latents
 
@@ -576,11 +582,12 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
                 f" size of {batch_size}. Make sure the batch size matches the length of the generators."
             )
 
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
-
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (self.vae_scale_factor * 2))
+        width = 2 * (int(width) // (self.vae_scale_factor * 2))
         shape = (batch_size, num_channels_latents, height, width)
-        latent_image_ids = self._prepare_latent_image_ids(batch_size, height, width, device, dtype)
+        latent_image_ids = self._prepare_latent_image_ids(batch_size, height // 2, width // 2, device, dtype)
 
         image = image.to(device=device, dtype=dtype)
         image_latents = self._encode_vae_image(image=image, generator=generator)
@@ -622,8 +629,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         device,
         generator,
     ):
-        height = 2 * (int(height) // self.vae_scale_factor)
-        width = 2 * (int(width) // self.vae_scale_factor)
+        # VAE applies 8x compression on images but we must also account for packing which requires
+        # latent height and width to be divisible by 2.
+        height = 2 * (int(height) // (self.vae_scale_factor * 2))
+        width = 2 * (int(width) // (self.vae_scale_factor * 2))
         # resize the mask to latents shape as we concatenate the mask to the latents
         # we do that before converting to dtype to avoid breaking in case we're using cpu_offload
         # and half precision
@@ -661,7 +670,6 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
 
         # aligning device to prevent device errors when concating it with the latent model input
         masked_image_latents = masked_image_latents.to(device=device, dtype=dtype)
-
         masked_image_latents = self._pack_latents(
             masked_image_latents,
             batch_size,
@@ -744,7 +752,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         width: Optional[int] = None,
         strength: float = 0.6,
         padding_mask_crop: Optional[int] = None,
-        timesteps: List[int] = None,
+        sigmas: Optional[List[float]] = None,
         num_inference_steps: int = 28,
         guidance_scale: float = 7.0,
         control_guidance_start: Union[float, List[float]] = 0.0,
@@ -791,8 +799,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
             num_inference_steps (`int`, *optional*, defaults to 28):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 7.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
             control_guidance_start (`float` or `List[float]`, *optional*, defaults to 0.0):
@@ -930,19 +940,22 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
             )
             height, width = control_image.shape[-2:]
 
-            # vae encode
-            control_image = self.vae.encode(control_image).latent_dist.sample()
-            control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
-
-            # pack
-            height_control_image, width_control_image = control_image.shape[2:]
-            control_image = self._pack_latents(
-                control_image,
-                batch_size * num_images_per_prompt,
-                num_channels_latents,
-                height_control_image,
-                width_control_image,
-            )
+            # xlab controlnet has a input_hint_block and instantx controlnet does not
+            controlnet_blocks_repeat = False if self.controlnet.input_hint_block is None else True
+            if self.controlnet.input_hint_block is None:
+                # vae encode
+                control_image = retrieve_latents(self.vae.encode(control_image), generator=generator)
+                control_image = (control_image - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+                # pack
+                height_control_image, width_control_image = control_image.shape[2:]
+                control_image = self._pack_latents(
+                    control_image,
+                    batch_size * num_images_per_prompt,
+                    num_channels_latents,
+                    height_control_image,
+                    width_control_image,
+                )
 
             # set control mode
             if control_mode is not None:
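The new branch encodes a dispatch rule used across these Flux pipelines: XLabs-style ControlNets expose an `input_hint_block` and take raw pixel hints (their block outputs are then repeated across transformer blocks), while InstantX-style ones take VAE-encoded, packed latents. A condensed restatement (where `controlnet` stands in for `self.controlnet`, or `self.controlnet.nets[0]` in the multi-ControlNet case):

```python
has_hint_block = getattr(controlnet, "input_hint_block", None) is not None

controlnet_blocks_repeat = has_hint_block  # XLabs: repeat block outputs
if not has_hint_block:
    # InstantX path: VAE-encode, shift/scale, then 2x2-pack the control image,
    # exactly as the rewritten block above does.
    ...
```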
@@ -952,7 +965,9 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
         elif isinstance(self.controlnet, FluxMultiControlNetModel):
             control_images = []
 
-            for control_image_ in control_image:
+            # xlab controlnet has a input_hint_block and instantx controlnet does not
+            controlnet_blocks_repeat = False if self.controlnet.nets[0].input_hint_block is None else True
+            for i, control_image_ in enumerate(control_image):
                 control_image_ = self.prepare_image(
                     image=control_image_,
                     width=width,
@@ -964,19 +979,20 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
                 )
                 height, width = control_image_.shape[-2:]
 
-                # vae encode
-                control_image_ = self.vae.encode(control_image_).latent_dist.sample()
-                control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
-
-                # pack
-                height_control_image, width_control_image = control_image_.shape[2:]
-                control_image_ = self._pack_latents(
-                    control_image_,
-                    batch_size * num_images_per_prompt,
-                    num_channels_latents,
-                    height_control_image,
-                    width_control_image,
-                )
+                if self.controlnet.nets[0].input_hint_block is None:
+                    # vae encode
+                    control_image_ = retrieve_latents(self.vae.encode(control_image_), generator=generator)
+                    control_image_ = (control_image_ - self.vae.config.shift_factor) * self.vae.config.scaling_factor
+
+                    # pack
+                    height_control_image, width_control_image = control_image_.shape[2:]
+                    control_image_ = self._pack_latents(
+                        control_image_,
+                        batch_size * num_images_per_prompt,
+                        num_channels_latents,
+                        height_control_image,
+                        width_control_image,
+                    )
 
                 control_images.append(control_image_)
 
@@ -995,8 +1011,10 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
 
         # 6. Prepare timesteps
 
-        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        image_seq_len = (int(global_height) // self.vae_scale_factor) * (int(global_width) // self.vae_scale_factor)
+        sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps) if sigmas is None else sigmas
+        image_seq_len = (int(global_height) // self.vae_scale_factor // 2) * (
+            int(global_width) // self.vae_scale_factor // 2
+        )
         mu = calculate_shift(
             image_seq_len,
             self.scheduler.config.base_image_seq_len,
@@ -1008,8 +1026,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
             self.scheduler,
             num_inference_steps,
             device,
-            timesteps,
-            sigmas,
+            sigmas=sigmas,
             mu=mu,
         )
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
@@ -1078,7 +1095,11 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
                 timestep = t.expand(latents.shape[0]).to(latents.dtype)
 
                 # predict the noise residual
-                if self.controlnet.config.guidance_embeds:
+                if isinstance(self.controlnet, FluxMultiControlNetModel):
+                    use_guidance = self.controlnet.nets[0].config.guidance_embeds
+                else:
+                    use_guidance = self.controlnet.config.guidance_embeds
+                if use_guidance:
                     guidance = torch.full([1], guidance_scale, device=device, dtype=torch.float32)
                     guidance = guidance.expand(latents.shape[0])
                 else:
@@ -1125,6 +1146,7 @@ class FluxControlNetInpaintPipeline(DiffusionPipeline, FluxLoraLoaderMixin, From
                     img_ids=latent_image_ids,
                     joint_attention_kwargs=self.joint_attention_kwargs,
                     return_dict=False,
+                    controlnet_blocks_repeat=controlnet_blocks_repeat,
                 )[0]
 
                 # compute the previous noisy sample x_t -> x_t-1
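Putting the inpainting changes together, a hedged end-to-end sketch against the 0.32.0 surface (checkpoint IDs and URLs are placeholders; the dimensions follow the new divisible-by-16 expectation):

```python
import torch
from diffusers import FluxControlNetInpaintPipeline, FluxControlNetModel
from diffusers.utils import load_image

controlnet = FluxControlNetModel.from_pretrained(
    "InstantX/FLUX.1-dev-Controlnet-Canny", torch_dtype=torch.bfloat16  # example checkpoint
)
pipe = FluxControlNetInpaintPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", controlnet=controlnet, torch_dtype=torch.bfloat16
).to("cuda")

result = pipe(
    prompt="a red brick wall",
    image=load_image("https://example.com/input.png"),  # placeholder URLs
    mask_image=load_image("https://example.com/mask.png"),
    control_image=load_image("https://example.com/canny.png"),
    height=1024,
    width=1024,  # divisible by vae_scale_factor * 2 == 16, per check_inputs above
    strength=0.85,
    num_inference_steps=28,
).images[0]
```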