diffusers-0.31.0-py3-none-any.whl → diffusers-0.32.0-py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (214)
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +74 -28
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
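
The dominant structural change in this release is the consolidation of the ControlNet model implementations into the new `diffusers/models/controlnets/` package (entries 38-51), while the old module paths (entries 34-37) shrink to thin deprecation shims scheduled for removal in 0.34; the diff below for `diffusers/models/controlnet_flux.py` (entry 35, +29 -495) shows the pattern in full. A minimal sketch of what the move means for import paths, assuming diffusers 0.32.0 is installed:

# Old path: still importable in 0.32.0, but constructing the class emits a
# deprecation notice and the module is slated for removal in 0.34.
from diffusers.models.controlnet_flux import FluxControlNetModel

# New canonical path after the move:
from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel

# Top-level imports are unaffected by the relocation:
from diffusers import FluxControlNetModel
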
diffusers/models/controlnet_flux.py

@@ -12,36 +12,24 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from dataclasses import dataclass
-from typing import Any, Dict, List, Optional, Tuple, Union
 
-import torch
-import torch.nn as nn
+from typing import List
 
-from ..configuration_utils import ConfigMixin, register_to_config
-from ..loaders import PeftAdapterMixin
-from ..models.attention_processor import AttentionProcessor
-from ..models.modeling_utils import ModelMixin
-from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
-from .controlnet import BaseOutput, ControlNetConditioningEmbedding, zero_module
-from .embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
-from .modeling_outputs import Transformer2DModelOutput
-from .transformers.transformer_flux import FluxSingleTransformerBlock, FluxTransformerBlock
+from ..utils import deprecate, logging
+from .controlnets.controlnet_flux import FluxControlNetModel, FluxControlNetOutput, FluxMultiControlNetModel
 
 
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
-@dataclass
-class FluxControlNetOutput(BaseOutput):
-    controlnet_block_samples: Tuple[torch.Tensor]
-    controlnet_single_block_samples: Tuple[torch.Tensor]
+class FluxControlNetOutput(FluxControlNetOutput):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `FluxControlNetOutput` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetOutput`, instead."
+        deprecate("diffusers.models.controlnet_flux.FluxControlNetOutput", "0.34", deprecation_message)
+        super().__init__(*args, **kwargs)
 
 
-class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
-    _supports_gradient_checkpointing = True
-
-    @register_to_config
+class FluxControlNetModel(FluxControlNetModel):
     def __init__(
         self,
         patch_size: int = 1,
@@ -57,480 +45,26 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         num_mode: int = None,
         conditioning_embedding_channels: int = None,
     ):
-        super().__init__()
-        self.out_channels = in_channels
-        self.inner_dim = num_attention_heads * attention_head_dim
-
-        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
-        text_time_guidance_cls = (
-            CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
-        )
-        self.time_text_embed = text_time_guidance_cls(
-            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
-        )
-
-        self.context_embedder = nn.Linear(joint_attention_dim, self.inner_dim)
-        self.x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
-
-        self.transformer_blocks = nn.ModuleList(
-            [
-                FluxTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=attention_head_dim,
-                )
-                for i in range(num_layers)
-            ]
-        )
-
-        self.single_transformer_blocks = nn.ModuleList(
-            [
-                FluxSingleTransformerBlock(
-                    dim=self.inner_dim,
-                    num_attention_heads=num_attention_heads,
-                    attention_head_dim=attention_head_dim,
-                )
-                for i in range(num_single_layers)
-            ]
-        )
-
-        # controlnet_blocks
-        self.controlnet_blocks = nn.ModuleList([])
-        for _ in range(len(self.transformer_blocks)):
-            self.controlnet_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
-
-        self.controlnet_single_blocks = nn.ModuleList([])
-        for _ in range(len(self.single_transformer_blocks)):
-            self.controlnet_single_blocks.append(zero_module(nn.Linear(self.inner_dim, self.inner_dim)))
-
-        self.union = num_mode is not None
-        if self.union:
-            self.controlnet_mode_embedder = nn.Embedding(num_mode, self.inner_dim)
-
-        if conditioning_embedding_channels is not None:
-            self.input_hint_block = ControlNetConditioningEmbedding(
-                conditioning_embedding_channels=conditioning_embedding_channels, block_out_channels=(16, 16, 16, 16)
-            )
-            self.controlnet_x_embedder = torch.nn.Linear(in_channels, self.inner_dim)
-        else:
-            self.input_hint_block = None
-            self.controlnet_x_embedder = zero_module(torch.nn.Linear(in_channels, self.inner_dim))
-
-        self.gradient_checkpointing = False
-
-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self):
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-    def _set_gradient_checkpointing(self, module, value=False):
-        if hasattr(module, "gradient_checkpointing"):
-            module.gradient_checkpointing = value
-
-    @classmethod
-    def from_transformer(
-        cls,
-        transformer,
-        num_layers: int = 4,
-        num_single_layers: int = 10,
-        attention_head_dim: int = 128,
-        num_attention_heads: int = 24,
-        load_weights_from_transformer=True,
-    ):
-        config = transformer.config
-        config["num_layers"] = num_layers
-        config["num_single_layers"] = num_single_layers
-        config["attention_head_dim"] = attention_head_dim
-        config["num_attention_heads"] = num_attention_heads
-
-        controlnet = cls(**config)
-
-        if load_weights_from_transformer:
-            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict())
-            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict())
-            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict())
-            controlnet.x_embedder.load_state_dict(transformer.x_embedder.state_dict())
-            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict(), strict=False)
-            controlnet.single_transformer_blocks.load_state_dict(
-                transformer.single_transformer_blocks.state_dict(), strict=False
-            )
-
-            controlnet.controlnet_x_embedder = zero_module(controlnet.controlnet_x_embedder)
-
-        return controlnet
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        controlnet_cond: torch.Tensor,
-        controlnet_mode: torch.Tensor = None,
-        conditioning_scale: float = 1.0,
-        encoder_hidden_states: torch.Tensor = None,
-        pooled_projections: torch.Tensor = None,
-        timestep: torch.LongTensor = None,
-        img_ids: torch.Tensor = None,
-        txt_ids: torch.Tensor = None,
-        guidance: torch.Tensor = None,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        return_dict: bool = True,
-    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
-        """
-        The [`FluxTransformer2DModel`] forward method.
-
-        Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
-                Input `hidden_states`.
-            controlnet_cond (`torch.Tensor`):
-                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
-            controlnet_mode (`torch.Tensor`):
-                The mode tensor of shape `(batch_size, 1)`.
-            conditioning_scale (`float`, defaults to `1.0`):
-                The scale factor for ControlNet outputs.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
-                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
-            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
-                from the embeddings of input conditions.
-            timestep ( `torch.LongTensor`):
-                Used to indicate denoising step.
-            block_controlnet_hidden_states: (`list` of `torch.Tensor`):
-                A list of tensors that if specified are added to the residuals of transformer blocks.
-            joint_attention_kwargs (`dict`, *optional*):
-                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
-                `self.processor` in
-                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
-            return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
-                tuple.
-
-        Returns:
-            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
-            `tuple` where the first element is the sample tensor.
-        """
-        if joint_attention_kwargs is not None:
-            joint_attention_kwargs = joint_attention_kwargs.copy()
-            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
-        else:
-            lora_scale = 1.0
-
-        if USE_PEFT_BACKEND:
-            # weight the lora layers by setting `lora_scale` for each PEFT layer
-            scale_lora_layers(self, lora_scale)
-        else:
-            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
-                logger.warning(
-                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
-                )
-        hidden_states = self.x_embedder(hidden_states)
-
-        if self.input_hint_block is not None:
-            controlnet_cond = self.input_hint_block(controlnet_cond)
-            batch_size, channels, height_pw, width_pw = controlnet_cond.shape
-            height = height_pw // self.config.patch_size
-            width = width_pw // self.config.patch_size
-            controlnet_cond = controlnet_cond.reshape(
-                batch_size, channels, height, self.config.patch_size, width, self.config.patch_size
-            )
-            controlnet_cond = controlnet_cond.permute(0, 2, 4, 1, 3, 5)
-            controlnet_cond = controlnet_cond.reshape(batch_size, height * width, -1)
-        # add
-        hidden_states = hidden_states + self.controlnet_x_embedder(controlnet_cond)
-
-        timestep = timestep.to(hidden_states.dtype) * 1000
-        if guidance is not None:
-            guidance = guidance.to(hidden_states.dtype) * 1000
-        else:
-            guidance = None
-        temb = (
-            self.time_text_embed(timestep, pooled_projections)
-            if guidance is None
-            else self.time_text_embed(timestep, guidance, pooled_projections)
+        deprecation_message = "Importing `FluxControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxControlNetModel`, instead."
+        deprecate("diffusers.models.controlnet_flux.FluxControlNetModel", "0.34", deprecation_message)
+        super().__init__(
+            patch_size=patch_size,
+            in_channels=in_channels,
+            num_layers=num_layers,
+            num_single_layers=num_single_layers,
+            attention_head_dim=attention_head_dim,
+            num_attention_heads=num_attention_heads,
+            joint_attention_dim=joint_attention_dim,
+            pooled_projection_dim=pooled_projection_dim,
+            guidance_embeds=guidance_embeds,
+            axes_dims_rope=axes_dims_rope,
+            num_mode=num_mode,
+            conditioning_embedding_channels=conditioning_embedding_channels,
         )
-        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
-
-        if self.union:
-            # union mode
-            if controlnet_mode is None:
-                raise ValueError("`controlnet_mode` cannot be `None` when applying ControlNet-Union")
-            # union mode emb
-            controlnet_mode_emb = self.controlnet_mode_embedder(controlnet_mode)
-            encoder_hidden_states = torch.cat([controlnet_mode_emb, encoder_hidden_states], dim=1)
-            txt_ids = torch.cat([txt_ids[:1], txt_ids], dim=0)
-
-        if txt_ids.ndim == 3:
-            logger.warning(
-                "Passing `txt_ids` 3d torch.Tensor is deprecated."
-                "Please remove the batch dimension and pass it as a 2d torch Tensor"
-            )
-            txt_ids = txt_ids[0]
-        if img_ids.ndim == 3:
-            logger.warning(
-                "Passing `img_ids` 3d torch.Tensor is deprecated."
-                "Please remove the batch dimension and pass it as a 2d torch Tensor"
-            )
-            img_ids = img_ids[0]
-
-        ids = torch.cat((txt_ids, img_ids), dim=0)
-        image_rotary_emb = self.pos_embed(ids)
-
-        block_samples = ()
-        for index_block, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                encoder_hidden_states, hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    encoder_hidden_states,
-                    temb,
-                    image_rotary_emb,
-                    **ckpt_kwargs,
-                )
-
-            else:
-                encoder_hidden_states, hidden_states = block(
-                    hidden_states=hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    temb=temb,
-                    image_rotary_emb=image_rotary_emb,
-                )
-            block_samples = block_samples + (hidden_states,)
-
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
-        single_block_samples = ()
-        for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
-
-                def create_custom_forward(module, return_dict=None):
-                    def custom_forward(*inputs):
-                        if return_dict is not None:
-                            return module(*inputs, return_dict=return_dict)
-                        else:
-                            return module(*inputs)
-
-                    return custom_forward
-
-                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(block),
-                    hidden_states,
-                    temb,
-                    image_rotary_emb,
-                    **ckpt_kwargs,
-                )
-
-            else:
-                hidden_states = block(
-                    hidden_states=hidden_states,
-                    temb=temb,
-                    image_rotary_emb=image_rotary_emb,
-                )
-            single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
-
-        # controlnet block
-        controlnet_block_samples = ()
-        for block_sample, controlnet_block in zip(block_samples, self.controlnet_blocks):
-            block_sample = controlnet_block(block_sample)
-            controlnet_block_samples = controlnet_block_samples + (block_sample,)
-
-        controlnet_single_block_samples = ()
-        for single_block_sample, controlnet_block in zip(single_block_samples, self.controlnet_single_blocks):
-            single_block_sample = controlnet_block(single_block_sample)
-            controlnet_single_block_samples = controlnet_single_block_samples + (single_block_sample,)
-
-        # scaling
-        controlnet_block_samples = [sample * conditioning_scale for sample in controlnet_block_samples]
-        controlnet_single_block_samples = [sample * conditioning_scale for sample in controlnet_single_block_samples]
-
-        controlnet_block_samples = None if len(controlnet_block_samples) == 0 else controlnet_block_samples
-        controlnet_single_block_samples = (
-            None if len(controlnet_single_block_samples) == 0 else controlnet_single_block_samples
-        )
-
-        if USE_PEFT_BACKEND:
-            # remove `lora_scale` from each PEFT layer
-            unscale_lora_layers(self, lora_scale)
-
-        if not return_dict:
-            return (controlnet_block_samples, controlnet_single_block_samples)
-
-        return FluxControlNetOutput(
-            controlnet_block_samples=controlnet_block_samples,
-            controlnet_single_block_samples=controlnet_single_block_samples,
-        )
-
-
-class FluxMultiControlNetModel(ModelMixin):
-    r"""
-    `FluxMultiControlNetModel` wrapper class for Multi-FluxControlNetModel
-
-    This module is a wrapper for multiple instances of the `FluxControlNetModel`. The `forward()` API is designed to be
-    compatible with `FluxControlNetModel`.
-
-    Args:
-        controlnets (`List[FluxControlNetModel]`):
-            Provides additional conditioning to the unet during the denoising process. You must set multiple
-            `FluxControlNetModel` as a list.
-    """
-
-    def __init__(self, controlnets):
-        super().__init__()
-        self.nets = nn.ModuleList(controlnets)
-
-    def forward(
-        self,
-        hidden_states: torch.FloatTensor,
-        controlnet_cond: List[torch.tensor],
-        controlnet_mode: List[torch.tensor],
-        conditioning_scale: List[float],
-        encoder_hidden_states: torch.Tensor = None,
-        pooled_projections: torch.Tensor = None,
-        timestep: torch.LongTensor = None,
-        img_ids: torch.Tensor = None,
-        txt_ids: torch.Tensor = None,
-        guidance: torch.Tensor = None,
-        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
-        return_dict: bool = True,
-    ) -> Union[FluxControlNetOutput, Tuple]:
-        # ControlNet-Union with multiple conditions
-        # only load one ControlNet for saving memories
-        if len(self.nets) == 1 and self.nets[0].union:
-            controlnet = self.nets[0]
-
-            for i, (image, mode, scale) in enumerate(zip(controlnet_cond, controlnet_mode, conditioning_scale)):
-                block_samples, single_block_samples = controlnet(
-                    hidden_states=hidden_states,
-                    controlnet_cond=image,
-                    controlnet_mode=mode[:, None],
-                    conditioning_scale=scale,
-                    timestep=timestep,
-                    guidance=guidance,
-                    pooled_projections=pooled_projections,
-                    encoder_hidden_states=encoder_hidden_states,
-                    txt_ids=txt_ids,
-                    img_ids=img_ids,
-                    joint_attention_kwargs=joint_attention_kwargs,
-                    return_dict=return_dict,
-                )
-
-                # merge samples
-                if i == 0:
-                    control_block_samples = block_samples
-                    control_single_block_samples = single_block_samples
-                else:
-                    control_block_samples = [
-                        control_block_sample + block_sample
-                        for control_block_sample, block_sample in zip(control_block_samples, block_samples)
-                    ]
-
-                    control_single_block_samples = [
-                        control_single_block_sample + block_sample
-                        for control_single_block_sample, block_sample in zip(
-                            control_single_block_samples, single_block_samples
-                        )
-                    ]
-
-        # Regular Multi-ControlNets
-        # load all ControlNets into memories
-        else:
-            for i, (image, mode, scale, controlnet) in enumerate(
-                zip(controlnet_cond, controlnet_mode, conditioning_scale, self.nets)
-            ):
-                block_samples, single_block_samples = controlnet(
-                    hidden_states=hidden_states,
-                    controlnet_cond=image,
-                    controlnet_mode=mode[:, None],
-                    conditioning_scale=scale,
-                    timestep=timestep,
-                    guidance=guidance,
-                    pooled_projections=pooled_projections,
-                    encoder_hidden_states=encoder_hidden_states,
-                    txt_ids=txt_ids,
-                    img_ids=img_ids,
-                    joint_attention_kwargs=joint_attention_kwargs,
-                    return_dict=return_dict,
-                )
 
-                # merge samples
-                if i == 0:
-                    control_block_samples = block_samples
-                    control_single_block_samples = single_block_samples
-                else:
-                    if block_samples is not None and control_block_samples is not None:
-                        control_block_samples = [
-                            control_block_sample + block_sample
-                            for control_block_sample, block_sample in zip(control_block_samples, block_samples)
-                        ]
-                    if single_block_samples is not None and control_single_block_samples is not None:
-                        control_single_block_samples = [
-                            control_single_block_sample + block_sample
-                            for control_single_block_sample, block_sample in zip(
-                                control_single_block_samples, single_block_samples
-                            )
-                        ]
 
-        return control_block_samples, control_single_block_samples
+class FluxMultiControlNetModel(FluxMultiControlNetModel):
+    def __init__(self, *args, **kwargs):
+        deprecation_message = "Importing `FluxMultiControlNetModel` from `diffusers.models.controlnet_flux` is deprecated and this will be removed in a future version. Please use `from diffusers.models.controlnets.controlnet_flux import FluxMultiControlNetModel`, instead."
+        deprecate("diffusers.models.controlnet_flux.FluxMultiControlNetModel", "0.34", deprecation_message)
+        super().__init__(*args, **kwargs)
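
To see the shim above in action, here is a minimal sketch, assuming diffusers 0.32.0 with a working torch install; the tiny configuration values are illustrative only, and `FutureWarning` is the category `diffusers.utils.deprecate` uses for these notices at the time of writing:

import warnings

from diffusers.models.controlnet_flux import FluxControlNetModel as LegacyFluxControlNet

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # Constructing through the old path routes to the relocated class and
    # fires deprecate(..., "0.34", ...) inside __init__.
    LegacyFluxControlNet(
        in_channels=4,
        num_layers=1,
        num_single_layers=1,
        attention_head_dim=8,
        num_attention_heads=2,
        joint_attention_dim=16,
        pooled_projection_dim=8,
        axes_dims_rope=[2, 3, 3],
    )

# Expect one notice pointing at diffusers.models.controlnets.controlnet_flux.
print([str(w.message) for w in caught if issubclass(w.category, FutureWarning)])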