diffusers 0.28.2__py3-none-any.whl → 0.29.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. diffusers/__init__.py +15 -1
  2. diffusers/commands/env.py +1 -5
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +2 -1
  5. diffusers/loaders/__init__.py +2 -2
  6. diffusers/loaders/lora.py +406 -140
  7. diffusers/loaders/lora_conversion_utils.py +7 -1
  8. diffusers/loaders/single_file.py +13 -1
  9. diffusers/loaders/single_file_model.py +15 -8
  10. diffusers/loaders/single_file_utils.py +267 -17
  11. diffusers/loaders/unet.py +307 -272
  12. diffusers/models/__init__.py +7 -3
  13. diffusers/models/attention.py +125 -1
  14. diffusers/models/attention_processor.py +169 -1
  15. diffusers/models/autoencoders/__init__.py +1 -0
  16. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  17. diffusers/models/autoencoders/autoencoder_kl.py +17 -6
  18. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
  19. diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
  20. diffusers/models/autoencoders/vq_model.py +182 -0
  21. diffusers/models/controlnet_sd3.py +418 -0
  22. diffusers/models/controlnet_xs.py +6 -6
  23. diffusers/models/embeddings.py +112 -84
  24. diffusers/models/model_loading_utils.py +55 -0
  25. diffusers/models/modeling_utils.py +138 -20
  26. diffusers/models/normalization.py +11 -6
  27. diffusers/models/transformers/__init__.py +1 -0
  28. diffusers/models/transformers/dual_transformer_2d.py +5 -4
  29. diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
  30. diffusers/models/transformers/prior_transformer.py +5 -5
  31. diffusers/models/transformers/transformer_2d.py +2 -2
  32. diffusers/models/transformers/transformer_sd3.py +353 -0
  33. diffusers/models/transformers/transformer_temporal.py +12 -10
  34. diffusers/models/unets/unet_1d.py +3 -3
  35. diffusers/models/unets/unet_2d.py +3 -3
  36. diffusers/models/unets/unet_2d_condition.py +4 -15
  37. diffusers/models/unets/unet_3d_condition.py +5 -17
  38. diffusers/models/unets/unet_i2vgen_xl.py +4 -4
  39. diffusers/models/unets/unet_motion_model.py +4 -4
  40. diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
  41. diffusers/models/vq_model.py +8 -165
  42. diffusers/pipelines/__init__.py +11 -0
  43. diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
  45. diffusers/pipelines/auto_pipeline.py +8 -0
  46. diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
  47. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
  48. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
  49. diffusers/pipelines/controlnet_sd3/__init__.py +53 -0
  50. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1062 -0
  51. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
  52. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  53. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
  54. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
  55. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
  56. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
  57. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
  58. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
  59. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
  60. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
  61. diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
  62. diffusers/pipelines/pia/pipeline_pia.py +4 -3
  63. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  64. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  65. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
  72. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
  73. diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
  74. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  75. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +904 -0
  76. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +941 -0
  77. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
  78. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
  79. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
  80. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  81. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
  82. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
  83. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
  84. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
  85. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
  86. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
  87. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
  88. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
  89. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  90. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
  91. diffusers/schedulers/__init__.py +2 -0
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  93. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
  94. diffusers/schedulers/scheduling_edm_euler.py +2 -4
  95. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
  96. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  97. diffusers/training_utils.py +4 -4
  98. diffusers/utils/__init__.py +3 -0
  99. diffusers/utils/constants.py +2 -0
  100. diffusers/utils/dummy_pt_objects.py +60 -0
  101. diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
  102. diffusers/utils/dynamic_modules_utils.py +15 -13
  103. diffusers/utils/hub_utils.py +106 -0
  104. diffusers/utils/import_utils.py +0 -1
  105. diffusers/utils/logging.py +3 -1
  106. diffusers/utils/state_dict_utils.py +2 -0
  107. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/METADATA +3 -3
  108. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/RECORD +112 -112
  109. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/WHEEL +1 -1
  110. diffusers/models/dual_transformer_2d.py +0 -20
  111. diffusers/models/prior_transformer.py +0 -12
  112. diffusers/models/t5_film_transformer.py +0 -70
  113. diffusers/models/transformer_2d.py +0 -25
  114. diffusers/models/transformer_temporal.py +0 -34
  115. diffusers/models/unet_1d.py +0 -26
  116. diffusers/models/unet_1d_blocks.py +0 -203
  117. diffusers/models/unet_2d.py +0 -27
  118. diffusers/models/unet_2d_blocks.py +0 -375
  119. diffusers/models/unet_2d_condition.py +0 -25
  120. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/LICENSE +0 -0
  121. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/entry_points.txt +0 -0
  122. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/top_level.txt +0 -0
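The headline change across 0.28.2 → 0.29.1 is Stable Diffusion 3 support: a new SD3 transformer (entry 32), SD3 text-to-image and img2img pipelines (entries 73-76), a flow-match Euler scheduler (entry 95), and the SD3 ControlNet model and pipeline (entries 21 and 49-50), two of which are shown in full below. For orientation only, a minimal sketch of loading the new pipeline; the checkpoint id and generation arguments are assumptions, not part of this diff.

```python
# Minimal sketch of the SD3 pipeline added in this release range.
# The checkpoint id below is an assumption (any SD3 checkpoint in diffusers
# layout should work); it is not taken from this diff.
import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
)
pipe = pipe.to("cuda")

image = pipe(
    "a photo of a cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd3.png")
```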
diffusers/models/autoencoders/vq_model.py +182 -0
@@ -0,0 +1,182 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import BaseOutput
+from ...utils.accelerate_utils import apply_forward_hook
+from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
+from ..modeling_utils import ModelMixin
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+    """
+    Output of VQModel encoding method.
+
+    Args:
+        latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            The encoded output sample from the last layer of the model.
+    """
+
+    latents: torch.Tensor
+
+
+class VQModel(ModelMixin, ConfigMixin):
+    r"""
+    A VQ-VAE model for decoding latent representations.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+            Tuple of block output channels.
+        layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
+        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+        num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
+        norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
+        vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
+        scaling_factor (`float`, *optional*, defaults to `0.18215`):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+        norm_type (`str`, *optional*, defaults to `"group"`):
+            Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
+        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int, ...] = (64,),
+        layers_per_block: int = 1,
+        act_fn: str = "silu",
+        latent_channels: int = 3,
+        sample_size: int = 32,
+        num_vq_embeddings: int = 256,
+        norm_num_groups: int = 32,
+        vq_embed_dim: Optional[int] = None,
+        scaling_factor: float = 0.18215,
+        norm_type: str = "group",  # group, spatial
+        mid_block_add_attention=True,
+        lookup_from_codebook=False,
+        force_upcast=False,
+    ):
+        super().__init__()
+
+        # pass init params to Encoder
+        self.encoder = Encoder(
+            in_channels=in_channels,
+            out_channels=latent_channels,
+            down_block_types=down_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
+            double_z=False,
+            mid_block_add_attention=mid_block_add_attention,
+        )
+
+        vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
+
+        self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
+        self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
+        self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
+
+        # pass init params to Decoder
+        self.decoder = Decoder(
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            up_block_types=up_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
+            norm_type=norm_type,
+            mid_block_add_attention=mid_block_add_attention,
+        )
+
+    @apply_forward_hook
+    def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
+        h = self.encoder(x)
+        h = self.quant_conv(h)
+
+        if not return_dict:
+            return (h,)
+
+        return VQEncoderOutput(latents=h)
+
+    @apply_forward_hook
+    def decode(
+        self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        # also go through quantization layer
+        if not force_not_quantize:
+            quant, commit_loss, _ = self.quantize(h)
+        elif self.config.lookup_from_codebook:
+            quant = self.quantize.get_codebook_entry(h, shape)
+            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
+        else:
+            quant = h
+            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
+        quant2 = self.post_quant_conv(quant)
+        dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
+
+        if not return_dict:
+            return dec, commit_loss
+
+        return DecoderOutput(sample=dec, commit_loss=commit_loss)
+
+    def forward(
+        self, sample: torch.Tensor, return_dict: bool = True
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
+        r"""
+        The [`VQModel`] forward method.
+
+        Args:
+            sample (`torch.Tensor`): Input sample.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vq_model.VQEncoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
+                is returned.
+        """
+
+        h = self.encode(sample).latents
+        dec = self.decode(h)
+
+        if not return_dict:
+            return dec.sample, dec.commit_loss
+        return dec
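Entries 15, 20, and 41 in the file list relocate `VQModel`: the class now lives in `diffusers/models/autoencoders/vq_model.py` (the +182-line file above), while the old `diffusers/models/vq_model.py` shrinks to a thin compatibility layer (+8 -165). A minimal encode/decode round trip against the class as defined above; the tiny config and input shape are illustrative only, not taken from any released checkpoint.

```python
# Round-trip sketch for the relocated VQModel; config and shapes are illustrative.
import torch
from diffusers import VQModel  # still re-exported at the top level

model = VQModel(
    in_channels=3,
    out_channels=3,
    latent_channels=3,
    num_vq_embeddings=256,
    block_out_channels=(64,),
)
model.eval()

x = torch.randn(1, 3, 32, 32)  # (batch, channels, height, width)
with torch.no_grad():
    latents = model.encode(x).latents  # VQEncoderOutput.latents
    out = model.decode(latents)        # DecoderOutput(sample=..., commit_loss=...)

print(latents.shape, out.sample.shape, out.commit_loss.shape)
```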
diffusers/models/controlnet_sd3.py +418 -0
@@ -0,0 +1,418 @@
+# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ..models.attention import JointTransformerBlock
+from ..models.attention_processor import Attention, AttentionProcessor
+from ..models.modeling_utils import ModelMixin
+from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from .controlnet import BaseOutput, zero_module
+from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
+from .transformers.transformer_2d import Transformer2DModelOutput
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class SD3ControlNetOutput(BaseOutput):
+    controlnet_block_samples: Tuple[torch.Tensor]
+
+
+class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: int = 128,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        num_layers: int = 18,
+        attention_head_dim: int = 64,
+        num_attention_heads: int = 18,
+        joint_attention_dim: int = 4096,
+        caption_projection_dim: int = 1152,
+        pooled_projection_dim: int = 2048,
+        out_channels: int = 16,
+        pos_embed_max_size: int = 96,
+    ):
+        super().__init__()
+        default_out_channels = in_channels
+        self.out_channels = out_channels if out_channels is not None else default_out_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        self.pos_embed = PatchEmbed(
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=self.inner_dim,
+            pos_embed_max_size=pos_embed_max_size,
+        )
+        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
+        )
+        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+
+        # `attention_head_dim` is doubled to account for the mixing.
+        # It needs to crafted when we get the actual checkpoints.
+        self.transformer_blocks = nn.ModuleList(
+            [
+                JointTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=self.inner_dim,
+                    context_pre_only=False,
+                )
+                for i in range(num_layers)
+            ]
+        )
+
+        # controlnet_blocks
+        self.controlnet_blocks = nn.ModuleList([])
+        for _ in range(len(self.transformer_blocks)):
+            controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)
+            controlnet_block = zero_module(controlnet_block)
+            self.controlnet_blocks.append(controlnet_block)
+        pos_embed_input = PatchEmbed(
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=self.inner_dim,
+            pos_embed_type=None,
+        )
+        self.pos_embed_input = zero_module(pos_embed_input)
+
+        self.gradient_checkpointing = False
+
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+        """
+        Sets the attention processor to use [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+                over each tensor of dim=`dim`.
+            dim (`int`, *optional*, defaults to `0`):
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
+        """
+        if dim not in [0, 1]:
+            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+        # By default chunk size is 1
+        chunk_size = chunk_size or 1
+
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, chunk_size, dim)
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    @classmethod
+    def from_transformer(cls, transformer, num_layers=None, load_weights_from_transformer=True):
+        config = transformer.config
+        config["num_layers"] = num_layers or config.num_layers
+        controlnet = cls(**config)
+
+        if load_weights_from_transformer:
+            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
+            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict(), strict=False)
+            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict(), strict=False)
+            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict())
+
+            controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
+
+        return controlnet
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        controlnet_cond: torch.Tensor,
+        conditioning_scale: float = 1.0,
+        encoder_hidden_states: torch.FloatTensor = None,
+        pooled_projections: torch.FloatTensor = None,
+        timestep: torch.LongTensor = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+        """
+        The [`SD3Transformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                Input `hidden_states`.
+            controlnet_cond (`torch.Tensor`):
+                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+            conditioning_scale (`float`, defaults to `1.0`):
+                The scale factor for ControlNet outputs.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+                from the embeddings of input conditions.
+            timestep ( `torch.LongTensor`):
+                Used to indicate denoising step.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        if joint_attention_kwargs is not None:
+            joint_attention_kwargs = joint_attention_kwargs.copy()
+            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        height, width = hidden_states.shape[-2:]
+
+        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+        temb = self.time_text_embed(timestep, pooled_projections)
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        # add
+        hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
+
+        block_res_samples = ()
+
+        for block in self.transformer_blocks:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                )
+
+            block_res_samples = block_res_samples + (hidden_states,)
+
+        controlnet_block_res_samples = ()
+        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+            block_res_sample = controlnet_block(block_res_sample)
+            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+        # 6. scaling
+        controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (controlnet_block_res_samples,)
+
+        return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
+
+
+class SD3MultiControlNetModel(ModelMixin):
+    r"""
+    `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
+
+    This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
+    compatible with `SD3ControlNetModel`.
+
+    Args:
+        controlnets (`List[SD3ControlNetModel]`):
+            Provides additional conditioning to the unet during the denoising process. You must set multiple
+            `SD3ControlNetModel` as a list.
+    """
+
+    def __init__(self, controlnets):
+        super().__init__()
+        self.nets = nn.ModuleList(controlnets)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        controlnet_cond: List[torch.tensor],
+        conditioning_scale: List[float],
+        pooled_projections: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        timestep: torch.LongTensor = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[SD3ControlNetOutput, Tuple]:
+        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+            block_samples = controlnet(
+                hidden_states=hidden_states,
+                timestep=timestep,
+                encoder_hidden_states=encoder_hidden_states,
+                pooled_projections=pooled_projections,
+                controlnet_cond=image,
+                conditioning_scale=scale,
+                joint_attention_kwargs=joint_attention_kwargs,
+                return_dict=return_dict,
+            )
+
+            # merge samples
+            if i == 0:
+                control_block_samples = block_samples
+            else:
+                control_block_samples = [
+                    control_block_sample + block_sample
+                    for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
+                ]
+                control_block_samples = (tuple(control_block_samples),)
+
+        return control_block_samples
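`SD3ControlNetModel` above mirrors the SD3 transformer's joint attention blocks and emits one zero-initialized residual per block (consumed by the new `controlnet_sd3` pipeline, entries 49-50), while `SD3MultiControlNetModel` sums the residuals of several controls. A smoke-test sketch using only the constructor and `forward` shown above; the tiny config is illustrative, and released checkpoints use the defaults (`num_layers=18`, `attention_head_dim=64`, etc.).

```python
# Sketch: instantiate a small SD3ControlNetModel from the module above and run a
# dummy forward pass. Config values and tensor shapes are illustrative only.
import torch
from diffusers.models.controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel

controlnet = SD3ControlNetModel(
    sample_size=32,
    patch_size=2,
    in_channels=16,
    num_layers=2,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,   # kept equal to inner_dim (4 * 8)
    pooled_projection_dim=64,
    pos_embed_max_size=64,
)

latents = torch.randn(1, 16, 32, 32)    # dummy latent input
cond = torch.randn(1, 16, 32, 32)       # control input, already in latent shape
prompt_embeds = torch.randn(1, 77, 32)  # matches joint_attention_dim
pooled = torch.randn(1, 64)             # matches pooled_projection_dim
timestep = torch.tensor([500])

out = controlnet(
    hidden_states=latents,
    controlnet_cond=cond,
    conditioning_scale=0.7,
    encoder_hidden_states=prompt_embeds,
    pooled_projections=pooled,
    timestep=timestep,
)
print(len(out.controlnet_block_samples))  # one residual per transformer block

# Several controls can be wrapped so their residuals are summed at inference time.
multi = SD3MultiControlNetModel([controlnet, controlnet])
```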
diffusers/models/controlnet_xs.py +6 -6
@@ -851,8 +851,8 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
     @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
         Returns:
@@ -911,7 +911,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
     def set_default_attn_processor(self):
         """
         Disables custom attention processors and sets the default attention implementation.
@@ -927,7 +927,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
 
         self.set_attn_processor(processor)
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
 
@@ -952,7 +952,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             setattr(upsample_block, "b1", b1)
             setattr(upsample_block, "b2", b2)
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
     def disable_freeu(self):
         """Disables the FreeU mechanism."""
         freeu_keys = {"s1", "s2", "b1", "b2"}
@@ -961,7 +961,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
                 if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                     setattr(upsample_block, k, None)
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -985,7 +985,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.