diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +38 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +238 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +40 -7
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +6 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
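Of the files listed above, the excerpts below show three of the model-level changes: the CogVideoX 3D transformer gaining PEFT/LoRA support, the PixArt transformer gaining a `set_default_attn_processor` helper, and the brand-new CogView3Plus transformer. The other headline addition in this release is the `diffusers/quantizers` package (bitsandbytes-backed 4-/8-bit model loading). A hedged sketch of how that entry point is typically wired up — the class and argument names follow the new `quantization_config.py` module, while the checkpoint id and field values are illustrative assumptions, not something this diff spells out:

```python
# Hedged sketch of the new quantization entry point added in 0.31.0.
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# 4-bit NF4 weights with bf16 compute; field names assumed to mirror transformers' config.
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)

transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",    # example repo id
    subfolder="transformer",
    quantization_config=quant_config,  # routed through the new diffusers/quantizers/auto.py
    torch_dtype=torch.bfloat16,
)
```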
--- a/diffusers/models/transformers/cogvideox_transformer_3d.py
+++ b/diffusers/models/transformers/cogvideox_transformer_3d.py
@@ -19,11 +19,12 @@ import torch
 from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
-from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
+from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -152,7 +153,7 @@ class CogVideoXBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     """
     A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
 
@@ -235,37 +236,42 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         spatial_interpolation_scale: float = 1.875,
         temporal_interpolation_scale: float = 1.0,
         use_rotary_positional_embeddings: bool = False,
+        use_learned_positional_embeddings: bool = False,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
 
-
-
-
-
+        if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+            raise ValueError(
+                "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
+                "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+                "issue at https://github.com/huggingface/diffusers/issues."
+            )
 
         # 1. Patch embedding
-        self.patch_embed = CogVideoXPatchEmbed(
-
-
-
-
-
-
-
-
-
+        self.patch_embed = CogVideoXPatchEmbed(
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=inner_dim,
+            text_embed_dim=text_embed_dim,
+            bias=True,
+            sample_width=sample_width,
+            sample_height=sample_height,
+            sample_frames=sample_frames,
+            temporal_compression_ratio=temporal_compression_ratio,
+            max_text_seq_length=max_text_seq_length,
+            spatial_interpolation_scale=spatial_interpolation_scale,
+            temporal_interpolation_scale=temporal_interpolation_scale,
+            use_positional_embeddings=not use_rotary_positional_embeddings,
+            use_learned_positional_embeddings=use_learned_positional_embeddings,
         )
-
-        pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
-        pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
-        self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+        self.embedding_dropout = nn.Dropout(dropout)
 
-        #
+        # 2. Time embeddings
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
 
-        #
+        # 3. Define spatio-temporal transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
                 CogVideoXBlock(
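The constructor now hands positional-embedding bookkeeping to `CogVideoXPatchEmbed` instead of registering a `pos_embedding` buffer on the transformer itself. Below is a minimal shape-check sketch of the expanded patch-embed constructor; only the keyword names come from the hunk above, every size is a made-up tiny value:

```python
import torch
from diffusers.models.embeddings import CogVideoXPatchEmbed

# Tiny, made-up sizes purely to exercise the new keyword arguments.
patch_embed = CogVideoXPatchEmbed(
    patch_size=2,
    in_channels=4,
    embed_dim=32,
    text_embed_dim=16,
    bias=True,
    sample_width=16,
    sample_height=16,
    sample_frames=9,
    temporal_compression_ratio=4,
    max_text_seq_length=8,
    spatial_interpolation_scale=1.875,
    temporal_interpolation_scale=1.0,
    use_positional_embeddings=True,           # sincos path, as when rotary embeddings are disabled
    use_learned_positional_embeddings=False,  # the learned variant targets the I2V checkpoints
)

text_embeds = torch.randn(1, 8, 16)         # (batch, max_text_seq_length, text_embed_dim)
latents = torch.randn(1, 3, 4, 16, 16)      # (batch, compressed frames, channels, height, width)
tokens = patch_embed(text_embeds, latents)  # (batch, text tokens + video patches, embed_dim)
print(tokens.shape)                         # expected: torch.Size([1, 8 + 3 * 8 * 8, 32])
```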
@@ -284,7 +290,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         )
         self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)
 
-        #
+        # 4. Output blocks
         self.norm_out = AdaLayerNorm(
             embedding_dim=time_embed_dim,
             output_dim=2 * inner_dim,
@@ -406,8 +412,24 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         timestep: Union[int, float, torch.LongTensor],
         timestep_cond: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ):
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
         batch_size, num_frames, channels, height, width = hidden_states.shape
 
         # 1. Time embedding
@@ -422,20 +444,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
 
         # 2. Patch embedding
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+        hidden_states = self.embedding_dropout(hidden_states)
 
-        # 3. Position embedding
         text_seq_length = encoder_hidden_states.shape[1]
-        if not self.config.use_rotary_positional_embeddings:
-            seq_length = height * width * num_frames // (self.config.patch_size**2)
-
-            pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
-            hidden_states = hidden_states + pos_embeds
-            hidden_states = self.embedding_dropout(hidden_states)
-
         encoder_hidden_states = hidden_states[:, :text_seq_length]
         hidden_states = hidden_states[:, text_seq_length:]
 
-        #
+        # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
             if self.training and self.gradient_checkpointing:
 
@@ -471,15 +486,22 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         hidden_states = self.norm_final(hidden_states)
         hidden_states = hidden_states[:, text_seq_length:]
 
-        #
+        # 4. Final block
         hidden_states = self.norm_out(hidden_states, temb=emb)
         hidden_states = self.proj_out(hidden_states)
 
-        #
+        # 5. Unpatchify
+        # Note: we use `-1` instead of `channels`:
+        # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
+        # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
         p = self.config.patch_size
-        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p,
+        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
 
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
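Together with the new `PeftAdapterMixin` base class and the expanded `diffusers/loaders/lora_pipeline.py` shown in the file list, these hooks let CogVideoX pipelines load LoRA weights and scale them through `attention_kwargs`. A hedged end-to-end sketch — the `load_lora_weights` entry point and the pipeline-level `attention_kwargs` argument are assumptions based on the loader changes, and the repo ids are placeholders:

```python
import torch
from diffusers import CogVideoXPipeline

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

# Assumed entry point provided by the new LoRA loader plumbing; the repo/adapter names are placeholders.
pipe.load_lora_weights("some-user/cogvideox-lora", adapter_name="example")

video = pipe(
    prompt="a corgi surfing a wave at sunset",
    num_inference_steps=50,
    # Popped as "scale" in the transformer forward above and applied via scale_lora_layers,
    # assuming the pipeline forwards attention_kwargs to the transformer.
    attention_kwargs={"scale": 0.8},
).frames[0]
```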
--- a/diffusers/models/transformers/pixart_transformer_2d.py
+++ b/diffusers/models/transformers/pixart_transformer_2d.py
@@ -19,7 +19,7 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
-from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
@@ -247,6 +247,14 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+
+        Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
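The new helper mirrors `set_default_attn_processor` on the UNet-based models. A short hedged usage sketch; the checkpoint id is only an example of a repo with a "transformer" subfolder:

```python
import torch
from diffusers import PixArtTransformer2DModel

transformer = PixArtTransformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-Sigma-XL-2-1024-MS",  # example repo id
    subfolder="transformer",
    torch_dtype=torch.float16,
)

# ... swap in custom attention processors for an experiment ...

# New in 0.31.0: reset every attention module back to the stock AttnProcessor().
transformer.set_default_attn_processor()
```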
--- /dev/null
+++ b/diffusers/models/transformers/transformer_cogview3plus.py
@@ -0,0 +1,386 @@
+# Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from typing import Any, Dict, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...models.attention import FeedForward
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    CogVideoXAttnProcessor2_0,
+)
+from ...models.modeling_utils import ModelMixin
+from ...models.normalization import AdaLayerNormContinuous
+from ...utils import is_torch_version, logging
+from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
+from ..modeling_outputs import Transformer2DModelOutput
+from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+class CogView3PlusTransformerBlock(nn.Module):
+    r"""
+    Transformer block used in [CogView](https://github.com/THUDM/CogView3) model.
+
+    Args:
+        dim (`int`):
+            The number of channels in the input and output.
+        num_attention_heads (`int`):
+            The number of heads to use for multi-head attention.
+        attention_head_dim (`int`):
+            The number of channels in each head.
+        time_embed_dim (`int`):
+            The number of channels in timestep embedding.
+    """
+
+    def __init__(
+        self,
+        dim: int = 2560,
+        num_attention_heads: int = 64,
+        attention_head_dim: int = 40,
+        time_embed_dim: int = 512,
+    ):
+        super().__init__()
+
+        self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
+
+        self.attn1 = Attention(
+            query_dim=dim,
+            heads=num_attention_heads,
+            dim_head=attention_head_dim,
+            out_dim=dim,
+            bias=True,
+            qk_norm="layer_norm",
+            elementwise_affine=False,
+            eps=1e-6,
+            processor=CogVideoXAttnProcessor2_0(),
+        )
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        emb: torch.Tensor,
+    ) -> torch.Tensor:
+        text_seq_length = encoder_hidden_states.size(1)
+
+        # norm & modulate
+        (
+            norm_hidden_states,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            norm_encoder_hidden_states,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        ) = self.norm1(hidden_states, encoder_hidden_states, emb)
+
+        # attention
+        attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+        )
+
+        hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
+        encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
+
+        # norm & modulate
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+        norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+        norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+        # feed-forward
+        norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+        ff_output = self.ff(norm_hidden_states)
+
+        hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
+        encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
+
+        if hidden_states.dtype == torch.float16:
+            hidden_states = hidden_states.clip(-65504, 65504)
+        if encoder_hidden_states.dtype == torch.float16:
+            encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+        return hidden_states, encoder_hidden_states
+
+
+class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
+    r"""
+    The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
+    Diffusion](https://huggingface.co/papers/2403.05121).
+
+    Args:
+        patch_size (`int`, defaults to `2`):
+            The size of the patches to use in the patch embedding layer.
+        in_channels (`int`, defaults to `16`):
+            The number of channels in the input.
+        num_layers (`int`, defaults to `30`):
+            The number of layers of Transformer blocks to use.
+        attention_head_dim (`int`, defaults to `40`):
+            The number of channels in each head.
+        num_attention_heads (`int`, defaults to `64`):
+            The number of heads to use for multi-head attention.
+        out_channels (`int`, defaults to `16`):
+            The number of channels in the output.
+        text_embed_dim (`int`, defaults to `4096`):
+            Input dimension of text embeddings from the text encoder.
+        time_embed_dim (`int`, defaults to `512`):
+            Output dimension of timestep embeddings.
+        condition_dim (`int`, defaults to `256`):
+            The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
+            crop_coords).
+        pos_embed_max_size (`int`, defaults to `128`):
+            The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
+            to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
+            means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
+            patch_size => 128 * 8 * 2 => 2048`.
+        sample_size (`int`, defaults to `128`):
+            The base resolution of input latents. If height/width is not provided during generation, this value is used
+            to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
+    """
+
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        num_layers: int = 30,
+        attention_head_dim: int = 40,
+        num_attention_heads: int = 64,
+        out_channels: int = 16,
+        text_embed_dim: int = 4096,
+        time_embed_dim: int = 512,
+        condition_dim: int = 256,
+        pos_embed_max_size: int = 128,
+        sample_size: int = 128,
+    ):
+        super().__init__()
+        self.out_channels = out_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
+        # Each of these are sincos embeddings of shape 2 * condition_dim
+        self.pooled_projection_dim = 3 * 2 * condition_dim
+
+        self.patch_embed = CogView3PlusPatchEmbed(
+            in_channels=in_channels,
+            hidden_size=self.inner_dim,
+            patch_size=patch_size,
+            text_hidden_size=text_embed_dim,
+            pos_embed_max_size=pos_embed_max_size,
+        )
+
+        self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
+            embedding_dim=time_embed_dim,
+            condition_dim=condition_dim,
+            pooled_projection_dim=self.pooled_projection_dim,
+            timesteps_dim=self.inner_dim,
+        )
+
+        self.transformer_blocks = nn.ModuleList(
+            [
+                CogView3PlusTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=attention_head_dim,
+                    time_embed_dim=time_embed_dim,
+                )
+                for _ in range(num_layers)
+            ]
+        )
+
+        self.norm_out = AdaLayerNormContinuous(
+            embedding_dim=self.inner_dim,
+            conditioning_embedding_dim=time_embed_dim,
+            elementwise_affine=False,
+            eps=1e-6,
+        )
+        self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+        self.gradient_checkpointing = False
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        timestep: torch.LongTensor,
+        original_size: torch.Tensor,
+        target_size: torch.Tensor,
+        crop_coords: torch.Tensor,
+        return_dict: bool = True,
+    ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+        """
+        The [`CogView3PlusTransformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.Tensor`):
+                Input `hidden_states` of shape `(batch size, channel, height, width)`.
+            encoder_hidden_states (`torch.Tensor`):
+                Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape
+                `(batch_size, sequence_len, text_embed_dim)`
+            timestep (`torch.LongTensor`):
+                Used to indicate denoising step.
+            original_size (`torch.Tensor`):
+                CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            target_size (`torch.Tensor`):
+                CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            crop_coords (`torch.Tensor`):
+                CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of
+                [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
+                The denoised latents using provided inputs as conditioning.
+        """
+        height, width = hidden_states.shape[-2:]
+        text_seq_length = encoder_hidden_states.shape[1]
+
+        hidden_states = self.patch_embed(
+            hidden_states, encoder_hidden_states
+        )  # takes care of adding positional embeddings too.
+        emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
+
+        encoder_hidden_states = hidden_states[:, :text_seq_length]
+        hidden_states = hidden_states[:, text_seq_length:]
+
+        for index_block, block in enumerate(self.transformer_blocks):
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module):
+                    def custom_forward(*inputs):
+                        return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    emb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states, encoder_hidden_states = block(
+                    hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    emb=emb,
+                )
+
+        hidden_states = self.norm_out(hidden_states, emb)
+        hidden_states = self.proj_out(hidden_states)  # (batch_size, height*width, patch_size*patch_size*out_channels)
+
+        # unpatchify
+        patch_size = self.config.patch_size
+        height = height // patch_size
+        width = width // patch_size
+
+        hidden_states = hidden_states.reshape(
+            shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
+        )
+        hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
+        output = hidden_states.reshape(
+            shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+        )
+
+        if not return_dict:
+            return (output,)
+
+        return Transformer2DModelOutput(sample=output)
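Since this file is entirely new, here is a minimal shape-check sketch of the forward contract it defines, using a deliberately tiny made-up configuration (the real defaults are 30 layers at width 2560); only the argument names come from the code above:

```python
import torch
from diffusers.models.transformers.transformer_cogview3plus import CogView3PlusTransformer2DModel

# Tiny, made-up configuration purely for a CPU smoke test.
model = CogView3PlusTransformer2DModel(
    patch_size=2,
    in_channels=16,
    num_layers=1,
    attention_head_dim=8,
    num_attention_heads=4,
    out_channels=16,
    text_embed_dim=32,
    time_embed_dim=32,
    condition_dim=16,
    pos_embed_max_size=16,
    sample_size=8,
)

latents = torch.randn(1, 16, 16, 16)   # (batch, in_channels, height, width)
prompt_embeds = torch.randn(1, 7, 32)  # (batch, sequence_len, text_embed_dim)

out = model(
    hidden_states=latents,
    encoder_hidden_states=prompt_embeds,
    timestep=torch.tensor([999]),
    original_size=torch.tensor([[128.0, 128.0]]),
    target_size=torch.tensor([[128.0, 128.0]]),
    crop_coords=torch.tensor([[0.0, 0.0]]),
)
print(out.sample.shape)  # expected: torch.Size([1, 16, 16, 16])
```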