diffusers 0.29.0__py3-none-any.whl → 0.29.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diffusers/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.29.0"
+ __version__ = "0.29.1"
 
  from typing import TYPE_CHECKING
 
@@ -91,6 +91,8 @@ else:
  "MultiAdapter",
  "PixArtTransformer2DModel",
  "PriorTransformer",
+ "SD3ControlNetModel",
+ "SD3MultiControlNetModel",
  "SD3Transformer2DModel",
  "StableCascadeUNet",
  "T2IAdapter",
@@ -278,6 +280,7 @@ else:
  "StableCascadeCombinedPipeline",
  "StableCascadeDecoderPipeline",
  "StableCascadePriorPipeline",
+ "StableDiffusion3ControlNetPipeline",
  "StableDiffusion3Img2ImgPipeline",
  "StableDiffusion3Pipeline",
  "StableDiffusionAdapterPipeline",
@@ -501,6 +504,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  MultiAdapter,
  PixArtTransformer2DModel,
  PriorTransformer,
+ SD3ControlNetModel,
+ SD3MultiControlNetModel,
  SD3Transformer2DModel,
  T2IAdapter,
  T5FilmDecoder,
@@ -666,6 +671,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  StableCascadeCombinedPipeline,
  StableCascadeDecoderPipeline,
  StableCascadePriorPipeline,
+ StableDiffusion3ControlNetPipeline,
  StableDiffusion3Img2ImgPipeline,
  StableDiffusion3Pipeline,
  StableDiffusionAdapterPipeline,
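
The hunks above export the new SD3 ControlNet model classes and pipeline from the top-level diffusers namespace. A minimal usage sketch of the new public API follows; the hub repository ids, the conditioning-image URL, and the call arguments control_image / controlnet_conditioning_scale are assumptions based on the existing ControlNet pipelines, not something this diff shows.

import torch
from diffusers import SD3ControlNetModel, StableDiffusion3ControlNetPipeline
from diffusers.utils import load_image

# Assumed hub repo ids; swap in whichever SD3 ControlNet checkpoint you actually use.
controlnet = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
pipe = StableDiffusion3ControlNetPipeline.from_pretrained(
    "stabilityai/stable-diffusion-3-medium-diffusers",
    controlnet=controlnet,
    torch_dtype=torch.float16,
).to("cuda")

control_image = load_image("https://example.com/canny.png")  # placeholder conditioning image
image = pipe(
    prompt="a photo of a cat",
    control_image=control_image,
    controlnet_conditioning_scale=0.7,
).images[0]
image.save("sd3_controlnet.png")
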
@@ -28,9 +28,11 @@ from .single_file_utils import (
  _legacy_load_safety_checker,
  _legacy_load_scheduler,
  create_diffusers_clip_model_from_ldm,
+ create_diffusers_t5_model_from_checkpoint,
  fetch_diffusers_config,
  fetch_original_config,
  is_clip_model_in_single_file,
+ is_t5_in_single_file,
  load_single_file_checkpoint,
  )
 
@@ -118,6 +120,16 @@ def load_single_file_sub_model(
  is_legacy_loading=is_legacy_loading,
  )
 
+ elif is_transformers_model and is_t5_in_single_file(checkpoint):
+ loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
+ class_obj,
+ checkpoint=checkpoint,
+ config=cached_model_config_path,
+ subfolder=name,
+ torch_dtype=torch_dtype,
+ local_files_only=local_files_only,
+ )
+
  elif is_tokenizer and is_legacy_loading:
  loaded_sub_model = _legacy_load_clip_tokenizer(
  class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
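
The new branch above routes T5 weights detected inside a single-file checkpoint through create_diffusers_t5_model_from_checkpoint, so SD3 checkpoints that bundle the T5-XXL text encoder can be loaded directly. A minimal sketch, assuming a local SD3 single-file checkpoint that includes the T5 encoder (the file name below is an assumption, not part of this diff):

import torch
from diffusers import StableDiffusion3Pipeline

# Assumed local path to an SD3 checkpoint that embeds the CLIP and T5-XXL encoders.
pipe = StableDiffusion3Pipeline.from_single_file(
    "sd3_medium_incl_clips_t5xxlfp16.safetensors",
    torch_dtype=torch.float16,
)
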
@@ -276,16 +276,18 @@ class FromOriginalModelMixin:
 
  if is_accelerate_available():
  unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
- if model._keys_to_ignore_on_load_unexpected is not None:
- for pat in model._keys_to_ignore_on_load_unexpected:
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
- if len(unexpected_keys) > 0:
- logger.warning(
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
- )
  else:
- model.load_state_dict(diffusers_format_checkpoint)
+ _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
+
+ if model._keys_to_ignore_on_load_unexpected is not None:
+ for pat in model._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+ )
 
  if torch_dtype is not None:
  model.to(torch_dtype)
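
The rewritten block above makes the non-accelerate path behave like the accelerate path: loading is done with strict=False, and the returned unexpected keys are filtered against _keys_to_ignore_on_load_unexpected before warning. An illustrative, standalone sketch of that pattern (the toy module and ignore pattern are assumptions for the example, not diffusers code):

import re
import torch.nn as nn

model = nn.Linear(4, 4)
state_dict = dict(model.state_dict())
state_dict["position_ids"] = state_dict["bias"]  # extra key the module does not expect

# strict=False returns (missing_keys, unexpected_keys) instead of raising
_, unexpected_keys = model.load_state_dict(state_dict, strict=False)

keys_to_ignore = [r"position_ids"]  # stand-in for model._keys_to_ignore_on_load_unexpected
for pat in keys_to_ignore:
    unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

if unexpected_keys:
    print(f"Some weights of the checkpoint were not used: {', '.join(unexpected_keys)}")
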
@@ -252,7 +252,6 @@ LDM_CONTROLNET_KEY = "control_model."
  LDM_CLIP_PREFIX_TO_REMOVE = [
  "cond_stage_model.transformer.",
  "conditioner.embedders.0.transformer.",
- "text_encoders.clip_l.transformer.",
  ]
  OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
  LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):
 
 
  def is_open_clip_sd3_model(checkpoint):
- is_open_clip_sdxl_refiner_model(checkpoint)
+ if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
+ return True
+
+ return False
 
 
  def is_open_clip_sdxl_refiner_model(checkpoint):
- if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
+ if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
  return True
 
  return False
@@ -1233,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
  return new_checkpoint
 
 
- def convert_ldm_clip_checkpoint(checkpoint):
+ def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
  keys = list(checkpoint.keys())
  text_model_dict = {}
 
- remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE
+ remove_prefixes = []
+ remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
+ if remove_prefix:
+ remove_prefixes.append(remove_prefix)
 
  for key in keys:
  for prefix in remove_prefixes:
@@ -1263,8 +1268,6 @@ def convert_open_clip_checkpoint(
  else:
  text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
 
- text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
-
  keys = list(checkpoint.keys())
  keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE
 
@@ -1313,9 +1316,6 @@ def convert_open_clip_checkpoint(
  else:
  text_model_dict[diffusers_key] = checkpoint.get(key)
 
- if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
- text_model_dict.pop("text_model.embeddings.position_ids", None)
-
  return text_model_dict
 
 
@@ -1376,6 +1376,13 @@ def create_diffusers_clip_model_from_ldm(
  ):
  diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
 
+ elif (
+ is_clip_sd3_model(checkpoint)
+ and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
+ ):
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
+ diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
+
  elif is_open_clip_model(checkpoint):
  prefix = "cond_stage_model.model."
  diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
@@ -1391,26 +1398,28 @@ def create_diffusers_clip_model_from_ldm(
  prefix = "conditioner.embedders.0.model."
  diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
 
- elif is_open_clip_sd3_model(checkpoint):
- prefix = "text_encoders.clip_g.transformer."
- diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
+ elif (
+ is_open_clip_sd3_model(checkpoint)
+ and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
+ ):
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
 
  else:
  raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
 
  if is_accelerate_available():
  unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
- if model._keys_to_ignore_on_load_unexpected is not None:
- for pat in model._keys_to_ignore_on_load_unexpected:
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
+ else:
+ _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
 
- if len(unexpected_keys) > 0:
- logger.warning(
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
- )
+ if model._keys_to_ignore_on_load_unexpected is not None:
+ for pat in model._keys_to_ignore_on_load_unexpected:
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
 
- else:
- model.load_state_dict(diffusers_format_checkpoint)
+ if len(unexpected_keys) > 0:
+ logger.warning(
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
+ )
 
  if torch_dtype is not None:
  model.to(torch_dtype)
@@ -1755,7 +1764,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
  keys = list(checkpoint.keys())
  text_model_dict = {}
 
- remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]
+ remove_prefixes = ["text_encoders.t5xxl.transformer."]
 
  for key in keys:
  for prefix in remove_prefixes:
@@ -1799,3 +1808,4 @@ def create_diffusers_t5_model_from_checkpoint(
 
  else:
  model.load_state_dict(diffusers_format_checkpoint)
+ return model
@@ -33,6 +33,7 @@ if is_torch_available():
  _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
  _import_structure["autoencoders.vq_model"] = ["VQModel"]
  _import_structure["controlnet"] = ["ControlNetModel"]
+ _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
  _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
  _import_structure["embeddings"] = ["ImageProjection"]
  _import_structure["modeling_utils"] = ["ModelMixin"]
@@ -74,6 +75,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
  VQModel,
  )
  from .controlnet import ControlNetModel
+ from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
  from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
  from .embeddings import ImageProjection
  from .modeling_utils import ModelMixin
@@ -0,0 +1,418 @@
+ # Copyright 2024 Stability AI, The HuggingFace Team and The InstantX Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+
+ from dataclasses import dataclass
+ from typing import Any, Dict, List, Optional, Tuple, Union
+
+ import torch
+ import torch.nn as nn
+
+ from ..configuration_utils import ConfigMixin, register_to_config
+ from ..loaders import FromOriginalModelMixin, PeftAdapterMixin
+ from ..models.attention import JointTransformerBlock
+ from ..models.attention_processor import Attention, AttentionProcessor
+ from ..models.modeling_utils import ModelMixin
+ from ..utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
+ from .controlnet import BaseOutput, zero_module
+ from .embeddings import CombinedTimestepTextProjEmbeddings, PatchEmbed
+ from .transformers.transformer_2d import Transformer2DModelOutput
+
+
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+ @dataclass
+ class SD3ControlNetOutput(BaseOutput):
+ controlnet_block_samples: Tuple[torch.Tensor]
+
+
+ class SD3ControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: int = 128,
+ patch_size: int = 2,
+ in_channels: int = 16,
+ num_layers: int = 18,
+ attention_head_dim: int = 64,
+ num_attention_heads: int = 18,
+ joint_attention_dim: int = 4096,
+ caption_projection_dim: int = 1152,
+ pooled_projection_dim: int = 2048,
+ out_channels: int = 16,
+ pos_embed_max_size: int = 96,
+ ):
+ super().__init__()
+ default_out_channels = in_channels
+ self.out_channels = out_channels if out_channels is not None else default_out_channels
+ self.inner_dim = num_attention_heads * attention_head_dim
+
+ self.pos_embed = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=self.inner_dim,
+ pos_embed_max_size=pos_embed_max_size,
+ )
+ self.time_text_embed = CombinedTimestepTextProjEmbeddings(
+ embedding_dim=self.inner_dim, pooled_projection_dim=pooled_projection_dim
+ )
+ self.context_embedder = nn.Linear(joint_attention_dim, caption_projection_dim)
+
+ # `attention_head_dim` is doubled to account for the mixing.
+ # It needs to crafted when we get the actual checkpoints.
+ self.transformer_blocks = nn.ModuleList(
+ [
+ JointTransformerBlock(
+ dim=self.inner_dim,
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=self.inner_dim,
+ context_pre_only=False,
+ )
+ for i in range(num_layers)
+ ]
+ )
+
+ # controlnet_blocks
+ self.controlnet_blocks = nn.ModuleList([])
+ for _ in range(len(self.transformer_blocks)):
+ controlnet_block = nn.Linear(self.inner_dim, self.inner_dim)
+ controlnet_block = zero_module(controlnet_block)
+ self.controlnet_blocks.append(controlnet_block)
+ pos_embed_input = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=self.inner_dim,
+ pos_embed_type=None,
+ )
+ self.pos_embed_input = zero_module(pos_embed_input)
+
+ self.gradient_checkpointing = False
+
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
+ """
+ Sets the attention processor to use [feed forward
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
+
+ Parameters:
+ chunk_size (`int`, *optional*):
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
+ over each tensor of dim=`dim`.
+ dim (`int`, *optional*, defaults to `0`):
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
+ or dim=1 (sequence length).
+ """
+ if dim not in [0, 1]:
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
+
+ # By default chunk size is 1
+ chunk_size = chunk_size or 1
+
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
+ if hasattr(module, "set_chunk_feed_forward"):
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
+
+ for child in module.children():
+ fn_recursive_feed_forward(child, chunk_size, dim)
+
+ for module in self.children():
+ fn_recursive_feed_forward(module, chunk_size, dim)
+
+ @property
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
+ r"""
+ Returns:
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
+ indexed by its weight name.
+ """
+ # set recursively
+ processors = {}
+
+ def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
+ if hasattr(module, "get_processor"):
+ processors[f"{name}.processor"] = module.get_processor(return_deprecated_lora=True)
+
+ for sub_name, child in module.named_children():
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
+
+ return processors
+
+ for name, module in self.named_children():
+ fn_recursive_add_processors(name, module, processors)
+
+ return processors
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
+ r"""
+ Sets the attention processor to use to compute attention.
+
+ Parameters:
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
+ for **all** `Attention` layers.
+
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
+ processor. This is strongly recommended when setting trainable attention processors.
+
+ """
+ count = len(self.attn_processors.keys())
+
+ if isinstance(processor, dict) and len(processor) != count:
+ raise ValueError(
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
+ )
+
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
+ if hasattr(module, "set_processor"):
+ if not isinstance(processor, dict):
+ module.set_processor(processor)
+ else:
+ module.set_processor(processor.pop(f"{name}.processor"))
+
+ for sub_name, child in module.named_children():
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
+
+ for name, module in self.named_children():
+ fn_recursive_attn_processor(name, module, processor)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+ def fuse_qkv_projections(self):
+ """
+ Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
+ are fused. For cross-attention modules, key and value projection matrices are fused.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+ """
+ self.original_attn_processors = None
+
+ for _, attn_processor in self.attn_processors.items():
+ if "Added" in str(attn_processor.__class__.__name__):
+ raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
+
+ self.original_attn_processors = self.attn_processors
+
+ for module in self.modules():
+ if isinstance(module, Attention):
+ module.fuse_projections(fuse=True)
+
+ # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+ def unfuse_qkv_projections(self):
+ """Disables the fused QKV projection if enabled.
+
+ <Tip warning={true}>
+
+ This API is 🧪 experimental.
+
+ </Tip>
+
+ """
+ if self.original_attn_processors is not None:
+ self.set_attn_processor(self.original_attn_processors)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if hasattr(module, "gradient_checkpointing"):
+ module.gradient_checkpointing = value
+
+ @classmethod
+ def from_transformer(cls, transformer, num_layers=None, load_weights_from_transformer=True):
+ config = transformer.config
+ config["num_layers"] = num_layers or config.num_layers
+ controlnet = cls(**config)
+
+ if load_weights_from_transformer:
+ controlnet.pos_embed.load_state_dict(transformer.pos_embed.state_dict(), strict=False)
+ controlnet.time_text_embed.load_state_dict(transformer.time_text_embed.state_dict(), strict=False)
+ controlnet.context_embedder.load_state_dict(transformer.context_embedder.state_dict(), strict=False)
+ controlnet.transformer_blocks.load_state_dict(transformer.transformer_blocks.state_dict())
+
+ controlnet.pos_embed_input = zero_module(controlnet.pos_embed_input)
+
+ return controlnet
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ controlnet_cond: torch.Tensor,
+ conditioning_scale: float = 1.0,
+ encoder_hidden_states: torch.FloatTensor = None,
+ pooled_projections: torch.FloatTensor = None,
+ timestep: torch.LongTensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[torch.FloatTensor, Transformer2DModelOutput]:
+ """
+ The [`SD3Transformer2DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.FloatTensor` of shape `(batch size, channel, height, width)`):
+ Input `hidden_states`.
+ controlnet_cond (`torch.Tensor`):
+ The conditional input tensor of shape `(batch_size, sequence_length, hidden_size)`.
+ conditioning_scale (`float`, defaults to `1.0`):
+ The scale factor for ControlNet outputs.
+ encoder_hidden_states (`torch.FloatTensor` of shape `(batch size, sequence_len, embed_dims)`):
+ Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
+ pooled_projections (`torch.FloatTensor` of shape `(batch_size, projection_dim)`): Embeddings projected
+ from the embeddings of input conditions.
+ timestep ( `torch.LongTensor`):
+ Used to indicate denoising step.
+ joint_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
+ `self.processor` in
+ [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
+ tuple.
+
+ Returns:
+ If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
+ `tuple` where the first element is the sample tensor.
+ """
+ if joint_attention_kwargs is not None:
+ joint_attention_kwargs = joint_attention_kwargs.copy()
+ lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+ else:
+ lora_scale = 1.0
+
+ if USE_PEFT_BACKEND:
+ # weight the lora layers by setting `lora_scale` for each PEFT layer
+ scale_lora_layers(self, lora_scale)
+ else:
+ if joint_attention_kwargs is not None and joint_attention_kwargs.get("scale", None) is not None:
+ logger.warning(
+ "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+ )
+
+ height, width = hidden_states.shape[-2:]
+
+ hidden_states = self.pos_embed(hidden_states) # takes care of adding positional embeddings too.
+ temb = self.time_text_embed(timestep, pooled_projections)
+ encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+ # add
+ hidden_states = hidden_states + self.pos_embed_input(controlnet_cond)
+
+ block_res_samples = ()
+
+ for block in self.transformer_blocks:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(block),
+ hidden_states,
+ encoder_hidden_states,
+ temb,
+ **ckpt_kwargs,
+ )
+
+ else:
+ encoder_hidden_states, hidden_states = block(
+ hidden_states=hidden_states, encoder_hidden_states=encoder_hidden_states, temb=temb
+ )
+
+ block_res_samples = block_res_samples + (hidden_states,)
+
+ controlnet_block_res_samples = ()
+ for block_res_sample, controlnet_block in zip(block_res_samples, self.controlnet_blocks):
+ block_res_sample = controlnet_block(block_res_sample)
+ controlnet_block_res_samples = controlnet_block_res_samples + (block_res_sample,)
+
+ # 6. scaling
+ controlnet_block_res_samples = [sample * conditioning_scale for sample in controlnet_block_res_samples]
+
+ if USE_PEFT_BACKEND:
+ # remove `lora_scale` from each PEFT layer
+ unscale_lora_layers(self, lora_scale)
+
+ if not return_dict:
+ return (controlnet_block_res_samples,)
+
+ return SD3ControlNetOutput(controlnet_block_samples=controlnet_block_res_samples)
+
+
+ class SD3MultiControlNetModel(ModelMixin):
+ r"""
+ `SD3ControlNetModel` wrapper class for Multi-SD3ControlNet
+
+ This module is a wrapper for multiple instances of the `SD3ControlNetModel`. The `forward()` API is designed to be
+ compatible with `SD3ControlNetModel`.
+
+ Args:
+ controlnets (`List[SD3ControlNetModel]`):
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
+ `SD3ControlNetModel` as a list.
+ """
+
+ def __init__(self, controlnets):
+ super().__init__()
+ self.nets = nn.ModuleList(controlnets)
+
+ def forward(
+ self,
+ hidden_states: torch.FloatTensor,
+ controlnet_cond: List[torch.tensor],
+ conditioning_scale: List[float],
+ pooled_projections: torch.FloatTensor,
+ encoder_hidden_states: torch.FloatTensor = None,
+ timestep: torch.LongTensor = None,
+ joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+ return_dict: bool = True,
+ ) -> Union[SD3ControlNetOutput, Tuple]:
+ for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
+ block_samples = controlnet(
+ hidden_states=hidden_states,
+ timestep=timestep,
+ encoder_hidden_states=encoder_hidden_states,
+ pooled_projections=pooled_projections,
+ controlnet_cond=image,
+ conditioning_scale=scale,
+ joint_attention_kwargs=joint_attention_kwargs,
+ return_dict=return_dict,
+ )
+
+ # merge samples
+ if i == 0:
+ control_block_samples = block_samples
+ else:
+ control_block_samples = [
+ control_block_sample + block_sample
+ for control_block_sample, block_sample in zip(control_block_samples[0], block_samples[0])
+ ]
+ control_block_samples = (tuple(control_block_samples),)
+
+ return control_block_samples
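
The new file above defines SD3ControlNetModel (a trimmed SD3 transformer whose per-block hidden states are projected through zero-initialized linear layers) and SD3MultiControlNetModel, which runs several ControlNets and sums their block outputs. A short sketch of composing the two; the hub repo ids are assumptions, and the pipeline wiring is omitted:

import torch
from diffusers import SD3ControlNetModel, SD3MultiControlNetModel

# Assumed hub repo ids for two SD3 ControlNet checkpoints.
controlnet_canny = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Canny", torch_dtype=torch.float16)
controlnet_pose = SD3ControlNetModel.from_pretrained("InstantX/SD3-Controlnet-Pose", torch_dtype=torch.float16)

# The wrapper keeps the single-ControlNet forward() signature; conditioning images and
# scales are passed as lists, one entry per ControlNet, and the block outputs are summed.
multi_controlnet = SD3MultiControlNetModel([controlnet_canny, controlnet_pose])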