diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. diffusers/__init__.py +38 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +238 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +40 -7
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +6 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
  53. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  54. diffusers/pipelines/cogview3/__init__.py +47 -0
  55. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  56. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  57. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  58. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  60. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  62. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  63. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  64. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  66. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  67. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  68. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  70. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  71. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  72. diffusers/pipelines/flux/__init__.py +10 -0
  73. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  74. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  76. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  77. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  78. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  79. diffusers/pipelines/free_noise_utils.py +365 -5
  80. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  81. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  82. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  83. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  84. diffusers/pipelines/kolors/tokenizer.py +4 -0
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  86. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  87. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  89. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  90. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  91. diffusers/pipelines/pag/__init__.py +6 -0
  92. diffusers/pipelines/pag/pag_utils.py +8 -2
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  96. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  97. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  98. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  100. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  101. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  102. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  103. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  106. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  107. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  108. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  109. diffusers/pipelines/pipeline_utils.py +123 -180
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  111. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  113. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  117. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  120. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  121. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  122. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  123. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  129. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  131. diffusers/quantizers/__init__.py +16 -0
  132. diffusers/quantizers/auto.py +126 -0
  133. diffusers/quantizers/base.py +233 -0
  134. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  135. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  136. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  137. diffusers/quantizers/quantization_config.py +391 -0
  138. diffusers/schedulers/scheduling_ddim.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  140. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm.py +4 -1
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  143. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  148. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  149. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  150. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  151. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  152. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  155. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  156. diffusers/schedulers/scheduling_sasolver.py +78 -1
  157. diffusers/schedulers/scheduling_unclip.py +4 -1
  158. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  159. diffusers/training_utils.py +48 -18
  160. diffusers/utils/__init__.py +2 -1
  161. diffusers/utils/dummy_pt_objects.py +60 -0
  162. diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
  163. diffusers/utils/hub_utils.py +16 -4
  164. diffusers/utils/import_utils.py +31 -8
  165. diffusers/utils/loading_utils.py +28 -4
  166. diffusers/utils/peft_utils.py +3 -3
  167. diffusers/utils/testing_utils.py +59 -0
  168. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  169. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
  170. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  172. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  173. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
diffusers/models/transformers/cogvideox_transformer_3d.py

@@ -19,11 +19,12 @@ import torch
  from torch import nn

  from ...configuration_utils import ConfigMixin, register_to_config
- from ...utils import is_torch_version, logging
+ from ...loaders import PeftAdapterMixin
+ from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
  from ...utils.torch_utils import maybe_allow_in_graph
  from ..attention import Attention, FeedForward
  from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
- from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps, get_3d_sincos_pos_embed
+ from ..embeddings import CogVideoXPatchEmbed, TimestepEmbedding, Timesteps
  from ..modeling_outputs import Transformer2DModelOutput
  from ..modeling_utils import ModelMixin
  from ..normalization import AdaLayerNorm, CogVideoXLayerNormZero
@@ -152,7 +153,7 @@ class CogVideoXBlock(nn.Module):
          return hidden_states, encoder_hidden_states


- class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
+ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
      """
      A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).

@@ -235,37 +236,42 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
          spatial_interpolation_scale: float = 1.875,
          temporal_interpolation_scale: float = 1.0,
          use_rotary_positional_embeddings: bool = False,
+         use_learned_positional_embeddings: bool = False,
      ):
          super().__init__()
          inner_dim = num_attention_heads * attention_head_dim

-         post_patch_height = sample_height // patch_size
-         post_patch_width = sample_width // patch_size
-         post_time_compression_frames = (sample_frames - 1) // temporal_compression_ratio + 1
-         self.num_patches = post_patch_height * post_patch_width * post_time_compression_frames
+         if not use_rotary_positional_embeddings and use_learned_positional_embeddings:
+             raise ValueError(
+                 "There are no CogVideoX checkpoints available with disable rotary embeddings and learned positional "
+                 "embeddings. If you're using a custom model and/or believe this should be supported, please open an "
+                 "issue at https://github.com/huggingface/diffusers/issues."
+             )

          # 1. Patch embedding
-         self.patch_embed = CogVideoXPatchEmbed(patch_size, in_channels, inner_dim, text_embed_dim, bias=True)
-         self.embedding_dropout = nn.Dropout(dropout)
-
-         # 2. 3D positional embeddings
-         spatial_pos_embedding = get_3d_sincos_pos_embed(
-             inner_dim,
-             (post_patch_width, post_patch_height),
-             post_time_compression_frames,
-             spatial_interpolation_scale,
-             temporal_interpolation_scale,
+         self.patch_embed = CogVideoXPatchEmbed(
+             patch_size=patch_size,
+             in_channels=in_channels,
+             embed_dim=inner_dim,
+             text_embed_dim=text_embed_dim,
+             bias=True,
+             sample_width=sample_width,
+             sample_height=sample_height,
+             sample_frames=sample_frames,
+             temporal_compression_ratio=temporal_compression_ratio,
+             max_text_seq_length=max_text_seq_length,
+             spatial_interpolation_scale=spatial_interpolation_scale,
+             temporal_interpolation_scale=temporal_interpolation_scale,
+             use_positional_embeddings=not use_rotary_positional_embeddings,
+             use_learned_positional_embeddings=use_learned_positional_embeddings,
          )
-         spatial_pos_embedding = torch.from_numpy(spatial_pos_embedding).flatten(0, 1)
-         pos_embedding = torch.zeros(1, max_text_seq_length + self.num_patches, inner_dim, requires_grad=False)
-         pos_embedding.data[:, max_text_seq_length:].copy_(spatial_pos_embedding)
-         self.register_buffer("pos_embedding", pos_embedding, persistent=False)
+         self.embedding_dropout = nn.Dropout(dropout)

-         # 3. Time embeddings
+         # 2. Time embeddings
          self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
          self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)

-         # 4. Define spatio-temporal transformers blocks
+         # 3. Define spatio-temporal transformers blocks
          self.transformer_blocks = nn.ModuleList(
              [
                  CogVideoXBlock(
@@ -284,7 +290,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
          )
          self.norm_final = nn.LayerNorm(inner_dim, norm_eps, norm_elementwise_affine)

-         # 5. Output blocks
+         # 4. Output blocks
          self.norm_out = AdaLayerNorm(
              embedding_dim=time_embed_dim,
              output_dim=2 * inner_dim,
@@ -406,8 +412,24 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
          timestep: Union[int, float, torch.LongTensor],
          timestep_cond: Optional[torch.Tensor] = None,
          image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+         attention_kwargs: Optional[Dict[str, Any]] = None,
          return_dict: bool = True,
      ):
+         if attention_kwargs is not None:
+             attention_kwargs = attention_kwargs.copy()
+             lora_scale = attention_kwargs.pop("scale", 1.0)
+         else:
+             lora_scale = 1.0
+
+         if USE_PEFT_BACKEND:
+             # weight the lora layers by setting `lora_scale` for each PEFT layer
+             scale_lora_layers(self, lora_scale)
+         else:
+             if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                 logger.warning(
+                     "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                 )
+
          batch_size, num_frames, channels, height, width = hidden_states.shape

          # 1. Time embedding
@@ -422,20 +444,13 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):

          # 2. Patch embedding
          hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
+         hidden_states = self.embedding_dropout(hidden_states)

-         # 3. Position embedding
          text_seq_length = encoder_hidden_states.shape[1]
-         if not self.config.use_rotary_positional_embeddings:
-             seq_length = height * width * num_frames // (self.config.patch_size**2)
-
-             pos_embeds = self.pos_embedding[:, : text_seq_length + seq_length]
-             hidden_states = hidden_states + pos_embeds
-             hidden_states = self.embedding_dropout(hidden_states)
-
          encoder_hidden_states = hidden_states[:, :text_seq_length]
          hidden_states = hidden_states[:, text_seq_length:]

-         # 4. Transformer blocks
+         # 3. Transformer blocks
          for i, block in enumerate(self.transformer_blocks):
              if self.training and self.gradient_checkpointing:

@@ -471,15 +486,22 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
          hidden_states = self.norm_final(hidden_states)
          hidden_states = hidden_states[:, text_seq_length:]

-         # 5. Final block
+         # 4. Final block
          hidden_states = self.norm_out(hidden_states, temb=emb)
          hidden_states = self.proj_out(hidden_states)

-         # 6. Unpatchify
+         # 5. Unpatchify
+         # Note: we use `-1` instead of `channels`:
+         #   - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
+         #   - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
          p = self.config.patch_size
-         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, channels, p, p)
+         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
          output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)

+         if USE_PEFT_BACKEND:
+             # remove `lora_scale` from each PEFT layer
+             unscale_lora_layers(self, lora_scale)
+
          if not return_dict:
              return (output,)
          return Transformer2DModelOutput(sample=output)
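
The hunks above wire CogVideoXTransformer3DModel into the PEFT LoRA machinery: the model now inherits PeftAdapterMixin, and forward() pops a `scale` entry out of `attention_kwargs` and applies it with scale_lora_layers()/unscale_lora_layers(). Below is a hedged usage sketch, not part of the diff: it assumes the 0.31.0 CogVideoX pipeline forwards `attention_kwargs` to the transformer and supports load_lora_weights(); the LoRA repo id is a placeholder.

import torch
from diffusers import CogVideoXPipeline
from diffusers.utils import export_to_video

pipe = CogVideoXPipeline.from_pretrained("THUDM/CogVideoX-5b", torch_dtype=torch.bfloat16).to("cuda")

# Placeholder repo id: any LoRA trained for the CogVideoX transformer.
pipe.load_lora_weights("your-username/cogvideox-lora-example")

video = pipe(
    prompt="A golden retriever surfing a small wave at sunset",
    num_frames=49,
    num_inference_steps=50,
    # The `scale` entry is popped inside the transformer's forward() and applied via
    # scale_lora_layers()/unscale_lora_layers(), exactly as in the hunk above.
    attention_kwargs={"scale": 0.8},
).frames[0]
export_to_video(video, "output.mp4", fps=8)
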
diffusers/models/transformers/pixart_transformer_2d.py

@@ -19,7 +19,7 @@ from torch import nn
  from ...configuration_utils import ConfigMixin, register_to_config
  from ...utils import is_torch_version, logging
  from ..attention import BasicTransformerBlock
- from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
+ from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
  from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
  from ..modeling_outputs import Transformer2DModelOutput
  from ..modeling_utils import ModelMixin
@@ -247,6 +247,14 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
          for name, module in self.named_children():
              fn_recursive_attn_processor(name, module, processor)

+     def set_default_attn_processor(self):
+         """
+         Disables custom attention processors and sets the default attention implementation.
+
+         Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
+         """
+         self.set_attn_processor(AttnProcessor())
+
      # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
      def fuse_qkv_projections(self):
          """
diffusers/models/transformers/transformer_cogview3plus.py (new file)

@@ -0,0 +1,386 @@
+ # Copyright 2024 The CogView team, Tsinghua University & ZhipuAI and The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from typing import Any, Dict, Union
+
+ import torch
+ import torch.nn as nn
+
+ from ...configuration_utils import ConfigMixin, register_to_config
+ from ...models.attention import FeedForward
+ from ...models.attention_processor import (
+     Attention,
+     AttentionProcessor,
+     CogVideoXAttnProcessor2_0,
+ )
+ from ...models.modeling_utils import ModelMixin
+ from ...models.normalization import AdaLayerNormContinuous
+ from ...utils import is_torch_version, logging
+ from ..embeddings import CogView3CombinedTimestepSizeEmbeddings, CogView3PlusPatchEmbed
+ from ..modeling_outputs import Transformer2DModelOutput
+ from ..normalization import CogView3PlusAdaLayerNormZeroTextImage
+
+
+ logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+ class CogView3PlusTransformerBlock(nn.Module):
+     r"""
+     Transformer block used in [CogView](https://github.com/THUDM/CogView3) model.
+
+     Args:
+         dim (`int`):
+             The number of channels in the input and output.
+         num_attention_heads (`int`):
+             The number of heads to use for multi-head attention.
+         attention_head_dim (`int`):
+             The number of channels in each head.
+         time_embed_dim (`int`):
+             The number of channels in timestep embedding.
+     """
+
+     def __init__(
+         self,
+         dim: int = 2560,
+         num_attention_heads: int = 64,
+         attention_head_dim: int = 40,
+         time_embed_dim: int = 512,
+     ):
+         super().__init__()
+
+         self.norm1 = CogView3PlusAdaLayerNormZeroTextImage(embedding_dim=time_embed_dim, dim=dim)
+
+         self.attn1 = Attention(
+             query_dim=dim,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             out_dim=dim,
+             bias=True,
+             qk_norm="layer_norm",
+             elementwise_affine=False,
+             eps=1e-6,
+             processor=CogVideoXAttnProcessor2_0(),
+         )
+
+         self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+         self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+
+         self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         emb: torch.Tensor,
+     ) -> torch.Tensor:
+         text_seq_length = encoder_hidden_states.size(1)
+
+         # norm & modulate
+         (
+             norm_hidden_states,
+             gate_msa,
+             shift_mlp,
+             scale_mlp,
+             gate_mlp,
+             norm_encoder_hidden_states,
+             c_gate_msa,
+             c_shift_mlp,
+             c_scale_mlp,
+             c_gate_mlp,
+         ) = self.norm1(hidden_states, encoder_hidden_states, emb)
+
+         # attention
+         attn_hidden_states, attn_encoder_hidden_states = self.attn1(
+             hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+         )
+
+         hidden_states = hidden_states + gate_msa.unsqueeze(1) * attn_hidden_states
+         encoder_hidden_states = encoder_hidden_states + c_gate_msa.unsqueeze(1) * attn_encoder_hidden_states
+
+         # norm & modulate
+         norm_hidden_states = self.norm2(hidden_states)
+         norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+
+         norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+         norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+
+         # feed-forward
+         norm_hidden_states = torch.cat([norm_encoder_hidden_states, norm_hidden_states], dim=1)
+         ff_output = self.ff(norm_hidden_states)
+
+         hidden_states = hidden_states + gate_mlp.unsqueeze(1) * ff_output[:, text_seq_length:]
+         encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * ff_output[:, :text_seq_length]
+
+         if hidden_states.dtype == torch.float16:
+             hidden_states = hidden_states.clip(-65504, 65504)
+         if encoder_hidden_states.dtype == torch.float16:
+             encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
+         return hidden_states, encoder_hidden_states
+
+
+ class CogView3PlusTransformer2DModel(ModelMixin, ConfigMixin):
+     r"""
+     The Transformer model introduced in [CogView3: Finer and Faster Text-to-Image Generation via Relay
+     Diffusion](https://huggingface.co/papers/2403.05121).
+
+     Args:
+         patch_size (`int`, defaults to `2`):
+             The size of the patches to use in the patch embedding layer.
+         in_channels (`int`, defaults to `16`):
+             The number of channels in the input.
+         num_layers (`int`, defaults to `30`):
+             The number of layers of Transformer blocks to use.
+         attention_head_dim (`int`, defaults to `40`):
+             The number of channels in each head.
+         num_attention_heads (`int`, defaults to `64`):
+             The number of heads to use for multi-head attention.
+         out_channels (`int`, defaults to `16`):
+             The number of channels in the output.
+         text_embed_dim (`int`, defaults to `4096`):
+             Input dimension of text embeddings from the text encoder.
+         time_embed_dim (`int`, defaults to `512`):
+             Output dimension of timestep embeddings.
+         condition_dim (`int`, defaults to `256`):
+             The embedding dimension of the input SDXL-style resolution conditions (original_size, target_size,
+             crop_coords).
+         pos_embed_max_size (`int`, defaults to `128`):
+             The maximum resolution of the positional embeddings, from which slices of shape `H x W` are taken and added
+             to input patched latents, where `H` and `W` are the latent height and width respectively. A value of 128
+             means that the maximum supported height and width for image generation is `128 * vae_scale_factor *
+             patch_size => 128 * 8 * 2 => 2048`.
+         sample_size (`int`, defaults to `128`):
+             The base resolution of input latents. If height/width is not provided during generation, this value is used
+             to determine the resolution as `sample_size * vae_scale_factor => 128 * 8 => 1024`
+     """
+
+     _supports_gradient_checkpointing = True
+
+     @register_to_config
+     def __init__(
+         self,
+         patch_size: int = 2,
+         in_channels: int = 16,
+         num_layers: int = 30,
+         attention_head_dim: int = 40,
+         num_attention_heads: int = 64,
+         out_channels: int = 16,
+         text_embed_dim: int = 4096,
+         time_embed_dim: int = 512,
+         condition_dim: int = 256,
+         pos_embed_max_size: int = 128,
+         sample_size: int = 128,
+     ):
+         super().__init__()
+         self.out_channels = out_channels
+         self.inner_dim = num_attention_heads * attention_head_dim
+
+         # CogView3 uses 3 additional SDXL-like conditions - original_size, target_size, crop_coords
+         # Each of these are sincos embeddings of shape 2 * condition_dim
+         self.pooled_projection_dim = 3 * 2 * condition_dim
+
+         self.patch_embed = CogView3PlusPatchEmbed(
+             in_channels=in_channels,
+             hidden_size=self.inner_dim,
+             patch_size=patch_size,
+             text_hidden_size=text_embed_dim,
+             pos_embed_max_size=pos_embed_max_size,
+         )
+
+         self.time_condition_embed = CogView3CombinedTimestepSizeEmbeddings(
+             embedding_dim=time_embed_dim,
+             condition_dim=condition_dim,
+             pooled_projection_dim=self.pooled_projection_dim,
+             timesteps_dim=self.inner_dim,
+         )
+
+         self.transformer_blocks = nn.ModuleList(
+             [
+                 CogView3PlusTransformerBlock(
+                     dim=self.inner_dim,
+                     num_attention_heads=num_attention_heads,
+                     attention_head_dim=attention_head_dim,
+                     time_embed_dim=time_embed_dim,
+                 )
+                 for _ in range(num_layers)
+             ]
+         )
+
+         self.norm_out = AdaLayerNormContinuous(
+             embedding_dim=self.inner_dim,
+             conditioning_embedding_dim=time_embed_dim,
+             elementwise_affine=False,
+             eps=1e-6,
+         )
+         self.proj_out = nn.Linear(self.inner_dim, patch_size * patch_size * self.out_channels, bias=True)
+
+         self.gradient_checkpointing = False
+
+     @property
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+     def attn_processors(self) -> Dict[str, AttentionProcessor]:
+         r"""
+         Returns:
+             `dict` of attention processors: A dictionary containing all attention processors used in the model with
+             indexed by its weight name.
+         """
+         # set recursively
+         processors = {}
+
+         def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+             if hasattr(module, "get_processor"):
+                 processors[f"{name}.processor"] = module.get_processor()
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+             return processors
+
+         for name, module in self.named_children():
+             fn_recursive_add_processors(name, module, processors)
+
+         return processors
+
+     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+     def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+         r"""
+         Sets the attention processor to use to compute attention.
+
+         Parameters:
+             processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                 The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                 for **all** `Attention` layers.
+
+                 If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                 processor. This is strongly recommended when setting trainable attention processors.
+
+         """
+         count = len(self.attn_processors.keys())
+
+         if isinstance(processor, dict) and len(processor) != count:
+             raise ValueError(
+                 f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                 f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+             )
+
+         def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+             if hasattr(module, "set_processor"):
+                 if not isinstance(processor, dict):
+                     module.set_processor(processor)
+                 else:
+                     module.set_processor(processor.pop(f"{name}.processor"))
+
+             for sub_name, child in module.named_children():
+                 fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+         for name, module in self.named_children():
+             fn_recursive_attn_processor(name, module, processor)
+
+     def _set_gradient_checkpointing(self, module, value=False):
+         if hasattr(module, "gradient_checkpointing"):
+             module.gradient_checkpointing = value
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         timestep: torch.LongTensor,
+         original_size: torch.Tensor,
+         target_size: torch.Tensor,
+         crop_coords: torch.Tensor,
+         return_dict: bool = True,
+     ) -> Union[torch.Tensor, Transformer2DModelOutput]:
+         """
+         The [`CogView3PlusTransformer2DModel`] forward method.
+
+         Args:
+             hidden_states (`torch.Tensor`):
+                 Input `hidden_states` of shape `(batch size, channel, height, width)`.
+             encoder_hidden_states (`torch.Tensor`):
+                 Conditional embeddings (embeddings computed from the input conditions such as prompts) of shape
+                 `(batch_size, sequence_len, text_embed_dim)`
+             timestep (`torch.LongTensor`):
+                 Used to indicate denoising step.
+             original_size (`torch.Tensor`):
+                 CogView3 uses SDXL-like micro-conditioning for original image size as explained in section 2.2 of
+                 [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+             target_size (`torch.Tensor`):
+                 CogView3 uses SDXL-like micro-conditioning for target image size as explained in section 2.2 of
+                 [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+             crop_coords (`torch.Tensor`):
+                 CogView3 uses SDXL-like micro-conditioning for crop coordinates as explained in section 2.2 of
+                 [https://huggingface.co/papers/2307.01952](https://huggingface.co/papers/2307.01952).
+             return_dict (`bool`, *optional*, defaults to `True`):
+                 Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                 tuple.
+
+         Returns:
+             `torch.Tensor` or [`~models.transformer_2d.Transformer2DModelOutput`]:
+                 The denoised latents using provided inputs as conditioning.
+         """
+         height, width = hidden_states.shape[-2:]
+         text_seq_length = encoder_hidden_states.shape[1]
+
+         hidden_states = self.patch_embed(
+             hidden_states, encoder_hidden_states
+         )  # takes care of adding positional embeddings too.
+         emb = self.time_condition_embed(timestep, original_size, target_size, crop_coords, hidden_states.dtype)
+
+         encoder_hidden_states = hidden_states[:, :text_seq_length]
+         hidden_states = hidden_states[:, text_seq_length:]
+
+         for index_block, block in enumerate(self.transformer_blocks):
+             if self.training and self.gradient_checkpointing:
+
+                 def create_custom_forward(module):
+                     def custom_forward(*inputs):
+                         return module(*inputs)
+
+                     return custom_forward
+
+                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                 hidden_states, encoder_hidden_states = torch.utils.checkpoint.checkpoint(
+                     create_custom_forward(block),
+                     hidden_states,
+                     encoder_hidden_states,
+                     emb,
+                     **ckpt_kwargs,
+                 )
+             else:
+                 hidden_states, encoder_hidden_states = block(
+                     hidden_states=hidden_states,
+                     encoder_hidden_states=encoder_hidden_states,
+                     emb=emb,
+                 )
+
+         hidden_states = self.norm_out(hidden_states, emb)
+         hidden_states = self.proj_out(hidden_states)  # (batch_size, height*width, patch_size*patch_size*out_channels)
+
+         # unpatchify
+         patch_size = self.config.patch_size
+         height = height // patch_size
+         width = width // patch_size
+
+         hidden_states = hidden_states.reshape(
+             shape=(hidden_states.shape[0], height, width, self.out_channels, patch_size, patch_size)
+         )
+         hidden_states = torch.einsum("nhwcpq->nchpwq", hidden_states)
+         output = hidden_states.reshape(
+             shape=(hidden_states.shape[0], self.out_channels, height * patch_size, width * patch_size)
+         )
+
+         if not return_dict:
+             return (output,)
+
+         return Transformer2DModelOutput(sample=output)
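
The new file defines the CogView3Plus DiT in two pieces: CogView3PlusTransformerBlock (adaLN-Zero modulation over concatenated text and image tokens) and CogView3PlusTransformer2DModel, which patch-embeds latents, conditions on the timestep plus SDXL-style size/crop embeddings, and unpatchifies at the end. A minimal smoke-test sketch, not part of the diff: it instantiates the model with an arbitrarily tiny config (an assumption, chosen only to keep tensors small) and runs one forward pass with random tensors to show the input/output shapes documented above; real checkpoints use the defaults from the class docstring.

import torch
from diffusers.models.transformers.transformer_cogview3plus import CogView3PlusTransformer2DModel

model = CogView3PlusTransformer2DModel(
    patch_size=2,
    in_channels=4,
    num_layers=1,
    attention_head_dim=8,
    num_attention_heads=4,
    out_channels=4,
    text_embed_dim=16,
    time_embed_dim=32,
    condition_dim=8,
    pos_embed_max_size=8,
)

batch = 1
latents = torch.randn(batch, 4, 16, 16)    # (B, in_channels, H, W), H and W divisible by patch_size
prompt_embeds = torch.randn(batch, 8, 16)  # (B, seq_len, text_embed_dim)
timestep = torch.tensor([500])
size_cond = torch.tensor([[16.0, 16.0]])   # SDXL-style (height, width) micro-conditions
crop_coords = torch.tensor([[0.0, 0.0]])

sample = model(
    hidden_states=latents,
    encoder_hidden_states=prompt_embeds,
    timestep=timestep,
    original_size=size_cond,
    target_size=size_cond,
    crop_coords=crop_coords,
    return_dict=False,
)[0]
print(sample.shape)  # expected: torch.Size([1, 4, 16, 16])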