diffusers 0.28.2__py3-none-any.whl → 0.29.1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.
Files changed (122)
  1. diffusers/__init__.py +15 -1
  2. diffusers/commands/env.py +1 -5
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +2 -1
  5. diffusers/loaders/__init__.py +2 -2
  6. diffusers/loaders/lora.py +406 -140
  7. diffusers/loaders/lora_conversion_utils.py +7 -1
  8. diffusers/loaders/single_file.py +13 -1
  9. diffusers/loaders/single_file_model.py +15 -8
  10. diffusers/loaders/single_file_utils.py +267 -17
  11. diffusers/loaders/unet.py +307 -272
  12. diffusers/models/__init__.py +7 -3
  13. diffusers/models/attention.py +125 -1
  14. diffusers/models/attention_processor.py +169 -1
  15. diffusers/models/autoencoders/__init__.py +1 -0
  16. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  17. diffusers/models/autoencoders/autoencoder_kl.py +17 -6
  18. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
  19. diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
  20. diffusers/models/autoencoders/vq_model.py +182 -0
  21. diffusers/models/controlnet_sd3.py +418 -0
  22. diffusers/models/controlnet_xs.py +6 -6
  23. diffusers/models/embeddings.py +112 -84
  24. diffusers/models/model_loading_utils.py +55 -0
  25. diffusers/models/modeling_utils.py +138 -20
  26. diffusers/models/normalization.py +11 -6
  27. diffusers/models/transformers/__init__.py +1 -0
  28. diffusers/models/transformers/dual_transformer_2d.py +5 -4
  29. diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
  30. diffusers/models/transformers/prior_transformer.py +5 -5
  31. diffusers/models/transformers/transformer_2d.py +2 -2
  32. diffusers/models/transformers/transformer_sd3.py +353 -0
  33. diffusers/models/transformers/transformer_temporal.py +12 -10
  34. diffusers/models/unets/unet_1d.py +3 -3
  35. diffusers/models/unets/unet_2d.py +3 -3
  36. diffusers/models/unets/unet_2d_condition.py +4 -15
  37. diffusers/models/unets/unet_3d_condition.py +5 -17
  38. diffusers/models/unets/unet_i2vgen_xl.py +4 -4
  39. diffusers/models/unets/unet_motion_model.py +4 -4
  40. diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
  41. diffusers/models/vq_model.py +8 -165
  42. diffusers/pipelines/__init__.py +11 -0
  43. diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
  45. diffusers/pipelines/auto_pipeline.py +8 -0
  46. diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
  47. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
  48. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
  49. diffusers/pipelines/controlnet_sd3/__init__.py +53 -0
  50. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1062 -0
  51. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
  52. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  53. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
  54. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
  55. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
  56. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
  57. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
  58. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
  59. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
  60. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
  61. diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
  62. diffusers/pipelines/pia/pipeline_pia.py +4 -3
  63. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  64. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  65. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
  72. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
  73. diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
  74. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  75. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +904 -0
  76. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +941 -0
  77. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
  78. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
  79. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
  80. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  81. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
  82. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
  83. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
  84. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
  85. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
  86. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
  87. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
  88. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
  89. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  90. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
  91. diffusers/schedulers/__init__.py +2 -0
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  93. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
  94. diffusers/schedulers/scheduling_edm_euler.py +2 -4
  95. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
  96. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  97. diffusers/training_utils.py +4 -4
  98. diffusers/utils/__init__.py +3 -0
  99. diffusers/utils/constants.py +2 -0
  100. diffusers/utils/dummy_pt_objects.py +60 -0
  101. diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
  102. diffusers/utils/dynamic_modules_utils.py +15 -13
  103. diffusers/utils/hub_utils.py +106 -0
  104. diffusers/utils/import_utils.py +0 -1
  105. diffusers/utils/logging.py +3 -1
  106. diffusers/utils/state_dict_utils.py +2 -0
  107. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/METADATA +3 -3
  108. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/RECORD +112 -112
  109. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/WHEEL +1 -1
  110. diffusers/models/dual_transformer_2d.py +0 -20
  111. diffusers/models/prior_transformer.py +0 -12
  112. diffusers/models/t5_film_transformer.py +0 -70
  113. diffusers/models/transformer_2d.py +0 -25
  114. diffusers/models/transformer_temporal.py +0 -34
  115. diffusers/models/unet_1d.py +0 -26
  116. diffusers/models/unet_1d_blocks.py +0 -203
  117. diffusers/models/unet_2d.py +0 -27
  118. diffusers/models/unet_2d_blocks.py +0 -375
  119. diffusers/models/unet_2d_condition.py +0 -25
  120. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/LICENSE +0 -0
  121. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/entry_points.txt +0 -0
  122. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/top_level.txt +0 -0
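
The list above adds a complete Stable Diffusion 3 stack: SD3Transformer2DModel, SD3ControlNetModel, FlowMatchEulerDiscreteScheduler, and the new stable_diffusion_3 pipelines. A minimal sketch of exercising the new text-to-image pipeline on 0.29.1 follows; the checkpoint id is an assumption and not part of this diff (the official SD3 repository is gated):

import torch
from diffusers import StableDiffusion3Pipeline

pipe = StableDiffusion3Pipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed checkpoint id; requires access to the gated repo
    torch_dtype=torch.float16,
)
pipe = pipe.to("cuda")
image = pipe(
    "a photo of a cat holding a sign that says hello world",
    num_inference_steps=28,
    guidance_scale=7.0,
).images[0]
image.save("sd3.png")
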
@@ -31,17 +31,20 @@ if is_torch_available():
     _import_structure["autoencoders.autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
     _import_structure["autoencoders.autoencoder_tiny"] = ["AutoencoderTiny"]
     _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
+    _import_structure["autoencoders.vq_model"] = ["VQModel"]
     _import_structure["controlnet"] = ["ControlNetModel"]
+    _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
     _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
-    _import_structure["dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
     _import_structure["transformers.dit_transformer_2d"] = ["DiTTransformer2DModel"]
+    _import_structure["transformers.dual_transformer_2d"] = ["DualTransformer2DModel"]
     _import_structure["transformers.hunyuan_transformer_2d"] = ["HunyuanDiT2DModel"]
     _import_structure["transformers.pixart_transformer_2d"] = ["PixArtTransformer2DModel"]
     _import_structure["transformers.prior_transformer"] = ["PriorTransformer"]
     _import_structure["transformers.t5_film_transformer"] = ["T5FilmDecoder"]
     _import_structure["transformers.transformer_2d"] = ["Transformer2DModel"]
+    _import_structure["transformers.transformer_sd3"] = ["SD3Transformer2DModel"]
     _import_structure["transformers.transformer_temporal"] = ["TransformerTemporalModel"]
     _import_structure["unets.unet_1d"] = ["UNet1DModel"]
     _import_structure["unets.unet_2d"] = ["UNet2DModel"]
@@ -53,7 +56,6 @@ if is_torch_available():
     _import_structure["unets.unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
     _import_structure["unets.unet_stable_cascade"] = ["StableCascadeUNet"]
     _import_structure["unets.uvit_2d"] = ["UVit2DModel"]
-    _import_structure["vq_model"] = ["VQModel"]

 if is_flax_available():
     _import_structure["controlnet_flax"] = ["FlaxControlNetModel"]
@@ -70,8 +72,10 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         AutoencoderKLTemporalDecoder,
         AutoencoderTiny,
         ConsistencyDecoderVAE,
+        VQModel,
     )
     from .controlnet import ControlNetModel
+    from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
     from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
     from .embeddings import ImageProjection
     from .modeling_utils import ModelMixin
@@ -81,6 +85,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         HunyuanDiT2DModel,
         PixArtTransformer2DModel,
         PriorTransformer,
+        SD3Transformer2DModel,
         T5FilmDecoder,
         Transformer2DModel,
         TransformerTemporalModel,
@@ -98,7 +103,6 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         UNetSpatioTemporalConditionModel,
         UVit2DModel,
     )
-    from .vq_model import VQModel

     if is_flax_available():
         from .controlnet_flax import FlaxControlNetModel
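
These hunks register the new SD3 model classes in the lazy-import table and relocate VQModel under diffusers.models.autoencoders; per the file list, diffusers/models/vq_model.py is reduced to a thin backwards-compatibility shim. A short sketch of the import paths that result, assuming 0.29.1:

# Top-level exports still resolve through the lazy _import_structure shown above.
from diffusers import SD3ControlNetModel, SD3Transformer2DModel, VQModel

# New canonical module locations in 0.29.x:
from diffusers.models.autoencoders.vq_model import VQModel
from diffusers.models.transformers.transformer_sd3 import SD3Transformer2DModel
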
@@ -20,7 +20,7 @@ from torch import nn
 from ..utils import deprecate, logging
 from ..utils.torch_utils import maybe_allow_in_graph
 from .activations import GEGLU, GELU, ApproximateGELU
-from .attention_processor import Attention
+from .attention_processor import Attention, JointAttnProcessor2_0
 from .embeddings import SinusoidalPositionalEmbedding
 from .normalization import AdaLayerNorm, AdaLayerNormContinuous, AdaLayerNormZero, RMSNorm

@@ -85,6 +85,130 @@ class GatedSelfAttentionDense(nn.Module):
         return x


+@maybe_allow_in_graph
+class JointTransformerBlock(nn.Module):
+    r"""
+    A Transformer block following the MMDiT architecture, introduced in Stable Diffusion 3.
+
+    Reference: https://arxiv.org/abs/2403.03206
+
+    Parameters:
+        dim (`int`): The number of channels in the input and output.
+        num_attention_heads (`int`): The number of heads to use for multi-head attention.
+        attention_head_dim (`int`): The number of channels in each head.
+        context_pre_only (`bool`): Boolean to determine if we should add some blocks associated with the
+            processing of `context` conditions.
+    """
+
+    def __init__(self, dim, num_attention_heads, attention_head_dim, context_pre_only=False):
+        super().__init__()
+
+        self.context_pre_only = context_pre_only
+        context_norm_type = "ada_norm_continous" if context_pre_only else "ada_norm_zero"
+
+        self.norm1 = AdaLayerNormZero(dim)
+
+        if context_norm_type == "ada_norm_continous":
+            self.norm1_context = AdaLayerNormContinuous(
+                dim, dim, elementwise_affine=False, eps=1e-6, bias=True, norm_type="layer_norm"
+            )
+        elif context_norm_type == "ada_norm_zero":
+            self.norm1_context = AdaLayerNormZero(dim)
+        else:
+            raise ValueError(
+                f"Unknown context_norm_type: {context_norm_type}, currently only support `ada_norm_continous`, `ada_norm_zero`"
+            )
+        if hasattr(F, "scaled_dot_product_attention"):
+            processor = JointAttnProcessor2_0()
+        else:
+            raise ValueError(
+                "The current PyTorch version does not support the `scaled_dot_product_attention` function."
+            )
+        self.attn = Attention(
+            query_dim=dim,
+            cross_attention_dim=None,
+            added_kv_proj_dim=dim,
+            dim_head=attention_head_dim // num_attention_heads,
+            heads=num_attention_heads,
+            out_dim=attention_head_dim,
+            context_pre_only=context_pre_only,
+            bias=True,
+            processor=processor,
+        )
+
+        self.norm2 = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+        self.ff = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+
+        if not context_pre_only:
+            self.norm2_context = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-6)
+            self.ff_context = FeedForward(dim=dim, dim_out=dim, activation_fn="gelu-approximate")
+        else:
+            self.norm2_context = None
+            self.ff_context = None
+
+        # let chunk size default to None
+        self._chunk_size = None
+        self._chunk_dim = 0
+
+    # Copied from diffusers.models.attention.BasicTransformerBlock.set_chunk_feed_forward
+    def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
+        # Sets chunk feed-forward
+        self._chunk_size = chunk_size
+        self._chunk_dim = dim
+
+    def forward(
+        self, hidden_states: torch.FloatTensor, encoder_hidden_states: torch.FloatTensor, temb: torch.FloatTensor
+    ):
+        norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)
+
+        if self.context_pre_only:
+            norm_encoder_hidden_states = self.norm1_context(encoder_hidden_states, temb)
+        else:
+            norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
+                encoder_hidden_states, emb=temb
+            )
+
+        # Attention.
+        attn_output, context_attn_output = self.attn(
+            hidden_states=norm_hidden_states, encoder_hidden_states=norm_encoder_hidden_states
+        )
+
+        # Process attention outputs for the `hidden_states`.
+        attn_output = gate_msa.unsqueeze(1) * attn_output
+        hidden_states = hidden_states + attn_output
+
+        norm_hidden_states = self.norm2(hidden_states)
+        norm_hidden_states = norm_hidden_states * (1 + scale_mlp[:, None]) + shift_mlp[:, None]
+        if self._chunk_size is not None:
+            # "feed_forward_chunk_size" can be used to save memory
+            ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+        else:
+            ff_output = self.ff(norm_hidden_states)
+        ff_output = gate_mlp.unsqueeze(1) * ff_output
+
+        hidden_states = hidden_states + ff_output
+
+        # Process attention outputs for the `encoder_hidden_states`.
+        if self.context_pre_only:
+            encoder_hidden_states = None
+        else:
+            context_attn_output = c_gate_msa.unsqueeze(1) * context_attn_output
+            encoder_hidden_states = encoder_hidden_states + context_attn_output
+
+            norm_encoder_hidden_states = self.norm2_context(encoder_hidden_states)
+            norm_encoder_hidden_states = norm_encoder_hidden_states * (1 + c_scale_mlp[:, None]) + c_shift_mlp[:, None]
+            if self._chunk_size is not None:
+                # "feed_forward_chunk_size" can be used to save memory
+                context_ff_output = _chunked_feed_forward(
+                    self.ff_context, norm_encoder_hidden_states, self._chunk_dim, self._chunk_size
+                )
+            else:
+                context_ff_output = self.ff_context(norm_encoder_hidden_states)
+            encoder_hidden_states = encoder_hidden_states + c_gate_mlp.unsqueeze(1) * context_ff_output
+
+        return encoder_hidden_states, hidden_states
+
+
 @maybe_allow_in_graph
 class BasicTransformerBlock(nn.Module):
     r"""
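
The new JointTransformerBlock runs the sample and context streams through one joint attention and separate feed-forwards, returning the updated (encoder_hidden_states, hidden_states) pair; note that attention_head_dim here is the total attention width and is divided by num_attention_heads inside the block. A minimal smoke-test sketch with illustrative (non-SD3) sizes, assuming PyTorch 2.x for scaled_dot_product_attention:

import torch
from diffusers.models.attention import JointTransformerBlock

block = JointTransformerBlock(dim=64, num_attention_heads=4, attention_head_dim=64, context_pre_only=False)
hidden_states = torch.randn(1, 16, 64)         # latent/image tokens
encoder_hidden_states = torch.randn(1, 8, 64)  # text-conditioning tokens
temb = torch.randn(1, 64)                      # pooled timestep/text embedding
context_out, hidden_out = block(hidden_states, encoder_hidden_states, temb)
assert hidden_out.shape == (1, 16, 64) and context_out.shape == (1, 8, 64)
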
@@ -116,6 +116,7 @@ class Attention(nn.Module):
         _from_deprecated_attn_block: bool = False,
         processor: Optional["AttnProcessor"] = None,
         out_dim: int = None,
+        context_pre_only=None,
     ):
         super().__init__()
         self.inner_dim = out_dim if out_dim is not None else dim_head * heads
@@ -130,6 +131,7 @@ class Attention(nn.Module):
         self.dropout = dropout
         self.fused_projections = False
         self.out_dim = out_dim if out_dim is not None else query_dim
+        self.context_pre_only = context_pre_only

         # we make use of this private variable to know whether this class is loaded
         # with an deprecated state dict so that we can convert it on the fly
@@ -207,11 +209,16 @@ class Attention(nn.Module):
         if self.added_kv_proj_dim is not None:
             self.add_k_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
             self.add_v_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)
+            if self.context_pre_only is not None:
+                self.add_q_proj = nn.Linear(added_kv_proj_dim, self.inner_dim)

         self.to_out = nn.ModuleList([])
         self.to_out.append(nn.Linear(self.inner_dim, self.out_dim, bias=out_bias))
         self.to_out.append(nn.Dropout(dropout))

+        if self.context_pre_only is not None and not self.context_pre_only:
+            self.to_add_out = nn.Linear(self.inner_dim, self.out_dim, bias=out_bias)
+
         # set attention processor
         # We use the AttnProcessor2_0 by default when torch 2.x is used which uses
         # torch.nn.functional.scaled_dot_product_attention for native Flash/memory_efficient_attention
@@ -539,7 +546,10 @@ class Attention(nn.Module):
         # For standard processors that are defined here, `**cross_attention_kwargs` is empty

         attn_parameters = set(inspect.signature(self.processor.__call__).parameters.keys())
-        unused_kwargs = [k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters]
+        quiet_attn_parameters = {"ip_adapter_masks"}
+        unused_kwargs = [
+            k for k, _ in cross_attention_kwargs.items() if k not in attn_parameters and k not in quiet_attn_parameters
+        ]
         if len(unused_kwargs) > 0:
             logger.warning(
                 f"cross_attention_kwargs {unused_kwargs} are not expected by {self.processor.__class__.__name__} and will be ignored."
@@ -1072,6 +1082,164 @@ class AttnAddedKVProcessor2_0:
         return hidden_states


+class JointAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        # `sample` projections.
+        query = attn.to_q(hidden_states)
+        key = attn.to_k(hidden_states)
+        value = attn.to_v(hidden_states)
+
+        # `context` projections.
+        encoder_hidden_states_query_proj = attn.add_q_proj(encoder_hidden_states)
+        encoder_hidden_states_key_proj = attn.add_k_proj(encoder_hidden_states)
+        encoder_hidden_states_value_proj = attn.add_v_proj(encoder_hidden_states)
+
+        # attention
+        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        hidden_states = hidden_states = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Split the attention outputs.
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, : residual.shape[1]],
+            hidden_states[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if not attn.context_pre_only:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states, encoder_hidden_states
+
+
+class FusedJointAttnProcessor2_0:
+    """Attention processor used typically in processing the SD3-like self-attention projections."""
+
+    def __init__(self):
+        if not hasattr(F, "scaled_dot_product_attention"):
+            raise ImportError("AttnProcessor2_0 requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0.")
+
+    def __call__(
+        self,
+        attn: Attention,
+        hidden_states: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        attention_mask: Optional[torch.FloatTensor] = None,
+        *args,
+        **kwargs,
+    ) -> torch.FloatTensor:
+        residual = hidden_states
+
+        input_ndim = hidden_states.ndim
+        if input_ndim == 4:
+            batch_size, channel, height, width = hidden_states.shape
+            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+        context_input_ndim = encoder_hidden_states.ndim
+        if context_input_ndim == 4:
+            batch_size, channel, height, width = encoder_hidden_states.shape
+            encoder_hidden_states = encoder_hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+        batch_size = encoder_hidden_states.shape[0]
+
+        # `sample` projections.
+        qkv = attn.to_qkv(hidden_states)
+        split_size = qkv.shape[-1] // 3
+        query, key, value = torch.split(qkv, split_size, dim=-1)
+
+        # `context` projections.
+        encoder_qkv = attn.to_added_qkv(encoder_hidden_states)
+        split_size = encoder_qkv.shape[-1] // 3
+        (
+            encoder_hidden_states_query_proj,
+            encoder_hidden_states_key_proj,
+            encoder_hidden_states_value_proj,
+        ) = torch.split(encoder_qkv, split_size, dim=-1)
+
+        # attention
+        query = torch.cat([query, encoder_hidden_states_query_proj], dim=1)
+        key = torch.cat([key, encoder_hidden_states_key_proj], dim=1)
+        value = torch.cat([value, encoder_hidden_states_value_proj], dim=1)
+
+        inner_dim = key.shape[-1]
+        head_dim = inner_dim // attn.heads
+        query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+        value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+        hidden_states = hidden_states = F.scaled_dot_product_attention(
+            query, key, value, dropout_p=0.0, is_causal=False
+        )
+        hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+        hidden_states = hidden_states.to(query.dtype)
+
+        # Split the attention outputs.
+        hidden_states, encoder_hidden_states = (
+            hidden_states[:, : residual.shape[1]],
+            hidden_states[:, residual.shape[1] :],
+        )
+
+        # linear proj
+        hidden_states = attn.to_out[0](hidden_states)
+        # dropout
+        hidden_states = attn.to_out[1](hidden_states)
+        if not attn.context_pre_only:
+            encoder_hidden_states = attn.to_add_out(encoder_hidden_states)
+
+        if input_ndim == 4:
+            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+        if context_input_ndim == 4:
+            encoder_hidden_states = encoder_hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+        return hidden_states, encoder_hidden_states
+
+
 class XFormersAttnAddedKVProcessor:
     r"""
     Processor for implementing memory efficient attention using xFormers.
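
JointAttnProcessor2_0 projects both streams, concatenates them along the sequence dimension for a single scaled_dot_product_attention call, then splits the result back into sample and context outputs; FusedJointAttnProcessor2_0 is the variant that assumes fused to_qkv / to_added_qkv projections are present on the module. A minimal sketch of wiring the joint processor onto a bare Attention module, mirroring the constructor call in JointTransformerBlock (illustrative sizes, PyTorch 2.x assumed):

import torch
from diffusers.models.attention_processor import Attention, JointAttnProcessor2_0

attn = Attention(
    query_dim=64,
    cross_attention_dim=None,
    added_kv_proj_dim=64,      # creates add_q/k/v_proj for the context stream
    dim_head=16,
    heads=4,
    out_dim=64,
    context_pre_only=False,    # also creates to_add_out for the context output
    bias=True,
    processor=JointAttnProcessor2_0(),
)
sample = torch.randn(1, 16, 64)
context = torch.randn(1, 8, 64)
sample_out, context_out = attn(sample, encoder_hidden_states=context)
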
@@ -3,3 +3,4 @@ from .autoencoder_kl import AutoencoderKL
 from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
 from .autoencoder_tiny import AutoencoderTiny
 from .consistency_decoder_vae import ConsistencyDecoderVAE
+from .vq_model import VQModel
@@ -176,7 +176,7 @@ class AsymmetricAutoencoderKL(ModelMixin, ConfigMixin):
             z = posterior.sample(generator=generator)
         else:
             z = posterior.mode()
-        dec = self.decode(z, sample, mask).sample
+        dec = self.decode(z, generator, sample, mask).sample

         if not return_dict:
             return (dec,)
@@ -81,9 +81,12 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         norm_num_groups: int = 32,
         sample_size: int = 32,
         scaling_factor: float = 0.18215,
+        shift_factor: Optional[float] = None,
         latents_mean: Optional[Tuple[float]] = None,
         latents_std: Optional[Tuple[float]] = None,
         force_upcast: float = True,
+        use_quant_conv: bool = True,
+        use_post_quant_conv: bool = True,
     ):
         super().__init__()

@@ -110,8 +113,8 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             act_fn=act_fn,
         )

-        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
-        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
+        self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) if use_quant_conv else None
+        self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) if use_post_quant_conv else None

         self.use_slicing = False
         self.use_tiling = False
@@ -260,7 +263,11 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         else:
             h = self.encoder(x)

-        moments = self.quant_conv(h)
+        if self.quant_conv is not None:
+            moments = self.quant_conv(h)
+        else:
+            moments = h
+
         posterior = DiagonalGaussianDistribution(moments)

         if not return_dict:
@@ -272,7 +279,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
             return self.tiled_decode(z, return_dict=return_dict)

-        z = self.post_quant_conv(z)
+        if self.post_quant_conv is not None:
+            z = self.post_quant_conv(z)
+
         dec = self.decoder(z)

         if not return_dict:
@@ -281,7 +290,9 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         return DecoderOutput(sample=dec)

     @apply_forward_hook
-    def decode(self, z: torch.Tensor, return_dict: bool = True, generator=None) -> Union[DecoderOutput, torch.Tensor]:
+    def decode(
+        self, z: torch.FloatTensor, return_dict: bool = True, generator=None
+    ) -> Union[DecoderOutput, torch.FloatTensor]:
         """
         Decode a batch of images.

@@ -300,7 +311,7 @@ class AutoencoderKL(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
             decoded = torch.cat(decoded_slices)
         else:
-            decoded = self._decode(z, return_dict=False)[0]
+            decoded = self._decode(z).sample

         if not return_dict:
             return (decoded,)
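
Together these changes let AutoencoderKL describe VAEs without quant_conv / post_quant_conv layers and with a shift_factor in their config, which is the shape of the SD3 VAE. A minimal sketch with illustrative values (not the released SD3 checkpoint config):

import torch
from diffusers import AutoencoderKL

vae = AutoencoderKL(
    latent_channels=16,
    scaling_factor=1.5305,      # illustrative; read the real values from the checkpoint config
    shift_factor=0.0609,        # new config field, consumed by pipelines when (un)scaling latents
    use_quant_conv=False,       # skip the 1x1 conv after the encoder
    use_post_quant_conv=False,  # skip the 1x1 conv before the decoder
)
x = torch.randn(1, 3, 64, 64)
latents = vae.encode(x).latent_dist.sample()
reconstruction = vae.decode(latents).sample
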
@@ -323,11 +323,13 @@ class AutoencoderKLTemporalDecoder(ModelMixin, ConfigMixin):
         Args:
             x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether to return a [`~models.autoencoder_kl.AutoencoderKLOutput`] instead of a plain tuple.
+                Whether to return a [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] instead of a plain
+                tuple.

         Returns:
             The latent representations of the encoded images. If `return_dict` is True, a
-            [`~models.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is returned.
+            [`~models.autoencoders.autoencoder_kl.AutoencoderKLOutput`] is returned, otherwise a plain `tuple` is
+            returned.
         """
         h = self.encoder(x)
         moments = self.quant_conv(h)
@@ -284,13 +284,13 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         Args:
             x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a plain
-                tuple.
+                Whether to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
+                instead of a plain tuple.

         Returns:
             The latent representations of the encoded images. If `return_dict` is True, a
-            [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a plain `tuple`
-            is returned.
+            [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned, otherwise a
+            plain `tuple` is returned.
         """
         if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
             return self.tiled_encode(x, return_dict=return_dict)
@@ -382,13 +382,13 @@ class ConsistencyDecoderVAE(ModelMixin, ConfigMixin):
         Args:
             x (`torch.Tensor`): Input batch of images.
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] instead of a
-                plain tuple.
+                Whether or not to return a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
+                instead of a plain tuple.

         Returns:
-            [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
-                If return_dict is True, a [`~models.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] is returned,
-                otherwise a plain `tuple` is returned.
+            [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`] or `tuple`:
+                If return_dict is True, a [`~models.autoencoders.consistency_decoder_vae.ConsistencyDecoderVAEOutput`]
+                is returned, otherwise a plain `tuple` is returned.
         """
         overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
         blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)