diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (220)
  1. diffusers/__init__.py +94 -3
  2. diffusers/commands/env.py +1 -5
  3. diffusers/configuration_utils.py +4 -9
  4. diffusers/dependency_versions_table.py +2 -2
  5. diffusers/image_processor.py +1 -2
  6. diffusers/loaders/__init__.py +17 -2
  7. diffusers/loaders/ip_adapter.py +10 -7
  8. diffusers/loaders/lora_base.py +752 -0
  9. diffusers/loaders/lora_pipeline.py +2252 -0
  10. diffusers/loaders/peft.py +213 -5
  11. diffusers/loaders/single_file.py +3 -14
  12. diffusers/loaders/single_file_model.py +31 -10
  13. diffusers/loaders/single_file_utils.py +293 -8
  14. diffusers/loaders/textual_inversion.py +1 -6
  15. diffusers/loaders/unet.py +23 -208
  16. diffusers/models/__init__.py +20 -0
  17. diffusers/models/activations.py +22 -0
  18. diffusers/models/attention.py +386 -7
  19. diffusers/models/attention_processor.py +1937 -629
  20. diffusers/models/autoencoders/__init__.py +2 -0
  21. diffusers/models/autoencoders/autoencoder_kl.py +14 -3
  22. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
  23. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
  24. diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
  25. diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
  26. diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
  27. diffusers/models/autoencoders/vq_model.py +4 -4
  28. diffusers/models/controlnet.py +2 -3
  29. diffusers/models/controlnet_hunyuan.py +401 -0
  30. diffusers/models/controlnet_sd3.py +11 -11
  31. diffusers/models/controlnet_sparsectrl.py +789 -0
  32. diffusers/models/controlnet_xs.py +40 -10
  33. diffusers/models/downsampling.py +68 -0
  34. diffusers/models/embeddings.py +403 -36
  35. diffusers/models/model_loading_utils.py +1 -3
  36. diffusers/models/modeling_flax_utils.py +1 -6
  37. diffusers/models/modeling_utils.py +4 -16
  38. diffusers/models/normalization.py +203 -12
  39. diffusers/models/transformers/__init__.py +6 -0
  40. diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
  41. diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
  42. diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
  43. diffusers/models/transformers/latte_transformer_3d.py +327 -0
  44. diffusers/models/transformers/lumina_nextdit2d.py +340 -0
  45. diffusers/models/transformers/pixart_transformer_2d.py +102 -1
  46. diffusers/models/transformers/prior_transformer.py +1 -1
  47. diffusers/models/transformers/stable_audio_transformer.py +458 -0
  48. diffusers/models/transformers/transformer_flux.py +455 -0
  49. diffusers/models/transformers/transformer_sd3.py +18 -4
  50. diffusers/models/unets/unet_1d_blocks.py +1 -1
  51. diffusers/models/unets/unet_2d_condition.py +8 -1
  52. diffusers/models/unets/unet_3d_blocks.py +51 -920
  53. diffusers/models/unets/unet_3d_condition.py +4 -1
  54. diffusers/models/unets/unet_i2vgen_xl.py +4 -1
  55. diffusers/models/unets/unet_kandinsky3.py +1 -1
  56. diffusers/models/unets/unet_motion_model.py +1330 -84
  57. diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
  58. diffusers/models/unets/unet_stable_cascade.py +1 -3
  59. diffusers/models/unets/uvit_2d.py +1 -1
  60. diffusers/models/upsampling.py +64 -0
  61. diffusers/models/vq_model.py +8 -4
  62. diffusers/optimization.py +1 -1
  63. diffusers/pipelines/__init__.py +100 -3
  64. diffusers/pipelines/animatediff/__init__.py +4 -0
  65. diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
  66. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
  70. diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
  71. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
  72. diffusers/pipelines/aura_flow/__init__.py +48 -0
  73. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
  74. diffusers/pipelines/auto_pipeline.py +97 -19
  75. diffusers/pipelines/cogvideo/__init__.py +48 -0
  76. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
  77. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
  78. diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
  79. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
  80. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
  81. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
  82. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
  83. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
  84. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
  85. diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
  86. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
  87. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
  88. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
  90. diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
  91. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
  96. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
  97. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
  98. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
  99. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  100. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
  101. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
  103. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  104. diffusers/pipelines/flux/__init__.py +47 -0
  105. diffusers/pipelines/flux/pipeline_flux.py +749 -0
  106. diffusers/pipelines/flux/pipeline_output.py +21 -0
  107. diffusers/pipelines/free_init_utils.py +2 -0
  108. diffusers/pipelines/free_noise_utils.py +236 -0
  109. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
  110. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
  111. diffusers/pipelines/kolors/__init__.py +54 -0
  112. diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
  113. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
  114. diffusers/pipelines/kolors/pipeline_output.py +21 -0
  115. diffusers/pipelines/kolors/text_encoder.py +889 -0
  116. diffusers/pipelines/kolors/tokenizer.py +334 -0
  117. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
  118. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
  119. diffusers/pipelines/latte/__init__.py +48 -0
  120. diffusers/pipelines/latte/pipeline_latte.py +881 -0
  121. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
  122. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
  123. diffusers/pipelines/lumina/__init__.py +48 -0
  124. diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
  125. diffusers/pipelines/pag/__init__.py +67 -0
  126. diffusers/pipelines/pag/pag_utils.py +237 -0
  127. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
  128. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
  129. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
  130. diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
  131. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
  132. diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
  133. diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
  134. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
  135. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
  136. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
  137. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
  138. diffusers/pipelines/pia/pipeline_pia.py +30 -37
  139. diffusers/pipelines/pipeline_flax_utils.py +4 -9
  140. diffusers/pipelines/pipeline_loading_utils.py +0 -3
  141. diffusers/pipelines/pipeline_utils.py +2 -14
  142. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
  143. diffusers/pipelines/stable_audio/__init__.py +50 -0
  144. diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
  145. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
  146. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
  147. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
  148. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
  149. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
  150. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
  151. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
  152. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
  153. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
  154. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
  155. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
  156. diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
  157. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
  158. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
  160. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
  161. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
  162. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
  163. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
  164. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
  165. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
  166. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
  167. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
  168. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
  170. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
  171. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
  172. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
  173. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
  174. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
  175. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
  176. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
  177. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
  178. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
  179. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  180. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  181. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
  182. diffusers/schedulers/__init__.py +8 -0
  183. diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
  184. diffusers/schedulers/scheduling_ddim.py +1 -1
  185. diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
  186. diffusers/schedulers/scheduling_ddpm.py +1 -1
  187. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
  188. diffusers/schedulers/scheduling_deis_multistep.py +2 -2
  189. diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
  190. diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
  191. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
  192. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
  193. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
  194. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
  195. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
  196. diffusers/schedulers/scheduling_ipndm.py +1 -1
  197. diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
  198. diffusers/schedulers/scheduling_utils.py +1 -3
  199. diffusers/schedulers/scheduling_utils_flax.py +1 -3
  200. diffusers/training_utils.py +99 -14
  201. diffusers/utils/__init__.py +2 -2
  202. diffusers/utils/dummy_pt_objects.py +210 -0
  203. diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
  204. diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
  205. diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
  206. diffusers/utils/dynamic_modules_utils.py +1 -11
  207. diffusers/utils/export_utils.py +50 -6
  208. diffusers/utils/hub_utils.py +45 -42
  209. diffusers/utils/import_utils.py +37 -15
  210. diffusers/utils/loading_utils.py +80 -3
  211. diffusers/utils/testing_utils.py +11 -8
  212. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
  213. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
  214. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
  215. diffusers/loaders/autoencoder.py +0 -146
  216. diffusers/loaders/controlnet.py +0 -136
  217. diffusers/loaders/lora.py +0 -1728
  218. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
  219. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
  220. {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,321 @@
1
+ # Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import Optional, Tuple, Union
17
+
18
+ import numpy as np
19
+ import torch
20
+
21
+ from ..configuration_utils import ConfigMixin, register_to_config
22
+ from ..utils import BaseOutput, logging
23
+ from ..utils.torch_utils import randn_tensor
24
+ from .scheduling_utils import SchedulerMixin
25
+
26
+
27
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
28
+
29
+
30
@dataclass
class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
    """
    Result container produced by the flow-matching Heun scheduler's `step` method when `return_dict=True`.

    Args:
        prev_sample (`torch.FloatTensor`, typically `(batch_size, num_channels, height, width)` for images):
            The sample `(x_{t-1})` computed for the previous timestep. Feed it back in as the model input on the
            next iteration of the denoising loop.
    """

    prev_sample: torch.FloatTensor
42
+
43
+
44
class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Heun (second-order) scheduler for flow-matching models.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Each inference step after the first is evaluated twice. The first call is an Euler (first-order) prediction; the
    second call is a Heun correction that averages the two derivatives. For this reason `set_timesteps` repeats every
    timestep after the first.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule. Sigmas are remapped as
            `shift * sigma / (1 + (shift - 1) * sigma)`, so `1.0` leaves them unchanged.
    """

    _compatibles = []
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
    ):
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        sigmas = timesteps / num_train_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

        # Heun state. It is initialised here so that `state_in_first_order` and `step` do not raise
        # AttributeError when used before `set_timesteps` has been called.
        self.prev_derivative = None
        self.dt = None
        self.sample = None

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching: returns `sigma * noise + (1 - sigma) * sample` at the current sigma.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain. Used to initialise the step index if it is not set yet.
            noise (`torch.FloatTensor`, *optional*):
                The noise to blend into `sample`. If `None`, standard Gaussian noise with the shape, dtype and
                device of `sample` is drawn.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        # Previously a `None` noise crashed with a TypeError despite the Optional signature.
        if noise is None:
            noise = randn_tensor(sample.shape, dtype=sample.dtype, device=sample.device)

        sigma = self.sigmas[self.step_index]
        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        # Map a sigma in [0, 1] back onto the training-timestep scale.
        return sigma * self.config.num_train_timesteps

    def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Every timestep after the first is duplicated, and every sigma except the first and the appended terminal
        `0` is duplicated. This lets each Heun step consume one first-order and one second-order call.

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        timesteps = np.linspace(
            self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
        )

        sigmas = timesteps / self.config.num_train_timesteps
        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        timesteps = sigmas * self.config.num_train_timesteps
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
        self.timesteps = timesteps.to(device=device)

        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        # Clear the Heun state (dt, derivative and cached sample) so a new run starts in first-order mode.
        self.prev_derivative = None
        self.dt = None
        self.sample = None

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the position of `timestep` in `schedule_timesteps` (defaults to `self.timesteps`).

        If the timestep occurs more than once, the second occurrence is returned.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        # Use `begin_index` when the pipeline set one; otherwise look the timestep up in the schedule.
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @property
    def state_in_first_order(self):
        """`True` when no Euler step is pending, i.e. the next `step` call is a first-order one."""
        return self.dt is None

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
        """
        Advance the sample by one (half-)step of Heun's method, using the learned model output.

        On a first-order call, the Euler derivative, step size and sample are cached, and the Euler prediction is
        returned. On the following second-order call, the cached derivative is averaged with the new derivative,
        and the step is re-taken from the cached sample.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float` or `torch.FloatTensor`):
                The current discrete timestep in the diffusion chain. Must be one of `scheduler.timesteps`;
                integer values are rejected.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`, defaults to 0.0):
                Amount of added stochasticity. The churn factor is
                `min(s_churn / (len(self.sigmas) - 1), sqrt(2) - 1)`.
            s_tmin (`float`, defaults to 0.0):
                Churn is only applied when `sigma >= s_tmin`.
            s_tmax (`float`, defaults to `inf`):
                Churn is only applied when `sigma <= s_tmax`.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`, defaults to `True`):
                Whether or not to return a [`FlowMatchHeunDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlowMatchHeunDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, a [`FlowMatchHeunDiscreteSchedulerOutput`] is returned, otherwise a
                tuple is returned where the first element is the sample tensor.

        Raises:
            ValueError: If `timestep` is a Python `int` or an integer tensor.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchHeunDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
            sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / Heun's method
            sigma = self.sigmas[self.step_index - 1]
            sigma_next = self.sigmas[self.step_index]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        # Noise is drawn on every call (even when gamma == 0) so generator consumption is unchanged.
        noise = randn_tensor(
            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
        )

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        if self.state_in_first_order:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma
            # 2. convert to an ODE derivative for 1st order
            derivative = (sample - denoised) / sigma_hat
            # 3. Delta timestep
            dt = sigma_next - sigma_hat

            # store for 2nd order step
            self.prev_derivative = derivative
            self.dt = dt
            self.sample = sample
        else:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma_next
            # 2. 2nd order / Heun's method
            derivative = (sample - denoised) / sigma_next
            derivative = 0.5 * (self.prev_derivative + derivative)

            # 3. take prev timestep & sample
            dt = self.dt
            sample = self.sample

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.prev_derivative = None
            self.dt = None
            self.sample = None

        prev_sample = sample + derivative * dt
        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
@@ -138,7 +138,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
138
138
  def step(
139
139
  self,
140
140
  model_output: torch.Tensor,
141
- timestep: int,
141
+ timestep: Union[int, torch.Tensor],
142
142
  sample: torch.Tensor,
143
143
  return_dict: bool = True,
144
144
  ) -> Union[SchedulerOutput, Tuple]:
@@ -822,7 +822,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
822
822
  def step(
823
823
  self,
824
824
  model_output: torch.Tensor,
825
- timestep: int,
825
+ timestep: Union[int, torch.Tensor],
826
826
  sample: torch.Tensor,
827
827
  return_dict: bool = True,
828
828
  ) -> Union[SchedulerOutput, Tuple]:
@@ -121,9 +121,7 @@ class SchedulerMixin(PushToHubMixin):
121
121
  force_download (`bool`, *optional*, defaults to `False`):
122
122
  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
123
123
  cached versions if they exist.
124
- resume_download:
125
- Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
126
- of Diffusers.
124
+
127
125
  proxies (`Dict[str, str]`, *optional*):
128
126
  A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
129
127
  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -102,9 +102,7 @@ class FlaxSchedulerMixin(PushToHubMixin):
102
102
  force_download (`bool`, *optional*, defaults to `False`):
103
103
  Whether or not to force the (re-)download of the model weights and configuration files, overriding the
104
104
  cached versions if they exist.
105
- resume_download:
106
- Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
107
- of Diffusers.
105
+
108
106
  proxies (`Dict[str, str]`, *optional*):
109
107
  A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
110
108
  'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -1,5 +1,6 @@
1
1
  import contextlib
2
2
  import copy
3
+ import math
3
4
  import random
4
5
  from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
5
6
 
@@ -220,6 +221,44 @@ def _set_state_dict_into_text_encoder(
220
221
  set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
221
222
 
222
223
 
224
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: float = None,
    logit_std: float = None,
    mode_scale: float = None,
    device: Union[torch.device, str] = "cpu",
    generator: Optional[torch.Generator] = None,
):
    """Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.

    Args:
        weighting_scheme (`str`):
            How to sample the values:
            - `"logit_normal"` samples `sigmoid(N(logit_mean, logit_std))`.
            - `"mode"` applies `1 - u - mode_scale * (cos(pi * u / 2) ** 2 - 1 + u)` to uniform samples `u`.
            - Any other value samples uniformly from `[0, 1)`.
        batch_size (`int`):
            Number of values to draw.
        logit_mean (`float`, *optional*):
            Mean of the normal distribution for `"logit_normal"`. `None` means `0.0`.
        logit_std (`float`, *optional*):
            Standard deviation of the normal distribution for `"logit_normal"`. `None` means `1.0`.
        mode_scale (`float`, *optional*):
            Scale for the `"mode"` scheme. Required for that scheme.
        device (`torch.device` or `str`, defaults to `"cpu"`):
            Device on which the values are drawn.
        generator (`torch.Generator`, *optional*):
            Generator used for reproducible sampling.

    Returns:
        `torch.Tensor` of shape `(batch_size,)`.

    Raises:
        ValueError: If `weighting_scheme == "mode"` and `mode_scale` is `None`.
    """
    if weighting_scheme == "logit_normal":
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$). Missing parameters fall back to those values
        # instead of crashing inside `torch.normal`.
        mean = 0.0 if logit_mean is None else logit_mean
        std = 1.0 if logit_std is None else logit_std
        u = torch.normal(mean=mean, std=std, size=(batch_size,), device=device, generator=generator)
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        if mode_scale is None:
            raise ValueError("`mode_scale` must be provided when `weighting_scheme` is 'mode'.")
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
    return u
243
+
244
+
245
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.

    Args:
        weighting_scheme (`str`):
            Which weighting to compute:
            - `"sigma_sqrt"` weights by `sigmas ** -2`, cast to float32.
            - `"cosmap"` weights by `2 / (pi * (1 - 2 * sigmas + 2 * sigmas ** 2))`.
            - Any other value gives weights of one.
        sigmas (`torch.Tensor`):
            Noise levels for the batch. Despite the `None` default, a value must be supplied; every scheme
            fails on `None`.

    Returns:
        Weights with the same shape as `sigmas`.
    """
    if weighting_scheme == "sigma_sqrt":
        return (sigmas**-2.0).float()
    if weighting_scheme == "cosmap":
        denominator = 1 - 2 * sigmas + 2 * sigmas**2
        return 2 / (math.pi * denominator)
    return torch.ones_like(sigmas)
260
+
261
+
223
262
  # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
224
263
  class EMAModel:
225
264
  """
@@ -235,6 +274,7 @@ class EMAModel:
235
274
  use_ema_warmup: bool = False,
236
275
  inv_gamma: Union[float, int] = 1.0,
237
276
  power: Union[float, int] = 2 / 3,
277
+ foreach: bool = False,
238
278
  model_cls: Optional[Any] = None,
239
279
  model_config: Dict[str, Any] = None,
240
280
  **kwargs,
@@ -249,6 +289,7 @@ class EMAModel:
249
289
  inv_gamma (float):
250
290
  Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
251
291
  power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
292
+ foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
252
293
  device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
253
294
  weights will be stored on CPU.
254
295
 
@@ -303,16 +344,17 @@ class EMAModel:
303
344
  self.power = power
304
345
  self.optimization_step = 0
305
346
  self.cur_decay_value = None # set in `step()`
347
+ self.foreach = foreach
306
348
 
307
349
  self.model_cls = model_cls
308
350
  self.model_config = model_config
309
351
 
310
352
  @classmethod
311
- def from_pretrained(cls, path, model_cls) -> "EMAModel":
353
+ def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
312
354
  _, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
313
355
  model = model_cls.from_pretrained(path)
314
356
 
315
- ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)
357
+ ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
316
358
 
317
359
  ema_model.load_state_dict(ema_kwargs)
318
360
  return ema_model
@@ -379,15 +421,37 @@ class EMAModel:
379
421
  if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
380
422
  import deepspeed
381
423
 
382
- for s_param, param in zip(self.shadow_params, parameters):
424
+ if self.foreach:
383
425
  if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
384
- context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
426
+ context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
385
427
 
386
428
  with context_manager():
387
- if param.requires_grad:
388
- s_param.sub_(one_minus_decay * (s_param - param))
389
- else:
390
- s_param.copy_(param)
429
+ params_grad = [param for param in parameters if param.requires_grad]
430
+ s_params_grad = [
431
+ s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
432
+ ]
433
+
434
+ if len(params_grad) < len(parameters):
435
+ torch._foreach_copy_(
436
+ [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
437
+ [param for param in parameters if not param.requires_grad],
438
+ non_blocking=True,
439
+ )
440
+
441
+ torch._foreach_sub_(
442
+ s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
443
+ )
444
+
445
+ else:
446
+ for s_param, param in zip(self.shadow_params, parameters):
447
+ if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
448
+ context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
449
+
450
+ with context_manager():
451
+ if param.requires_grad:
452
+ s_param.sub_(one_minus_decay * (s_param - param))
453
+ else:
454
+ s_param.copy_(param)
391
455
 
392
456
  def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
393
457
  """
@@ -399,10 +463,24 @@ class EMAModel:
399
463
  `ExponentialMovingAverage` was initialized will be used.
400
464
  """
401
465
  parameters = list(parameters)
402
- for s_param, param in zip(self.shadow_params, parameters):
403
- param.data.copy_(s_param.to(param.device).data)
466
+ if self.foreach:
467
+ torch._foreach_copy_(
468
+ [param.data for param in parameters],
469
+ [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
470
+ )
471
+ else:
472
+ for s_param, param in zip(self.shadow_params, parameters):
473
+ param.data.copy_(s_param.to(param.device).data)
404
474
 
405
- def to(self, device=None, dtype=None) -> None:
475
+ def pin_memory(self) -> None:
476
+ r"""
477
+ Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
478
+ offloading EMA params to the host.
479
+ """
480
+
481
+ self.shadow_params = [p.pin_memory() for p in self.shadow_params]
482
+
483
+ def to(self, device=None, dtype=None, non_blocking=False) -> None:
406
484
  r"""Move internal buffers of the ExponentialMovingAverage to `device`.
407
485
 
408
486
  Args:
@@ -410,7 +488,9 @@ class EMAModel:
410
488
  """
411
489
  # .to() on the tensors handles None correctly
412
490
  self.shadow_params = [
413
- p.to(device=device, dtype=dtype) if p.is_floating_point() else p.to(device=device)
491
+ p.to(device=device, dtype=dtype, non_blocking=non_blocking)
492
+ if p.is_floating_point()
493
+ else p.to(device=device, non_blocking=non_blocking)
414
494
  for p in self.shadow_params
415
495
  ]
416
496
 
@@ -454,8 +534,13 @@ class EMAModel:
454
534
  """
455
535
  if self.temp_stored_params is None:
456
536
  raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
457
- for c_param, param in zip(self.temp_stored_params, parameters):
458
- param.data.copy_(c_param.data)
537
+ if self.foreach:
538
+ torch._foreach_copy_(
539
+ [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
540
+ )
541
+ else:
542
+ for c_param, param in zip(self.temp_stored_params, parameters):
543
+ param.data.copy_(c_param.data)
459
544
 
460
545
  # Better memory-wise.
461
546
  self.temp_stored_params = None
@@ -73,12 +73,12 @@ from .import_utils import (
73
73
  is_librosa_available,
74
74
  is_matplotlib_available,
75
75
  is_note_seq_available,
76
- is_notebook,
77
76
  is_onnx_available,
78
77
  is_peft_available,
79
78
  is_peft_version,
80
79
  is_safetensors_available,
81
80
  is_scipy_available,
81
+ is_sentencepiece_available,
82
82
  is_tensorboard_available,
83
83
  is_timm_available,
84
84
  is_torch_available,
@@ -94,7 +94,7 @@ from .import_utils import (
94
94
  is_xformers_available,
95
95
  requires_backends,
96
96
  )
97
- from .loading_utils import load_image
97
+ from .loading_utils import load_image, load_video
98
98
  from .logging import get_logger
99
99
  from .outputs import BaseOutput
100
100
  from .peft_utils import (