diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff compares the contents of two publicly available versions of this package, as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their respective public registries.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +72 -26
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,20 @@
1
+ from dataclasses import dataclass
2
+
3
+ import torch
4
+
5
+ from diffusers.utils import BaseOutput
6
+
7
+
8
+ @dataclass
9
+ class MochiPipelineOutput(BaseOutput):
10
+ r"""
11
+ Output class for Mochi pipelines.
12
+
13
+ Args:
14
+ frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
15
+ List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
16
+ denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
17
+ `(batch_size, num_frames, channels, height, width)`.
18
+ """
19
+
20
+ frames: torch.Tensor
@@ -29,10 +29,14 @@ else:
29
29
  _import_structure["pipeline_pag_hunyuandit"] = ["HunyuanDiTPAGPipeline"]
30
30
  _import_structure["pipeline_pag_kolors"] = ["KolorsPAGPipeline"]
31
31
  _import_structure["pipeline_pag_pixart_sigma"] = ["PixArtSigmaPAGPipeline"]
32
+ _import_structure["pipeline_pag_sana"] = ["SanaPAGPipeline"]
32
33
  _import_structure["pipeline_pag_sd"] = ["StableDiffusionPAGPipeline"]
33
34
  _import_structure["pipeline_pag_sd_3"] = ["StableDiffusion3PAGPipeline"]
35
+ _import_structure["pipeline_pag_sd_3_img2img"] = ["StableDiffusion3PAGImg2ImgPipeline"]
34
36
  _import_structure["pipeline_pag_sd_animatediff"] = ["AnimateDiffPAGPipeline"]
35
37
  _import_structure["pipeline_pag_sd_img2img"] = ["StableDiffusionPAGImg2ImgPipeline"]
38
+ _import_structure["pipeline_pag_sd_inpaint"] = ["StableDiffusionPAGInpaintPipeline"]
39
+
36
40
  _import_structure["pipeline_pag_sd_xl"] = ["StableDiffusionXLPAGPipeline"]
37
41
  _import_structure["pipeline_pag_sd_xl_img2img"] = ["StableDiffusionXLPAGImg2ImgPipeline"]
38
42
  _import_structure["pipeline_pag_sd_xl_inpaint"] = ["StableDiffusionXLPAGInpaintPipeline"]
@@ -52,10 +56,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
52
56
  from .pipeline_pag_hunyuandit import HunyuanDiTPAGPipeline
53
57
  from .pipeline_pag_kolors import KolorsPAGPipeline
54
58
  from .pipeline_pag_pixart_sigma import PixArtSigmaPAGPipeline
59
+ from .pipeline_pag_sana import SanaPAGPipeline
55
60
  from .pipeline_pag_sd import StableDiffusionPAGPipeline
56
61
  from .pipeline_pag_sd_3 import StableDiffusion3PAGPipeline
62
+ from .pipeline_pag_sd_3_img2img import StableDiffusion3PAGImg2ImgPipeline
57
63
  from .pipeline_pag_sd_animatediff import AnimateDiffPAGPipeline
58
64
  from .pipeline_pag_sd_img2img import StableDiffusionPAGImg2ImgPipeline
65
+ from .pipeline_pag_sd_inpaint import StableDiffusionPAGInpaintPipeline
59
66
  from .pipeline_pag_sd_xl import StableDiffusionXLPAGPipeline
60
67
  from .pipeline_pag_sd_xl_img2img import StableDiffusionXLPAGImg2ImgPipeline
61
68
  from .pipeline_pag_sd_xl_inpaint import StableDiffusionXLPAGInpaintPipeline
@@ -25,7 +25,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
25
25
  from ...callbacks import MultiPipelineCallbacks, PipelineCallback
26
26
  from ...image_processor import PipelineImageInput, VaeImageProcessor
27
27
  from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
28
- from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
28
+ from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
29
29
  from ...models.lora import adjust_lora_scale_text_encoder
30
30
  from ...schedulers import KarrasDiffusionSchedulers
31
31
  from ...utils import (
@@ -36,7 +36,6 @@ from ...utils import (
36
36
  unscale_lora_layers,
37
37
  )
38
38
  from ...utils.torch_utils import is_compiled_module, is_torch_version, randn_tensor
39
- from ..controlnet.multicontrolnet import MultiControlNetModel
40
39
  from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
41
40
  from ..stable_diffusion.pipeline_output import StableDiffusionPipelineOutput
42
41
  from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -26,7 +26,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
26
26
  from ...callbacks import MultiPipelineCallbacks, PipelineCallback
27
27
  from ...image_processor import PipelineImageInput, VaeImageProcessor
28
28
  from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
29
- from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
29
+ from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
30
30
  from ...models.lora import adjust_lora_scale_text_encoder
31
31
  from ...schedulers import KarrasDiffusionSchedulers
32
32
  from ...utils import (
@@ -37,7 +37,6 @@ from ...utils import (
37
37
  unscale_lora_layers,
38
38
  )
39
39
  from ...utils.torch_utils import is_compiled_module, randn_tensor
40
- from ..controlnet.multicontrolnet import MultiControlNetModel
41
40
  from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
42
41
  from ..stable_diffusion import StableDiffusionPipelineOutput
43
42
  from ..stable_diffusion.safety_checker import StableDiffusionSafetyChecker
@@ -38,7 +38,7 @@ from ...loaders import (
38
38
  StableDiffusionXLLoraLoaderMixin,
39
39
  TextualInversionLoaderMixin,
40
40
  )
41
- from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
41
+ from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
42
42
  from ...models.attention_processor import (
43
43
  AttnProcessor2_0,
44
44
  XFormersAttnProcessor,
@@ -61,8 +61,6 @@ from .pag_utils import PAGMixin
61
61
  if is_invisible_watermark_available():
62
62
  from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
63
63
 
64
- from ..controlnet.multicontrolnet import MultiControlNetModel
65
-
66
64
 
67
65
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
68
66
 
@@ -38,7 +38,7 @@ from ...loaders import (
38
38
  StableDiffusionXLLoraLoaderMixin,
39
39
  TextualInversionLoaderMixin,
40
40
  )
41
- from ...models import AutoencoderKL, ControlNetModel, ImageProjection, UNet2DConditionModel
41
+ from ...models import AutoencoderKL, ControlNetModel, ImageProjection, MultiControlNetModel, UNet2DConditionModel
42
42
  from ...models.attention_processor import (
43
43
  AttnProcessor2_0,
44
44
  XFormersAttnProcessor,
@@ -61,8 +61,6 @@ from .pag_utils import PAGMixin
61
61
  if is_invisible_watermark_available():
62
62
  from ..stable_diffusion_xl.watermark import StableDiffusionXLWatermarker
63
63
 
64
- from ..controlnet.multicontrolnet import MultiControlNetModel
65
-
66
64
 
67
65
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
68
66
 
@@ -818,7 +818,11 @@ class HunyuanDiTPAGPipeline(DiffusionPipeline, PAGMixin):
818
818
  base_size = 512 // 8 // self.transformer.config.patch_size
819
819
  grid_crops_coords = get_resize_crop_region_for_grid((grid_height, grid_width), base_size)
820
820
  image_rotary_emb = get_2d_rotary_pos_embed(
821
- self.transformer.inner_dim // self.transformer.num_heads, grid_crops_coords, (grid_height, grid_width)
821
+ self.transformer.inner_dim // self.transformer.num_heads,
822
+ grid_crops_coords,
823
+ (grid_height, grid_width),
824
+ device=device,
825
+ output_type="pt",
822
826
  )
823
827
 
824
828
  style = torch.tensor([0], device=device)
@@ -227,13 +227,6 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
227
227
  if device is None:
228
228
  device = self._execution_device
229
229
 
230
- if prompt is not None and isinstance(prompt, str):
231
- batch_size = 1
232
- elif prompt is not None and isinstance(prompt, list):
233
- batch_size = len(prompt)
234
- else:
235
- batch_size = prompt_embeds.shape[0]
236
-
237
230
  # See Section 3.1. of the paper.
238
231
  max_length = max_sequence_length
239
232
 
@@ -278,12 +271,12 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
278
271
  # duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
279
272
  prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
280
273
  prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
281
- prompt_attention_mask = prompt_attention_mask.view(bs_embed, -1)
282
- prompt_attention_mask = prompt_attention_mask.repeat(num_images_per_prompt, 1)
274
+ prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
275
+ prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
283
276
 
284
277
  # get unconditional embeddings for classifier free guidance
285
278
  if do_classifier_free_guidance and negative_prompt_embeds is None:
286
- uncond_tokens = [negative_prompt] * batch_size if isinstance(negative_prompt, str) else negative_prompt
279
+ uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
287
280
  uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
288
281
  max_length = prompt_embeds.shape[1]
289
282
  uncond_input = self.tokenizer(
@@ -310,10 +303,10 @@ class PixArtSigmaPAGPipeline(DiffusionPipeline, PAGMixin):
310
303
  negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
311
304
 
312
305
  negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
313
- negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
306
+ negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
314
307
 
315
- negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed, -1)
316
- negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(num_images_per_prompt, 1)
308
+ negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
309
+ negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
317
310
  else:
318
311
  negative_prompt_embeds = None
319
312
  negative_prompt_attention_mask = None