diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +38 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +238 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +40 -7
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +6 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
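
Note on the new quantizers package listed above (diffusers/quantizers/*): 0.31.0 adds a bitsandbytes-backed quantization backend. A minimal usage sketch, assuming the config is exposed as diffusers.BitsAndBytesConfig and accepted by from_pretrained via quantization_config; the checkpoint id and settings below are illustrative, not taken from this diff.

import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# 4-bit NF4 quantization of a single model component through the new bitsandbytes backend
quant_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",   # illustrative checkpoint
    subfolder="transformer",
    quantization_config=quant_config,
    torch_dtype=torch.bfloat16,
)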
diffusers/models/transformers/transformer_flux.py
CHANGED
@@ -1,4 +1,4 @@
-# Copyright 2024 Black Forest Labs, The HuggingFace Team. All rights reserved.
+# Copyright 2024 Black Forest Labs, The HuggingFace Team and The InstantX Team. All rights reserved.
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -13,8 +13,9 @@
 # limitations under the License.


-from typing import Any, Dict,
+from typing import Any, Dict, Optional, Tuple, Union

+import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
@@ -22,52 +23,23 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import FromOriginalModelMixin, PeftAdapterMixin
 from ...models.attention import FeedForward
-from ...models.attention_processor import
+from ...models.attention_processor import (
+    Attention,
+    AttentionProcessor,
+    FluxAttnProcessor2_0,
+    FusedFluxAttnProcessor2_0,
+)
 from ...models.modeling_utils import ModelMixin
 from ...models.normalization import AdaLayerNormContinuous, AdaLayerNormZero, AdaLayerNormZeroSingle
 from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings
+from ..embeddings import CombinedTimestepGuidanceTextProjEmbeddings, CombinedTimestepTextProjEmbeddings, FluxPosEmbed
 from ..modeling_outputs import Transformer2DModelOutput


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


-# YiYi to-do: refactor rope related functions/classes
-def rope(pos: torch.Tensor, dim: int, theta: int) -> torch.Tensor:
-    assert dim % 2 == 0, "The dimension must be even."
-
-    scale = torch.arange(0, dim, 2, dtype=torch.float64, device=pos.device) / dim
-    omega = 1.0 / (theta**scale)
-
-    batch_size, seq_length = pos.shape
-    out = torch.einsum("...n,d->...nd", pos, omega)
-    cos_out = torch.cos(out)
-    sin_out = torch.sin(out)
-
-    stacked_out = torch.stack([cos_out, -sin_out, sin_out, cos_out], dim=-1)
-    out = stacked_out.view(batch_size, -1, dim // 2, 2, 2)
-    return out.float()
-
-
-# YiYi to-do: refactor rope related functions/classes
-class EmbedND(nn.Module):
-    def __init__(self, dim: int, theta: int, axes_dim: List[int]):
-        super().__init__()
-        self.dim = dim
-        self.theta = theta
-        self.axes_dim = axes_dim
-
-    def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        n_axes = ids.shape[-1]
-        emb = torch.cat(
-            [rope(ids[..., i], self.axes_dim[i], self.theta) for i in range(n_axes)],
-            dim=-3,
-        )
-        return emb.unsqueeze(1)
-
-
 @maybe_allow_in_graph
 class FluxSingleTransformerBlock(nn.Module):
     r"""
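
Note: the module-level rope()/EmbedND helpers removed above are superseded by FluxPosEmbed, now imported from ..embeddings. A rough sketch of the replacement, assuming FluxPosEmbed(theta, axes_dim) returns the rotary cos/sin tables consumed by the Flux attention processors; shapes and sequence lengths are illustrative.

import torch
from diffusers.models.embeddings import FluxPosEmbed

pos_embed = FluxPosEmbed(theta=10000, axes_dim=[16, 56, 56])
ids = torch.zeros(512 + 4096, 3)        # concatenated text + image position ids, now 2d (see forward() below)
freqs_cos, freqs_sin = pos_embed(ids)   # assumed return: one cos/sin pair of rotary tables per position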
@@ -92,7 +64,7 @@ class FluxSingleTransformerBlock(nn.Module):
         self.act_mlp = nn.GELU(approximate="tanh")
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

-        processor =
+        processor = FluxAttnProcessor2_0()
         self.attn = Attention(
             query_dim=dim,
             cross_attention_dim=None,
@@ -111,14 +83,16 @@ class FluxSingleTransformerBlock(nn.Module):
         hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         residual = hidden_states
         norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
         mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         attn_output = self.attn(
             hidden_states=norm_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )

         hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
@@ -189,18 +163,20 @@ class FluxTransformerBlock(nn.Module):
         encoder_hidden_states: torch.FloatTensor,
         temb: torch.FloatTensor,
         image_rotary_emb=None,
+        joint_attention_kwargs=None,
     ):
         norm_hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp = self.norm1(hidden_states, emb=temb)

         norm_encoder_hidden_states, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp = self.norm1_context(
             encoder_hidden_states, emb=temb
         )
-
+        joint_attention_kwargs = joint_attention_kwargs or {}
         # Attention.
         attn_output, context_attn_output = self.attn(
             hidden_states=norm_hidden_states,
             encoder_hidden_states=norm_encoder_hidden_states,
             image_rotary_emb=image_rotary_emb,
+            **joint_attention_kwargs,
         )

         # Process attention outputs for the `hidden_states`.
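
Note: with the two blocks above now forwarding **joint_attention_kwargs into their attention calls, keyword arguments passed at the pipeline level can reach every Flux attention processor. A hedged sketch; the "scale" key is the LoRA scale as read by the Flux pipeline, and the model id and values are illustrative.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
image = pipe(
    "a photo of a forest clearing",
    joint_attention_kwargs={"scale": 0.8},  # forwarded down to each transformer block's attention call
).images[0]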
@@ -250,6 +226,7 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
     """

     _supports_gradient_checkpointing = True
+    _no_split_modules = ["FluxTransformerBlock", "FluxSingleTransformerBlock"]

     @register_to_config
     def __init__(
@@ -263,13 +240,14 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         joint_attention_dim: int = 4096,
         pooled_projection_dim: int = 768,
         guidance_embeds: bool = False,
-        axes_dims_rope:
+        axes_dims_rope: Tuple[int] = (16, 56, 56),
     ):
         super().__init__()
         self.out_channels = in_channels
         self.inner_dim = self.config.num_attention_heads * self.config.attention_head_dim

-        self.pos_embed =
+        self.pos_embed = FluxPosEmbed(theta=10000, axes_dim=axes_dims_rope)
+
         text_time_guidance_cls = (
             CombinedTimestepGuidanceTextProjEmbeddings if guidance_embeds else CombinedTimestepTextProjEmbeddings
         )
@@ -307,6 +285,106 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig

         self.gradient_checkpointing = False

+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor()
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+        self.set_attn_processor(FusedFluxAttnProcessor2_0())
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
     def _set_gradient_checkpointing(self, module, value=False):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
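
Note: the block above copies the attn_processors / set_attn_processor / fuse_qkv_projections machinery from UNet2DConditionModel onto FluxTransformer2DModel. A short usage sketch; the methods are the ones added above, while the checkpoint id and call arguments are illustrative.

import torch
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
pipe.transformer.fuse_qkv_projections()    # experimental: fuse q/k/v projections into single matmuls
image = pipe("a tiny astronaut", num_inference_steps=4).images[0]
pipe.transformer.unfuse_qkv_projections()  # restore the original attention processors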
@@ -321,7 +399,10 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         txt_ids: torch.Tensor = None,
         guidance: torch.Tensor = None,
         joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        controlnet_block_samples=None,
+        controlnet_single_block_samples=None,
         return_dict: bool = True,
+        controlnet_blocks_repeat: bool = False,
     ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
         """
         The [`FluxTransformer2DModel`] forward method.
@@ -377,7 +458,20 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
         )
         encoder_hidden_states = self.context_embedder(encoder_hidden_states)

-
+        if txt_ids.ndim == 3:
+            logger.warning(
+                "Passing `txt_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            txt_ids = txt_ids[0]
+        if img_ids.ndim == 3:
+            logger.warning(
+                "Passing `img_ids` 3d torch.Tensor is deprecated."
+                "Please remove the batch dimension and pass it as a 2d torch Tensor"
+            )
+            img_ids = img_ids[0]
+
+        ids = torch.cat((txt_ids, img_ids), dim=0)
         image_rotary_emb = self.pos_embed(ids)

         for index_block, block in enumerate(self.transformer_blocks):
@@ -408,8 +502,21 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
                     encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
                 )

+            # controlnet residual
+            if controlnet_block_samples is not None:
+                interval_control = len(self.transformer_blocks) / len(controlnet_block_samples)
+                interval_control = int(np.ceil(interval_control))
+                # For Xlabs ControlNet.
+                if controlnet_blocks_repeat:
+                    hidden_states = (
+                        hidden_states + controlnet_block_samples[index_block % len(controlnet_block_samples)]
+                    )
+                else:
+                    hidden_states = hidden_states + controlnet_block_samples[index_block // interval_control]
+
         hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)

         for index_block, block in enumerate(self.single_transformer_blocks):
@@ -438,6 +545,16 @@ class FluxTransformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrig
                     hidden_states=hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
+                    joint_attention_kwargs=joint_attention_kwargs,
+                )
+
+            # controlnet residual
+            if controlnet_single_block_samples is not None:
+                interval_control = len(self.single_transformer_blocks) / len(controlnet_single_block_samples)
+                interval_control = int(np.ceil(interval_control))
+                hidden_states[:, encoder_hidden_states.shape[1] :, ...] = (
+                    hidden_states[:, encoder_hidden_states.shape[1] :, ...]
+                    + controlnet_single_block_samples[index_block // interval_control]
                 )

         hidden_states = hidden_states[:, encoder_hidden_states.shape[1] :, ...]
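
Note: the two "controlnet residual" blocks added above map a (possibly shorter) list of ControlNet block samples onto the transformer blocks. A small standalone sketch of that index arithmetic, mirroring the added code:

import numpy as np

def controlnet_sample_index(index_block: int, num_blocks: int, num_samples: int, repeat: bool = False) -> int:
    # repeat=True corresponds to controlnet_blocks_repeat (XLabs-style ControlNets)
    if repeat:
        return index_block % num_samples
    interval = int(np.ceil(num_blocks / num_samples))
    return index_block // interval

# e.g. 19 double blocks fed by 4 controlnet samples: every 5 blocks move on to the next sample
assert [controlnet_sample_index(i, 19, 4) for i in (0, 4, 5, 18)] == [0, 0, 1, 3]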
diffusers/models/transformers/transformer_sd3.py
CHANGED
@@ -13,7 +13,7 @@
 # limitations under the License.


-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Dict, List, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -69,6 +69,10 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
         pooled_projection_dim: int = 2048,
         out_channels: int = 16,
         pos_embed_max_size: int = 96,
+        dual_attention_layers: Tuple[
+            int, ...
+        ] = (),  # () for sd3.0; (0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12) for sd3.5
+        qk_norm: Optional[str] = None,
     ):
         super().__init__()
         default_out_channels = in_channels
@@ -97,6 +101,8 @@ class SD3Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOrigi
                     num_attention_heads=self.config.num_attention_heads,
                     attention_head_dim=self.config.attention_head_dim,
                     context_pre_only=i == num_layers - 1,
+                    qk_norm=qk_norm,
+                    use_dual_attention=True if i in dual_attention_layers else False,
                 )
                 for i in range(self.config.num_layers)
             ]
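
Note: the new dual_attention_layers / qk_norm arguments above are what distinguish SD 3.5-style checkpoints (dual attention on the first 13 blocks, per the inline comment) from SD 3.0. A tiny illustrative instantiation with made-up small dimensions; real checkpoints set all of this through their config.json.

from diffusers import SD3Transformer2DModel

model = SD3Transformer2DModel(
    sample_size=32,
    num_layers=4,
    attention_head_dim=8,
    num_attention_heads=4,
    joint_attention_dim=32,
    caption_projection_dim=32,
    pooled_projection_dim=64,
    dual_attention_layers=(0, 1),  # which blocks get the extra self-attention branch
    qk_norm="rms_norm",            # assumed value for illustration; not taken from this diff
)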
diffusers/models/unets/unet_2d_condition.py
CHANGED
@@ -463,7 +463,6 @@ class UNet2DConditionModel(
                 dropout=dropout,
             )
             self.up_blocks.append(up_block)
-            prev_output_channel = output_channel

         # out
         if norm_num_groups is not None:
@@ -599,7 +598,7 @@ class UNet2DConditionModel(
            )
         elif encoder_hid_dim_type is not None:
             raise ValueError(
-                f"encoder_hid_dim_type
+                f"`encoder_hid_dim_type`: {encoder_hid_dim_type} must be None, 'text_proj', 'text_image_proj', or 'image_proj'."
             )
         else:
             self.encoder_hid_proj = None
@@ -679,7 +678,9 @@ class UNet2DConditionModel(
             # Kandinsky 2.2 ControlNet
             self.add_embedding = ImageHintTimeEmbedding(image_embed_dim=encoder_hid_dim, time_embed_dim=time_embed_dim)
         elif addition_embed_type is not None:
-            raise ValueError(
+            raise ValueError(
+                f"`addition_embed_type`: {addition_embed_type} must be None, 'text', 'text_image', 'text_time', 'image', or 'image_hint'."
+            )

     def _set_pos_net_if_use_gligen(self, attention_type: str, cross_attention_dim: int):
         if attention_type in ["gated", "gated-text-image"]:
@@ -990,7 +991,7 @@ class UNet2DConditionModel(
             image_embs = added_cond_kwargs.get("image_embeds")
             aug_emb = self.add_embedding(image_embs)
         elif self.config.addition_embed_type == "image_hint":
-            # Kandinsky 2.2 - style
+            # Kandinsky 2.2 ControlNet - style
             if "image_embeds" not in added_cond_kwargs or "hint" not in added_cond_kwargs:
                 raise ValueError(
                     f"{self.__class__} has the config param `addition_embed_type` set to 'image_hint' which requires the keyword arguments `image_embeds` and `hint` to be passed in `added_cond_kwargs`"
@@ -1009,7 +1010,7 @@ class UNet2DConditionModel(
             # Kandinsky 2.1 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
-                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'text_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                 )

             image_embeds = added_cond_kwargs.get("image_embeds")
@@ -1018,14 +1019,14 @@ class UNet2DConditionModel(
             # Kandinsky 2.2 - style
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
-                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                 )
             image_embeds = added_cond_kwargs.get("image_embeds")
             encoder_hidden_states = self.encoder_hid_proj(image_embeds)
         elif self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
             if "image_embeds" not in added_cond_kwargs:
                 raise ValueError(
-                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in
+                    f"{self.__class__} has the config param `encoder_hid_dim_type` set to 'ip_image_proj' which requires the keyword argument `image_embeds` to be passed in `added_cond_kwargs`"
                 )

             if hasattr(self, "text_encoder_hid_proj") and self.text_encoder_hid_proj is not None:
@@ -1140,7 +1141,6 @@ class UNet2DConditionModel(
         # 1. time
         t_emb = self.get_time_embed(sample=sample, timestep=timestep)
         emb = self.time_embedding(t_emb, timestep_cond)
-        aug_emb = None

         class_emb = self.get_class_embed(sample=sample, class_labels=class_labels)
         if class_emb is not None:
diffusers/models/unets/unet_motion_model.py
CHANGED
@@ -116,7 +116,7 @@ class AnimateDiffTransformer3D(nn.Module):

         self.in_channels = in_channels

-        self.norm =
+        self.norm = nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
         self.proj_in = nn.Linear(in_channels, inner_dim)

         # 3. Define transformers blocks
@@ -187,12 +187,12 @@ class AnimateDiffTransformer3D(nn.Module):
         hidden_states = self.norm(hidden_states)
         hidden_states = hidden_states.permute(0, 3, 4, 2, 1).reshape(batch_size * height * width, num_frames, channel)

-        hidden_states = self.proj_in(hidden_states)
+        hidden_states = self.proj_in(input=hidden_states)

         # 2. Blocks
         for block in self.transformer_blocks:
             hidden_states = block(
-                hidden_states,
+                hidden_states=hidden_states,
                 encoder_hidden_states=encoder_hidden_states,
                 timestep=timestep,
                 cross_attention_kwargs=cross_attention_kwargs,
@@ -200,7 +200,7 @@ class AnimateDiffTransformer3D(nn.Module):
             )

         # 3. Output
-        hidden_states = self.proj_out(hidden_states)
+        hidden_states = self.proj_out(input=hidden_states)
         hidden_states = (
             hidden_states[None, None, :]
             .reshape(batch_size, height, width, num_frames, channel)
@@ -344,7 +344,7 @@ class DownBlockMotion(nn.Module):
                 )

             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)

             hidden_states = motion_module(hidden_states, num_frames=num_frames)

@@ -352,7 +352,7 @@ class DownBlockMotion(nn.Module):

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                hidden_states = downsampler(hidden_states=hidden_states)

             output_states = output_states + (hidden_states,)

@@ -531,25 +531,18 @@ class CrossAttnDownBlockMotion(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)
+
+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]

-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
             hidden_states = motion_module(
                 hidden_states,
                 num_frames=num_frames,
@@ -563,7 +556,7 @@ class CrossAttnDownBlockMotion(nn.Module):

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
-                hidden_states = downsampler(hidden_states)
+                hidden_states = downsampler(hidden_states=hidden_states)

             output_states = output_states + (hidden_states,)

@@ -757,25 +750,18 @@ class CrossAttnUpBlockMotion(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)
+
+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]

-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
             hidden_states = motion_module(
                 hidden_states,
                 num_frames=num_frames,
@@ -783,7 +769,7 @@ class CrossAttnUpBlockMotion(nn.Module):

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
+                hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)

         return hidden_states

@@ -929,13 +915,13 @@ class UpBlockMotion(nn.Module):
                     create_custom_forward(resnet), hidden_states, temb
                 )
             else:
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)

             hidden_states = motion_module(hidden_states, num_frames=num_frames)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states, upsample_size)
+                hidden_states = upsampler(hidden_states=hidden_states, output_size=upsample_size)

         return hidden_states

@@ -1080,10 +1066,19 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
         if cross_attention_kwargs.get("scale", None) is not None:
             logger.warning("Passing `scale` to `cross_attention_kwargs` is deprecated. `scale` will be ignored.")

-        hidden_states = self.resnets[0](hidden_states, temb)
+        hidden_states = self.resnets[0](input_tensor=hidden_states, temb=temb)

         blocks = zip(self.attentions, self.resnets[1:], self.motion_modules)
         for attn, resnet, motion_module in blocks:
+            hidden_states = attn(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                cross_attention_kwargs=cross_attention_kwargs,
+                attention_mask=attention_mask,
+                encoder_attention_mask=encoder_attention_mask,
+                return_dict=False,
+            )[0]
+
             if self.training and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
@@ -1096,14 +1091,6 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
                     return custom_forward

                 ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
                 hidden_states = torch.utils.checkpoint.checkpoint(
                     create_custom_forward(motion_module),
                     hidden_states,
@@ -1117,19 +1104,11 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
                     **ckpt_kwargs,
                 )
             else:
-                hidden_states = attn(
-                    hidden_states,
-                    encoder_hidden_states=encoder_hidden_states,
-                    cross_attention_kwargs=cross_attention_kwargs,
-                    attention_mask=attention_mask,
-                    encoder_attention_mask=encoder_attention_mask,
-                    return_dict=False,
-                )[0]
                 hidden_states = motion_module(
                     hidden_states,
                     num_frames=num_frames,
                 )
-                hidden_states = resnet(hidden_states, temb)
+                hidden_states = resnet(input_tensor=hidden_states, temb=temb)

         return hidden_states

@@ -2178,7 +2157,6 @@ class UNetMotionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, Peft

         emb = emb if aug_emb is None else emb + aug_emb
         emb = emb.repeat_interleave(repeats=num_frames, dim=0)
-        encoder_hidden_states = encoder_hidden_states.repeat_interleave(repeats=num_frames, dim=0)

         if self.encoder_hid_proj is not None and self.config.encoder_hid_dim_type == "ip_image_proj":
             if "image_embeds" not in added_cond_kwargs:
diffusers/models/upsampling.py
CHANGED
@@ -19,6 +19,7 @@ import torch.nn as nn
 import torch.nn.functional as F

 from ..utils import deprecate
+from ..utils.import_utils import is_torch_version
 from .normalization import RMSNorm


@@ -151,11 +152,10 @@
         if self.use_conv_transpose:
             return self.conv(hidden_states)

-        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16
-        #
-        # https://github.com/pytorch/pytorch/issues/86679
+        # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 until PyTorch 2.1
+        # https://github.com/pytorch/pytorch/issues/86679#issuecomment-1783978767
         dtype = hidden_states.dtype
-        if dtype == torch.bfloat16:
+        if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
             hidden_states = hidden_states.to(torch.float32)

         # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
@@ -170,8 +170,8 @@
         else:
             hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")

-        #
-        if dtype == torch.bfloat16:
+        # Cast back to original dtype
+        if dtype == torch.bfloat16 and is_torch_version("<", "2.1"):
             hidden_states = hidden_states.to(dtype)

         # TODO(Suraj, Patrick) - clean up after weight dicts are correctly renamed
|