diffusers 0.23.1__py3-none-any.whl → 0.24.0__py3-none-any.whl

Files changed (176)
  1. diffusers/__init__.py +16 -2
  2. diffusers/configuration_utils.py +1 -0
  3. diffusers/dependency_versions_check.py +0 -1
  4. diffusers/dependency_versions_table.py +4 -5
  5. diffusers/image_processor.py +186 -14
  6. diffusers/loaders/__init__.py +82 -0
  7. diffusers/loaders/ip_adapter.py +157 -0
  8. diffusers/loaders/lora.py +1415 -0
  9. diffusers/loaders/lora_conversion_utils.py +284 -0
  10. diffusers/loaders/single_file.py +631 -0
  11. diffusers/loaders/textual_inversion.py +459 -0
  12. diffusers/loaders/unet.py +735 -0
  13. diffusers/loaders/utils.py +59 -0
  14. diffusers/models/__init__.py +12 -1
  15. diffusers/models/attention.py +165 -14
  16. diffusers/models/attention_flax.py +9 -1
  17. diffusers/models/attention_processor.py +286 -1
  18. diffusers/models/autoencoder_asym_kl.py +14 -9
  19. diffusers/models/autoencoder_kl.py +3 -18
  20. diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
  21. diffusers/models/autoencoder_tiny.py +20 -24
  22. diffusers/models/consistency_decoder_vae.py +37 -30
  23. diffusers/models/controlnet.py +59 -39
  24. diffusers/models/controlnet_flax.py +19 -18
  25. diffusers/models/embeddings_flax.py +2 -0
  26. diffusers/models/lora.py +131 -1
  27. diffusers/models/modeling_flax_utils.py +2 -1
  28. diffusers/models/modeling_outputs.py +17 -0
  29. diffusers/models/modeling_utils.py +27 -19
  30. diffusers/models/normalization.py +2 -2
  31. diffusers/models/resnet.py +390 -59
  32. diffusers/models/transformer_2d.py +20 -3
  33. diffusers/models/transformer_temporal.py +183 -1
  34. diffusers/models/unet_2d_blocks_flax.py +5 -0
  35. diffusers/models/unet_2d_condition.py +9 -0
  36. diffusers/models/unet_2d_condition_flax.py +13 -13
  37. diffusers/models/unet_3d_blocks.py +957 -173
  38. diffusers/models/unet_3d_condition.py +16 -8
  39. diffusers/models/unet_kandi3.py +589 -0
  40. diffusers/models/unet_motion_model.py +48 -33
  41. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  42. diffusers/models/vae.py +63 -13
  43. diffusers/models/vae_flax.py +7 -0
  44. diffusers/models/vq_model.py +3 -1
  45. diffusers/optimization.py +16 -9
  46. diffusers/pipelines/__init__.py +65 -12
  47. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
  48. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
  49. diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
  50. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  51. diffusers/pipelines/auto_pipeline.py +6 -0
  52. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  53. diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
  54. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
  55. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
  56. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
  57. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
  58. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
  59. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  60. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  61. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  62. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  63. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  64. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  65. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  66. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  67. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  68. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  69. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  70. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  71. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  72. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  73. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  74. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  75. diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
  76. diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
  77. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
  78. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
  79. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  80. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  81. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  82. diffusers/pipelines/pipeline_flax_utils.py +4 -2
  83. diffusers/pipelines/pipeline_utils.py +33 -13
  84. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
  85. diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
  86. diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
  87. diffusers/pipelines/stable_diffusion/__init__.py +64 -21
  88. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
  89. diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
  90. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  91. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  92. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  93. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
  94. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
  95. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
  96. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  97. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
  98. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
  107. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
  108. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
  109. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
  110. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
  111. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
  112. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
  114. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
  115. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
  116. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
  117. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
  118. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  119. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
  120. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  121. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
  122. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  123. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
  124. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
  125. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
  126. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
  127. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
  128. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
  129. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
  130. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
  131. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  132. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  133. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  134. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  135. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
  136. diffusers/schedulers/__init__.py +2 -4
  137. diffusers/schedulers/deprecated/__init__.py +50 -0
  138. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  139. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  140. diffusers/schedulers/scheduling_ddim.py +1 -3
  141. diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
  142. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  143. diffusers/schedulers/scheduling_ddpm.py +1 -3
  144. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
  145. diffusers/schedulers/scheduling_deis_multistep.py +15 -5
  146. diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
  147. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
  148. diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
  149. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
  150. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
  151. diffusers/schedulers/scheduling_euler_discrete.py +40 -13
  152. diffusers/schedulers/scheduling_heun_discrete.py +15 -5
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
  155. diffusers/schedulers/scheduling_lcm.py +123 -29
  156. diffusers/schedulers/scheduling_lms_discrete.py +1 -3
  157. diffusers/schedulers/scheduling_pndm.py +1 -3
  158. diffusers/schedulers/scheduling_repaint.py +1 -3
  159. diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
  160. diffusers/utils/__init__.py +1 -0
  161. diffusers/utils/constants.py +8 -7
  162. diffusers/utils/dummy_pt_objects.py +45 -0
  163. diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
  164. diffusers/utils/dynamic_modules_utils.py +4 -4
  165. diffusers/utils/export_utils.py +8 -3
  166. diffusers/utils/logging.py +10 -10
  167. diffusers/utils/outputs.py +5 -5
  168. diffusers/utils/peft_utils.py +88 -44
  169. diffusers/utils/torch_utils.py +2 -2
  170. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/METADATA +38 -22
  171. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/RECORD +175 -157
  172. diffusers/loaders.py +0 -3336
  173. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
  174. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/WHEEL +0 -0
  175. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +0 -0
  176. {diffusers-0.23.1.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
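The headline change in this diff is the removal of the monolithic `diffusers/loaders.py` (entry 172, −3336 lines) in favor of a `diffusers/loaders/` package (entries 6–13). The hunk reproduced below is the largest of the new files, `diffusers/loaders/unet.py` (entry 12, +735 lines), which defines `UNet2DConditionLoadersMixin`. As a minimal usage sketch (the repo id and weight name mirror the docstring examples inside the hunk and are illustrative only):

```py
# Minimal sketch, assuming diffusers 0.24.0 is installed; repo id and weight
# name below mirror the docstring examples in the hunk and are placeholders.
import torch
from diffusers import AutoPipelineForText2Image

pipeline = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# UNet2DConditionLoadersMixin (defined in diffusers/loaders/unet.py) is mixed
# into the UNet, so LoRA attention-processor weights load directly on it:
pipeline.unet.load_attn_procs(
    "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors"
)

# fuse_lora()/unfuse_lora() from the same mixin merge the loaded LoRA weights
# into the base layers (and undo the merge) so inference needs no extra matmuls:
pipeline.unet.fuse_lora(lora_scale=1.0)
```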
@@ -0,0 +1,735 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ import os
+ from collections import defaultdict
+ from contextlib import nullcontext
+ from typing import Callable, Dict, List, Optional, Union
+
+ import safetensors
+ import torch
+ import torch.nn.functional as F
+ from torch import nn
+
+ from ..models.embeddings import ImageProjection
+ from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, load_model_dict_into_meta
+ from ..utils import (
+     DIFFUSERS_CACHE,
+     HF_HUB_OFFLINE,
+     USE_PEFT_BACKEND,
+     _get_model_file,
+     delete_adapter_layers,
+     is_accelerate_available,
+     logging,
+     set_adapter_layers,
+     set_weights_and_activate_adapters,
+ )
+ from .utils import AttnProcsLayers
+
+
+ if is_accelerate_available():
+     from accelerate import init_empty_weights
+     from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
+
+ logger = logging.get_logger(__name__)
+
+
+ TEXT_ENCODER_NAME = "text_encoder"
+ UNET_NAME = "unet"
+
+ LORA_WEIGHT_NAME = "pytorch_lora_weights.bin"
+ LORA_WEIGHT_NAME_SAFE = "pytorch_lora_weights.safetensors"
+
+ CUSTOM_DIFFUSION_WEIGHT_NAME = "pytorch_custom_diffusion_weights.bin"
+ CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE = "pytorch_custom_diffusion_weights.safetensors"
+
+
+ class UNet2DConditionLoadersMixin:
+     """
+     Load LoRA layers into a [`UNet2DConditionModel`].
+     """
+
+     text_encoder_name = TEXT_ENCODER_NAME
+     unet_name = UNET_NAME
+
+     def load_attn_procs(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+         r"""
+         Load pretrained attention processor layers into [`UNet2DConditionModel`]. Attention processor layers have to be
+         defined in
+         [`attention_processor.py`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py)
+         and be a `torch.nn.Module` class.
+
+         Parameters:
+             pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                 Can be either:
+
+                     - A string, the model id (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                       the Hub.
+                     - A path to a directory (for example `./my_model_directory`) containing the model weights saved
+                       with [`ModelMixin.save_pretrained`].
+                     - A [torch state
+                       dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+             cache_dir (`Union[str, os.PathLike]`, *optional*):
+                 Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                 is not used.
+             force_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                 cached versions if they exist.
+             resume_download (`bool`, *optional*, defaults to `False`):
+                 Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+                 incompletely downloaded files are deleted.
+             proxies (`Dict[str, str]`, *optional*):
+                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+             local_files_only (`bool`, *optional*, defaults to `False`):
+                 Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                 won't be downloaded from the Hub.
+             use_auth_token (`str` or *bool*, *optional*):
+                 The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                 `diffusers-cli login` (stored in `~/.huggingface`) is used.
+             low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
+                 Speed up model loading by only loading the pretrained weights and not initializing the weights. This also
+                 tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                 Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                 argument to `True` will raise an error.
+             revision (`str`, *optional*, defaults to `"main"`):
+                 The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                 allowed by Git.
+             subfolder (`str`, *optional*, defaults to `""`):
+                 The subfolder location of a model file within a larger model repository on the Hub or locally.
+             mirror (`str`, *optional*):
+                 Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+                 guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+                 information.
+
+         Example:
+
+         ```py
+         from diffusers import AutoPipelineForText2Image
+         import torch
+
+         pipeline = AutoPipelineForText2Image.from_pretrained(
+             "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+         ).to("cuda")
+         pipeline.unet.load_attn_procs(
+             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+         )
+         ```
+         """
+         from ..models.attention_processor import CustomDiffusionAttnProcessor
+         from ..models.lora import LoRACompatibleConv, LoRACompatibleLinear, LoRAConv2dLayer, LoRALinearLayer
+
+         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
+         force_download = kwargs.pop("force_download", False)
+         resume_download = kwargs.pop("resume_download", False)
+         proxies = kwargs.pop("proxies", None)
+         local_files_only = kwargs.pop("local_files_only", HF_HUB_OFFLINE)
+         use_auth_token = kwargs.pop("use_auth_token", None)
+         revision = kwargs.pop("revision", None)
+         subfolder = kwargs.pop("subfolder", None)
+         weight_name = kwargs.pop("weight_name", None)
+         use_safetensors = kwargs.pop("use_safetensors", None)
+         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
+         # This value has the same meaning as the `--network_alpha` option in the kohya-ss trainer script.
+         # See https://github.com/darkstorm2150/sd-scripts/blob/main/docs/train_network_README-en.md#execute-learning
+         network_alphas = kwargs.pop("network_alphas", None)
+
+         _pipeline = kwargs.pop("_pipeline", None)
+
+         is_network_alphas_none = network_alphas is None
+
+         allow_pickle = False
+
+         if use_safetensors is None:
+             use_safetensors = True
+             allow_pickle = True
+
+         user_agent = {
+             "file_type": "attn_procs_weights",
+             "framework": "pytorch",
+         }
+
+         if low_cpu_mem_usage and not is_accelerate_available():
+             low_cpu_mem_usage = False
+             logger.warning(
+                 "Cannot initialize model with low cpu memory usage because `accelerate` was not found in the"
+                 " environment. Defaulting to `low_cpu_mem_usage=False`. It is strongly recommended to install"
+                 " `accelerate` for faster and less memory-intense model loading. You can do so with: \n```\npip"
+                 " install accelerate\n```\n."
+             )
+
+         model_file = None
+         if not isinstance(pretrained_model_name_or_path_or_dict, dict):
+             # Let's first try to load .safetensors weights
+             if (use_safetensors and weight_name is None) or (
+                 weight_name is not None and weight_name.endswith(".safetensors")
+             ):
+                 try:
+                     model_file = _get_model_file(
+                         pretrained_model_name_or_path_or_dict,
+                         weights_name=weight_name or LORA_WEIGHT_NAME_SAFE,
+                         cache_dir=cache_dir,
+                         force_download=force_download,
+                         resume_download=resume_download,
+                         proxies=proxies,
+                         local_files_only=local_files_only,
+                         use_auth_token=use_auth_token,
+                         revision=revision,
+                         subfolder=subfolder,
+                         user_agent=user_agent,
+                     )
+                     state_dict = safetensors.torch.load_file(model_file, device="cpu")
+                 except IOError as e:
+                     if not allow_pickle:
+                         raise e
+                     # try loading non-safetensors weights
+                     pass
+             if model_file is None:
+                 model_file = _get_model_file(
+                     pretrained_model_name_or_path_or_dict,
+                     weights_name=weight_name or LORA_WEIGHT_NAME,
+                     cache_dir=cache_dir,
+                     force_download=force_download,
+                     resume_download=resume_download,
+                     proxies=proxies,
+                     local_files_only=local_files_only,
+                     use_auth_token=use_auth_token,
+                     revision=revision,
+                     subfolder=subfolder,
+                     user_agent=user_agent,
+                 )
+                 state_dict = torch.load(model_file, map_location="cpu")
+         else:
+             state_dict = pretrained_model_name_or_path_or_dict
+
+         # fill attn processors
+         lora_layers_list = []
+
+         is_lora = all(("lora" in k or k.endswith(".alpha")) for k in state_dict.keys()) and not USE_PEFT_BACKEND
+         is_custom_diffusion = any("custom_diffusion" in k for k in state_dict.keys())
+
+         if is_lora:
+             # correct keys
+             state_dict, network_alphas = self.convert_state_dict_legacy_attn_format(state_dict, network_alphas)
+
+             if network_alphas is not None:
+                 network_alphas_keys = list(network_alphas.keys())
+                 used_network_alphas_keys = set()
+
+             lora_grouped_dict = defaultdict(dict)
+             mapped_network_alphas = {}
+
+             all_keys = list(state_dict.keys())
+             for key in all_keys:
+                 value = state_dict.pop(key)
+                 attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+                 lora_grouped_dict[attn_processor_key][sub_key] = value
+
+                 # Create another `mapped_network_alphas` dictionary so that we can properly map them.
+                 if network_alphas is not None:
+                     for k in network_alphas_keys:
+                         if k.replace(".alpha", "") in key:
+                             mapped_network_alphas.update({attn_processor_key: network_alphas.get(k)})
+                             used_network_alphas_keys.add(k)
+
+             if not is_network_alphas_none:
+                 if len(set(network_alphas_keys) - used_network_alphas_keys) > 0:
+                     raise ValueError(
+                         f"The `network_alphas` has to be empty at this point but has the following keys \n\n {', '.join(network_alphas.keys())}"
+                     )
+
+             if len(state_dict) > 0:
+                 raise ValueError(
+                     f"The `state_dict` has to be empty at this point but has the following keys \n\n {', '.join(state_dict.keys())}"
+                 )
+
+             for key, value_dict in lora_grouped_dict.items():
+                 attn_processor = self
+                 for sub_key in key.split("."):
+                     attn_processor = getattr(attn_processor, sub_key)
+
+                 # Process non-attention layers, which don't have to_{k,v,q,out_proj}_lora layers
+                 # or add_{k,v,q,out_proj}_proj_lora layers.
+                 rank = value_dict["lora.down.weight"].shape[0]
+
+                 if isinstance(attn_processor, LoRACompatibleConv):
+                     in_features = attn_processor.in_channels
+                     out_features = attn_processor.out_channels
+                     kernel_size = attn_processor.kernel_size
+
+                     ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                     with ctx():
+                         lora = LoRAConv2dLayer(
+                             in_features=in_features,
+                             out_features=out_features,
+                             rank=rank,
+                             kernel_size=kernel_size,
+                             stride=attn_processor.stride,
+                             padding=attn_processor.padding,
+                             network_alpha=mapped_network_alphas.get(key),
+                         )
+                 elif isinstance(attn_processor, LoRACompatibleLinear):
+                     ctx = init_empty_weights if low_cpu_mem_usage else nullcontext
+                     with ctx():
+                         lora = LoRALinearLayer(
+                             attn_processor.in_features,
+                             attn_processor.out_features,
+                             rank,
+                             mapped_network_alphas.get(key),
+                         )
+                 else:
+                     raise ValueError(f"Module {key} is not a LoRACompatibleConv or LoRACompatibleLinear module.")
+
+                 value_dict = {k.replace("lora.", ""): v for k, v in value_dict.items()}
+                 lora_layers_list.append((attn_processor, lora))
+
+                 if low_cpu_mem_usage:
+                     device = next(iter(value_dict.values())).device
+                     dtype = next(iter(value_dict.values())).dtype
+                     load_model_dict_into_meta(lora, value_dict, device=device, dtype=dtype)
+                 else:
+                     lora.load_state_dict(value_dict)
+
+         elif is_custom_diffusion:
+             attn_processors = {}
+             custom_diffusion_grouped_dict = defaultdict(dict)
+             for key, value in state_dict.items():
+                 if len(value) == 0:
+                     custom_diffusion_grouped_dict[key] = {}
+                 else:
+                     if "to_out" in key:
+                         attn_processor_key, sub_key = ".".join(key.split(".")[:-3]), ".".join(key.split(".")[-3:])
+                     else:
+                         attn_processor_key, sub_key = ".".join(key.split(".")[:-2]), ".".join(key.split(".")[-2:])
+                     custom_diffusion_grouped_dict[attn_processor_key][sub_key] = value
+
+             for key, value_dict in custom_diffusion_grouped_dict.items():
+                 if len(value_dict) == 0:
+                     attn_processors[key] = CustomDiffusionAttnProcessor(
+                         train_kv=False, train_q_out=False, hidden_size=None, cross_attention_dim=None
+                     )
+                 else:
+                     cross_attention_dim = value_dict["to_k_custom_diffusion.weight"].shape[1]
+                     hidden_size = value_dict["to_k_custom_diffusion.weight"].shape[0]
+                     train_q_out = True if "to_q_custom_diffusion.weight" in value_dict else False
+                     attn_processors[key] = CustomDiffusionAttnProcessor(
+                         train_kv=True,
+                         train_q_out=train_q_out,
+                         hidden_size=hidden_size,
+                         cross_attention_dim=cross_attention_dim,
+                     )
+                     attn_processors[key].load_state_dict(value_dict)
+         elif USE_PEFT_BACKEND:
+             # In that case we have nothing to do as loading the adapter weights is already handled above by `set_peft_model_state_dict`
+             # on the Unet
+             pass
+         else:
+             raise ValueError(
+                 f"{model_file} does not seem to be in the correct format expected by LoRA or Custom Diffusion training."
+             )
+
+         # <Unsafe code
+         # We can be sure that the following works as it just sets attention processors, lora layers and puts all in the same dtype
+         # Now we remove any existing hooks to
+         is_model_cpu_offload = False
+         is_sequential_cpu_offload = False
+
+         # For PEFT backend the Unet is already offloaded at this stage as it is handled inside `lora_lora_weights_into_unet`
+         if not USE_PEFT_BACKEND:
+             if _pipeline is not None:
+                 for _, component in _pipeline.components.items():
+                     if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
+                         is_model_cpu_offload = isinstance(getattr(component, "_hf_hook"), CpuOffload)
+                         is_sequential_cpu_offload = isinstance(getattr(component, "_hf_hook"), AlignDevicesHook)
+
+                         logger.info(
+                             "Accelerate hooks detected. Since you have called `load_lora_weights()`, the previous hooks will be first removed. Then the LoRA parameters will be loaded and the hooks will be applied again."
+                         )
+                         remove_hook_from_module(component, recurse=is_sequential_cpu_offload)
+
+             # only custom diffusion needs to set attn processors
+             if is_custom_diffusion:
+                 self.set_attn_processor(attn_processors)
+
+             # set lora layers
+             for target_module, lora_layer in lora_layers_list:
+                 target_module.set_lora_layer(lora_layer)
+
+             self.to(dtype=self.dtype, device=self.device)
+
+         # Offload back.
+         if is_model_cpu_offload:
+             _pipeline.enable_model_cpu_offload()
+         elif is_sequential_cpu_offload:
+             _pipeline.enable_sequential_cpu_offload()
+         # Unsafe code />
377
+
378
+ def convert_state_dict_legacy_attn_format(self, state_dict, network_alphas):
379
+ is_new_lora_format = all(
380
+ key.startswith(self.unet_name) or key.startswith(self.text_encoder_name) for key in state_dict.keys()
381
+ )
382
+ if is_new_lora_format:
383
+ # Strip the `"unet"` prefix.
384
+ is_text_encoder_present = any(key.startswith(self.text_encoder_name) for key in state_dict.keys())
385
+ if is_text_encoder_present:
386
+ warn_message = "The state_dict contains LoRA params corresponding to the text encoder which are not being used here. To use both UNet and text encoder related LoRA params, use [`pipe.load_lora_weights()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraLoaderMixin.load_lora_weights)."
387
+ logger.warn(warn_message)
388
+ unet_keys = [k for k in state_dict.keys() if k.startswith(self.unet_name)]
389
+ state_dict = {k.replace(f"{self.unet_name}.", ""): v for k, v in state_dict.items() if k in unet_keys}
390
+
391
+ # change processor format to 'pure' LoRACompatibleLinear format
392
+ if any("processor" in k.split(".") for k in state_dict.keys()):
393
+
394
+ def format_to_lora_compatible(key):
395
+ if "processor" not in key.split("."):
396
+ return key
397
+ return key.replace(".processor", "").replace("to_out_lora", "to_out.0.lora").replace("_lora", ".lora")
398
+
399
+ state_dict = {format_to_lora_compatible(k): v for k, v in state_dict.items()}
400
+
401
+ if network_alphas is not None:
402
+ network_alphas = {format_to_lora_compatible(k): v for k, v in network_alphas.items()}
403
+ return state_dict, network_alphas
404
+
405
+ def save_attn_procs(
406
+ self,
407
+ save_directory: Union[str, os.PathLike],
408
+ is_main_process: bool = True,
409
+ weight_name: str = None,
410
+ save_function: Callable = None,
411
+ safe_serialization: bool = True,
412
+ **kwargs,
413
+ ):
414
+ r"""
415
+ Save attention processor layers to a directory so that it can be reloaded with the
416
+ [`~loaders.UNet2DConditionLoadersMixin.load_attn_procs`] method.
417
+
418
+ Arguments:
419
+ save_directory (`str` or `os.PathLike`):
420
+ Directory to save an attention processor to (will be created if it doesn't exist).
421
+ is_main_process (`bool`, *optional*, defaults to `True`):
422
+ Whether the process calling this is the main process or not. Useful during distributed training and you
423
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
424
+ process to avoid race conditions.
425
+ save_function (`Callable`):
426
+ The function to use to save the state dictionary. Useful during distributed training when you need to
427
+ replace `torch.save` with another method. Can be configured with the environment variable
428
+ `DIFFUSERS_SAVE_MODE`.
429
+ safe_serialization (`bool`, *optional*, defaults to `True`):
430
+ Whether to save the model using `safetensors` or with `pickle`.
431
+
432
+ Example:
433
+
434
+ ```py
435
+ import torch
436
+ from diffusers import DiffusionPipeline
437
+
438
+ pipeline = DiffusionPipeline.from_pretrained(
439
+ "CompVis/stable-diffusion-v1-4",
440
+ torch_dtype=torch.float16,
441
+ ).to("cuda")
442
+ pipeline.unet.load_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
443
+ pipeline.unet.save_attn_procs("path-to-save-model", weight_name="pytorch_custom_diffusion_weights.bin")
444
+ ```
445
+ """
446
+ from ..models.attention_processor import (
447
+ CustomDiffusionAttnProcessor,
448
+ CustomDiffusionAttnProcessor2_0,
449
+ CustomDiffusionXFormersAttnProcessor,
450
+ )
451
+
452
+ if os.path.isfile(save_directory):
453
+ logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
454
+ return
455
+
456
+ if save_function is None:
457
+ if safe_serialization:
458
+
459
+ def save_function(weights, filename):
460
+ return safetensors.torch.save_file(weights, filename, metadata={"format": "pt"})
461
+
462
+ else:
463
+ save_function = torch.save
464
+
465
+ os.makedirs(save_directory, exist_ok=True)
466
+
467
+ is_custom_diffusion = any(
468
+ isinstance(
469
+ x,
470
+ (CustomDiffusionAttnProcessor, CustomDiffusionAttnProcessor2_0, CustomDiffusionXFormersAttnProcessor),
471
+ )
472
+ for (_, x) in self.attn_processors.items()
473
+ )
474
+ if is_custom_diffusion:
475
+ model_to_save = AttnProcsLayers(
476
+ {
477
+ y: x
478
+ for (y, x) in self.attn_processors.items()
479
+ if isinstance(
480
+ x,
481
+ (
482
+ CustomDiffusionAttnProcessor,
483
+ CustomDiffusionAttnProcessor2_0,
484
+ CustomDiffusionXFormersAttnProcessor,
485
+ ),
486
+ )
487
+ }
488
+ )
489
+ state_dict = model_to_save.state_dict()
490
+ for name, attn in self.attn_processors.items():
491
+ if len(attn.state_dict()) == 0:
492
+ state_dict[name] = {}
493
+ else:
494
+ model_to_save = AttnProcsLayers(self.attn_processors)
495
+ state_dict = model_to_save.state_dict()
496
+
497
+ if weight_name is None:
498
+ if safe_serialization:
499
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME_SAFE if is_custom_diffusion else LORA_WEIGHT_NAME_SAFE
500
+ else:
501
+ weight_name = CUSTOM_DIFFUSION_WEIGHT_NAME if is_custom_diffusion else LORA_WEIGHT_NAME
502
+
503
+ # Save the model
504
+ save_function(state_dict, os.path.join(save_directory, weight_name))
505
+ logger.info(f"Model weights saved in {os.path.join(save_directory, weight_name)}")
506
+
+     def fuse_lora(self, lora_scale=1.0, safe_fusing=False):
+         self.lora_scale = lora_scale
+         self._safe_fusing = safe_fusing
+         self.apply(self._fuse_lora_apply)
+
+     def _fuse_lora_apply(self, module):
+         if not USE_PEFT_BACKEND:
+             if hasattr(module, "_fuse_lora"):
+                 module._fuse_lora(self.lora_scale, self._safe_fusing)
+         else:
+             from peft.tuners.tuners_utils import BaseTunerLayer
+
+             if isinstance(module, BaseTunerLayer):
+                 if self.lora_scale != 1.0:
+                     module.scale_layer(self.lora_scale)
+                 module.merge(safe_merge=self._safe_fusing)
+
+     def unfuse_lora(self):
+         self.apply(self._unfuse_lora_apply)
+
+     def _unfuse_lora_apply(self, module):
+         if not USE_PEFT_BACKEND:
+             if hasattr(module, "_unfuse_lora"):
+                 module._unfuse_lora()
+         else:
+             from peft.tuners.tuners_utils import BaseTunerLayer
+
+             if isinstance(module, BaseTunerLayer):
+                 module.unmerge()
+
+     def set_adapters(
+         self,
+         adapter_names: Union[List[str], str],
+         weights: Optional[Union[List[float], float]] = None,
+     ):
+         """
+         Set the currently active adapters for use in the UNet.
+
+         Args:
+             adapter_names (`List[str]` or `str`):
+                 The names of the adapters to use.
+             weights (`Union[List[float], float]`, *optional*):
+                 The adapter(s) weights to use with the UNet. If `None`, the weights are set to `1.0` for all the
+                 adapters.
+
+         Example:
+
+         ```py
+         from diffusers import AutoPipelineForText2Image
+         import torch
+
+         pipeline = AutoPipelineForText2Image.from_pretrained(
+             "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+         ).to("cuda")
+         pipeline.load_lora_weights(
+             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+         )
+         pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+         pipeline.set_adapters(["cinematic", "pixel"], adapter_weights=[0.5, 0.5])
+         ```
+         """
+         if not USE_PEFT_BACKEND:
+             raise ValueError("PEFT backend is required for `set_adapters()`.")
+
+         adapter_names = [adapter_names] if isinstance(adapter_names, str) else adapter_names
+
+         if weights is None:
+             weights = [1.0] * len(adapter_names)
+         elif isinstance(weights, float):
+             weights = [weights] * len(adapter_names)
+
+         if len(adapter_names) != len(weights):
+             raise ValueError(
+                 f"Length of adapter names {len(adapter_names)} is not equal to the length of their weights {len(weights)}."
+             )
+
+         set_weights_and_activate_adapters(self, adapter_names, weights)
+
+     def disable_lora(self):
+         """
+         Disable the UNet's active LoRA layers.
+
+         Example:
+
+         ```py
+         from diffusers import AutoPipelineForText2Image
+         import torch
+
+         pipeline = AutoPipelineForText2Image.from_pretrained(
+             "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+         ).to("cuda")
+         pipeline.load_lora_weights(
+             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+         )
+         pipeline.disable_lora()
+         ```
+         """
+         if not USE_PEFT_BACKEND:
+             raise ValueError("PEFT backend is required for this method.")
+         set_adapter_layers(self, enabled=False)
+
+     def enable_lora(self):
+         """
+         Enable the UNet's active LoRA layers.
+
+         Example:
+
+         ```py
+         from diffusers import AutoPipelineForText2Image
+         import torch
+
+         pipeline = AutoPipelineForText2Image.from_pretrained(
+             "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+         ).to("cuda")
+         pipeline.load_lora_weights(
+             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+         )
+         pipeline.enable_lora()
+         ```
+         """
+         if not USE_PEFT_BACKEND:
+             raise ValueError("PEFT backend is required for this method.")
+         set_adapter_layers(self, enabled=True)
+
+     def delete_adapters(self, adapter_names: Union[List[str], str]):
+         """
+         Delete an adapter's LoRA layers from the UNet.
+
+         Args:
+             adapter_names (`Union[List[str], str]`):
+                 The names (single string or list of strings) of the adapter to delete.
+
+         Example:
+
+         ```py
+         from diffusers import AutoPipelineForText2Image
+         import torch
+
+         pipeline = AutoPipelineForText2Image.from_pretrained(
+             "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+         ).to("cuda")
+         pipeline.load_lora_weights(
+             "jbilcke-hf/sdxl-cinematic-1", weight_name="pytorch_lora_weights.safetensors", adapter_name="cinematic"
+         )
+         pipeline.delete_adapters("cinematic")
+         ```
+         """
+         if not USE_PEFT_BACKEND:
+             raise ValueError("PEFT backend is required for this method.")
+
+         if isinstance(adapter_names, str):
+             adapter_names = [adapter_names]
+
+         for adapter_name in adapter_names:
+             delete_adapter_layers(self, adapter_name)
+
+             # Pop also the corresponding adapter from the config
+             if hasattr(self, "peft_config"):
+                 self.peft_config.pop(adapter_name, None)
+
+     def _load_ip_adapter_weights(self, state_dict):
+         from ..models.attention_processor import (
+             AttnProcessor,
+             AttnProcessor2_0,
+             IPAdapterAttnProcessor,
+             IPAdapterAttnProcessor2_0,
+         )
+
+         # set ip-adapter cross-attention processors & load state_dict
+         attn_procs = {}
+         key_id = 1
+         for name in self.attn_processors.keys():
+             cross_attention_dim = None if name.endswith("attn1.processor") else self.config.cross_attention_dim
+             if name.startswith("mid_block"):
+                 hidden_size = self.config.block_out_channels[-1]
+             elif name.startswith("up_blocks"):
+                 block_id = int(name[len("up_blocks.")])
+                 hidden_size = list(reversed(self.config.block_out_channels))[block_id]
+             elif name.startswith("down_blocks"):
+                 block_id = int(name[len("down_blocks.")])
+                 hidden_size = self.config.block_out_channels[block_id]
+             if cross_attention_dim is None or "motion_modules" in name:
+                 attn_processor_class = (
+                     AttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else AttnProcessor
+                 )
+                 attn_procs[name] = attn_processor_class()
+             else:
+                 attn_processor_class = (
+                     IPAdapterAttnProcessor2_0 if hasattr(F, "scaled_dot_product_attention") else IPAdapterAttnProcessor
+                 )
+                 attn_procs[name] = attn_processor_class(
+                     hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, scale=1.0
+                 ).to(dtype=self.dtype, device=self.device)
+
+                 value_dict = {}
+                 for k, w in attn_procs[name].state_dict().items():
+                     value_dict.update({f"{k}": state_dict["ip_adapter"][f"{key_id}.{k}"]})
+
+                 attn_procs[name].load_state_dict(value_dict)
+                 key_id += 2
+
+         self.set_attn_processor(attn_procs)
+
+         # create image projection layers.
+         clip_embeddings_dim = state_dict["image_proj"]["proj.weight"].shape[-1]
+         cross_attention_dim = state_dict["image_proj"]["proj.weight"].shape[0] // 4
+
+         image_projection = ImageProjection(
+             cross_attention_dim=cross_attention_dim, image_embed_dim=clip_embeddings_dim, num_image_text_embeds=4
+         )
+         image_projection.to(dtype=self.dtype, device=self.device)
+
+         # load image projection layer weights
+         image_proj_state_dict = {}
+         image_proj_state_dict.update(
+             {
+                 "image_embeds.weight": state_dict["image_proj"]["proj.weight"],
+                 "image_embeds.bias": state_dict["image_proj"]["proj.bias"],
+                 "norm.weight": state_dict["image_proj"]["norm.weight"],
+                 "norm.bias": state_dict["image_proj"]["norm.bias"],
+             }
+         )
+
+         image_projection.load_state_dict(image_proj_state_dict)
+
+         self.encoder_hid_proj = image_projection.to(device=self.device, dtype=self.dtype)
+         self.config.encoder_hid_dim_type = "ip_image_proj"
+
+ delete_adapter_layers
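For reference, a hedged sketch of the nested `state_dict` layout that `_load_ip_adapter_weights` above consumes. The shapes and the `to_k_ip`/`to_v_ip` parameter names are assumptions inferred from the loop over `attn_procs[name].state_dict()`; a real IP-Adapter checkpoint fixes the actual sizes and number of entries.

```py
# Illustrative only: the nested layout expected by
# UNet2DConditionLoadersMixin._load_ip_adapter_weights in the hunk above.
# Dimensions are placeholders, not real IP-Adapter checkpoint sizes.
import torch

cross_attention_dim = 768   # assumed cross-attention width
clip_embeddings_dim = 1024  # assumed CLIP image-embedding width
hidden_size = 320           # assumed attention hidden size for one block

state_dict = {
    # Consumed by the ImageProjection layer: proj maps one CLIP image embedding
    # to 4 "image text" tokens, hence the 4 * cross_attention_dim rows read back
    # via proj.weight.shape[0] // 4 in the code above.
    "image_proj": {
        "proj.weight": torch.zeros(4 * cross_attention_dim, clip_embeddings_dim),
        "proj.bias": torch.zeros(4 * cross_attention_dim),
        "norm.weight": torch.zeros(cross_attention_dim),
        "norm.bias": torch.zeros(cross_attention_dim),
    },
    # One group of weights per IP-Adapter attention processor, keyed
    # "<key_id>.<param name>"; key_id advances by 2 per processor in the loop.
    "ip_adapter": {
        "1.to_k_ip.weight": torch.zeros(hidden_size, cross_attention_dim),
        "1.to_v_ip.weight": torch.zeros(hidden_size, cross_attention_dim),
        # "3.to_k_ip.weight": ..., "3.to_v_ip.weight": ..., and so on
    },
}

print(sorted(state_dict["image_proj"]))  # keys mirrored into image_proj_state_dict
```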