diffusers-0.31.0-py3-none-any.whl → diffusers-0.32.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +74 -28
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py

@@ -74,6 +74,20 @@ EXAMPLE_DOC_STRING = """
 """
 
 
+# Copied from diffusers.pipelines.flux.pipeline_flux.calculate_shift
+def calculate_shift(
+    image_seq_len,
+    base_seq_len: int = 256,
+    max_seq_len: int = 4096,
+    base_shift: float = 0.5,
+    max_shift: float = 1.16,
+):
+    m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+    b = base_shift - m * base_seq_len
+    mu = image_seq_len * m + b
+    return mu
+
+
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion_img2img.retrieve_latents
 def retrieve_latents(
     encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
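For orientation, the `calculate_shift` helper added above linearly interpolates the flow-matching shift parameter `mu` between `base_shift` and `max_shift` as the number of image tokens grows. A quick endpoint check, a sketch using the defaults shown in the hunk:

```python
# Sanity check of calculate_shift as defined in the hunk above:
# mu equals base_shift at base_seq_len and max_shift at max_seq_len.
assert abs(calculate_shift(256) - 0.5) < 1e-9    # base endpoint
assert abs(calculate_shift(4096) - 1.16) < 1e-9  # max endpoint
print(calculate_shift(2176))  # ~0.83, halfway between the endpoints
```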
@@ -224,6 +238,9 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         )
         self.tokenizer_max_length = self.tokenizer.model_max_length
         self.default_sample_size = self.transformer.config.sample_size
+        self.patch_size = (
+            self.transformer.config.patch_size if hasattr(self, "transformer") and self.transformer is not None else 2
+        )
 
     # Copied from diffusers.pipelines.stable_diffusion_3.pipeline_stable_diffusion_3.StableDiffusion3Pipeline._get_t5_prompt_embeds
     def _get_t5_prompt_embeds(
@@ -538,6 +555,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         prompt,
         prompt_2,
         prompt_3,
+        height,
+        width,
         strength,
         negative_prompt=None,
         negative_prompt_2=None,
@@ -549,6 +568,15 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         callback_on_step_end_tensor_inputs=None,
         max_sequence_length=None,
     ):
+        if (
+            height % (self.vae_scale_factor * self.patch_size) != 0
+            or width % (self.vae_scale_factor * self.patch_size) != 0
+        ):
+            raise ValueError(
+                f"`height` and `width` have to be divisible by {self.vae_scale_factor * self.patch_size} but are {height} and {width}. "
+                f"You can use height {height - height % (self.vae_scale_factor * self.patch_size)} and width {width - width % (self.vae_scale_factor * self.patch_size)}."
+            )
+
         if strength < 0 or strength > 1:
             raise ValueError(f"The value of strength should in [0.0, 1.0] but is {strength}")
 
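Concretely, the new check constrains both dimensions to a multiple of `vae_scale_factor * patch_size`. A small sketch of the rounding the error message suggests; the factor values 8 and 2 are typical SD3 values assumed here, not read from this hunk:

```python
# Assumed typical SD3 factors: vae_scale_factor = 8, patch_size = 2,
# so height and width must be multiples of 16.
vae_scale_factor, patch_size = 8, 2
multiple = vae_scale_factor * patch_size  # 16

def round_down(dim: int) -> int:
    # The same arithmetic the error message suggests: drop the remainder.
    return dim - dim % multiple

print(round_down(1000))  # 992
print(round_down(1024))  # 1024 (already valid)
```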
@@ -806,7 +834,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         padding_mask_crop: Optional[int] = None,
         strength: float = 0.6,
         num_inference_steps: int = 50,
-        timesteps: List[int] = None,
+        sigmas: Optional[List[float]] = None,
         guidance_scale: float = 7.0,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt_2: Optional[Union[str, List[str]]] = None,
@@ -824,6 +852,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
         callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
         callback_on_step_end_tensor_inputs: List[str] = ["latents"],
         max_sequence_length: int = 256,
+        mu: Optional[float] = None,
     ):
         r"""
         Function invoked when calling the pipeline for generation.
@@ -874,10 +903,10 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
             num_inference_steps (`int`, *optional*, defaults to 50):
                 The number of denoising steps. More denoising steps usually lead to a higher quality image at the
                 expense of slower inference.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
+            sigmas (`List[float]`, *optional*):
+                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
+                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
+                will be used.
             guidance_scale (`float`, *optional*, defaults to 7.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -921,8 +950,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
                 The output format of the generate image. Choose between
                 [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~pipelines.stable_diffusion_xl.StableDiffusionXLPipelineOutput`] instead
-                of a plain tuple.
+                Whether or not to return a [`~pipelines.stable_diffusion_3.StableDiffusion3PipelineOutput`] instead of
+                a plain tuple.
             callback_on_step_end (`Callable`, *optional*):
                 A function that calls at the end of each denoising steps during the inference. The function is called
                 with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
@@ -933,6 +962,7 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
                 will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                 `._callback_tensor_inputs` attribute of your pipeline class.
             max_sequence_length (`int` defaults to 256): Maximum sequence length to use with the `prompt`.
+            mu (`float`, *optional*): `mu` value used for `dynamic_shifting`.
 
         Examples:
 
@@ -953,6 +983,8 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
             prompt,
             prompt_2,
             prompt_3,
+            height,
+            width,
             strength,
             negative_prompt=negative_prompt,
             negative_prompt_2=negative_prompt_2,
@@ -1007,7 +1039,24 @@ class StableDiffusion3InpaintPipeline(DiffusionPipeline, SD3LoraLoaderMixin, Fro
             pooled_prompt_embeds = torch.cat([negative_pooled_prompt_embeds, pooled_prompt_embeds], dim=0)
 
         # 3. Prepare timesteps
-        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
+        scheduler_kwargs = {}
+        if self.scheduler.config.get("use_dynamic_shifting", None) and mu is None:
+            image_seq_len = (int(height) // self.vae_scale_factor // self.transformer.config.patch_size) * (
+                int(width) // self.vae_scale_factor // self.transformer.config.patch_size
+            )
+            mu = calculate_shift(
+                image_seq_len,
+                self.scheduler.config.base_image_seq_len,
+                self.scheduler.config.max_image_seq_len,
+                self.scheduler.config.base_shift,
+                self.scheduler.config.max_shift,
+            )
+            scheduler_kwargs["mu"] = mu
+        elif mu is not None:
+            scheduler_kwargs["mu"] = mu
+        timesteps, num_inference_steps = retrieve_timesteps(
+            self.scheduler, num_inference_steps, device, sigmas=sigmas, **scheduler_kwargs
+        )
         timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength, device)
         # check that number of inference steps is not < 1 - as this doesn't make sense
         if num_inference_steps < 1:
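A worked pass through the dynamic-shifting branch above, assuming `vae_scale_factor = 8` and `patch_size = 2` (typical SD3 values, not read from this diff): a 1024x1024 image yields 64 x 64 = 4096 latent patches, which lands exactly at the default `max_seq_len`, so `mu` comes out at `max_shift`:

```python
# Worked example for the branch above (factor values are assumptions).
height = width = 1024
vae_scale_factor, patch_size = 8, 2

image_seq_len = (height // vae_scale_factor // patch_size) * (
    width // vae_scale_factor // patch_size
)
print(image_seq_len)  # 4096

# calculate_shift is the helper added at the top of this file.
mu = calculate_shift(image_seq_len)
print(mu)  # ~1.16 (4096 tokens sits at the default max_seq_len)
```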
diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py

@@ -446,13 +446,14 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
             extra_step_kwargs["generator"] = generator
         return extra_step_kwargs
 
-    # Copied from diffusers.pipelines.stable_diffusion_k_diffusion.pipeline_stable_diffusion_k_diffusion.StableDiffusionKDiffusionPipeline.check_inputs
     def check_inputs(
         self,
         prompt,
         height,
         width,
         callback_steps,
+        gligen_images,
+        gligen_phrases,
         negative_prompt=None,
         prompt_embeds=None,
         negative_prompt_embeds=None,
@@ -499,6 +500,13 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
                 f" {negative_prompt_embeds.shape}."
             )
 
+        if gligen_images is not None and gligen_phrases is not None:
+            if len(gligen_images) != len(gligen_phrases):
+                raise ValueError(
+                    "`gligen_images` and `gligen_phrases` must have the same length when both are provided, but"
+                    f" got: `gligen_images` with length {len(gligen_images)} != `gligen_phrases` with length {len(gligen_phrases)}."
+                )
+
     # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_latents
     def prepare_latents(self, batch_size, num_channels_latents, height, width, dtype, device, generator, latents=None):
         shape = (
@@ -814,6 +822,8 @@ class StableDiffusionGLIGENTextImagePipeline(DiffusionPipeline, StableDiffusionM
             height,
             width,
             callback_steps,
+            gligen_images,
+            gligen_phrases,
             negative_prompt,
             prompt_embeds,
             negative_prompt_embeds,
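A sketch of the mismatch the new guard rejects early, with illustrative placeholder inputs:

```python
from PIL import Image

# Illustrative inputs: two grounding phrases but only one grounding image.
gligen_phrases = ["a birthday cake", "a teddy bear"]
gligen_images = [Image.new("RGB", (512, 512))]  # one entry short

# The new check in check_inputs fires before any denoising work:
if len(gligen_images) != len(gligen_phrases):
    print(f"would raise ValueError: lengths {len(gligen_images)} != {len(gligen_phrases)}")
```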
diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py

@@ -237,11 +237,8 @@ class StableDiffusionXLPipeline(
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",
-        "negative_prompt_embeds",
         "add_text_embeds",
         "add_time_ids",
-        "negative_pooled_prompt_embeds",
-        "negative_add_time_ids",
     ]
 
     def __init__(
@@ -1243,13 +1240,8 @@ class StableDiffusionXLPipeline(
 
                     latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                     add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
-                    negative_pooled_prompt_embeds = callback_outputs.pop(
-                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                    )
                     add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
-                    negative_add_time_ids = callback_outputs.pop("negative_add_time_ids", negative_add_time_ids)
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py

@@ -257,11 +257,8 @@ class StableDiffusionXLImg2ImgPipeline(
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",
-        "negative_prompt_embeds",
         "add_text_embeds",
         "add_time_ids",
-        "negative_pooled_prompt_embeds",
-        "add_neg_time_ids",
     ]
 
     def __init__(
@@ -1438,13 +1435,8 @@ class StableDiffusionXLImg2ImgPipeline(
 
                     latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                     add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
-                    negative_pooled_prompt_embeds = callback_outputs.pop(
-                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                    )
                     add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
-                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
 
                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

@@ -285,11 +285,8 @@ class StableDiffusionXLInpaintPipeline(
     _callback_tensor_inputs = [
         "latents",
         "prompt_embeds",
-        "negative_prompt_embeds",
         "add_text_embeds",
         "add_time_ids",
-        "negative_pooled_prompt_embeds",
-        "add_neg_time_ids",
         "mask",
         "masked_image_latents",
     ]
@@ -1671,13 +1668,8 @@ class StableDiffusionXLInpaintPipeline(
 
                     latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
                     add_text_embeds = callback_outputs.pop("add_text_embeds", add_text_embeds)
-                    negative_pooled_prompt_embeds = callback_outputs.pop(
-                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                    )
                     add_time_ids = callback_outputs.pop("add_time_ids", add_time_ids)
-                    add_neg_time_ids = callback_outputs.pop("add_neg_time_ids", add_neg_time_ids)
                     mask = callback_outputs.pop("mask", mask)
                     masked_image_latents = callback_outputs.pop("masked_image_latents", masked_image_latents)
 
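All three SDXL pipelines drop the negative-embedding entries from `_callback_tensor_inputs`, so step-end callbacks can intercept only the remaining tensors; `check_inputs` validates requested names against this list, so asking for `"negative_prompt_embeds"` now fails. A hedged sketch of a callback that remains valid after this change:

```python
import torch
from diffusers import StableDiffusionXLPipeline

pipe = StableDiffusionXLPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

def rescale_latents(pipe, step, timestep, callback_kwargs):
    # Illustrative tweak; only names in _callback_tensor_inputs are available.
    callback_kwargs["latents"] = callback_kwargs["latents"] * 0.99
    return callback_kwargs

image = pipe(
    "a photo of an astronaut",
    callback_on_step_end=rescale_latents,
    callback_on_step_end_tensor_inputs=["latents"],  # "negative_prompt_embeds" would now raise
).images[0]
```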
diffusers/pipelines/unidiffuser/modeling_uvit.py

@@ -104,8 +104,8 @@ class PatchEmbed(nn.Module):
 
         self.use_pos_embed = use_pos_embed
         if self.use_pos_embed:
-            pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5))
-            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+            pos_embed = get_2d_sincos_pos_embed(embed_dim, int(num_patches**0.5), output_type="pt")
+            self.register_buffer("pos_embed", pos_embed.float().unsqueeze(0), persistent=False)
 
     def forward(self, latent):
         latent = self.proj(latent)
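With `output_type="pt"`, `get_2d_sincos_pos_embed` returns a torch tensor directly, so the numpy round-trip through `torch.from_numpy` disappears. A small sketch; the grid size and embedding width are illustrative, and the helper lives in `diffusers.models.embeddings`:

```python
from diffusers.models.embeddings import get_2d_sincos_pos_embed

# For a 16x16 patch grid and 768-dim embeddings the result is a
# (256, 768) torch.Tensor, no numpy conversion needed.
pos_embed = get_2d_sincos_pos_embed(768, 16, output_type="pt")
buffer = pos_embed.float().unsqueeze(0)  # (1, 256, 768)
print(buffer.shape)
```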
diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py

@@ -158,7 +158,7 @@ class WuerstchenPrior(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft
         c_embed = self.cond_mapper(c)
         r_embed = self.gen_r_embedding(r)
 
-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
 
             def create_custom_forward(module):
                 def custom_forward(*inputs):
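The guard now keys off torch's grad mode instead of the module's train/eval flag: checkpointing engages whenever autograd is recording, and is skipped under `torch.no_grad()` even in train mode. A minimal standalone sketch of the difference:

```python
import torch

block = torch.nn.Linear(4, 4)
block.gradient_checkpointing = True  # attribute used by the guard above

block.eval()
old_guard = block.training and block.gradient_checkpointing           # False
new_guard = torch.is_grad_enabled() and block.gradient_checkpointing  # True

with torch.no_grad():
    block.train()
    # Inference under no_grad now skips checkpointing even in train mode.
    assert not (torch.is_grad_enabled() and block.gradient_checkpointing)
```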
diffusers/quantizers/auto.py

@@ -15,21 +15,34 @@
 Adapted from
 https://github.com/huggingface/transformers/blob/c409cd81777fb27aadc043ed3d8339dbc020fb3b/src/transformers/quantizers/auto.py
 """
+
 import warnings
 from typing import Dict, Optional, Union
 
 from .bitsandbytes import BnB4BitDiffusersQuantizer, BnB8BitDiffusersQuantizer
-from .quantization_config import BitsAndBytesConfig, QuantizationConfigMixin, QuantizationMethod
+from .gguf import GGUFQuantizer
+from .quantization_config import (
+    BitsAndBytesConfig,
+    GGUFQuantizationConfig,
+    QuantizationConfigMixin,
+    QuantizationMethod,
+    TorchAoConfig,
+)
+from .torchao import TorchAoHfQuantizer
 
 
 AUTO_QUANTIZER_MAPPING = {
     "bitsandbytes_4bit": BnB4BitDiffusersQuantizer,
     "bitsandbytes_8bit": BnB8BitDiffusersQuantizer,
+    "gguf": GGUFQuantizer,
+    "torchao": TorchAoHfQuantizer,
 }
 
 AUTO_QUANTIZATION_CONFIG_MAPPING = {
     "bitsandbytes_4bit": BitsAndBytesConfig,
     "bitsandbytes_8bit": BitsAndBytesConfig,
+    "gguf": GGUFQuantizationConfig,
+    "torchao": TorchAoConfig,
 }
 
 
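With the two new backends registered, a config's `quant_method` string picks the quantizer class out of the mapping. A hedged sketch of the dispatch, assuming the GGUF config exposes its `quant_method` the way the bitsandbytes configs do:

```python
import torch
from diffusers.quantizers.auto import AUTO_QUANTIZER_MAPPING
from diffusers.quantizers.quantization_config import GGUFQuantizationConfig

config = GGUFQuantizationConfig(compute_dtype=torch.bfloat16)
quantizer_cls = AUTO_QUANTIZER_MAPPING[config.quant_method]  # -> GGUFQuantizer
quantizer = quantizer_cls(config)
```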
diffusers/quantizers/bitsandbytes/bnb_quantizer.py

@@ -204,7 +204,10 @@ class BnB4BitDiffusersQuantizer(DiffusersQuantizer):
 
         module._parameters[tensor_name] = new_value
 
-    def check_quantized_param_shape(self, param_name, current_param_shape, loaded_param_shape):
+    def check_quantized_param_shape(self, param_name, current_param, loaded_param):
+        current_param_shape = current_param.shape
+        loaded_param_shape = loaded_param.shape
+
         n = current_param_shape.numel()
         inferred_shape = (n,) if "bias" in param_name else ((n + 1) // 2, 1)
         if loaded_param_shape != inferred_shape:
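The signature widens from bare shapes to full parameters so the hook can inspect more than `.shape`; the GGUF quantizer below relies on exactly that, reading `loaded_param.quant_type` in its own `check_quantized_param_shape`.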
diffusers/quantizers/gguf/__init__.py

@@ -0,0 +1 @@
+from .gguf_quantizer import GGUFQuantizer
diffusers/quantizers/gguf/gguf_quantizer.py

@@ -0,0 +1,159 @@
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union
+
+from ..base import DiffusersQuantizer
+
+
+if TYPE_CHECKING:
+    from ...models.modeling_utils import ModelMixin
+
+
+from ...utils import (
+    get_module_from_name,
+    is_accelerate_available,
+    is_accelerate_version,
+    is_gguf_available,
+    is_gguf_version,
+    is_torch_available,
+    logging,
+)
+
+
+if is_torch_available() and is_gguf_available():
+    import torch
+
+    from .utils import (
+        GGML_QUANT_SIZES,
+        GGUFParameter,
+        _dequantize_gguf_and_restore_linear,
+        _quant_shape_from_byte_shape,
+        _replace_with_gguf_linear,
+    )
+
+
+logger = logging.get_logger(__name__)
+
+
+class GGUFQuantizer(DiffusersQuantizer):
+    use_keep_in_fp32_modules = True
+
+    def __init__(self, quantization_config, **kwargs):
+        super().__init__(quantization_config, **kwargs)
+
+        self.compute_dtype = quantization_config.compute_dtype
+        self.pre_quantized = quantization_config.pre_quantized
+        self.modules_to_not_convert = quantization_config.modules_to_not_convert
+
+        if not isinstance(self.modules_to_not_convert, list):
+            self.modules_to_not_convert = [self.modules_to_not_convert]
+
+    def validate_environment(self, *args, **kwargs):
+        if not is_accelerate_available() or is_accelerate_version("<", "0.26.0"):
+            raise ImportError(
+                "Loading GGUF Parameters requires `accelerate` installed in your environment: `pip install 'accelerate>=0.26.0'`"
+            )
+        if not is_gguf_available() or is_gguf_version("<", "0.10.0"):
+            raise ImportError(
+                "To load GGUF format files you must have `gguf` installed in your environment: `pip install gguf>=0.10.0`"
+            )
+
+    # Copied from diffusers.quantizers.bitsandbytes.bnb_quantizer.BnB4BitDiffusersQuantizer.adjust_max_memory
+    def adjust_max_memory(self, max_memory: Dict[str, Union[int, str]]) -> Dict[str, Union[int, str]]:
+        # need more space for buffers that are created during quantization
+        max_memory = {key: val * 0.90 for key, val in max_memory.items()}
+        return max_memory
+
+    def adjust_target_dtype(self, target_dtype: "torch.dtype") -> "torch.dtype":
+        if target_dtype != torch.uint8:
+            logger.info(f"target_dtype {target_dtype} is replaced by `torch.uint8` for GGUF quantization")
+        return torch.uint8
+
+    def update_torch_dtype(self, torch_dtype: "torch.dtype") -> "torch.dtype":
+        if torch_dtype is None:
+            torch_dtype = self.compute_dtype
+        return torch_dtype
+
+    def check_quantized_param_shape(self, param_name, current_param, loaded_param):
+        loaded_param_shape = loaded_param.shape
+        current_param_shape = current_param.shape
+        quant_type = loaded_param.quant_type
+
+        block_size, type_size = GGML_QUANT_SIZES[quant_type]
+
+        inferred_shape = _quant_shape_from_byte_shape(loaded_param_shape, type_size, block_size)
+        if inferred_shape != current_param_shape:
+            raise ValueError(
+                f"{param_name} has an expected quantized shape of: {inferred_shape}, but received shape: {loaded_param_shape}"
+            )
+
+        return True
+
+    def check_if_quantized_param(
+        self,
+        model: "ModelMixin",
+        param_value: Union["GGUFParameter", "torch.Tensor"],
+        param_name: str,
+        state_dict: Dict[str, Any],
+        **kwargs,
+    ) -> bool:
+        if isinstance(param_value, GGUFParameter):
+            return True
+
+        return False
+
+    def create_quantized_param(
+        self,
+        model: "ModelMixin",
+        param_value: Union["GGUFParameter", "torch.Tensor"],
+        param_name: str,
+        target_device: "torch.device",
+        state_dict: Optional[Dict[str, Any]] = None,
+        unexpected_keys: Optional[List[str]] = None,
+    ):
+        module, tensor_name = get_module_from_name(model, param_name)
+        if tensor_name not in module._parameters and tensor_name not in module._buffers:
+            raise ValueError(f"{module} does not have a parameter or a buffer named {tensor_name}.")
+
+        if tensor_name in module._parameters:
+            module._parameters[tensor_name] = param_value.to(target_device)
+        if tensor_name in module._buffers:
+            module._buffers[tensor_name] = param_value.to(target_device)
+
+    def _process_model_before_weight_loading(
+        self,
+        model: "ModelMixin",
+        device_map,
+        keep_in_fp32_modules: List[str] = [],
+        **kwargs,
+    ):
+        state_dict = kwargs.get("state_dict", None)
+
+        self.modules_to_not_convert.extend(keep_in_fp32_modules)
+        self.modules_to_not_convert = [module for module in self.modules_to_not_convert if module is not None]
+
+        _replace_with_gguf_linear(
+            model, self.compute_dtype, state_dict, modules_to_not_convert=self.modules_to_not_convert
+        )
+
+    def _process_model_after_weight_loading(self, model: "ModelMixin", **kwargs):
+        return model
+
+    @property
+    def is_serializable(self):
+        return False
+
+    @property
+    def is_trainable(self) -> bool:
+        return False
+
+    def _dequantize(self, model):
+        is_model_on_cpu = model.device.type == "cpu"
+        if is_model_on_cpu:
+            logger.info(
+                "Model was found to be on CPU (could happen as a result of `enable_model_cpu_offload()`). So, moving it to GPU. After dequantization, will move the model back to CPU again to preserve the previous device."
+            )
+            model.to(torch.cuda.current_device())
+
+        model = _dequantize_gguf_and_restore_linear(model, self.modules_to_not_convert)
+        if is_model_on_cpu:
+            model.to("cpu")
+        return model
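End to end, the quantizer is driven from model loading rather than called directly. A hedged usage sketch: the repository and file name are illustrative of a pre-quantized GGUF checkpoint, and `from_single_file` accepting `quantization_config` reflects the 0.32 API as documented:

```python
import torch
from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

# Illustrative checkpoint location; any pre-quantized .gguf file should work.
ckpt = "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf"

transformer = FluxTransformer2DModel.from_single_file(
    ckpt,
    quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
    torch_dtype=torch.bfloat16,
)
```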