diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. diffusers/__init__.py +38 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +238 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +40 -7
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +6 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
  53. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  54. diffusers/pipelines/cogview3/__init__.py +47 -0
  55. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  56. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  57. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  58. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  60. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  62. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  63. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  64. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  66. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  67. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  68. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  70. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  71. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  72. diffusers/pipelines/flux/__init__.py +10 -0
  73. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  74. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  76. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  77. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  78. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  79. diffusers/pipelines/free_noise_utils.py +365 -5
  80. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  81. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  82. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  83. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  84. diffusers/pipelines/kolors/tokenizer.py +4 -0
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  86. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  87. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  89. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  90. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  91. diffusers/pipelines/pag/__init__.py +6 -0
  92. diffusers/pipelines/pag/pag_utils.py +8 -2
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  96. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  97. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  98. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  100. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  101. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  102. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  103. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  106. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  107. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  108. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  109. diffusers/pipelines/pipeline_utils.py +123 -180
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  111. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  113. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  117. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  120. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  121. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  122. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  123. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  129. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  131. diffusers/quantizers/__init__.py +16 -0
  132. diffusers/quantizers/auto.py +126 -0
  133. diffusers/quantizers/base.py +233 -0
  134. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  135. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  136. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  137. diffusers/quantizers/quantization_config.py +391 -0
  138. diffusers/schedulers/scheduling_ddim.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  140. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm.py +4 -1
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  143. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  148. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  149. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  150. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  151. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  152. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  155. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  156. diffusers/schedulers/scheduling_sasolver.py +78 -1
  157. diffusers/schedulers/scheduling_unclip.py +4 -1
  158. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  159. diffusers/training_utils.py +48 -18
  160. diffusers/utils/__init__.py +2 -1
  161. diffusers/utils/dummy_pt_objects.py +60 -0
  162. diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
  163. diffusers/utils/hub_utils.py +16 -4
  164. diffusers/utils/import_utils.py +31 -8
  165. diffusers/utils/loading_utils.py +28 -4
  166. diffusers/utils/peft_utils.py +3 -3
  167. diffusers/utils/testing_utils.py +59 -0
  168. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  169. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
  170. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  172. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  173. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
@@ -12,16 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn as nn
 
 from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
+from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from ..models.transformers.transformer_2d import Transformer2DModel
 from ..models.unets.unet_motion_model import (
+    AnimateDiffTransformer3D,
     CrossAttnDownBlockMotion,
     DownBlockMotion,
     UpBlockMotion,
 )
+from ..pipelines.pipeline_utils import DiffusionPipeline
 from ..utils import logging
 from ..utils.torch_utils import randn_tensor
 
@@ -29,6 +34,114 @@ from ..utils.torch_utils import randn_tensor
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class SplitInferenceModule(nn.Module):
+    r"""
+    A wrapper module class that splits inputs along a specified dimension before performing a forward pass.
+
+    This module is useful when you need to perform inference on large tensors in a memory-efficient way by breaking
+    them into smaller chunks, processing each chunk separately, and then reassembling the results.
+
+    Args:
+        module (`nn.Module`):
+            The underlying PyTorch module that will be applied to each chunk of split inputs.
+        split_size (`int`, defaults to `1`):
+            The size of each chunk after splitting the input tensor.
+        split_dim (`int`, defaults to `0`):
+            The dimension along which the input tensors are split.
+        input_kwargs_to_split (`List[str]`, defaults to `["hidden_states"]`):
+            A list of keyword arguments (strings) that represent the input tensors to be split.
+
+    Workflow:
+        1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using
+        `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`.
+        2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments
+        that were passed.
+        3. The output tensors from each split are concatenated back together along `split_dim` before returning.
+
+    Example:
+    ```python
+    >>> import torch
+    >>> import torch.nn as nn
+
+    >>> model = nn.Linear(1000, 1000)
+    >>> split_module = SplitInferenceModule(model, split_size=2, split_dim=0, input_kwargs_to_split=["input"])
+
+    >>> input_tensor = torch.randn(42, 1000)
+    >>> # Will split the tensor into 21 slices of shape [2, 1000].
+    >>> output = split_module(input=input_tensor)
+    ```
+
+    It is also possible to nest `SplitInferenceModule` across different split dimensions for more complex
+    multi-dimensional splitting.
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        split_size: int = 1,
+        split_dim: int = 0,
+        input_kwargs_to_split: List[str] = ["hidden_states"],
+    ) -> None:
+        super().__init__()
+
+        self.module = module
+        self.split_size = split_size
+        self.split_dim = split_dim
+        self.input_kwargs_to_split = set(input_kwargs_to_split)
+
+    def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
+        r"""Forward method for the `SplitInferenceModule`.
+
+        This method processes the input by splitting specified keyword arguments along a given dimension, running the
+        underlying module on each split, and then concatenating the results. The splitting is controlled by the
+        `split_size` and `split_dim` parameters specified during initialization.
+
+        Args:
+            *args (`Any`):
+                Positional arguments that are passed directly to the `module` without modification.
+            **kwargs (`Dict[str, torch.Tensor]`):
+                Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the
+                entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword
+                arguments are passed unchanged.
+
+        Returns:
+            `Union[torch.Tensor, Tuple[torch.Tensor]]`:
+                The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred
+                without it.
+                - If the underlying module returns a single tensor, the result will be a single concatenated tensor
+                along the same `split_dim` after processing all splits.
+                - If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated
+                along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors.
+        """
+        split_inputs = {}
+
+        # 1. Split inputs that were specified during initialization and also present in passed kwargs
+        for key in list(kwargs.keys()):
+            if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]):
+                continue
+            split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim)
+            kwargs.pop(key)
+
+        # 2. Invoke forward pass across each split
+        results = []
+        for split_input in zip(*split_inputs.values()):
+            inputs = dict(zip(split_inputs.keys(), split_input))
+            inputs.update(kwargs)
+
+            intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
+            results.append(intermediate_tensor_or_tensor_tuple)
+
+        # 3. Concatenate split restuls to obtain final outputs
+        if isinstance(results[0], torch.Tensor):
+            return torch.cat(results, dim=self.split_dim)
+        elif isinstance(results[0], tuple):
+            return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)])
+        else:
+            raise ValueError(
+                "In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
+            )
+
+
 class AnimateDiffFreeNoiseMixin:
     r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
 
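The docstring above notes that `SplitInferenceModule` wrappers can be nested to split along more than one dimension. A minimal sketch of what that could look like, assuming the class is imported from its module path; the wrapped module and split sizes are purely illustrative:

```python
import torch
import torch.nn as nn

from diffusers.pipelines.free_noise_utils import SplitInferenceModule

# A toy module whose forward accepts a single keyword argument named "input".
model = nn.Linear(64, 64)

# Inner wrapper splits along dim 0, outer wrapper splits along dim 1.
inner = SplitInferenceModule(model, split_size=4, split_dim=0, input_kwargs_to_split=["input"])
nested = SplitInferenceModule(inner, split_size=8, split_dim=1, input_kwargs_to_split=["input"])

x = torch.randn(16, 32, 64)
# Each chunk of 8 columns is further processed in row-chunks of 4 before the
# results are concatenated back along both dimensions.
out = nested(input=x)
print(out.shape)  # torch.Size([16, 32, 64])
```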
@@ -69,6 +182,9 @@ class AnimateDiffFreeNoiseMixin:
                 motion_module.transformer_blocks[i].load_state_dict(
                     basic_transfomer_block.state_dict(), strict=True
                 )
+                motion_module.transformer_blocks[i].set_chunk_feed_forward(
+                    basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim
+                )
 
     def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
         r"""Helper function to disable FreeNoise in transformer blocks."""
@@ -97,6 +213,145 @@ class AnimateDiffFreeNoiseMixin:
                 motion_module.transformer_blocks[i].load_state_dict(
                     free_noise_transfomer_block.state_dict(), strict=True
                 )
+                motion_module.transformer_blocks[i].set_chunk_feed_forward(
+                    free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim
+                )
+
+    def _check_inputs_free_noise(
+        self,
+        prompt,
+        negative_prompt,
+        prompt_embeds,
+        negative_prompt_embeds,
+        num_frames,
+    ) -> None:
+        if not isinstance(prompt, (str, dict)):
+            raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}")
+
+        if negative_prompt is not None:
+            if not isinstance(negative_prompt, (str, dict)):
+                raise ValueError(
+                    f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}"
+                )
+
+        if prompt_embeds is not None or negative_prompt_embeds is not None:
+            raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.")
+
+        frame_indices = [isinstance(x, int) for x in prompt.keys()]
+        frame_prompts = [isinstance(x, str) for x in prompt.values()]
+        min_frame = min(list(prompt.keys()))
+        max_frame = max(list(prompt.keys()))
+
+        if not all(frame_indices):
+            raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.")
+        if not all(frame_prompts):
+            raise ValueError("Expected str values in `prompt` dict for FreeNoise.")
+        if min_frame != 0:
+            raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.")
+        if max_frame >= num_frames:
+            raise ValueError(
+                f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing."
+            )
+
+    def _encode_prompt_free_noise(
+        self,
+        prompt: Union[str, Dict[int, str]],
+        num_frames: int,
+        device: torch.device,
+        num_videos_per_prompt: int,
+        do_classifier_free_guidance: bool,
+        negative_prompt: Optional[Union[str, Dict[int, str]]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ) -> torch.Tensor:
+        if negative_prompt is None:
+            negative_prompt = ""
+
+        # Ensure that we have a dictionary of prompts
+        if isinstance(prompt, str):
+            prompt = {0: prompt}
+        if isinstance(negative_prompt, str):
+            negative_prompt = {0: negative_prompt}
+
+        self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames)
+
+        # Sort the prompts based on frame indices
+        prompt = dict(sorted(prompt.items()))
+        negative_prompt = dict(sorted(negative_prompt.items()))
+
+        # Ensure that we have a prompt for the last frame index
+        prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]]
+        negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]]
+
+        frame_indices = list(prompt.keys())
+        frame_prompts = list(prompt.values())
+        frame_negative_indices = list(negative_prompt.keys())
+        frame_negative_prompts = list(negative_prompt.values())
+
+        # Generate and interpolate positive prompts
+        prompt_embeds, _ = self.encode_prompt(
+            prompt=frame_prompts,
+            device=device,
+            num_images_per_prompt=num_videos_per_prompt,
+            do_classifier_free_guidance=False,
+            negative_prompt=None,
+            prompt_embeds=None,
+            negative_prompt_embeds=None,
+            lora_scale=lora_scale,
+            clip_skip=clip_skip,
+        )
+
+        shape = (num_frames, *prompt_embeds.shape[1:])
+        prompt_interpolation_embeds = prompt_embeds.new_zeros(shape)
+
+        for i in range(len(frame_indices) - 1):
+            start_frame = frame_indices[i]
+            end_frame = frame_indices[i + 1]
+            start_tensor = prompt_embeds[i].unsqueeze(0)
+            end_tensor = prompt_embeds[i + 1].unsqueeze(0)
+
+            prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(
+                start_frame, end_frame, start_tensor, end_tensor
+            )
+
+        # Generate and interpolate negative prompts
+        negative_prompt_embeds = None
+        negative_prompt_interpolation_embeds = None
+
+        if do_classifier_free_guidance:
+            _, negative_prompt_embeds = self.encode_prompt(
+                prompt=[""] * len(frame_negative_prompts),
+                device=device,
+                num_images_per_prompt=num_videos_per_prompt,
+                do_classifier_free_guidance=True,
+                negative_prompt=frame_negative_prompts,
+                prompt_embeds=None,
+                negative_prompt_embeds=None,
+                lora_scale=lora_scale,
+                clip_skip=clip_skip,
+            )
+
+            negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape)
+
+            for i in range(len(frame_negative_indices) - 1):
+                start_frame = frame_negative_indices[i]
+                end_frame = frame_negative_indices[i + 1]
+                start_tensor = negative_prompt_embeds[i].unsqueeze(0)
+                end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
+
+                negative_prompt_interpolation_embeds[
+                    start_frame : end_frame + 1
+                ] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
+
+        prompt_embeds = prompt_interpolation_embeds
+        negative_prompt_embeds = negative_prompt_interpolation_embeds
+
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        return prompt_embeds, negative_prompt_embeds
 
     def _prepare_latents_free_noise(
         self,
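The new `_encode_prompt_free_noise` helper is what makes frame-indexed prompt dictionaries work with FreeNoise: prompts are encoded per key frame and the embeddings are interpolated in between. A usage sketch, assuming an AnimateDiff-style pipeline that routes dictionary prompts through this helper when FreeNoise is enabled; the checkpoint ids and generation settings are illustrative:

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter
from diffusers.utils import export_to_gif

# Illustrative checkpoints; any AnimateDiff-compatible base model and motion adapter should do.
adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-3", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained(
    "emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16
)
pipe.enable_model_cpu_offload()

# FreeNoise must be enabled for dictionary prompts; keys are 0-based frame indices,
# and embeddings are interpolated between key frames (see `_encode_prompt_free_noise`).
pipe.enable_free_noise(context_length=16, context_stride=4)

prompt = {
    0: "a panda surfing on a wave, ocean spray, sunny day",
    32: "a panda surfing at sunset, golden light",
    63: "a panda surfing under a starry night sky",
}

video = pipe(prompt=prompt, negative_prompt="low quality", num_frames=64).frames[0]
export_to_gif(video, "panda_surfing.gif")
```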
@@ -172,12 +427,29 @@ class AnimateDiffFreeNoiseMixin:
         latents = latents[:, :, :num_frames]
         return latents
 
+    def _lerp(
+        self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor
+    ) -> torch.Tensor:
+        num_indices = end_index - start_index + 1
+        interpolated_tensors = []
+
+        for i in range(num_indices):
+            alpha = i / (num_indices - 1)
+            interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor
+            interpolated_tensors.append(interpolated_tensor)
+
+        interpolated_tensors = torch.cat(interpolated_tensors)
+        return interpolated_tensors
+
     def enable_free_noise(
         self,
         context_length: Optional[int] = 16,
         context_stride: int = 4,
         weighting_scheme: str = "pyramid",
         noise_type: str = "shuffle_context",
+        prompt_interpolation_callback: Optional[
+            Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor]
+        ] = None,
     ) -> None:
         r"""
         Enable long video generation using FreeNoise.
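The default `_lerp` callback above produces an evenly spaced blend between two endpoint embeddings, inclusive of both ends. A standalone sketch of the same arithmetic with toy tensors:

```python
import torch


def lerp(start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor) -> torch.Tensor:
    # Mirrors the arithmetic of AnimateDiffFreeNoiseMixin._lerp shown above.
    num_indices = end_index - start_index + 1
    out = []
    for i in range(num_indices):
        alpha = i / (num_indices - 1)
        out.append((1 - alpha) * start_tensor + alpha * end_tensor)
    return torch.cat(out)


start = torch.zeros(1, 2)
end = torch.ones(1, 2)
# Frames 0..4 -> alphas 0.0, 0.25, 0.5, 0.75, 1.0
print(lerp(0, 4, start, end)[:, 0])  # tensor([0.0000, 0.2500, 0.5000, 0.7500, 1.0000])
```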
@@ -195,13 +467,27 @@ class AnimateDiffFreeNoiseMixin:
             weighting_scheme (`str`, defaults to `pyramid`):
                 Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting
                 schemes are supported currently:
+                - "flat"
+                    Performs weighting averaging with a flat weight pattern: [1, 1, 1, 1, 1].
                 - "pyramid"
-                    Peforms weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1].
+                    Performs weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1].
+                - "delayed_reverse_sawtooth"
+                    Performs weighted averaging with low weights for earlier frames and high-to-low weights for
+                    later frames: [0.01, 0.01, 3, 2, 1].
             noise_type (`str`, defaults to "shuffle_context"):
-                TODO
+                Must be one of ["shuffle_context", "repeat_context", "random"].
+                - "shuffle_context"
+                    Shuffles a fixed batch of `context_length` latents to create a final latent of size
+                    `num_frames`. This is usually the best setting for most generation scenarious. However, there
+                    might be visible repetition noticeable in the kinds of motion/animation generated.
+                - "repeated_context"
+                    Repeats a fixed batch of `context_length` latents to create a final latent of size
+                    `num_frames`.
+                - "random"
+                    The final latents are random without any repetition.
         """
 
-        allowed_weighting_scheme = ["pyramid"]
+        allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"]
         allowed_noise_type = ["shuffle_context", "repeat_context", "random"]
 
         if context_length > self.motion_adapter.config.motion_max_seq_length:
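With the expanded options documented above, enabling FreeNoise in 0.31.0 can look like the following sketch; the checkpoints are illustrative, and any AnimateDiff-compatible base model and motion adapter should behave the same way:

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter

adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16)

# 0.31.0 accepts "flat" and "delayed_reverse_sawtooth" in addition to "pyramid".
pipe.enable_free_noise(
    context_length=16,  # must not exceed motion_adapter.config.motion_max_seq_length
    context_stride=4,
    weighting_scheme="flat",
    noise_type="shuffle_context",  # or "repeat_context" / "random"
)

# FreeNoise can be switched off again at any time.
pipe.disable_free_noise()
```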
@@ -219,18 +505,92 @@ class AnimateDiffFreeNoiseMixin:
         self._free_noise_context_stride = context_stride
         self._free_noise_weighting_scheme = weighting_scheme
         self._free_noise_noise_type = noise_type
+        self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or self._lerp
+
+        if hasattr(self.unet.mid_block, "motion_modules"):
+            blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
+        else:
+            blocks = [*self.unet.down_blocks, *self.unet.up_blocks]
 
-        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
         for block in blocks:
             self._enable_free_noise_in_block(block)
 
     def disable_free_noise(self) -> None:
+        r"""Disable the FreeNoise sampling mechanism."""
         self._free_noise_context_length = None
 
+        if hasattr(self.unet.mid_block, "motion_modules"):
+            blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
+        else:
+            blocks = [*self.unet.down_blocks, *self.unet.up_blocks]
+
         blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
         for block in blocks:
             self._disable_free_noise_in_block(block)
 
+    def _enable_split_inference_motion_modules_(
+        self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int
+    ) -> None:
+        for motion_module in motion_modules:
+            motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"])
+
+            for i in range(len(motion_module.transformer_blocks)):
+                motion_module.transformer_blocks[i] = SplitInferenceModule(
+                    motion_module.transformer_blocks[i],
+                    spatial_split_size,
+                    0,
+                    ["hidden_states", "encoder_hidden_states"],
+                )
+
+            motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"])
+
+    def _enable_split_inference_attentions_(
+        self, attentions: List[Transformer2DModel], temporal_split_size: int
+    ) -> None:
+        for i in range(len(attentions)):
+            attentions[i] = SplitInferenceModule(
+                attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"]
+            )
+
+    def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None:
+        for i in range(len(resnets)):
+            resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"])
+
+    def _enable_split_inference_samplers_(
+        self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int
+    ) -> None:
+        for i in range(len(samplers)):
+            samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"])
+
+    def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None:
+        r"""
+        Enable FreeNoise memory optimizations by utilizing
+        [`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks.
+
+        Args:
+            spatial_split_size (`int`, defaults to `256`):
+                The split size across spatial dimensions for internal blocks. This is used in facilitating split
+                inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion
+                modeling blocks.
+            temporal_split_size (`int`, defaults to `16`):
+                The split size across temporal dimensions for internal blocks. This is used in facilitating split
+                inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in spatial
+                attention, resnets, downsampling and upsampling blocks.
+        """
+        # TODO(aryan): Discuss on what's the best way to provide more control to users
+        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
+        for block in blocks:
+            if getattr(block, "motion_modules", None) is not None:
+                self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size)
+            if getattr(block, "attentions", None) is not None:
+                self._enable_split_inference_attentions_(block.attentions, temporal_split_size)
+            if getattr(block, "resnets", None) is not None:
+                self._enable_split_inference_resnets_(block.resnets, temporal_split_size)
+            if getattr(block, "downsamplers", None) is not None:
+                self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size)
+            if getattr(block, "upsamplers", None) is not None:
+                self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size)
+
     @property
     def free_noise_enabled(self):
         return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
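The new `enable_free_noise_split_inference` entry point wraps the motion modules, attentions, resnets and samplers shown above in `SplitInferenceModule`s. A usage sketch, assuming an AnimateDiff-style pipeline (checkpoints and split sizes are illustrative), together with a hypothetical prompt interpolation callback that matches how `_free_noise_prompt_interpolation_callback` is invoked above (four positional arguments: start index, end index, start tensor, end tensor):

```python
import torch
from diffusers import AnimateDiffPipeline, MotionAdapter


def cosine_blend(start_index: int, end_index: int, start: torch.Tensor, end: torch.Tensor) -> torch.Tensor:
    # Hypothetical replacement for the default _lerp: cosine-eased blend between the two embeddings.
    num = end_index - start_index + 1
    outs = []
    for i in range(num):
        alpha = 0.5 - 0.5 * torch.cos(torch.tensor(i / (num - 1) * torch.pi))
        outs.append((1 - alpha) * start + alpha * end)
    return torch.cat(outs)


adapter = MotionAdapter.from_pretrained("guoyww/animatediff-motion-adapter-v1-5-2", torch_dtype=torch.float16)
pipe = AnimateDiffPipeline.from_pretrained("emilianJR/epiCRealism", motion_adapter=adapter, torch_dtype=torch.float16)
pipe.enable_model_cpu_offload()

pipe.enable_free_noise(context_length=16, context_stride=4, prompt_interpolation_callback=cosine_blend)

# Trade speed for peak memory: intermediate activations are processed in chunks of
# 256 along the spatial batch dimension and 16 along the temporal batch dimension.
pipe.enable_free_noise_split_inference(spatial_split_size=256, temporal_split_size=16)
```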
@@ -125,9 +125,21 @@ def get_resize_crop_region_for_grid(src, tgt_size):
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
-    """
-    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
-    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    r"""
+    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+    Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+    Args:
+        noise_cfg (`torch.Tensor`):
+            The predicted noise tensor for the guided diffusion process.
+        noise_pred_text (`torch.Tensor`):
+            The predicted noise tensor for the text-guided diffusion process.
+        guidance_rescale (`float`, *optional*, defaults to 0.0):
+            A rescale factor applied to the noise predictions.
+
+    Returns:
+        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
     """
     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
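Only the docstring changes in this hunk; the function body is untouched. For reference, a standalone sketch of the rescale-and-blend computation the new docstring describes, following the two `std` context lines above and Section 3.4 of the cited paper:

```python
import torch


def rescale_noise_cfg(noise_cfg: torch.Tensor, noise_pred_text: torch.Tensor, guidance_rescale: float = 0.0) -> torch.Tensor:
    # Per-sample std over all non-batch dimensions, as in the context lines above.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the guided prediction to the std of the text prediction (counters overexposure),
    # then blend it back with the original prediction by `guidance_rescale`.
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    return guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg


noise_cfg = torch.randn(2, 4, 64, 64)
noise_pred_text = torch.randn(2, 4, 64, 64)
rescaled = rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.7)
```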
@@ -547,7 +547,7 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
         negative_image_embeds = prior_outputs[1]
 
         prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
-        image = [image] if isinstance(prompt, PIL.Image.Image) else image
+        image = [image] if isinstance(image, PIL.Image.Image) else image
 
         if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
             prompt = (image_embeds.shape[0] // len(prompt)) * prompt
@@ -813,7 +813,7 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
         negative_image_embeds = prior_outputs[1]
 
         prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
-        image = [image] if isinstance(prompt, PIL.Image.Image) else image
+        image = [image] if isinstance(image, PIL.Image.Image) else image
         mask_image = [mask_image] if isinstance(mask_image, PIL.Image.Image) else mask_image
 
         if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
@@ -70,7 +70,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
@@ -89,7 +89,7 @@ def retrieve_timesteps(
    sigmas: Optional[List[float]] = None,
    **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
@@ -564,14 +564,16 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
         if denoising_start is None:
             init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
             t_start = max(num_inference_steps - init_timestep, 0)
-        else:
-            t_start = 0
 
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+            timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+            if hasattr(self.scheduler, "set_begin_index"):
+                self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
-        # Strength is irrelevant if we directly request a timestep to start at;
-        # that is, strength is determined by the denoising_start instead.
-        if denoising_start is not None:
+            return timesteps, num_inference_steps - t_start
+
+        else:
+            # Strength is irrelevant if we directly request a timestep to start at;
+            # that is, strength is determined by the denoising_start instead.
             discrete_timestep_cutoff = int(
                 round(
                     self.scheduler.config.num_train_timesteps
@@ -579,7 +581,7 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
                 )
             )
 
-            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+            num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
             if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
                 # if the scheduler is a 2nd order scheduler we might have to do +1
                 # because `num_inference_steps` might be even given that every timestep
@@ -590,11 +592,12 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
                 num_inference_steps = num_inference_steps + 1
 
             # because t_n+1 >= t_n, we slice the timesteps starting from the end
-            timesteps = timesteps[-num_inference_steps:]
+            t_start = len(self.scheduler.timesteps) - num_inference_steps
+            timesteps = self.scheduler.timesteps[t_start:]
+            if hasattr(self.scheduler, "set_begin_index"):
+                self.scheduler.set_begin_index(t_start)
             return timesteps, num_inference_steps
 
-        return timesteps, num_inference_steps - t_start
-
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
     def prepare_latents(
         self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
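The reworked `get_timesteps` slices the scheduler's timesteps according to `strength` (or `denoising_start`) and, new in this release, informs schedulers that implement `set_begin_index` where denoising starts. A worked sketch of the `strength` branch arithmetic for a first-order scheduler (plain Python, values illustrative):

```python
# Worked example of the strength-based branch above (first-order scheduler, order == 1).
num_inference_steps = 50
strength = 0.6

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 30
t_start = max(num_inference_steps - init_timestep, 0)                          # 20

# timesteps = scheduler.timesteps[t_start * scheduler.order :]
# -> the last 30 of the 50 scheduled timesteps are used, and schedulers that
#    implement `set_begin_index` are told to start bookkeeping at index 20.
print(init_timestep, t_start, num_inference_steps - t_start)  # 30 20 30
```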
@@ -277,6 +277,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
         pad_to_multiple_of: Optional[int] = None,
         return_attention_mask: Optional[bool] = None,
+        padding_side: Optional[bool] = None,
     ) -> dict:
         """
         Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
@@ -298,6 +299,9 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                 This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
                 `>= 7.5` (Volta).
+            padding_side (`str`, *optional*):
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
             return_attention_mask:
                 (optional) Set to False to avoid returning attention mask (default: set to model specifics)
         """
@@ -66,7 +66,7 @@ def retrieve_timesteps(
    sigmas: Optional[List[float]] = None,
    **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
@@ -70,7 +70,7 @@ def retrieve_timesteps(
    sigmas: Optional[List[float]] = None,
    **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
@@ -56,7 +56,7 @@ EXAMPLE_DOC_STRING = """
         >>> from diffusers.utils import export_to_gif
 
         >>> # You can replace the checkpoint id with "maxin-cn/Latte-1" too.
-        >>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16).to("cuda")
+        >>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
         >>> # Enable memory optimizations.
         >>> pipe.enable_model_cpu_offload()
 
@@ -76,7 +76,7 @@ def retrieve_timesteps(
    sigmas: Optional[List[float]] = None,
    **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
@@ -234,9 +234,21 @@ class LEDITSCrossAttnProcessor:
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
-    """
-    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
-    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
+    r"""
+    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+    Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+    Args:
+        noise_cfg (`torch.Tensor`):
+            The predicted noise tensor for the guided diffusion process.
+        noise_pred_text (`torch.Tensor`):
+            The predicted noise tensor for the text-guided diffusion process.
+        guidance_rescale (`float`, *optional*, defaults to 0.0):
+            A rescale factor applied to the noise predictions.
+
+    Returns:
+        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
     """
     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)