diffusers 0.28.2__py3-none-any.whl → 0.29.1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only and reflects the package versions exactly as they appear in their public registry.
- diffusers/__init__.py +15 -1
- diffusers/commands/env.py +1 -5
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +2 -1
- diffusers/loaders/__init__.py +2 -2
- diffusers/loaders/lora.py +406 -140
- diffusers/loaders/lora_conversion_utils.py +7 -1
- diffusers/loaders/single_file.py +13 -1
- diffusers/loaders/single_file_model.py +15 -8
- diffusers/loaders/single_file_utils.py +267 -17
- diffusers/loaders/unet.py +307 -272
- diffusers/models/__init__.py +7 -3
- diffusers/models/attention.py +125 -1
- diffusers/models/attention_processor.py +169 -1
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl.py +17 -6
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
- diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
- diffusers/models/autoencoders/vq_model.py +182 -0
- diffusers/models/controlnet_sd3.py +418 -0
- diffusers/models/controlnet_xs.py +6 -6
- diffusers/models/embeddings.py +112 -84
- diffusers/models/model_loading_utils.py +55 -0
- diffusers/models/modeling_utils.py +138 -20
- diffusers/models/normalization.py +11 -6
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/dual_transformer_2d.py +5 -4
- diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
- diffusers/models/transformers/prior_transformer.py +5 -5
- diffusers/models/transformers/transformer_2d.py +2 -2
- diffusers/models/transformers/transformer_sd3.py +353 -0
- diffusers/models/transformers/transformer_temporal.py +12 -10
- diffusers/models/unets/unet_1d.py +3 -3
- diffusers/models/unets/unet_2d.py +3 -3
- diffusers/models/unets/unet_2d_condition.py +4 -15
- diffusers/models/unets/unet_3d_condition.py +5 -17
- diffusers/models/unets/unet_i2vgen_xl.py +4 -4
- diffusers/models/unets/unet_motion_model.py +4 -4
- diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
- diffusers/models/vq_model.py +8 -165
- diffusers/pipelines/__init__.py +11 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
- diffusers/pipelines/auto_pipeline.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +53 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1062 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
- diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
- diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
- diffusers/pipelines/pia/pipeline_pia.py +4 -3
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +904 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +941 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
- diffusers/schedulers/__init__.py +2 -0
- diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
- diffusers/schedulers/scheduling_edm_euler.py +2 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
- diffusers/schedulers/scheduling_lms_discrete.py +2 -2
- diffusers/training_utils.py +4 -4
- diffusers/utils/__init__.py +3 -0
- diffusers/utils/constants.py +2 -0
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
- diffusers/utils/dynamic_modules_utils.py +15 -13
- diffusers/utils/hub_utils.py +106 -0
- diffusers/utils/import_utils.py +0 -1
- diffusers/utils/logging.py +3 -1
- diffusers/utils/state_dict_utils.py +2 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/METADATA +3 -3
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/RECORD +112 -112
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/WHEEL +1 -1
- diffusers/models/dual_transformer_2d.py +0 -20
- diffusers/models/prior_transformer.py +0 -12
- diffusers/models/t5_film_transformer.py +0 -70
- diffusers/models/transformer_2d.py +0 -25
- diffusers/models/transformer_temporal.py +0 -34
- diffusers/models/unet_1d.py +0 -26
- diffusers/models/unet_1d_blocks.py +0 -203
- diffusers/models/unet_2d.py +0 -27
- diffusers/models/unet_2d_blocks.py +0 -375
- diffusers/models/unet_2d_condition.py +0 -25
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/LICENSE +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/top_level.txt +0 -0
diffusers/loaders/unet.py
CHANGED
@@ -33,34 +33,32 @@ from ..models.embeddings import (
     IPAdapterPlusImageProjection,
     MultiIPAdapterImageProjection,
 )
-from ..models.modeling_utils import
+from ..models.modeling_utils import load_model_dict_into_meta, load_state_dict
 from ..utils import (
     USE_PEFT_BACKEND,
     _get_model_file,
+    convert_unet_state_dict_to_peft,
     delete_adapter_layers,
+    get_adapter_name,
+    get_peft_kwargs,
     is_accelerate_available,
+    is_peft_version,
     is_torch_version,
     logging,
     set_adapter_layers,
     set_weights_and_activate_adapters,
 )
+from .lora import LORA_WEIGHT_NAME, LORA_WEIGHT_NAME_SAFE, TEXT_ENCODER_NAME, UNET_NAME
 from .unet_loader_utils import _maybe_expand_lora_scales
 from .utils import AttnProcsLayers


 if is_accelerate_available():
-    from accelerate import init_empty_weights
     from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module

 logger = logging.get_logger(__name__)


-TEXT_ENCODER_NAME = "text_encoder"
-UNET_NAME = "unet"
-
-LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
-LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
-
 CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
 CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"

@@ -79,7 +77,8 @@ class UNet2DConditionLoadersMixin:
         Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
         defined in
         [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
-        and be a `torch.nn.Module` class.
+        and be a `torch.nn.Module` class. Currently supported: LoRA, Custom Diffusion. For LoRA, one must install
+        `peft`: `pip install -U peft`.

         Parameters:
             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
@@ -110,20 +109,20 @@ class UNet2DConditionLoadersMixin:
             token (`str` or *bool*, *optional*):
                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
-            low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
-                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
-                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
-                argument to `True` will raise an error.
             revision (`str`, *optional*, defaults to `"main"`):
                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                 allowed by Git.
             subfolder (`str`, *optional*, defaults to `""`):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
-
-
-
-
+            network_alphas (`Dict[str, float]`):
+                The value of the network alpha used for stable learning and preventing underflow. This value has the
+                same meaning as the `--network_alpha` option in the kohya-ss trainer script. Refer to [this
+                link](https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning).
+            adapter_name (`str`, *optional*, defaults to None):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            weight_name (`str`, *optional*, defaults to None):
+                Name of the serialized state dict file.

         Example:

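The `network_alphas`, `adapter_name`, and `weight_name` parameters documented above surface the new PEFT-based loading path. A minimal usage sketch, not part of the diff; the Hub repo id and adapter name are illustrative placeholders:

```python
# Hedged sketch: loading LoRA weights directly into the UNet via
# load_attn_procs (requires `pip install -U peft`).
import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.unet.load_attn_procs(
    "some-user/some-lora",  # placeholder repo id, not a real checkpoint
    weight_name="pytorch_lora_weights.safetensors",
    adapter_name="style",  # optional; defaults to "default_{i}"
)
```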
@@ -139,9 +138,6 @@ class UNet2DConditionLoadersMixin:
         )
         ```
         """
-        from ..models.attention_processor import CustomDiffusionAttnProcessor
-        from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
-
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", None)
@@ -152,15 +148,9 @@ class UNet2DConditionLoadersMixin:
         subfolder = kwargs.pop("subfolder", None)
         weight_name = kwargs.pop("weight_name", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
-
-        # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
-        # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
-        network_alphas = kwargs.pop("network_alphas", None)
-
+        adapter_name = kwargs.pop("adapter_name", None)
         _pipeline = kwargs.pop("_pipeline", None)
-
-        is_network_alphas_none = network_alphas is None
-
+        network_alphas = kwargs.pop("network_alphas", None)
         allow_pickle = False

         if use_safetensors is None:
@@ -216,198 +206,196 @@ class UNet2DConditionLoadersMixin:
         else:
             state_dict = pretrained_model_name_or_path_or_dict

-        # fill attn processors
-        lora_layers_list = []
-
-        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
         is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
+        is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys())
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False

-        if
-
-
+        if is_custom_diffusion:
+            attn_processors = self._process_custom_diffusion(state_dict=state_dict)
+        elif is_lora:
+            is_model_cpu_offload, is_sequential_cpu_offload = self._process_lora(
+                state_dict=state_dict,
+                unet_identifier_key=self.unet_name,
+                network_alphas=network_alphas,
+                adapter_name=adapter_name,
+                _pipeline=_pipeline,
+            )
+        else:
+            raise ValueError(
+                f"{model_file} does not seem to be in the correct format expected by Custom Diffusion training."
+            )

-
-
-
-
-        lora_grouped_dict = defaultdict(dict)
-        mapped_network_alphas = {}
-
-        all_keys = list(state_dict.keys())
-        for key in all_keys:
-            value = state_dict.pop(key)
-            attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
-            lora_grouped_dict[attn_processor_key][sub_key] = value
-
-            # Create another `mapped_network_alphas` dictionary so that we can properly map them.
-            if network_alphas is not None:
-                for k in network_alphas_keys:
-                    if k.replace(".alpha", "") in key:
-                        mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
-                        used_network_alphas_keys.add(k)
-
-        if not is_network_alphas_none:
-            if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
-                raise ValueError(
-                    f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
-                )
+        # <Unsafe code
+        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
+        # Now we remove any existing hooks to `_pipeline`.

-
-
-
-        )
+        # For LoRA, the UNet is already offloaded at this stage as it is handled inside `_process_lora`.
+        if is_custom_diffusion and _pipeline is not None:
+            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline=_pipeline)

-
-
-
-            attn_processor = getattr(attn_processor, sub_key)
-
-            # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
-            # or add_{k,v,q,out_proj}_proj_lora layers.
-            rank = value_dict["lora.down.weight"].shape[0]
-
-            if isinstance(attn_processor, LoRACompatibleConv):
-                in_features = attn_processor.in_channels
-                out_features = attn_processor.out_channels
-                kernel_size = attn_processor.kernel_size
-
-                ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
-                with ctx():
-                    lora = LoRAConv2dLayer(
-                        in_features=in_features,
-                        out_features=out_features,
-                        rank=rank,
-                        kernel_size=kernel_size,
-                        stride=attn_processor.stride,
-                        padding=attn_processor.padding,
-                        network_alpha=mapped_network_alphas.get(key),
-                    )
-            elif isinstance(attn_processor, LoRACompatibleLinear):
-                ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
-                with ctx():
-                    lora = LoRALinearLayer(
-                        attn_processor.in_features,
-                        attn_processor.out_features,
-                        rank,
-                        mapped_network_alphas.get(key),
-                    )
-            else:
-                raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
+            # only custom diffusion needs to set attn processors
+            self.set_attn_processor(attn_processors)
+            self.to(dtype=self.dtype, device=self.device)

-
-
+        # Offload back.
+        if is_model_cpu_offload:
+            _pipeline.enable_model_cpu_offload()
+        elif is_sequential_cpu_offload:
+            _pipeline.enable_sequential_cpu_offload()
+        # Unsafe code />

-
-
-            dtype = next(iter(value_dict.values())).dtype
-            load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
-        else:
-            lora.load_state_dict(value_dict)
+    def _process_custom_diffusion(self, state_dict):
+        from ..models.attention_processor import CustomDiffusionAttnProcessor

-
-
-
-
-
-
+        attn_processors = {}
+        custom_diffusion_grouped_dict = defaultdict(dict)
+        for key, value in state_dict.items():
+            if len(value) == 0:
+                custom_diffusion_grouped_dict[key] = {}
+            else:
+                if "to_out" in key:
+                    attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
                 else:
-
-
-                else:
-                    attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
-                custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
+                    attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
+                custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        for key, value_dict in custom_diffusion_grouped_dict.items():
+            if len(value_dict) == 0:
+                attn_processors[key] = CustomDiffusionAttnProcessor(
+                    train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
+                )
+            else:
+                cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
+                hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
+                train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
+                attn_processors[key] = CustomDiffusionAttnProcessor(
+                    train_kv=True,
+                    train_q_out=train_q_out,
+                    hidden_size=hidden_size,
+                    cross_attention_dim=cross_attention_dim,
+                )
+                attn_processors[key].load_state_dict(value_dict)
+
+        return attn_processors
+
+    def _process_lora(self, state_dict, unet_identifier_key, network_alphas, adapter_name, _pipeline):
+        # This method does the following things:
+        # 1. Filters the `state_dict` with keys matching `unet_identifier_key` when using the non-legacy
+        #    format. For legacy format no filtering is applied.
+        # 2. Converts the `state_dict` to the `peft` compatible format.
+        # 3. Creates a `LoraConfig` and then injects the converted `state_dict` into the UNet per the
+        #    `LoraConfig` specs.
+        # 4. It also reports if the underlying `_pipeline` has any kind of offloading inside of it.
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        from peft import LoraConfig, inject_adapter_in_model, set_peft_model_state_dict
+
+        keys = list(state_dict.keys())
+
+        unet_keys = [k for k in keys if k.startswith(unet_identifier_key)]
+        unet_state_dict = {
+            k.replace(f"{unet_identifier_key}.", ""): v for k, v in state_dict.items() if k in unet_keys
+        }
+
+        if network_alphas is not None:
+            alpha_keys = [k for k in network_alphas.keys() if k.startswith(unet_identifier_key)]
+            network_alphas = {
+                k.replace(f"{unet_identifier_key}.", ""): v for k, v in network_alphas.items() if k in alpha_keys
+            }

-        # <Unsafe code
-        # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
-        # Now we remove any existing hooks to
         is_model_cpu_offload = False
         is_sequential_cpu_offload = False
+        state_dict_to_be_used = unet_state_dict if len(unet_state_dict) > 0 else state_dict

-
-
-
-
-
-                    is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
-                    is_sequential_cpu_offload = (
-                        isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
-                        or hasattr(component._hf_hook, "hooks")
-                        and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
-                    )
+        if len(state_dict_to_be_used) > 0:
+            if adapter_name in getattr(self, "peft_config", {}):
+                raise ValueError(
+                    f"Adapter name {adapter_name} already in use in the Unet - please select a new adapter name."
+                )

-
-                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
-                    )
-                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+            state_dict = convert_unet_state_dict_to_peft(state_dict_to_be_used)

-
-
-
+            if network_alphas is not None:
+                # The alphas state dict have the same structure as Unet, thus we convert it to peft format using
+                # `convert_unet_state_dict_to_peft` method.
+                network_alphas = convert_unet_state_dict_to_peft(network_alphas)
+
+            rank = {}
+            for key, val in state_dict.items():
+                if "lora_B" in key:
+                    rank[key] = val.shape[1]
+
+            lora_config_kwargs = get_peft_kwargs(rank, network_alphas, state_dict, is_unet=True)
+            if "use_dora" in lora_config_kwargs:
+                if lora_config_kwargs["use_dora"]:
+                    if is_peft_version("<", "0.9.0"):
+                        raise ValueError(
+                            "You need `peft` 0.9.0 at least to use DoRA-enabled LoRAs. Please upgrade your installation of `peft`."
+                        )
+                else:
+                    if is_peft_version("<", "0.9.0"):
+                        lora_config_kwargs.pop("use_dora")
+            lora_config = LoraConfig(**lora_config_kwargs)
+
+            # adapter_name
+            if adapter_name is None:
+                adapter_name = get_adapter_name(self)
+
+            # In case the pipeline has been already offloaded to CPU - temporarily remove the hooks
+            # otherwise loading LoRA weights will lead to an error
+            is_model_cpu_offload, is_sequential_cpu_offload = self._optionally_disable_offloading(_pipeline)
+
+            inject_adapter_in_model(lora_config, self, adapter_name=adapter_name)
+            incompatible_keys = set_peft_model_state_dict(self, state_dict, adapter_name)
+
+            if incompatible_keys is not None:
+                # check only for unexpected keys
+                unexpected_keys = getattr(incompatible_keys, "unexpected_keys", None)
+                if unexpected_keys:
+                    logger.warning(
+                        f"Loading adapter weights from state_dict led to unexpected keys not found in the model: "
+                        f" {unexpected_keys}. "
+                    )

-
-        for target_module, lora_layer in lora_layers_list:
-            target_module.set_lora_layer(lora_layer)
+        return is_model_cpu_offload, is_sequential_cpu_offload

-
-
-
-
-
-        elif is_sequential_cpu_offload:
-            _pipeline.enable_sequential_cpu_offload()
-        # Unsafe code />
+    @classmethod
+    # Copied from diffusers.loaders.lora.LoraLoaderMixin._optionally_disable_offloading
+    def _optionally_disable_offloading(cls, _pipeline):
+        """
+        Optionally removes offloading in case the pipeline has been already sequentially offloaded to CPU.

-
-
-
-        )
-        if is_new_lora_format:
-            # Strip the `"unet"` prefix.
-            is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
-            if is_text_encoder_present:
-                warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
-                logger.warning(warn_message)
-            unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
-            state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
+        Args:
+            _pipeline (`DiffusionPipeline`):
+                The pipeline to disable offloading for.

-
-
+        Returns:
+            tuple:
+                A tuple indicating if `is_model_cpu_offload` or `is_sequential_cpu_offload` is True.
+        """
+        is_model_cpu_offload = False
+        is_sequential_cpu_offload = False

-
-
-
-
+        if _pipeline is not None and _pipeline.hf_device_map is None:
+            for _, component in _pipeline.components.items():
+                if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                    if not is_model_cpu_offload:
+                        is_model_cpu_offload = isinstance(component._hf_hook, CpuOffload)
+                    if not is_sequential_cpu_offload:
+                        is_sequential_cpu_offload = (
+                            isinstance(component._hf_hook, AlignDevicesHook)
+                            or hasattr(component._hf_hook, "hooks")
+                            and isinstance(component._hf_hook.hooks[0], AlignDevicesHook)
+                        )

-
+                    logger.info(
+                        "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                    )
+                    remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

-
-        network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}
-        return state_dict, network_alphas
+        return (is_model_cpu_offload, is_sequential_cpu_offload)

     def save_attn_procs(
         self,
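The rewrite above replaces the hand-rolled `LoRACompatibleConv`/`LoRACompatibleLinear` path with two helpers: `_process_custom_diffusion` groups Custom Diffusion keys into attention processors, while `_process_lora` converts the state dict to PEFT format and injects it via a `LoraConfig`. One detail worth spelling out is the rank inference: in the PEFT layout, `lora_B.weight` has shape `[out_features, rank]`, so the rank is its second dimension. A self-contained sketch with dummy tensors (the key below is illustrative, not from a real checkpoint):

```python
# Rank inference as done inside `_process_lora`: collect val.shape[1]
# for every lora_B weight in the (already PEFT-converted) state dict.
import torch

state_dict = {
    "mid_block.attentions.0.to_q.lora_A.weight": torch.zeros(4, 320),  # [rank, in]
    "mid_block.attentions.0.to_q.lora_B.weight": torch.zeros(320, 4),  # [out, rank]
}

rank = {key: val.shape[1] for key, val in state_dict.items() if "lora_B" in key}
print(rank)  # {'mid_block.attentions.0.to_q.lora_B.weight': 4}
```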
@@ -460,6 +448,23 @@ class UNet2DConditionLoadersMixin:
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return

+        is_custom_diffusion = any(
+            isinstance(
+                x,
+                (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
+            )
+            for (_, x) in self.attn_processors.items()
+        )
+        if is_custom_diffusion:
+            state_dict = self._get_custom_diffusion_state_dict()
+        else:
+            if not USE_PEFT_BACKEND:
+                raise ValueError("PEFT backend is required for saving LoRAs using the `save_attn_procs()` method.")
+
+            from peft.utils import get_peft_model_state_dict
+
+            state_dict = get_peft_model_state_dict(self)
+
         if save_function is None:
             if safe_serialization:
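With this hunk, the LoRA branch of `save_attn_procs` serializes through PEFT's `get_peft_model_state_dict` instead of `AttnProcsLayers`; the Custom Diffusion path moves into `_get_custom_diffusion_state_dict` (added two hunks below). A hedged usage sketch, continuing the `pipe` from the earlier sketch; the output directory is a placeholder:

```python
# After a LoRA adapter has been loaded or trained on pipe.unet, save it.
# With safe_serialization=True the default file name is
# pytorch_lora_weights.safetensors.
pipe.unet.save_attn_procs("./lora_out", safe_serialization=True)
```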
@@ -471,36 +476,6 @@ class UNet2DConditionLoadersMixin:

         os.makedirs(save_directory, exist_ok=True)

-        is_custom_diffusion = any(
-            isinstance(
-                x,
-                (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
-            )
-            for (_, x) in self.attn_processors.items()
-        )
-        if is_custom_diffusion:
-            model_to_save = AttnProcsLayers(
-                {
-                    y: x
-                    for (y, x) in self.attn_processors.items()
-                    if isinstance(
-                        x,
-                        (
-                            CustomDiffusionAttnProcessor,
-                            CustomDiffusionAttnProcessor2_0,
-                            CustomDiffusionXFormersAttnProcessor,
-                        ),
-                    )
-                }
-            )
-            state_dict = model_to_save.state_dict()
-            for name, attn in self.attn_processors.items():
-                if len(attn.state_dict()) == 0:
-                    state_dict[name] = {}
-        else:
-            model_to_save = AttnProcsLayers(self.attn_processors)
-            state_dict = model_to_save.state_dict()
-
         if weight_name is None:
             if safe_serialization:
                 weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
@@ -512,56 +487,84 @@ class UNet2DConditionLoadersMixin:
         save_function(state_dict, save_path)
         logger.info(f"Model weights saved in {save_path}")

+    def _get_custom_diffusion_state_dict(self):
+        from ..models.attention_processor import (
+            CustomDiffusionAttnProcessor,
+            CustomDiffusionAttnProcessor2_0,
+            CustomDiffusionXFormersAttnProcessor,
+        )
+
+        model_to_save = AttnProcsLayers(
+            {
+                y: x
+                for (y, x) in self.attn_processors.items()
+                if isinstance(
+                    x,
+                    (
+                        CustomDiffusionAttnProcessor,
+                        CustomDiffusionAttnProcessor2_0,
+                        CustomDiffusionXFormersAttnProcessor,
+                    ),
+                )
+            }
+        )
+        state_dict = model_to_save.state_dict()
+        for name, attn in self.attn_processors.items():
+            if len(attn.state_dict()) == 0:
+                state_dict[name] = {}
+
+        return state_dict
+
     def fuse_lora(self, lora_scale=1.0, safe_fusing=False, adapter_names=None):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `fuse_lora()`.")
+
         self.lora_scale = lora_scale
         self._safe_fusing = safe_fusing
         self.apply(partial(self._fuse_lora_apply, adapter_names=adapter_names))

     def _fuse_lora_apply(self, module, adapter_names=None):
-
-
-
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        merge_kwargs = {"safe_merge": self._safe_fusing}
+
+        if isinstance(module, BaseTunerLayer):
+            if self.lora_scale != 1.0:
+                module.scale_layer(self.lora_scale)

-
+            # For BC with prevous PEFT versions, we need to check the signature
+            # of the `merge` method to see if it supports the `adapter_names` argument.
+            supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
+            if "adapter_names" in supported_merge_kwargs:
+                merge_kwargs["adapter_names"] = adapter_names
+            elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
                 raise ValueError(
-                    "The `adapter_names` argument is not supported
-                    " to
-                    " `pip install -U peft transformers`"
+                    "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
+                    " to the latest version of PEFT. `pip install -U peft`"
                 )
-        else:
-            from peft.tuners.tuners_utils import BaseTunerLayer
-
-            merge_kwargs = {"safe_merge": self._safe_fusing}
-
-            if isinstance(module, BaseTunerLayer):
-                if self.lora_scale != 1.0:
-                    module.scale_layer(self.lora_scale)
-
-                # For BC with prevous PEFT versions, we need to check the signature
-                # of the `merge` method to see if it supports the `adapter_names` argument.
-                supported_merge_kwargs = list(inspect.signature(module.merge).parameters)
-                if "adapter_names" in supported_merge_kwargs:
-                    merge_kwargs["adapter_names"] = adapter_names
-                elif "adapter_names" not in supported_merge_kwargs and adapter_names is not None:
-                    raise ValueError(
-                        "The `adapter_names` argument is not supported with your PEFT version. Please upgrade"
-                        " to the latest version of PEFT. `pip install -U peft`"
-                    )

-
+            module.merge(**merge_kwargs)

     def unfuse_lora(self):
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for `unfuse_lora()`.")
         self.apply(self._unfuse_lora_apply)

     def _unfuse_lora_apply(self, module):
+        from peft.tuners.tuners_utils import BaseTunerLayer
+
+        if isinstance(module, BaseTunerLayer):
+            module.unmerge()
+
+    def unload_lora(self):
         if not USE_PEFT_BACKEND:
-
-
-
-            from peft.tuners.tuners_utils import BaseTunerLayer
+            raise ValueError("PEFT backend is required for `unload_lora()`.")
+
+        from ..utils import recurse_remove_peft_layers

-
-
+        recurse_remove_peft_layers(self)
+        if hasattr(self, "peft_config"):
+            del self.peft_config

     def set_adapters(
         self,
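`fuse_lora`, `unfuse_lora`, and the new `unload_lora` now all hard-require the PEFT backend and otherwise raise a `ValueError`. A hedged usage sketch, again continuing the `pipe` from the first sketch; the scale value is illustrative:

```python
# Bake the loaded LoRA into the base weights for inference, undo the merge,
# then strip the adapter layers entirely.
pipe.unet.fuse_lora(lora_scale=0.8, safe_fusing=True)
# ... run inference with the fused weights ...
pipe.unet.unfuse_lora()
pipe.unet.unload_lora()
```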
@@ -847,7 +850,12 @@ class UNet2DConditionLoadersMixin:
         embed_dims = state_dict["proj_in.weight"].shape[1]
         output_dims = state_dict["proj_out.weight"].shape[0]
         hidden_dims = state_dict["latents"].shape[2]
-
+        attn_key_present = any("attn" in k for k in state_dict)
+        heads = (
+            state_dict["layers.0.attn.to_q.weight"].shape[0] // 64
+            if attn_key_present
+            else state_dict["layers.0.0.to_q.weight"].shape[0] // 64
+        )

         with init_context():
             image_projection = IPAdapterPlusImageProjection(
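The added `heads` computation infers the attention head count from the width of the first layer's `to_q` projection, under the assumption (hard-coded in the diff) that each head is 64-dimensional. A worked example with a dummy tensor:

```python
# If to_q projects to 1280 features and each head is 64-dim,
# the checkpoint was exported with 1280 // 64 = 20 heads.
import torch

to_q_weight = torch.zeros(1280, 1280)  # stand-in for state_dict["layers.0.attn.to_q.weight"]
heads = to_q_weight.shape[0] // 64
assert heads == 20
```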
@@ -860,26 +868,53 @@ class UNet2DConditionLoadersMixin:

         for key, value in state_dict.items():
             diffusers_name = key.replace("0.to", "2.to")
-            diffusers_name = diffusers_name.replace("1.0.weight", "3.0.weight")
-            diffusers_name = diffusers_name.replace("1.0.bias", "3.0.bias")
-            diffusers_name = diffusers_name.replace("1.1.weight", "3.1.net.0.proj.weight")
-            diffusers_name = diffusers_name.replace("1.3.weight", "3.1.net.2.weight")

-
-
-
-
-
+            diffusers_name = diffusers_name.replace("0.0.norm1", "0.ln0")
+            diffusers_name = diffusers_name.replace("0.0.norm2", "0.ln1")
+            diffusers_name = diffusers_name.replace("1.0.norm1", "1.ln0")
+            diffusers_name = diffusers_name.replace("1.0.norm2", "1.ln1")
+            diffusers_name = diffusers_name.replace("2.0.norm1", "2.ln0")
+            diffusers_name = diffusers_name.replace("2.0.norm2", "2.ln1")
+            diffusers_name = diffusers_name.replace("3.0.norm1", "3.ln0")
+            diffusers_name = diffusers_name.replace("3.0.norm2", "3.ln1")
+
+            if "to_kv" in diffusers_name:
+                parts = diffusers_name.split(".")
+                parts[2] = "attn"
+                diffusers_name = ".".join(parts)
                 v_chunk = value.chunk(2, dim=0)
                 updated_state_dict[diffusers_name.replace("to_kv", "to_k")] = v_chunk[0]
                 updated_state_dict[diffusers_name.replace("to_kv", "to_v")] = v_chunk[1]
+            elif "to_q" in diffusers_name:
+                parts = diffusers_name.split(".")
+                parts[2] = "attn"
+                diffusers_name = ".".join(parts)
+                updated_state_dict[diffusers_name] = value
             elif "to_out" in diffusers_name:
+                parts = diffusers_name.split(".")
+                parts[2] = "attn"
+                diffusers_name = ".".join(parts)
                 updated_state_dict[diffusers_name.replace("to_out", "to_out.0")] = value
             else:
+                diffusers_name = diffusers_name.replace("0.1.0", "0.ff.0")
+                diffusers_name = diffusers_name.replace("0.1.1", "0.ff.1.net.0.proj")
+                diffusers_name = diffusers_name.replace("0.1.3", "0.ff.1.net.2")
+
+                diffusers_name = diffusers_name.replace("1.1.0", "1.ff.0")
+                diffusers_name = diffusers_name.replace("1.1.1", "1.ff.1.net.0.proj")
+                diffusers_name = diffusers_name.replace("1.1.3", "1.ff.1.net.2")
+
+                diffusers_name = diffusers_name.replace("2.1.0", "2.ff.0")
+                diffusers_name = diffusers_name.replace("2.1.1", "2.ff.1.net.0.proj")
+                diffusers_name = diffusers_name.replace("2.1.3", "2.ff.1.net.2")
+
+                diffusers_name = diffusers_name.replace("3.1.0", "3.ff.0")
+                diffusers_name = diffusers_name.replace("3.1.1", "3.ff.1.net.0.proj")
+                diffusers_name = diffusers_name.replace("3.1.3", "3.ff.1.net.2")
                 updated_state_dict[diffusers_name] = value

         if not low_cpu_mem_usage:
-            image_projection.load_state_dict(updated_state_dict)
+            image_projection.load_state_dict(updated_state_dict, strict=True)
         else:
             load_model_dict_into_meta(image_projection, updated_state_dict, device=self.device, dtype=self.dtype)

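The expanded renaming table above maps IP-Adapter "plus" checkpoint keys (norms stored as `N.0.normX`, feed-forward weights as `N.1.Y`, fused `to_kv` projections) onto the `IPAdapterPlusImageProjection` layout (`ln0`/`ln1`, `ff.*`, and separate `to_k`/`to_v` under an `attn` submodule). A self-contained sketch of two of those rewrites; the keys are made up to follow the rewritten patterns, not copied from a real checkpoint:

```python
import torch

updated = {}

# A norm key: "layers.0.0.norm1.weight" -> "layers.0.ln0.weight"
name = "layers.0.0.norm1.weight".replace("0.0.norm1", "0.ln0")
assert name == "layers.0.ln0.weight"

# A fused kv projection: re-root under the block's "attn" submodule and
# split into separate to_k / to_v halves, as the new code does for to_kv keys.
key, value = "layers.0.0.to_kv.weight", torch.zeros(128, 64)
parts = key.split(".")
parts[2] = "attn"
name = ".".join(parts)  # "layers.0.attn.to_kv.weight"
k_half, v_half = value.chunk(2, dim=0)
updated[name.replace("to_kv", "to_k")] = k_half
updated[name.replace("to_kv", "to_v")] = v_half
assert sorted(updated) == ["layers.0.attn.to_k.weight", "layers.0.attn.to_v.weight"]
```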