diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff shows the content changes between package versions that were publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (172)
  1. diffusers/__init__.py +34 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +170 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +35 -6
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +2 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
  53. diffusers/pipelines/cogview3/__init__.py +47 -0
  54. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  55. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  56. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  57. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  58. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  60. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  62. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  63. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  64. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  66. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  67. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  68. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  70. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  71. diffusers/pipelines/flux/__init__.py +10 -0
  72. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  73. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  74. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  76. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  77. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  78. diffusers/pipelines/free_noise_utils.py +365 -5
  79. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  80. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  81. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  82. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  83. diffusers/pipelines/kolors/tokenizer.py +4 -0
  84. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  86. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  87. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  89. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  90. diffusers/pipelines/pag/__init__.py +6 -0
  91. diffusers/pipelines/pag/pag_utils.py +8 -2
  92. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  96. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  97. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  98. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  100. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  101. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  102. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  103. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  106. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  107. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  108. diffusers/pipelines/pipeline_utils.py +123 -180
  109. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  111. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  117. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  120. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  121. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  122. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  123. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  126. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  127. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  129. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  130. diffusers/quantizers/__init__.py +16 -0
  131. diffusers/quantizers/auto.py +126 -0
  132. diffusers/quantizers/base.py +233 -0
  133. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  134. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  135. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  136. diffusers/quantizers/quantization_config.py +391 -0
  137. diffusers/schedulers/scheduling_ddim.py +4 -1
  138. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  140. diffusers/schedulers/scheduling_ddpm.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  142. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  143. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  145. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  146. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  147. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  148. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  149. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  150. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  151. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  152. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  153. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  154. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  155. diffusers/schedulers/scheduling_sasolver.py +78 -1
  156. diffusers/schedulers/scheduling_unclip.py +4 -1
  157. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  158. diffusers/training_utils.py +48 -18
  159. diffusers/utils/__init__.py +2 -1
  160. diffusers/utils/dummy_pt_objects.py +60 -0
  161. diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
  162. diffusers/utils/hub_utils.py +16 -4
  163. diffusers/utils/import_utils.py +31 -8
  164. diffusers/utils/loading_utils.py +28 -4
  165. diffusers/utils/peft_utils.py +3 -3
  166. diffusers/utils/testing_utils.py +59 -0
  167. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  168. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
  169. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  170. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
  171. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  172. {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/transformer_cogview3plus.py (new file)
@@ -0,0 +1,386 @@
+ # Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import Any, Dict, Union
+
+ import torch
+ import torch.nn as nn
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...models.attention import FeedForward
+ from ...models.attention_processor import (
+     Attention,
+     AttentionProcessor,
+     CogVideoXAttnProcessor2_0,
+ )
+ from ...models.modeling_utils import ModelMixin
+ from ...models.normalization import AdaLayerNormContinuous
+ from ...utils import is_torch_version, logging
+ from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
+ from ..modeling_outputs import Transformer2DModelOutput
+ from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ class CogView3PlusTransformerBlock(nn.Module):
+     r"""
+     Transformer block used in [CogView](https://github.com/THUDM/CogView3) model.
+
+     Args:
+         dim (`int`):
+             The number of channels in the input and output.
+         num_attention_heads (`int`):
+             The number of heads to use for multi-head attention.
+         attention_head_dim (`int`):
+             The number of channels in each head.
+         time_embed_dim (`int`):
+             The number of channels in timestep embedding.
+     """
+
+     def __init__(
+         self,
+         dim: int = 2560,
+         num_attention_heads: int = 64,
+         attention_head_dim: int = 40,
+         time_embed_dim: int = 512,
+     ):
+         super().__init__()
+
+         self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
+
+         self.attn1 = Attention(
+             query_dim=dim,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             out_dim=dim,
+             bias=True,
+             qk_norm="layer_norm",
+             elementwise_affine=False,
+             eps=1e-6,
+             processor=CogVideoXAttnProcessor2_0(),
+         )
+
+         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+         self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+
+         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         emb: torch.Tensor,
+     ) -> torch.Tensor:
+         text_seq_length = encoder_hidden_states.size(1)
+
+         # norm & modulate
+         (
+             norm_hidden_states,
+             gate_msa,
+             shift_mlp,
+             scale_mlp,
+             gate_mlp,
+             norm_encoder_hidden_states,
+             c_gate_msa,
+             c_shift_mlp,
+             c_scale_mlp,
+             c_gate_mlp,
+         ) = self.norm1(hidden_states, encoder_hidden_states, emb)
+
+         # attention
+         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+             hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+         )
+
+         hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
+         encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
+
+         # norm & modulate
+         norm_hidden_states = self.norm2(hidden_states)
+         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+         norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+         norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+         # feed-forward
+         norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+         ff_output = self.ff(norm_hidden_states)
+
+         hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
+         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
+
+         if hidden_states.dtype == torch.float16:
+             hidden_states = hidden_states.clip(-65504, 65504)
+         if encoder_hidden_states.dtype == torch.float16:
+             encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+         return hidden_states, encoder_hidden_states
+
+
+ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
+     r"""
+     The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
+     Diffusion](https://huggingface.co/papers/2403.05121).
+
+     Args:
+         patch_size (`int`, defaults to `2`):
+             The size of the patches to use in the patch embedding layer.
+         in_channels (`int`, defaults to `16`):
+             The number of channels in the input.
+         num_layers (`int`, defaults to `30`):
+             The number of layers of Transformer blocks to use.
+         attention_head_dim (`int`, defaults to `40`):
+             The number of channels in each head.
+         num_attention_heads (`int`, defaults to `64`):
+             The number of heads to use for multi-head attention.
+         out_channels (`int`, defaults to `16`):
+             The number of channels in the output.
+         text_embed_dim (`int`, defaults to `4096`):
+             Input dimension of text embeddings from the text encoder.
+         time_embed_dim (`int`, defaults to `512`):
+             Output dimension of timestep embeddings.
+         condition_dim (`int`, defaults to `256`):
+             The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
+             crop_coords).
+         pos_embed_max_size (`int`, defaults to `128`):
+             The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
+             to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
+             means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
+             patch_size => 128 * 8 * 2 => 2048`.
+         sample_size (`int`, defaults to `128`):
+             The base resolution of input latents. If height/width is not provided during generation, this value is used
+             to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
+     """
+
+     _supports_gradient_checkpointing = True
+
+     @register_to_config
+     def __init__(
+         self,
+         patch_size: int = 2,
+         in_channels: int = 16,
+         num_layers: int = 30,
+         attention_head_dim: int = 40,
+         num_attention_heads: int = 64,
+         out_channels: int = 16,
+         text_embed_dim: int = 4096,
+         time_embed_dim: int = 512,
+         condition_dim: int = 256,
+         pos_embed_max_size: int = 128,
+         sample_size: int = 128,
+     ):
+         super().__init__()
+         self.out_channels = out_channels
+         self.inner_dim = num_attention_heads * attention_head_dim
+
+         # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
+         # Each of these are sincos embeddings of shape 2 * condition_dim
+         self.pooled_projection_dim = 3 * 2 * condition_dim
+
+         self.patch_embed = CogView3PlusPatchEmbed(
+             in_channels=in_channels,
+             hidden_size=self.inner_dim,
+             patch_size=patch_size,
+             text_hidden_size=text_embed_dim,
+             pos_embed_max_size=pos_embed_max_size,
+         )
+
+         self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
+             embedding_dim=time_embed_dim,
+             condition_dim=condition_dim,
+             pooled_projection_dim=self.pooled_projection_dim,
+             timesteps_dim=self.inner_dim,
+         )
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 CogView3PlusTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=num_attention_heads,
+                     attention_head_dim=attention_head_dim,
+                     time_embed_dim=time_embed_dim,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         self.norm_out = AdaLayerNormContinuous(
+             embedding_dim=self.inner_dim,
+             conditioning_embedding_dim=time_embed_dim,
+             elementwise_affine=False,
+             eps=1e-6,
+         )
+         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+         self.gradient_checkpointing = False
+
+     @property
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+     def attn_processors(self) -> Dict[str, AttentionProcessor]:
+         r"""
+         Returns:
+             `dict` of attention processors: A dictionary containing all attention processors used in the model with
+             indexed by its weight name.
+         """
+         # set recursively
+         processors = {}
+
+         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+             if hasattr(module, "get_processor"):
+                 processors[f"{name}.processor"] = module.get_processor()
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+             return processors
+
+         for name, module in self.named_children():
+             fn_recursive_add_processors(name, module, processors)
+
+         return processors
+
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+         r"""
+         Sets the attention processor to use to compute attention.
+
+         Parameters:
+             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                 for **all** `Attention` layers.
+
+                 If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                 processor. This is strongly recommended when setting trainable attention processors.
+
+         """
+         count = len(self.attn_processors.keys())
+
+         if isinstance(processor, dict) and len(processor) != count:
+             raise ValueError(
+                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+             )
+
+         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+             if hasattr(module, "set_processor"):
+                 if not isinstance(processor, dict):
+                     module.set_processor(processor)
+                 else:
+                     module.set_processor(processor.pop(f"{name}.processor"))
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+         for name, module in self.named_children():
+             fn_recursive_attn_processor(name, module, processor)
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if hasattr(module, "gradient_checkpointing"):
+             module.gradient_checkpointing = value
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         timestep: torch.LongTensor,
+         original_size: torch.Tensor,
+         target_size: torch.Tensor,
+         crop_coords: torch.Tensor,
+         return_dict: bool = True,
+     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+         """
+         The [`CogView3PlusTransformer2DModel`] forward method.
+
+         Args:
+             hidden_states (`torch.Tensor`):
+                 Input `hidden_states` of shape `(batch size, channel, height, width)`.
+             encoder_hidden_states (`torch.Tensor`):
+                 Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape
+                 `(batch_size, sequence_len, text_embed_dim)`
+             timestep (`torch.LongTensor`):
+                 Used to indicate denoising step.
+             original_size (`torch.Tensor`):
+                 CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of
+                 [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+             target_size (`torch.Tensor`):
+                 CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of
+                 [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+             crop_coords (`torch.Tensor`):
+                 CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of
+                 [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                 tuple.
+
+         Returns:
+             `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
+                 The denoised latents using provided inputs as conditioning.
+         """
+         height, width = hidden_states.shape[-2:]
+         text_seq_length = encoder_hidden_states.shape[1]
+
+         hidden_states = self.patch_embed(
+             hidden_states, encoder_hidden_states
+         )  # takes care of adding positional embeddings too.
+         emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
+
+         encoder_hidden_states = hidden_states[:, :text_seq_length]
+         hidden_states = hidden_states[:, text_seq_length:]
+
+         for index_block, block in enumerate(self.transformer_blocks):
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     hidden_states,
+                     encoder_hidden_states,
+                     emb,
+                     **ckpt_kwargs,
+                 )
+             else:
+                 hidden_states, encoder_hidden_states = block(
+                     hidden_states=hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     emb=emb,
+                 )
+
+         hidden_states = self.norm_out(hidden_states, emb)
+         hidden_states = self.proj_out(hidden_states)  # (batch_size, height*width, patch_size*patch_size*out_channels)
+
+         # unpatchify
+         patch_size = self.config.patch_size
+         height = height // patch_size
+         width = width // patch_size
+
+         hidden_states = hidden_states.reshape(
+             shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
+         )
+         hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
+         output = hidden_states.reshape(
+             shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+         )
+
+         if not return_dict:
+             return (output,)
+
+         return Transformer2DModelOutput(sample=output)
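
As a quick orientation to the new model class added above, here is a minimal sketch (not part of the package) that instantiates CogView3PlusTransformer2DModel and runs a single forward pass to check the input/output shapes implied by its signature. The tiny layer counts and dimensions, the dummy tensors, and the batch size are illustrative assumptions chosen so the module fits in memory; they are not the released 30-layer, 2560-dim configuration shown in the defaults.

# Shape-check sketch for the new CogView3PlusTransformer2DModel (toy config, illustrative only).
import torch

from diffusers import CogView3PlusTransformer2DModel

# Tiny, arbitrary dimensions for illustration; the released checkpoint uses the defaults
# from the diff above (30 layers, 64 heads x 40 dims, text_embed_dim=4096, condition_dim=256).
model = CogView3PlusTransformer2DModel(
    patch_size=2,
    in_channels=16,
    num_layers=1,
    attention_head_dim=8,
    num_attention_heads=4,
    out_channels=16,
    text_embed_dim=32,
    time_embed_dim=32,
    condition_dim=8,
    pos_embed_max_size=128,
    sample_size=128,
)

batch = 2
latents = torch.randn(batch, 16, 32, 32)                   # (B, in_channels, H, W) latent grid
prompt_embeds = torch.randn(batch, 16, 32)                 # (B, seq_len, text_embed_dim)
timestep = torch.randint(0, 1000, (batch,))                # denoising step per sample
original_size = torch.tensor([[1024.0, 1024.0]] * batch)   # SDXL-style micro-conditioning
target_size = torch.tensor([[1024.0, 1024.0]] * batch)
crop_coords = torch.tensor([[0.0, 0.0]] * batch)

with torch.no_grad():
    out = model(
        hidden_states=latents,
        encoder_hidden_states=prompt_embeds,
        timestep=timestep,
        original_size=original_size,
        target_size=target_size,
        crop_coords=crop_coords,
    )

# The forward unpatchifies back to the latent resolution: (B, out_channels, H, W).
print(out.sample.shape)  # expected: torch.Size([2, 16, 32, 32])

In practice the model would normally be loaded from a released checkpoint via CogView3PlusTransformer2DModel.from_pretrained (inherited from ModelMixin), or indirectly through the new CogView3Plus pipeline listed in the files above, rather than constructed from scratch as done here.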