diffusers 0.29.0__py3-none-any.whl → 0.29.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +7 -1
- diffusers/loaders/single_file.py +12 -0
- diffusers/loaders/single_file_model.py +10 -8
- diffusers/loaders/single_file_utils.py +33 -23
- diffusers/models/__init__.py +2 -0
- diffusers/models/controlnet_sd3.py +418 -0
- diffusers/models/modeling_utils.py +10 -3
- diffusers/models/transformers/transformer_sd3.py +16 -7
- diffusers/pipelines/__init__.py +9 -0
- diffusers/pipelines/auto_pipeline.py +8 -0
- diffusers/pipelines/controlnet_sd3/__init__.py +53 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1062 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +23 -5
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +23 -5
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +15 -0
- {diffusers-0.29.0.dist-info → diffusers-0.29.1.dist-info}/METADATA +44 -44
- {diffusers-0.29.0.dist-info → diffusers-0.29.1.dist-info}/RECORD +22 -19
- {diffusers-0.29.0.dist-info → diffusers-0.29.1.dist-info}/WHEEL +1 -1
- {diffusers-0.29.0.dist-info → diffusers-0.29.1.dist-info}/LICENSE +0 -0
- {diffusers-0.29.0.dist-info → diffusers-0.29.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.0.dist-info → diffusers-0.29.1.dist-info}/top_level.txt +0 -0
diffusers/__init__.py
CHANGED
@@ -1,4 +1,4 @@
-__version__ = "0.29.0"
+__version__ = "0.29.1"
 
 from typing import TYPE_CHECKING
 
@@ -91,6 +91,8 @@ else:
             "MultiAdapter",
             "PixArtTransformer2DModel",
             "PriorTransformer",
+            "SD3ControlNetModel",
+            "SD3MultiControlNetModel",
             "SD3Transformer2DModel",
             "StableCascadeUNet",
             "T2IAdapter",
@@ -278,6 +280,7 @@ else:
             "StableCascadeCombinedPipeline",
             "StableCascadeDecoderPipeline",
             "StableCascadePriorPipeline",
+            "StableDiffusion3ControlNetPipeline",
             "StableDiffusion3Img2ImgPipeline",
             "StableDiffusion3Pipeline",
             "StableDiffusionAdapterPipeline",
@@ -501,6 +504,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             MultiAdapter,
             PixArtTransformer2DModel,
             PriorTransformer,
+            SD3ControlNetModel,
+            SD3MultiControlNetModel,
             SD3Transformer2DModel,
             T2IAdapter,
             T5FilmDecoder,
@@ -666,6 +671,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             StableCascadeCombinedPipeline,
             StableCascadeDecoderPipeline,
             StableCascadePriorPipeline,
+            StableDiffusion3ControlNetPipeline,
             StableDiffusion3Img2ImgPipeline,
             StableDiffusion3Pipeline,
             StableDiffusionAdapterPipeline,
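Note: the new SD3 ControlNet classes and pipeline are re-exported at the package root, so they can be imported directly after upgrading. A minimal sanity check (assumes torch and transformers are installed; the check itself is illustrative and not part of the diff):

    # Verify the new 0.29.1 top-level exports resolve.
    from diffusers import (
        SD3ControlNetModel,
        SD3MultiControlNetModel,
        StableDiffusion3ControlNetPipeline,
    )

    print(SD3ControlNetModel.__module__)  # diffusers.models.controlnet_sd3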
diffusers/loaders/single_file.py
CHANGED
@@ -28,9 +28,11 @@ from .single_file_utils import (
     _legacy_load_safety_checker,
     _legacy_load_scheduler,
     create_diffusers_clip_model_from_ldm,
+    create_diffusers_t5_model_from_checkpoint,
     fetch_diffusers_config,
     fetch_original_config,
     is_clip_model_in_single_file,
+    is_t5_in_single_file,
     load_single_file_checkpoint,
 )
 
@@ -118,6 +120,16 @@ def load_single_file_sub_model(
             is_legacy_loading=is_legacy_loading,
         )
 
+    elif is_transformers_model and is_t5_in_single_file(checkpoint):
+        loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
+            class_obj,
+            checkpoint=checkpoint,
+            config=cached_model_config_path,
+            subfolder=name,
+            torch_dtype=torch_dtype,
+            local_files_only=local_files_only,
+        )
+
     elif is_tokenizer and is_legacy_loading:
         loaded_sub_model = _legacy_load_clip_tokenizer(
             class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
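With `is_t5_in_single_file` and `create_diffusers_t5_model_from_checkpoint` wired into `load_single_file_sub_model`, single-file SD3 checkpoints that bundle the T5-XXL encoder can be loaded through the existing single-file path. A hedged sketch of how this is typically exercised; the checkpoint filename is an illustrative placeholder, not taken from this diff:

    import torch
    from diffusers import StableDiffusion3Pipeline

    # Hypothetical local path to an SD3 checkpoint that bundles CLIP-L, CLIP-G and T5-XXL.
    ckpt_path = "sd3_medium_incl_clips_t5xxl.safetensors"

    pipe = StableDiffusion3Pipeline.from_single_file(ckpt_path, torch_dtype=torch.float16)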
diffusers/loaders/single_file_model.py
CHANGED
@@ -276,16 +276,18 @@ class FromOriginalModelMixin:
 
         if is_accelerate_available():
             unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-            if model._keys_to_ignore_on_load_unexpected is not None:
-                for pat in model._keys_to_ignore_on_load_unexpected:
-                    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
-            if len(unexpected_keys) > 0:
-                logger.warning(
-                    f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
-                )
         else:
-            model.load_state_dict(diffusers_format_checkpoint)
+            _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
+
+        if model._keys_to_ignore_on_load_unexpected is not None:
+            for pat in model._keys_to_ignore_on_load_unexpected:
+                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+        if len(unexpected_keys) > 0:
+            logger.warning(
+                f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+            )
 
         if torch_dtype is not None:
             model.to(torch_dtype)
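The refactor routes both loading paths (with and without accelerate) through the same unexpected-key filtering and warning. It relies on `load_state_dict(..., strict=False)` returning a `(missing_keys, unexpected_keys)` result instead of raising. A small standalone illustration of that return value, using plain PyTorch rather than diffusers:

    import torch
    import torch.nn as nn

    model = nn.Linear(4, 4)
    state = {"weight": torch.zeros(4, 4), "bias": torch.zeros(4), "extra.scale": torch.ones(1)}

    # strict=False reports incompatibilities instead of raising a RuntimeError.
    missing, unexpected = model.load_state_dict(state, strict=False)
    print(unexpected)  # ['extra.scale']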
diffusers/loaders/single_file_utils.py
CHANGED
@@ -252,7 +252,6 @@ LDM_CONTROLNET_KEY = "control_model."
 LDM_CLIP_PREFIX_TO_REMOVE = [
     "cond_stage_model.transformer.",
     "conditioner.embedders.0.transformer.",
-    "text_encoders.clip_l.transformer.",
 ]
 OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
 LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):
 
 
 def is_open_clip_sd3_model(checkpoint):
-
+    if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
+        return True
+
+    return False
 
 
 def is_open_clip_sdxl_refiner_model(checkpoint):
-    if CHECKPOINT_KEY_NAMES["
+    if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
         return True
 
     return False
@@ -1233,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
     return new_checkpoint
 
 
-def convert_ldm_clip_checkpoint(checkpoint):
+def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
     keys = list(checkpoint.keys())
     text_model_dict = {}
 
-    remove_prefixes =
+    remove_prefixes = []
+    remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
+    if remove_prefix:
+        remove_prefixes.append(remove_prefix)
 
     for key in keys:
         for prefix in remove_prefixes:
@@ -1263,8 +1268,6 @@ def convert_open_clip_checkpoint(
     else:
         text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
 
-    text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
-
     keys = list(checkpoint.keys())
     keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE
 
@@ -1313,9 +1316,6 @@ def convert_open_clip_checkpoint(
         else:
             text_model_dict[diffusers_key] = checkpoint.get(key)
 
-    if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
-        text_model_dict.pop("text_model.embeddings.position_ids", None)
-
     return text_model_dict
 
 
@@ -1376,6 +1376,13 @@ def create_diffusers_clip_model_from_ldm(
     ):
         diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
 
+    elif (
+        is_clip_sd3_model(checkpoint)
+        and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
+    ):
+        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
+        diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
+
     elif is_open_clip_model(checkpoint):
         prefix = "cond_stage_model.model."
         diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
@@ -1391,26 +1398,28 @@ def create_diffusers_clip_model_from_ldm(
         prefix = "conditioner.embedders.0.model."
         diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
 
-    elif
-
-
+    elif (
+        is_open_clip_sd3_model(checkpoint)
+        and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
+    ):
+        diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
 
     else:
         raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
 
     if is_accelerate_available():
         unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
-
-
-                unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+    else:
+        _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
 
-
-
-
-            )
+    if model._keys_to_ignore_on_load_unexpected is not None:
+        for pat in model._keys_to_ignore_on_load_unexpected:
+            unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
-
-
+    if len(unexpected_keys) > 0:
+        logger.warning(
+            f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+        )
 
     if torch_dtype is not None:
         model.to(torch_dtype)
@@ -1755,7 +1764,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
     keys = list(checkpoint.keys())
     text_model_dict = {}
 
-    remove_prefixes = ["text_encoders.t5xxl.transformer.
+    remove_prefixes = ["text_encoders.t5xxl.transformer."]
 
     for key in keys:
         for prefix in remove_prefixes:
@@ -1799,3 +1808,4 @@ def create_diffusers_t5_model_from_checkpoint(
 
     else:
         model.load_state_dict(diffusers_format_checkpoint)
+    return model
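The new `remove_prefix` parameter lets the SD3 branches strip `text_encoders.clip_l.transformer.` or `text_encoders.clip_g.transformer.` only when converting an SD3 checkpoint, instead of keeping those prefixes in the global `LDM_CLIP_PREFIX_TO_REMOVE` list. A rough, simplified sketch of the prefix handling these converters rely on (the function name is hypothetical and the real converter does additional key remapping):

    def strip_prefixes(checkpoint, extra_prefix=None):
        # Mirrors the remove_prefixes handling: global prefixes plus an optional per-call one.
        prefixes = ["cond_stage_model.transformer.", "conditioner.embedders.0.transformer."]
        if extra_prefix:
            prefixes.append(extra_prefix)

        converted = {}
        for key, value in checkpoint.items():
            for prefix in prefixes:
                if key.startswith(prefix):
                    converted[key[len(prefix):]] = value
        return converted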
diffusers/models/__init__.py
CHANGED
@@ -33,6 +33,7 @@ if is_torch_available():
     _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
     _import_structure["autoencoders.vq_model"] = ["VQModel"]
     _import_structure["controlnet"] = ["ControlNetModel"]
+    _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
     _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
     _import_structure["embeddings"] = ["ImageProjection"]
     _import_structure["modeling_utils"] = ["ModelMixin"]
@@ -74,6 +75,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         VQModel,
     )
     from .controlnet import ControlNetModel
+    from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
     from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
     from .embeddings import ImageProjection
     from .modeling_utils import ModelMixin
diffusers/models/controlnet_sd3.py
ADDED
@@ -0,0 +1,418 @@
+# Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+
+from dataclasses import dataclass
+from typing import Any, Dict, List, Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ..configuration_utils import ConfigMixin, register_to_config
+from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
+from ..models.attention import JointTransformerBlock
+from ..models.attention_processor import Attention, AttentionProcessor
+from ..models.modeling_utils import ModelMixin
+from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+from .controlnet import BaseOutput, zero_module
+from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
+from .transformers.transformer_2d import Transformer2DModelOutput
+
+
+logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
+
+
+@dataclass
+class SD3ControlNetOutput(BaseOutput):
+    controlnet_block_samples: Tuple[torch.Tensor]
+
+
+class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+    _supports_gradient_checkpointing = True
+
+    @register_to_config
+    def __init__(
+        self,
+        sample_size: int = 128,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        num_layers: int = 18,
+        attention_head_dim: int = 64,
+        num_attention_heads: int = 18,
+        joint_attention_dim: int = 4096,
+        caption_projection_dim: int = 1152,
+        pooled_projection_dim: int = 2048,
+        out_channels: int = 16,
+        pos_embed_max_size: int = 96,
+    ):
+        super().__init__()
+        default_out_channels = in_channels
+        self.out_channels = out_channels if out_channels is not None else default_out_channels
+        self.inner_dim = num_attention_heads * attention_head_dim
+
+        self.pos_embed = PatchEmbed(
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=self.inner_dim,
+            pos_embed_max_size=pos_embed_max_size,
+        )
+        self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+            embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
+        )
+        self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+
+        # `attention_head_dim` is doubled to account for the mixing.
+        # It needs to crafted when we get the actual checkpoints.
+        self.transformer_blocks = nn.ModuleList(
+            [
+                JointTransformerBlock(
+                    dim=self.inner_dim,
+                    num_attention_heads=num_attention_heads,
+                    attention_head_dim=self.inner_dim,
+                    context_pre_only=False,
+                )
+                for i in range(num_layers)
+            ]
+        )
+
+        # controlnet_blocks
+        self.controlnet_blocks = nn.ModuleList([])
+        for _ in range(len(self.transformer_blocks)):
+            controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)
+            controlnet_block = zero_module(controlnet_block)
+            self.controlnet_blocks.append(controlnet_block)
+        pos_embed_input = PatchEmbed(
+            height=sample_size,
+            width=sample_size,
+            patch_size=patch_size,
+            in_channels=in_channels,
+            embed_dim=self.inner_dim,
+            pos_embed_type=None,
+        )
+        self.pos_embed_input = zero_module(pos_embed_input)
+
+        self.gradient_checkpointing = False
+
+    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+        """
+        Sets the attention processor to use [feed forward
+        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+        Parameters:
+            chunk_size (`int`, *optional*):
+                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+                over each tensor of dim=`dim`.
+            dim (`int`, *optional*, defaults to `0`):
+                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+                or dim=1 (sequence length).
+        """
+        if dim not in [0, 1]:
+            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+        # By default chunk size is 1
+        chunk_size = chunk_size or 1
+
+        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+            if hasattr(module, "set_chunk_feed_forward"):
+                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+            for child in module.children():
+                fn_recursive_feed_forward(child, chunk_size, dim)
+
+        for module in self.children():
+            fn_recursive_feed_forward(module, chunk_size, dim)
+
+    @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+    def attn_processors(self) -> Dict[str, AttentionProcessor]:
+        r"""
+        Returns:
+            `dict` of attention processors: A dictionary containing all attention processors used in the model with
+            indexed by its weight name.
+        """
+        # set recursively
+        processors = {}
+
+        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+            if hasattr(module, "get_processor"):
+                processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+            for sub_name, child in module.named_children():
+                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+            return processors
+
+        for name, module in self.named_children():
+            fn_recursive_add_processors(name, module, processors)
+
+        return processors
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+        r"""
+        Sets the attention processor to use to compute attention.
+
+        Parameters:
+            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+                The instantiated processor class or a dictionary of processor classes that will be set as the processor
+                for **all** `Attention` layers.
+
+                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+                processor. This is strongly recommended when setting trainable attention processors.
+
+        """
+        count = len(self.attn_processors.keys())
+
+        if isinstance(processor, dict) and len(processor) != count:
+            raise ValueError(
+                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+            )
+
+        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+            if hasattr(module, "set_processor"):
+                if not isinstance(processor, dict):
+                    module.set_processor(processor)
+                else:
+                    module.set_processor(processor.pop(f"{name}.processor"))
+
+            for sub_name, child in module.named_children():
+                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+        for name, module in self.named_children():
+            fn_recursive_attn_processor(name, module, processor)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    def fuse_qkv_projections(self):
+        """
+        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+        are fused. For cross-attention modules, key and value projection matrices are fused.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+        """
+        self.original_attn_processors = None
+
+        for _, attn_processor in self.attn_processors.items():
+            if "Added" in str(attn_processor.__class__.__name__):
+                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+        self.original_attn_processors = self.attn_processors
+
+        for module in self.modules():
+            if isinstance(module, Attention):
+                module.fuse_projections(fuse=True)
+
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    def unfuse_qkv_projections(self):
+        """Disables the fused QKV projection if enabled.
+
+        <Tip warning={true}>
+
+        This API is 🧪 experimental.
+
+        </Tip>
+
+        """
+        if self.original_attn_processors is not None:
+            self.set_attn_processor(self.original_attn_processors)
+
+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
+    @classmethod
+    def from_transformer(cls, transformer, num_layers=None, load_weights_from_transformer=True):
+        config = transformer.config
+        config["num_layers"] = num_layers or config.num_layers
+        controlnet = cls(**config)
+
+        if load_weights_from_transformer:
+            controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
+            controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict(), strict=False)
+            controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict(), strict=False)
+            controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict())
+
+            controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
+
+        return controlnet
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        controlnet_cond: torch.Tensor,
+        conditioning_scale: float = 1.0,
+        encoder_hidden_states: torch.FloatTensor = None,
+        pooled_projections: torch.FloatTensor = None,
+        timestep: torch.LongTensor = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+        """
+        The [`SD3Transformer2DModel`] forward method.
+
+        Args:
+            hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+                Input `hidden_states`.
+            controlnet_cond (`torch.Tensor`):
+                The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+            conditioning_scale (`float`, defaults to `1.0`):
+                The scale factor for ControlNet outputs.
+            encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+                Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+            pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+                from the embeddings of input conditions.
+            timestep ( `torch.LongTensor`):
+                Used to indicate denoising step.
+            joint_attention_kwargs (`dict`, *optional*):
+                A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+                `self.processor` in
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+        """
+        if joint_attention_kwargs is not None:
+            joint_attention_kwargs = joint_attention_kwargs.copy()
+            lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
+        height, width = hidden_states.shape[-2:]
+
+        hidden_states = self.pos_embed(hidden_states)  # takes care of adding positional embeddings too.
+        temb = self.time_text_embed(timestep, pooled_projections)
+        encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+        # add
+        hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
+
+        block_res_samples = ()
+
+        for block in self.transformer_blocks:
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(block),
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+
+            else:
+                encoder_hidden_states, hidden_states = block(
+                    hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+                )
+
+            block_res_samples = block_res_samples + (hidden_states,)
+
+        controlnet_block_res_samples = ()
+        for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+            block_res_sample = controlnet_block(block_res_sample)
+            controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+        # 6. scaling
+        controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
+
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
+        if not return_dict:
+            return (controlnet_block_res_samples,)
+
+        return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
+
+
+class SD3MultiControlNetModel(ModelMixin):
+    r"""
+    `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
+
+    This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
+    compatible with `SD3ControlNetModel`.
+
+    Args:
+        controlnets (`List[SD3ControlNetModel]`):
+            Provides additional conditioning to the unet during the denoising process. You must set multiple
+            `SD3ControlNetModel` as a list.
+    """
+
+    def __init__(self, controlnets):
+        super().__init__()
+        self.nets = nn.ModuleList(controlnets)
+
+    def forward(
+        self,
+        hidden_states: torch.FloatTensor,
+        controlnet_cond: List[torch.tensor],
+        conditioning_scale: List[float],
+        pooled_projections: torch.FloatTensor,
+        encoder_hidden_states: torch.FloatTensor = None,
+        timestep: torch.LongTensor = None,
+        joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+        return_dict: bool = True,
+    ) -> Union[SD3ControlNetOutput, Tuple]:
+        for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+            block_samples = controlnet(
+                hidden_states=hidden_states,
+                timestep=timestep,
+                encoder_hidden_states=encoder_hidden_states,
+                pooled_projections=pooled_projections,
+                controlnet_cond=image,
+                conditioning_scale=scale,
+                joint_attention_kwargs=joint_attention_kwargs,
+                return_dict=return_dict,
+            )
+
+            # merge samples
+            if i == 0:
+                control_block_samples = block_samples
+            else:
+                control_block_samples = [
+                    control_block_sample + block_sample
+                    for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
+                ]
+                control_block_samples = (tuple(control_block_samples),)
+
+        return control_block_samples
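Together with the new `StableDiffusion3ControlNetPipeline`, the model is used roughly as sketched below; the repository ids and conditioning image path are illustrative assumptions, not something stated in this diff:

    import torch
    from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
    from diffusers.utils import load_image

    # Repo ids below are assumptions for illustration.
    controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
    pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",
        controlnet=controlnet,
        torch_dtype=torch.float16,
    ).to("cuda")

    control_image = load_image("canny_edge_map.png")  # placeholder conditioning image
    image = pipe(
        prompt="a photo of a cat",
        control_image=control_image,
        controlnet_conditioning_scale=0.7,
    ).images[0]

For training, `SD3ControlNetModel.from_transformer` initializes a ControlNet from an existing `SD3Transformer2DModel`, copying the embedder and transformer-block weights and zero-initializing the conditioning input embedder.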