diffusers 0.29.0__py3-none-any.whl → 0.29.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
diffusers/__init__.py CHANGED
@@ -1,4 +1,4 @@
1
- __version__ = "0.29.0"
1
+ __version__ = "0.29.2"
2
2
 
3
3
  from typing import TYPE_CHECKING
4
4
 
@@ -91,6 +91,8 @@ else:
91
91
  "MultiAdapter",
92
92
  "PixArtTransformer2DModel",
93
93
  "PriorTransformer",
94
+ "SD3ControlNetModel",
95
+ "SD3MultiControlNetModel",
94
96
  "SD3Transformer2DModel",
95
97
  "StableCascadeUNet",
96
98
  "T2IAdapter",
@@ -278,6 +280,7 @@ else:
278
280
  "StableCascadeCombinedPipeline",
279
281
  "StableCascadeDecoderPipeline",
280
282
  "StableCascadePriorPipeline",
283
+ "StableDiffusion3ControlNetPipeline",
281
284
  "StableDiffusion3Img2ImgPipeline",
282
285
  "StableDiffusion3Pipeline",
283
286
  "StableDiffusionAdapterPipeline",
@@ -501,6 +504,8 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
501
504
  MultiAdapter,
502
505
  PixArtTransformer2DModel,
503
506
  PriorTransformer,
507
+ SD3ControlNetModel,
508
+ SD3MultiControlNetModel,
504
509
  SD3Transformer2DModel,
505
510
  T2IAdapter,
506
511
  T5FilmDecoder,
@@ -666,6 +671,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
666
671
  StableCascadeCombinedPipeline,
667
672
  StableCascadeDecoderPipeline,
668
673
  StableCascadePriorPipeline,
674
+ StableDiffusion3ControlNetPipeline,
669
675
  StableDiffusion3Img2ImgPipeline,
670
676
  StableDiffusion3Pipeline,
671
677
  StableDiffusionAdapterPipeline,
diffusers/loaders/lora.py CHANGED
@@ -42,7 +42,7 @@ from ..utils import (
42
42
  set_adapter_layers,
43
43
  set_weights_and_activate_adapters,
44
44
  )
45
- from .lora_conversion_utils import _convert_kohya_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
45
+ from .lora_conversion_utils import _convert_non_diffusers_lora_to_diffusers, _maybe_map_sgm_blocks_to_diffusers
46
46
 
47
47
 
48
48
  if is_transformers_available():
@@ -287,7 +287,7 @@ class LoraLoaderMixin:
287
287
  if unet_config is not None:
288
288
  # use unet config to remap block numbers
289
289
  state_dict = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
290
- state_dict, network_alphas = _convert_kohya_lora_to_diffusers(state_dict)
290
+ state_dict, network_alphas = _convert_non_diffusers_lora_to_diffusers(state_dict)
291
291
 
292
292
  return state_dict, network_alphas
293
293
 
@@ -395,8 +395,7 @@ class LoraLoaderMixin:
395
395
  # their prefixes.
396
396
  keys = list(state_dict.keys())
397
397
  only_text_encoder = all(key.startswith(cls.text_encoder_name) for key in keys)
398
-
399
- if any(key.startswith(cls.unet_name) for key in keys) and not only_text_encoder:
398
+ if not only_text_encoder:
400
399
  # Load the layers corresponding to UNet.
401
400
  logger.info(f"Loading {cls.unet_name}.")
402
401
  unet.load_attn_procs(
diffusers/loaders/lora_conversion_utils.py CHANGED
@@ -123,134 +123,76 @@ def _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config, delimiter="_", b
123
123
  return new_state_dict
124
124
 
125
125
 
126
- def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
126
+ def _convert_non_diffusers_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_name="text_encoder"):
127
+ """
128
+ Converts a non-Diffusers LoRA state dict to a Diffusers compatible state dict.
129
+
130
+ Args:
131
+ state_dict (`dict`): The state dict to convert.
132
+ unet_name (`str`, optional): The name of the U-Net module in the Diffusers model. Defaults to "unet".
133
+ text_encoder_name (`str`, optional): The name of the text encoder module in the Diffusers model. Defaults to
134
+ "text_encoder".
135
+
136
+ Returns:
137
+ `tuple`: A tuple containing the converted state dict and a dictionary of alphas.
138
+ """
127
139
  unet_state_dict = {}
128
140
  te_state_dict = {}
129
141
  te2_state_dict = {}
130
142
  network_alphas = {}
131
- is_unet_dora_lora = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
132
- is_te_dora_lora = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
133
- is_te2_dora_lora = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
134
143
 
135
- if is_unet_dora_lora or is_te_dora_lora or is_te2_dora_lora:
144
+ # Check for DoRA-enabled LoRAs.
145
+ dora_present_in_unet = any("dora_scale" in k and "lora_unet_" in k for k in state_dict)
146
+ dora_present_in_te = any("dora_scale" in k and ("lora_te_" in k or "lora_te1_" in k) for k in state_dict)
147
+ dora_present_in_te2 = any("dora_scale" in k and "lora_te2_" in k for k in state_dict)
148
+ if dora_present_in_unet or dora_present_in_te or dora_present_in_te2:
136
149
  if is_peft_version("<", "0.9.0"):
137
150
  raise ValueError(
138
151
  "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
139
152
  )
140
153
 
141
- # every down weight has a corresponding up weight and potentially an alpha weight
142
- lora_keys = [k for k in state_dict.keys() if k.endswith("lora_down.weight")]
143
- for key in lora_keys:
154
+ # Iterate over all LoRA weights.
155
+ all_lora_keys = list(state_dict.keys())
156
+ for key in all_lora_keys:
157
+ if not key.endswith("lora_down.weight"):
158
+ continue
159
+
160
+ # Extract LoRA name.
144
161
  lora_name = key.split(".")[0]
162
+
163
+ # Find corresponding up weight and alpha.
145
164
  lora_name_up = lora_name + ".lora_up.weight"
146
165
  lora_name_alpha = lora_name + ".alpha"
147
166
 
167
+ # Handle U-Net LoRAs.
148
168
  if lora_name.startswith("lora_unet_"):
149
- diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
150
-
151
- if "input.blocks" in diffusers_name:
152
- diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
153
- else:
154
- diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
169
+ diffusers_name = _convert_unet_lora_key(key)
155
170
 
156
- if "middle.block" in diffusers_name:
157
- diffusers_name = diffusers_name.replace("middle.block", "mid_block")
158
- else:
159
- diffusers_name = diffusers_name.replace("mid.block", "mid_block")
160
- if "output.blocks" in diffusers_name:
161
- diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
162
- else:
163
- diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
164
-
165
- diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
166
- diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
167
- diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
168
- diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
169
- diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
170
- diffusers_name = diffusers_name.replace("proj.in", "proj_in")
171
- diffusers_name = diffusers_name.replace("proj.out", "proj_out")
172
- diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
173
-
174
- # SDXL specificity.
175
- if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
176
- pattern = r"\.\d+(?=\D*$)"
177
- diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
178
- if ".in." in diffusers_name:
179
- diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
180
- if ".out." in diffusers_name:
181
- diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
182
- if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
183
- diffusers_name = diffusers_name.replace("op", "conv")
184
- if "skip" in diffusers_name:
185
- diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
186
-
187
- # LyCORIS specificity.
188
- if "time.emb.proj" in diffusers_name:
189
- diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
190
- if "conv.shortcut" in diffusers_name:
191
- diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
192
-
193
- # General coverage.
194
- if "transformer_blocks" in diffusers_name:
195
- if "attn1" in diffusers_name or "attn2" in diffusers_name:
196
- diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
197
- diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
198
- unet_state_dict[diffusers_name] = state_dict.pop(key)
199
- unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
200
- elif "ff" in diffusers_name:
201
- unet_state_dict[diffusers_name] = state_dict.pop(key)
202
- unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
203
- elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
204
- unet_state_dict[diffusers_name] = state_dict.pop(key)
205
- unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
206
- else:
207
- unet_state_dict[diffusers_name] = state_dict.pop(key)
208
- unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
171
+ # Store down and up weights.
172
+ unet_state_dict[diffusers_name] = state_dict.pop(key)
173
+ unet_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
209
174
 
210
- if is_unet_dora_lora:
175
+ # Store DoRA scale if present.
176
+ if dora_present_in_unet:
211
177
  dora_scale_key_to_replace = "_lora.down." if "_lora.down." in diffusers_name else ".lora.down."
212
178
  unet_state_dict[
213
179
  diffusers_name.replace(dora_scale_key_to_replace, ".lora_magnitude_vector.")
214
180
  ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
215
181
 
182
+ # Handle text encoder LoRAs.
216
183
  elif lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
184
+ diffusers_name = _convert_text_encoder_lora_key(key, lora_name)
185
+
186
+ # Store down and up weights for te or te2.
217
187
  if lora_name.startswith(("lora_te_", "lora_te1_")):
218
- key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
188
+ te_state_dict[diffusers_name] = state_dict.pop(key)
189
+ te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
219
190
  else:
220
- key_to_replace = "lora_te2_"
221
-
222
- diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
223
- diffusers_name = diffusers_name.replace("text.model", "text_model")
224
- diffusers_name = diffusers_name.replace("self.attn", "self_attn")
225
- diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
226
- diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
227
- diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
228
- diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
229
- diffusers_name = diffusers_name.replace("text.projection", "text_projection")
230
-
231
- if "self_attn" in diffusers_name:
232
- if lora_name.startswith(("lora_te_", "lora_te1_")):
233
- te_state_dict[diffusers_name] = state_dict.pop(key)
234
- te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
235
- else:
236
- te2_state_dict[diffusers_name] = state_dict.pop(key)
237
- te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
238
- elif "mlp" in diffusers_name:
239
- # Be aware that this is the new diffusers convention and the rest of the code might
240
- # not utilize it yet.
241
- diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
242
- if lora_name.startswith(("lora_te_", "lora_te1_")):
243
- te_state_dict[diffusers_name] = state_dict.pop(key)
244
- te_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
245
- else:
246
- te2_state_dict[diffusers_name] = state_dict.pop(key)
247
- te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
248
- # OneTrainer specificity
249
- elif "text_projection" in diffusers_name and lora_name.startswith("lora_te2_"):
250
191
  te2_state_dict[diffusers_name] = state_dict.pop(key)
251
192
  te2_state_dict[diffusers_name.replace(".down.", ".up.")] = state_dict.pop(lora_name_up)
252
193
 
253
- if (is_te_dora_lora or is_te2_dora_lora) and lora_name.startswith(("lora_te_", "lora_te1_", "lora_te2_")):
194
+ # Store DoRA scale if present.
195
+ if dora_present_in_te or dora_present_in_te2:
254
196
  dora_scale_key_to_replace_te = (
255
197
  "_lora.down." if "_lora.down." in diffusers_name else ".lora_linear_layer."
256
198
  )
@@ -263,22 +205,18 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
263
205
  diffusers_name.replace(dora_scale_key_to_replace_te, ".lora_magnitude_vector.")
264
206
  ] = state_dict.pop(key.replace("lora_down.weight", "dora_scale"))
265
207
 
266
- # Rename the alphas so that they can be mapped appropriately.
208
+ # Store alpha if present.
267
209
  if lora_name_alpha in state_dict:
268
210
  alpha = state_dict.pop(lora_name_alpha).item()
269
- if lora_name_alpha.startswith("lora_unet_"):
270
- prefix = "unet."
271
- elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
272
- prefix = "text_encoder."
273
- else:
274
- prefix = "text_encoder_2."
275
- new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
276
- network_alphas.update({new_name: alpha})
211
+ network_alphas.update(_get_alpha_name(lora_name_alpha, diffusers_name, alpha))
277
212
 
213
+ # Check if any keys remain.
278
214
  if len(state_dict) > 0:
279
215
  raise ValueError(f"The following keys have not been correctly renamed: \n\n {', '.join(state_dict.keys())}")
280
216
 
281
- logger.info("Kohya-style checkpoint detected.")
217
+ logger.info("Non-diffusers checkpoint detected.")
218
+
219
+ # Construct final state dict.
282
220
  unet_state_dict = {f"{unet_name}.{module_name}": params for module_name, params in unet_state_dict.items()}
283
221
  te_state_dict = {f"{text_encoder_name}.{module_name}": params for module_name, params in te_state_dict.items()}
284
222
  te2_state_dict = (
@@ -291,3 +229,100 @@ def _convert_kohya_lora_to_diffusers(state_dict, unet_name="unet", text_encoder_
291
229
 
292
230
  new_state_dict = {**unet_state_dict, **te_state_dict}
293
231
  return new_state_dict, network_alphas
232
+
233
+
234
+ def _convert_unet_lora_key(key):
235
+ """
236
+ Converts a U-Net LoRA key to a Diffusers compatible key.
237
+ """
238
+ diffusers_name = key.replace("lora_unet_", "").replace("_", ".")
239
+
240
+ # Replace common U-Net naming patterns.
241
+ diffusers_name = diffusers_name.replace("input.blocks", "down_blocks")
242
+ diffusers_name = diffusers_name.replace("down.blocks", "down_blocks")
243
+ diffusers_name = diffusers_name.replace("middle.block", "mid_block")
244
+ diffusers_name = diffusers_name.replace("mid.block", "mid_block")
245
+ diffusers_name = diffusers_name.replace("output.blocks", "up_blocks")
246
+ diffusers_name = diffusers_name.replace("up.blocks", "up_blocks")
247
+ diffusers_name = diffusers_name.replace("transformer.blocks", "transformer_blocks")
248
+ diffusers_name = diffusers_name.replace("to.q.lora", "to_q_lora")
249
+ diffusers_name = diffusers_name.replace("to.k.lora", "to_k_lora")
250
+ diffusers_name = diffusers_name.replace("to.v.lora", "to_v_lora")
251
+ diffusers_name = diffusers_name.replace("to.out.0.lora", "to_out_lora")
252
+ diffusers_name = diffusers_name.replace("proj.in", "proj_in")
253
+ diffusers_name = diffusers_name.replace("proj.out", "proj_out")
254
+ diffusers_name = diffusers_name.replace("emb.layers", "time_emb_proj")
255
+
256
+ # SDXL specific conversions.
257
+ if "emb" in diffusers_name and "time.emb.proj" not in diffusers_name:
258
+ pattern = r"\.\d+(?=\D*$)"
259
+ diffusers_name = re.sub(pattern, "", diffusers_name, count=1)
260
+ if ".in." in diffusers_name:
261
+ diffusers_name = diffusers_name.replace("in.layers.2", "conv1")
262
+ if ".out." in diffusers_name:
263
+ diffusers_name = diffusers_name.replace("out.layers.3", "conv2")
264
+ if "downsamplers" in diffusers_name or "upsamplers" in diffusers_name:
265
+ diffusers_name = diffusers_name.replace("op", "conv")
266
+ if "skip" in diffusers_name:
267
+ diffusers_name = diffusers_name.replace("skip.connection", "conv_shortcut")
268
+
269
+ # LyCORIS specific conversions.
270
+ if "time.emb.proj" in diffusers_name:
271
+ diffusers_name = diffusers_name.replace("time.emb.proj", "time_emb_proj")
272
+ if "conv.shortcut" in diffusers_name:
273
+ diffusers_name = diffusers_name.replace("conv.shortcut", "conv_shortcut")
274
+
275
+ # General conversions.
276
+ if "transformer_blocks" in diffusers_name:
277
+ if "attn1" in diffusers_name or "attn2" in diffusers_name:
278
+ diffusers_name = diffusers_name.replace("attn1", "attn1.processor")
279
+ diffusers_name = diffusers_name.replace("attn2", "attn2.processor")
280
+ elif "ff" in diffusers_name:
281
+ pass
282
+ elif any(key in diffusers_name for key in ("proj_in", "proj_out")):
283
+ pass
284
+ else:
285
+ pass
286
+
287
+ return diffusers_name
288
+
289
+
290
+ def _convert_text_encoder_lora_key(key, lora_name):
291
+ """
292
+ Converts a text encoder LoRA key to a Diffusers compatible key.
293
+ """
294
+ if lora_name.startswith(("lora_te_", "lora_te1_")):
295
+ key_to_replace = "lora_te_" if lora_name.startswith("lora_te_") else "lora_te1_"
296
+ else:
297
+ key_to_replace = "lora_te2_"
298
+
299
+ diffusers_name = key.replace(key_to_replace, "").replace("_", ".")
300
+ diffusers_name = diffusers_name.replace("text.model", "text_model")
301
+ diffusers_name = diffusers_name.replace("self.attn", "self_attn")
302
+ diffusers_name = diffusers_name.replace("q.proj.lora", "to_q_lora")
303
+ diffusers_name = diffusers_name.replace("k.proj.lora", "to_k_lora")
304
+ diffusers_name = diffusers_name.replace("v.proj.lora", "to_v_lora")
305
+ diffusers_name = diffusers_name.replace("out.proj.lora", "to_out_lora")
306
+ diffusers_name = diffusers_name.replace("text.projection", "text_projection")
307
+
308
+ if "self_attn" in diffusers_name or "text_projection" in diffusers_name:
309
+ pass
310
+ elif "mlp" in diffusers_name:
311
+ # Be aware that this is the new diffusers convention and the rest of the code might
312
+ # not utilize it yet.
313
+ diffusers_name = diffusers_name.replace(".lora.", ".lora_linear_layer.")
314
+ return diffusers_name
315
+
316
+
317
+ def _get_alpha_name(lora_name_alpha, diffusers_name, alpha):
318
+ """
319
+ Gets the correct alpha name for the Diffusers model.
320
+ """
321
+ if lora_name_alpha.startswith("lora_unet_"):
322
+ prefix = "unet."
323
+ elif lora_name_alpha.startswith(("lora_te_", "lora_te1_")):
324
+ prefix = "text_encoder."
325
+ else:
326
+ prefix = "text_encoder_2."
327
+ new_name = prefix + diffusers_name.split(".lora.")[0] + ".alpha"
328
+ return {new_name: alpha}
diffusers/loaders/single_file.py CHANGED
@@ -28,9 +28,11 @@ from .single_file_utils import (
28
28
  _legacy_load_safety_checker,
29
29
  _legacy_load_scheduler,
30
30
  create_diffusers_clip_model_from_ldm,
31
+ create_diffusers_t5_model_from_checkpoint,
31
32
  fetch_diffusers_config,
32
33
  fetch_original_config,
33
34
  is_clip_model_in_single_file,
35
+ is_t5_in_single_file,
34
36
  load_single_file_checkpoint,
35
37
  )
36
38
 
@@ -118,6 +120,16 @@ def load_single_file_sub_model(
118
120
  is_legacy_loading=is_legacy_loading,
119
121
  )
120
122
 
123
+ elif is_transformers_model and is_t5_in_single_file(checkpoint):
124
+ loaded_sub_model = create_diffusers_t5_model_from_checkpoint(
125
+ class_obj,
126
+ checkpoint=checkpoint,
127
+ config=cached_model_config_path,
128
+ subfolder=name,
129
+ torch_dtype=torch_dtype,
130
+ local_files_only=local_files_only,
131
+ )
132
+
121
133
  elif is_tokenizer and is_legacy_loading:
122
134
  loaded_sub_model = _legacy_load_clip_tokenizer(
123
135
  class_obj, checkpoint=checkpoint, config=cached_model_config_path, local_files_only=local_files_only
diffusers/loaders/single_file_model.py CHANGED
@@ -276,16 +276,18 @@ class FromOriginalModelMixin:
276
276
 
277
277
  if is_accelerate_available():
278
278
  unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
279
- if model._keys_to_ignore_on_load_unexpected is not None:
280
- for pat in model._keys_to_ignore_on_load_unexpected:
281
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
282
279
 
283
- if len(unexpected_keys) > 0:
284
- logger.warning(
285
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
286
- )
287
280
  else:
288
- model.load_state_dict(diffusers_format_checkpoint)
281
+ _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
282
+
283
+ if model._keys_to_ignore_on_load_unexpected is not None:
284
+ for pat in model._keys_to_ignore_on_load_unexpected:
285
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
286
+
287
+ if len(unexpected_keys) > 0:
288
+ logger.warning(
289
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
290
+ )
289
291
 
290
292
  if torch_dtype is not None:
291
293
  model.to(torch_dtype)
diffusers/loaders/single_file_utils.py CHANGED
@@ -252,7 +252,6 @@ LDM_CONTROLNET_KEY = "control_model."
252
252
  LDM_CLIP_PREFIX_TO_REMOVE = [
253
253
  "cond_stage_model.transformer.",
254
254
  "conditioner.embedders.0.transformer.",
255
- "text_encoders.clip_l.transformer.",
256
255
  ]
257
256
  OPEN_CLIP_PREFIX = "conditioner.embedders.0.model."
258
257
  LDM_OPEN_CLIP_TEXT_PROJECTION_DIM = 1024
@@ -399,11 +398,14 @@ def is_open_clip_sdxl_model(checkpoint):
399
398
 
400
399
 
401
400
  def is_open_clip_sd3_model(checkpoint):
402
- is_open_clip_sdxl_refiner_model(checkpoint)
401
+ if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
402
+ return True
403
+
404
+ return False
403
405
 
404
406
 
405
407
  def is_open_clip_sdxl_refiner_model(checkpoint):
406
- if CHECKPOINT_KEY_NAMES["open_clip_sd3"] in checkpoint:
408
+ if CHECKPOINT_KEY_NAMES["open_clip_sdxl_refiner"] in checkpoint:
407
409
  return True
408
410
 
409
411
  return False
@@ -1233,11 +1235,14 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
1233
1235
  return new_checkpoint
1234
1236
 
1235
1237
 
1236
- def convert_ldm_clip_checkpoint(checkpoint):
1238
+ def convert_ldm_clip_checkpoint(checkpoint, remove_prefix=None):
1237
1239
  keys = list(checkpoint.keys())
1238
1240
  text_model_dict = {}
1239
1241
 
1240
- remove_prefixes = LDM_CLIP_PREFIX_TO_REMOVE
1242
+ remove_prefixes = []
1243
+ remove_prefixes.extend(LDM_CLIP_PREFIX_TO_REMOVE)
1244
+ if remove_prefix:
1245
+ remove_prefixes.append(remove_prefix)
1241
1246
 
1242
1247
  for key in keys:
1243
1248
  for prefix in remove_prefixes:
@@ -1263,8 +1268,6 @@ def convert_open_clip_checkpoint(
1263
1268
  else:
1264
1269
  text_proj_dim = LDM_OPEN_CLIP_TEXT_PROJECTION_DIM
1265
1270
 
1266
- text_model_dict["text_model.embeddings.position_ids"] = text_model.text_model.embeddings.get_buffer("position_ids")
1267
-
1268
1271
  keys = list(checkpoint.keys())
1269
1272
  keys_to_ignore = SD_2_TEXT_ENCODER_KEYS_TO_IGNORE
1270
1273
 
@@ -1313,9 +1316,6 @@ def convert_open_clip_checkpoint(
1313
1316
  else:
1314
1317
  text_model_dict[diffusers_key] = checkpoint.get(key)
1315
1318
 
1316
- if not (hasattr(text_model, "embeddings") and hasattr(text_model.embeddings.position_ids)):
1317
- text_model_dict.pop("text_model.embeddings.position_ids", None)
1318
-
1319
1319
  return text_model_dict
1320
1320
 
1321
1321
 
@@ -1376,6 +1376,13 @@ def create_diffusers_clip_model_from_ldm(
1376
1376
  ):
1377
1377
  diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint)
1378
1378
 
1379
+ elif (
1380
+ is_clip_sd3_model(checkpoint)
1381
+ and checkpoint[CHECKPOINT_KEY_NAMES["clip_sd3"]].shape[-1] == position_embedding_dim
1382
+ ):
1383
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_l.transformer.")
1384
+ diffusers_format_checkpoint["text_projection.weight"] = torch.eye(position_embedding_dim)
1385
+
1379
1386
  elif is_open_clip_model(checkpoint):
1380
1387
  prefix = "cond_stage_model.model."
1381
1388
  diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
@@ -1391,26 +1398,28 @@ def create_diffusers_clip_model_from_ldm(
1391
1398
  prefix = "conditioner.embedders.0.model."
1392
1399
  diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
1393
1400
 
1394
- elif is_open_clip_sd3_model(checkpoint):
1395
- prefix = "text_encoders.clip_g.transformer."
1396
- diffusers_format_checkpoint = convert_open_clip_checkpoint(model, checkpoint, prefix=prefix)
1401
+ elif (
1402
+ is_open_clip_sd3_model(checkpoint)
1403
+ and checkpoint[CHECKPOINT_KEY_NAMES["open_clip_sd3"]].shape[-1] == position_embedding_dim
1404
+ ):
1405
+ diffusers_format_checkpoint = convert_ldm_clip_checkpoint(checkpoint, "text_encoders.clip_g.transformer.")
1397
1406
 
1398
1407
  else:
1399
1408
  raise ValueError("The provided checkpoint does not seem to contain a valid CLIP model.")
1400
1409
 
1401
1410
  if is_accelerate_available():
1402
1411
  unexpected_keys = load_model_dict_into_meta(model, diffusers_format_checkpoint, dtype=torch_dtype)
1403
- if model._keys_to_ignore_on_load_unexpected is not None:
1404
- for pat in model._keys_to_ignore_on_load_unexpected:
1405
- unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1412
+ else:
1413
+ _, unexpected_keys = model.load_state_dict(diffusers_format_checkpoint, strict=False)
1406
1414
 
1407
- if len(unexpected_keys) > 0:
1408
- logger.warning(
1409
- f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1410
- )
1415
+ if model._keys_to_ignore_on_load_unexpected is not None:
1416
+ for pat in model._keys_to_ignore_on_load_unexpected:
1417
+ unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]
1411
1418
 
1412
- else:
1413
- model.load_state_dict(diffusers_format_checkpoint)
1419
+ if len(unexpected_keys) > 0:
1420
+ logger.warning(
1421
+ f"Some weights of the model checkpoint were not used when initializing {cls.__name__}: \n {[', '.join(unexpected_keys)]}"
1422
+ )
1414
1423
 
1415
1424
  if torch_dtype is not None:
1416
1425
  model.to(torch_dtype)
@@ -1755,7 +1764,7 @@ def convert_sd3_t5_checkpoint_to_diffusers(checkpoint):
1755
1764
  keys = list(checkpoint.keys())
1756
1765
  text_model_dict = {}
1757
1766
 
1758
- remove_prefixes = ["text_encoders.t5xxl.transformer.encoder."]
1767
+ remove_prefixes = ["text_encoders.t5xxl.transformer."]
1759
1768
 
1760
1769
  for key in keys:
1761
1770
  for prefix in remove_prefixes:
@@ -1799,3 +1808,4 @@ def create_diffusers_t5_model_from_checkpoint(
1799
1808
 
1800
1809
  else:
1801
1810
  model.load_state_dict(diffusers_format_checkpoint)
1811
+ return model
diffusers/models/__init__.py CHANGED
@@ -33,6 +33,7 @@ if is_torch_available():
33
33
  _import_structure["autoencoders.consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
34
34
  _import_structure["autoencoders.vq_model"] = ["VQModel"]
35
35
  _import_structure["controlnet"] = ["ControlNetModel"]
36
+ _import_structure["controlnet_sd3"] = ["SD3ControlNetModel", "SD3MultiControlNetModel"]
36
37
  _import_structure["controlnet_xs"] = ["ControlNetXSAdapter", "UNetControlNetXSModel"]
37
38
  _import_structure["embeddings"] = ["ImageProjection"]
38
39
  _import_structure["modeling_utils"] = ["ModelMixin"]
@@ -74,6 +75,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
74
75
  VQModel,
75
76
  )
76
77
  from .controlnet import ControlNetModel
78
+ from .controlnet_sd3 import SD3ControlNetModel, SD3MultiControlNetModel
77
79
  from .controlnet_xs import ControlNetXSAdapter, UNetControlNetXSModel
78
80
  from .embeddings import ImageProjection
79
81
  from .modeling_utils import ModelMixin