diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220):
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2252 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +3 -14
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +293 -8
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1937 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +403 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +50 -6
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +37 -15
  210. diffusers/utils/loading_utils.py +80 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,21 @@
1
+ from dataclasses import dataclass
2
+ from typing import List, Union
3
+
4
+ import numpy as np
5
+ import PIL.Image
6
+
7
+ from ...utils import BaseOutput
8
+
9
+
10
+ @dataclass
11
+ class FluxPipelineOutput(BaseOutput):
12
+ """
13
+ Output class for Flux pipelines.
14
+
15
+ Args:
16
+ images (`List[PIL.Image.Image]` or `np.ndarray`)
17
+ List of denoised PIL images of length `batch_size` or numpy array of shape `(batch_size, height, width,
18
+ num_channels)`. PIL images or numpy array present the denoised images of the diffusion pipeline.
19
+ """
20
+
21
+ images: Union[List[PIL.Image.Image], np.ndarray]
@@ -180,6 +180,8 @@ class FreeInitMixin:
180
180
  num_inference_steps = max(
181
181
  1, int(num_inference_steps / self._free_init_num_iters * (free_init_iteration + 1))
182
182
  )
183
+
184
+ if num_inference_steps > 0:
183
185
  self.scheduler.set_timesteps(num_inference_steps, device=device)
184
186
 
185
187
  return latents, self.scheduler.timesteps
@@ -0,0 +1,236 @@
1
+ # Copyright 2024 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Optional, Union
16
+
17
+ import torch
18
+
19
+ from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
20
+ from ..models.unets.unet_motion_model import (
21
+ CrossAttnDownBlockMotion,
22
+ DownBlockMotion,
23
+ UpBlockMotion,
24
+ )
25
+ from ..utils import logging
26
+ from ..utils.torch_utils import randn_tensor
27
+
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ class AnimateDiffFreeNoiseMixin:
33
+ r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
34
+
35
+ def _enable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
36
+ r"""Helper function to enable FreeNoise in transformer blocks."""
37
+
38
+ for motion_module in block.motion_modules:
39
+ num_transformer_blocks = len(motion_module.transformer_blocks)
40
+
41
+ for i in range(num_transformer_blocks):
42
+ if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
43
+ motion_module.transformer_blocks[i].set_free_noise_properties(
44
+ self._free_noise_context_length,
45
+ self._free_noise_context_stride,
46
+ self._free_noise_weighting_scheme,
47
+ )
48
+ else:
49
+ assert isinstance(motion_module.transformer_blocks[i], BasicTransformerBlock)
50
+ basic_transfomer_block = motion_module.transformer_blocks[i]
51
+
52
+ motion_module.transformer_blocks[i] = FreeNoiseTransformerBlock(
53
+ dim=basic_transfomer_block.dim,
54
+ num_attention_heads=basic_transfomer_block.num_attention_heads,
55
+ attention_head_dim=basic_transfomer_block.attention_head_dim,
56
+ dropout=basic_transfomer_block.dropout,
57
+ cross_attention_dim=basic_transfomer_block.cross_attention_dim,
58
+ activation_fn=basic_transfomer_block.activation_fn,
59
+ attention_bias=basic_transfomer_block.attention_bias,
60
+ only_cross_attention=basic_transfomer_block.only_cross_attention,
61
+ double_self_attention=basic_transfomer_block.double_self_attention,
62
+ positional_embeddings=basic_transfomer_block.positional_embeddings,
63
+ num_positional_embeddings=basic_transfomer_block.num_positional_embeddings,
64
+ context_length=self._free_noise_context_length,
65
+ context_stride=self._free_noise_context_stride,
66
+ weighting_scheme=self._free_noise_weighting_scheme,
67
+ ).to(device=self.device, dtype=self.dtype)
68
+
69
+ motion_module.transformer_blocks[i].load_state_dict(
70
+ basic_transfomer_block.state_dict(), strict=True
71
+ )
72
+
73
+ def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
74
+ r"""Helper function to disable FreeNoise in transformer blocks."""
75
+
76
+ for motion_module in block.motion_modules:
77
+ num_transformer_blocks = len(motion_module.transformer_blocks)
78
+
79
+ for i in range(num_transformer_blocks):
80
+ if isinstance(motion_module.transformer_blocks[i], FreeNoiseTransformerBlock):
81
+ free_noise_transfomer_block = motion_module.transformer_blocks[i]
82
+
83
+ motion_module.transformer_blocks[i] = BasicTransformerBlock(
84
+ dim=free_noise_transfomer_block.dim,
85
+ num_attention_heads=free_noise_transfomer_block.num_attention_heads,
86
+ attention_head_dim=free_noise_transfomer_block.attention_head_dim,
87
+ dropout=free_noise_transfomer_block.dropout,
88
+ cross_attention_dim=free_noise_transfomer_block.cross_attention_dim,
89
+ activation_fn=free_noise_transfomer_block.activation_fn,
90
+ attention_bias=free_noise_transfomer_block.attention_bias,
91
+ only_cross_attention=free_noise_transfomer_block.only_cross_attention,
92
+ double_self_attention=free_noise_transfomer_block.double_self_attention,
93
+ positional_embeddings=free_noise_transfomer_block.positional_embeddings,
94
+ num_positional_embeddings=free_noise_transfomer_block.num_positional_embeddings,
95
+ ).to(device=self.device, dtype=self.dtype)
96
+
97
+ motion_module.transformer_blocks[i].load_state_dict(
98
+ free_noise_transfomer_block.state_dict(), strict=True
99
+ )
100
+
101
+ def _prepare_latents_free_noise(
102
+ self,
103
+ batch_size: int,
104
+ num_channels_latents: int,
105
+ num_frames: int,
106
+ height: int,
107
+ width: int,
108
+ dtype: torch.dtype,
109
+ device: torch.device,
110
+ generator: Optional[torch.Generator] = None,
111
+ latents: Optional[torch.Tensor] = None,
112
+ ):
113
+ if isinstance(generator, list) and len(generator) != batch_size:
114
+ raise ValueError(
115
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
116
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
117
+ )
118
+
119
+ context_num_frames = (
120
+ self._free_noise_context_length if self._free_noise_context_length == "repeat_context" else num_frames
121
+ )
122
+
123
+ shape = (
124
+ batch_size,
125
+ num_channels_latents,
126
+ context_num_frames,
127
+ height // self.vae_scale_factor,
128
+ width // self.vae_scale_factor,
129
+ )
130
+
131
+ if latents is None:
132
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
133
+ if self._free_noise_noise_type == "random":
134
+ return latents
135
+ else:
136
+ if latents.size(2) == num_frames:
137
+ return latents
138
+ elif latents.size(2) != self._free_noise_context_length:
139
+ raise ValueError(
140
+ f"You have passed `latents` as a parameter to FreeNoise. The expected number of frames is either {num_frames} or {self._free_noise_context_length}, but found {latents.size(2)}"
141
+ )
142
+ latents = latents.to(device)
143
+
144
+ if self._free_noise_noise_type == "shuffle_context":
145
+ for i in range(self._free_noise_context_length, num_frames, self._free_noise_context_stride):
146
+ # ensure window is within bounds
147
+ window_start = max(0, i - self._free_noise_context_length)
148
+ window_end = min(num_frames, window_start + self._free_noise_context_stride)
149
+ window_length = window_end - window_start
150
+
151
+ if window_length == 0:
152
+ break
153
+
154
+ indices = torch.LongTensor(list(range(window_start, window_end)))
155
+ shuffled_indices = indices[torch.randperm(window_length, generator=generator)]
156
+
157
+ current_start = i
158
+ current_end = min(num_frames, current_start + window_length)
159
+ if current_end == current_start + window_length:
160
+ # batch of frames perfectly fits the window
161
+ latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices]
162
+ else:
163
+ # handle the case where the last batch of frames does not fit perfectly with the window
164
+ prefix_length = current_end - current_start
165
+ shuffled_indices = shuffled_indices[:prefix_length]
166
+ latents[:, :, current_start:current_end] = latents[:, :, shuffled_indices]
167
+
168
+ elif self._free_noise_noise_type == "repeat_context":
169
+ num_repeats = (num_frames + self._free_noise_context_length - 1) // self._free_noise_context_length
170
+ latents = torch.cat([latents] * num_repeats, dim=2)
171
+
172
+ latents = latents[:, :, :num_frames]
173
+ return latents
174
+
175
+ def enable_free_noise(
176
+ self,
177
+ context_length: Optional[int] = 16,
178
+ context_stride: int = 4,
179
+ weighting_scheme: str = "pyramid",
180
+ noise_type: str = "shuffle_context",
181
+ ) -> None:
182
+ r"""
183
+ Enable long video generation using FreeNoise.
184
+
185
+ Args:
186
+ context_length (`int`, defaults to `16`, *optional*):
187
+ The number of video frames to process at once. It's recommended to set this to the maximum frames the
188
+ Motion Adapter was trained with (usually 16/24/32). If `None`, the default value from the motion
189
+ adapter config is used.
190
+ context_stride (`int`, *optional*):
191
+ Long videos are generated by processing many frames. FreeNoise processes these frames in sliding
192
+ windows of size `context_length`. Context stride allows you to specify how many frames to skip between
193
+ each window. For example, a context length of 16 and context stride of 4 would process 24 frames as:
194
+ [0, 15], [4, 19], [8, 23] (0-based indexing)
195
+ weighting_scheme (`str`, defaults to `pyramid`):
196
+ Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting
197
+ schemes are supported currently:
198
+ - "pyramid"
199
+ Performs weighted averaging with a pyramid-like weight pattern: [1, 2, 3, 2, 1].
200
+ noise_type (`str`, defaults to "shuffle_context"):
201
+ TODO
202
+ """
203
+
204
+ allowed_weighting_scheme = ["pyramid"]
205
+ allowed_noise_type = ["shuffle_context", "repeat_context", "random"]
206
+
207
+ if context_length > self.motion_adapter.config.motion_max_seq_length:
208
+ logger.warning(
209
+ f"You have set {context_length=} which is greater than {self.motion_adapter.config.motion_max_seq_length=}. This can lead to bad generation results."
210
+ )
211
+ if weighting_scheme not in allowed_weighting_scheme:
212
+ raise ValueError(
213
+ f"The parameter `weighting_scheme` must be one of {allowed_weighting_scheme}, but got {weighting_scheme=}"
214
+ )
215
+ if noise_type not in allowed_noise_type:
216
+ raise ValueError(f"The parameter `noise_type` must be one of {allowed_noise_type}, but got {noise_type=}")
217
+
218
+ self._free_noise_context_length = context_length or self.motion_adapter.config.motion_max_seq_length
219
+ self._free_noise_context_stride = context_stride
220
+ self._free_noise_weighting_scheme = weighting_scheme
221
+ self._free_noise_noise_type = noise_type
222
+
223
+ blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
224
+ for block in blocks:
225
+ self._enable_free_noise_in_block(block)
226
+
227
+ def disable_free_noise(self) -> None:
228
+ self._free_noise_context_length = None
229
+
230
+ blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
231
+ for block in blocks:
232
+ self._disable_free_noise_in_block(block)
233
+
234
+ @property
235
+ def free_noise_enabled(self):
236
+ return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
@@ -3,7 +3,7 @@ from typing import Callable, Dict, List, Optional, Union
3
3
  import torch
4
4
  from transformers import T5EncoderModel, T5Tokenizer
5
5
 
6
- from ...loaders import LoraLoaderMixin
6
+ from ...loaders import StableDiffusionLoraLoaderMixin
7
7
  from ...models import Kandinsky3UNet, VQModel
8
8
  from ...schedulers import DDPMScheduler
9
9
  from ...utils import (
@@ -47,7 +47,7 @@ def downscale_height_and_width(height, width, scale_factor=8):
47
47
  return new_height * scale_factor, new_width * scale_factor
48
48
 
49
49
 
50
- class Kandinsky3Pipeline(DiffusionPipeline, LoraLoaderMixin):
50
+ class Kandinsky3Pipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
51
51
  model_cpu_offload_seq = "text_encoder->unet->movq"
52
52
  _callback_tensor_inputs = [
53
53
  "latents",
@@ -7,7 +7,7 @@ import PIL.Image
7
7
  import torch
8
8
  from transformers import T5EncoderModel, T5Tokenizer
9
9
 
10
- from ...loaders import LoraLoaderMixin
10
+ from ...loaders import StableDiffusionLoraLoaderMixin
11
11
  from ...models import Kandinsky3UNet, VQModel
12
12
  from ...schedulers import DDPMScheduler
13
13
  from ...utils import (
@@ -62,7 +62,7 @@ def prepare_image(pil_image):
62
62
  return image
63
63
 
64
64
 
65
- class Kandinsky3Img2ImgPipeline(DiffusionPipeline, LoraLoaderMixin):
65
+ class Kandinsky3Img2ImgPipeline(DiffusionPipeline, StableDiffusionLoraLoaderMixin):
66
66
  model_cpu_offload_seq = "text_encoder->movq->unet->movq"
67
67
  _callback_tensor_inputs = [
68
68
  "latents",
@@ -0,0 +1,54 @@
1
+ from typing import TYPE_CHECKING
2
+
3
+ from ...utils import (
4
+ DIFFUSERS_SLOW_IMPORT,
5
+ OptionalDependencyNotAvailable,
6
+ _LazyModule,
7
+ get_objects_from_module,
8
+ is_sentencepiece_available,
9
+ is_torch_available,
10
+ is_transformers_available,
11
+ )
12
+
13
+
14
+ _dummy_objects = {}
15
+ _import_structure = {}
16
+
17
+ try:
18
+ if not (is_transformers_available() and is_torch_available()) and is_sentencepiece_available():
19
+ raise OptionalDependencyNotAvailable()
20
+ except OptionalDependencyNotAvailable:
21
+ from ...utils import dummy_torch_and_transformers_and_sentencepiece_objects # noqa F403
22
+
23
+ _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_and_sentencepiece_objects))
24
+ else:
25
+ _import_structure["pipeline_kolors"] = ["KolorsPipeline"]
26
+ _import_structure["pipeline_kolors_img2img"] = ["KolorsImg2ImgPipeline"]
27
+ _import_structure["text_encoder"] = ["ChatGLMModel"]
28
+ _import_structure["tokenizer"] = ["ChatGLMTokenizer"]
29
+
30
+ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
31
+ try:
32
+ if not (is_transformers_available() and is_torch_available()) and is_sentencepiece_available():
33
+ raise OptionalDependencyNotAvailable()
34
+ except OptionalDependencyNotAvailable:
35
+ from ...utils.dummy_torch_and_transformers_and_sentencepiece_objects import *
36
+
37
+ else:
38
+ from .pipeline_kolors import KolorsPipeline
39
+ from .pipeline_kolors_img2img import KolorsImg2ImgPipeline
40
+ from .text_encoder import ChatGLMModel
41
+ from .tokenizer import ChatGLMTokenizer
42
+
43
+ else:
44
+ import sys
45
+
46
+ sys.modules[__name__] = _LazyModule(
47
+ __name__,
48
+ globals()["__file__"],
49
+ _import_structure,
50
+ module_spec=__spec__,
51
+ )
52
+
53
+ for name, value in _dummy_objects.items():
54
+ setattr(sys.modules[__name__], name, value)