diffusers 0.23.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (177) hide show
  1. diffusers/__init__.py +16 -2
  2. diffusers/configuration_utils.py +1 -0
  3. diffusers/dependency_versions_check.py +1 -14
  4. diffusers/dependency_versions_table.py +5 -4
  5. diffusers/image_processor.py +186 -14
  6. diffusers/loaders/__init__.py +82 -0
  7. diffusers/loaders/ip_adapter.py +157 -0
  8. diffusers/loaders/lora.py +1415 -0
  9. diffusers/loaders/lora_conversion_utils.py +284 -0
  10. diffusers/loaders/single_file.py +631 -0
  11. diffusers/loaders/textual_inversion.py +459 -0
  12. diffusers/loaders/unet.py +735 -0
  13. diffusers/loaders/utils.py +59 -0
  14. diffusers/models/__init__.py +12 -1
  15. diffusers/models/attention.py +165 -14
  16. diffusers/models/attention_flax.py +9 -1
  17. diffusers/models/attention_processor.py +286 -1
  18. diffusers/models/autoencoder_asym_kl.py +14 -9
  19. diffusers/models/autoencoder_kl.py +3 -18
  20. diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
  21. diffusers/models/autoencoder_tiny.py +20 -24
  22. diffusers/models/consistency_decoder_vae.py +37 -30
  23. diffusers/models/controlnet.py +59 -39
  24. diffusers/models/controlnet_flax.py +19 -18
  25. diffusers/models/embeddings_flax.py +2 -0
  26. diffusers/models/lora.py +131 -1
  27. diffusers/models/modeling_flax_utils.py +2 -1
  28. diffusers/models/modeling_outputs.py +17 -0
  29. diffusers/models/modeling_utils.py +27 -19
  30. diffusers/models/normalization.py +2 -2
  31. diffusers/models/resnet.py +390 -59
  32. diffusers/models/transformer_2d.py +20 -3
  33. diffusers/models/transformer_temporal.py +183 -1
  34. diffusers/models/unet_2d_blocks_flax.py +5 -0
  35. diffusers/models/unet_2d_condition.py +9 -0
  36. diffusers/models/unet_2d_condition_flax.py +13 -13
  37. diffusers/models/unet_3d_blocks.py +957 -173
  38. diffusers/models/unet_3d_condition.py +16 -8
  39. diffusers/models/unet_kandi3.py +589 -0
  40. diffusers/models/unet_motion_model.py +48 -33
  41. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  42. diffusers/models/vae.py +63 -13
  43. diffusers/models/vae_flax.py +7 -0
  44. diffusers/models/vq_model.py +3 -1
  45. diffusers/optimization.py +16 -9
  46. diffusers/pipelines/__init__.py +65 -12
  47. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
  48. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
  49. diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
  50. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  51. diffusers/pipelines/auto_pipeline.py +6 -0
  52. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  53. diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
  54. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
  55. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
  56. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
  57. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
  58. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
  59. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  60. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  61. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  62. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  63. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  64. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  65. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  66. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  67. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  68. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  69. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  70. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  71. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  72. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  73. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  74. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  75. diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
  76. diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
  77. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
  78. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
  79. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  80. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  81. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  82. diffusers/pipelines/pipeline_flax_utils.py +4 -2
  83. diffusers/pipelines/pipeline_utils.py +33 -13
  84. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
  85. diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
  86. diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
  87. diffusers/pipelines/stable_diffusion/__init__.py +64 -21
  88. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
  89. diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
  90. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  91. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  92. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  93. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
  94. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
  95. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
  96. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  97. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
  98. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
  107. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
  108. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
  109. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
  110. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
  111. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
  112. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
  114. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
  115. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
  116. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
  117. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
  118. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  119. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
  120. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  121. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
  122. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  123. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
  124. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
  125. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
  126. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
  127. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
  128. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
  129. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
  130. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
  131. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  132. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  133. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  134. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  135. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
  136. diffusers/schedulers/__init__.py +2 -4
  137. diffusers/schedulers/deprecated/__init__.py +50 -0
  138. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  139. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  140. diffusers/schedulers/scheduling_ddim.py +1 -3
  141. diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
  142. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  143. diffusers/schedulers/scheduling_ddpm.py +1 -3
  144. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
  145. diffusers/schedulers/scheduling_deis_multistep.py +15 -5
  146. diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
  147. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
  148. diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
  149. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
  150. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
  151. diffusers/schedulers/scheduling_euler_discrete.py +40 -13
  152. diffusers/schedulers/scheduling_heun_discrete.py +15 -5
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
  155. diffusers/schedulers/scheduling_lcm.py +123 -29
  156. diffusers/schedulers/scheduling_lms_discrete.py +1 -3
  157. diffusers/schedulers/scheduling_pndm.py +1 -3
  158. diffusers/schedulers/scheduling_repaint.py +1 -3
  159. diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
  160. diffusers/utils/__init__.py +1 -0
  161. diffusers/utils/constants.py +11 -6
  162. diffusers/utils/dummy_pt_objects.py +45 -0
  163. diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
  164. diffusers/utils/dynamic_modules_utils.py +4 -4
  165. diffusers/utils/export_utils.py +8 -3
  166. diffusers/utils/logging.py +10 -10
  167. diffusers/utils/outputs.py +5 -5
  168. diffusers/utils/peft_utils.py +88 -44
  169. diffusers/utils/torch_utils.py +2 -2
  170. diffusers/utils/versions.py +117 -0
  171. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/METADATA +83 -64
  172. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/RECORD +176 -157
  173. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/WHEEL +1 -1
  174. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +1 -0
  175. diffusers/loaders.py +0 -3336
  176. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
  177. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,459 @@
1
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, List, Optional, Union
15
+
16
+ import safetensors
17
+ import torch
18
+ from torch import nn
19
+
20
+ from ..utils import (
21
+ DIFFUSERS_CACHE,
22
+ HF_HUB_OFFLINE,
23
+ _get_model_file,
24
+ is_accelerate_available,
25
+ is_transformers_available,
26
+ logging,
27
+ )
28
+
29
+
30
+ if is_transformers_available():
31
+ from transformers import PreTrainedModel, PreTrainedTokenizer
32
+
33
+ if is_accelerate_available():
34
+ from accelerate.hooks import AlignDevicesHook, CpuOffload, remove_hook_from_module
35
+
36
# Default weight file names looked up when the caller does not pass `weight_name`:
# the pickle-based `.bin` file and its safetensors counterpart.
TEXT_INVERSION_NAME = "learned_embeds.bin"
TEXT_INVERSION_NAME_SAFE = "learned_embeds.safetensors"

# Module logger, created through the project's logging utilities.
logger = logging.get_logger(__name__)
40
+
41
+
42
def load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs):
    """
    Load one textual inversion state dict per entry of `pretrained_model_name_or_paths`.

    Entries that are already a `dict` or a `torch.Tensor` are passed through unchanged. Any other entry (Hub model
    id, local directory or file) is resolved with `_get_model_file` and loaded onto the CPU.

    File selection:
        - With `weight_name=None`, `learned_embeds.safetensors` is tried first when `use_safetensors` is truthy
          (the default), otherwise `learned_embeds.bin` is loaded directly.
        - With an explicit `weight_name`, that file is used; it is read with safetensors only when its name ends
          with `.safetensors`.
        - A failed safetensors attempt falls back to the pickle file only when `use_safetensors` was not passed
          (i.e. left as `None`); otherwise the error is re-raised.

    Args:
        pretrained_model_name_or_paths (`list`):
            Sources to load, one per textual inversion.
        **kwargs:
            Download/lookup options consumed here: `cache_dir`, `force_download`, `resume_download`, `proxies`,
            `local_files_only`, `use_auth_token`, `revision`, `subfolder`, `weight_name`, `use_safetensors`.
            Any other key is ignored.

    Returns:
        `list`: One entry per input, in the same order: the passed-through dict/tensor, the dict returned by
        safetensors, or whatever object `torch.load` produced for pickle files.
    """
    # Import the submodule explicitly: a bare `import safetensors` does not guarantee that `safetensors.torch`
    # is loaded, and an AttributeError raised inside the broad `try` below would otherwise be swallowed and
    # silently turned into a pickle fallback.
    from safetensors.torch import load_file as safe_load_file

    weight_name = kwargs.pop("weight_name", None)
    use_safetensors = kwargs.pop("use_safetensors", None)

    # Options forwarded unchanged to every `_get_model_file` call.
    download_kwargs = {
        "cache_dir": kwargs.pop("cache_dir", DIFFUSERS_CACHE),
        "force_download": kwargs.pop("force_download", False),
        "resume_download": kwargs.pop("resume_download", False),
        "proxies": kwargs.pop("proxies", None),
        "local_files_only": kwargs.pop("local_files_only", HF_HUB_OFFLINE),
        "use_auth_token": kwargs.pop("use_auth_token", None),
        "revision": kwargs.pop("revision", None),
        "subfolder": kwargs.pop("subfolder", None),
        "user_agent": {
            "file_type": "text_inversion",
            "framework": "pytorch",
        },
    }

    # Only fall back from safetensors to pickle when the caller expressed no preference.
    allow_pickle = use_safetensors is None
    if allow_pickle:
        use_safetensors = True

    try_safetensors = (use_safetensors and weight_name is None) or (
        weight_name is not None and weight_name.endswith(".safetensors")
    )

    state_dicts = []
    for pretrained_model_name_or_path in pretrained_model_name_or_paths:
        if isinstance(pretrained_model_name_or_path, (dict, torch.Tensor)):
            # Already in memory: nothing to download or deserialize.
            state_dicts.append(pretrained_model_name_or_path)
            continue

        model_file = None
        state_dict = None

        if try_safetensors:
            try:
                model_file = _get_model_file(
                    pretrained_model_name_or_path,
                    weights_name=weight_name or TEXT_INVERSION_NAME_SAFE,
                    **download_kwargs,
                )
                state_dict = safe_load_file(model_file, device="cpu")
            except Exception:
                if not allow_pickle:
                    raise
                # Resolution or loading failed: retry with the pickle file below.
                model_file = None

        if model_file is None:
            model_file = _get_model_file(
                pretrained_model_name_or_path,
                weights_name=weight_name or TEXT_INVERSION_NAME,
                **download_kwargs,
            )
            # NOTE(security): `torch.load` unpickles the file, which can execute arbitrary code. Only load
            # pickle-based textual inversion files from trusted sources; prefer `.safetensors` files.
            state_dict = torch.load(model_file, map_location="cpu")

        state_dicts.append(state_dict)

    return state_dicts
115
+
116
+
117
class TextualInversionLoaderMixin:
    r"""
    Load Textual Inversion tokens and embeddings to the tokenizer and text encoder.

    Meant to be mixed into a pipeline: `load_textual_inversion` falls back to `self.tokenizer` and
    `self.text_encoder`, iterates `self.components`, and may call `self.enable_model_cpu_offload()` or
    `self.enable_sequential_cpu_offload()`.
    """

    def maybe_convert_prompt(self, prompt: Union[str, List[str]], tokenizer: "PreTrainedTokenizer"):  # noqa: F821
        r"""
        Processes prompts that include a special token corresponding to a multi-vector textual inversion embedding to
        be replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
        inversion token or if the textual inversion token is a single vector, the input prompt is returned.

        Parameters:
            prompt (`str` or list of `str`):
                The prompt or prompts to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str` or list of `str`: The converted prompt
        """
        prompts = prompt if isinstance(prompt, list) else [prompt]
        prompts = [self._maybe_convert_prompt(p, tokenizer) for p in prompts]

        # Return the same container kind the caller passed in.
        return prompts if isinstance(prompt, list) else prompts[0]

    def _maybe_convert_prompt(self, prompt: str, tokenizer: "PreTrainedTokenizer"):  # noqa: F821
        r"""
        Maybe convert a prompt into a "multi vector"-compatible prompt. If the prompt includes a token that corresponds
        to a multi-vector textual inversion embedding, this function will process the prompt so that the special token
        is replaced with multiple special tokens each corresponding to one of the vectors. If the prompt has no textual
        inversion token or a textual inversion token that is a single vector, the input prompt is simply returned.

        Parameters:
            prompt (`str`):
                The prompt to guide the image generation.
            tokenizer (`PreTrainedTokenizer`):
                The tokenizer responsible for encoding the prompt into input tokens.

        Returns:
            `str`: The converted prompt
        """
        tokens = tokenizer.tokenize(prompt)
        unique_tokens = set(tokens)
        for token in unique_tokens:
            if token in tokenizer.added_tokens_encoder:
                # Collect the follow-up tokens `token_1`, `token_2`, ... registered for a multi-vector embedding.
                replacement = token
                i = 1
                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
                    replacement += f" {token}_{i}"
                    i += 1

                prompt = prompt.replace(token, replacement)

        return prompt

    def _check_text_inv_inputs(self, tokenizer, text_encoder, pretrained_model_name_or_paths, tokens):
        r"""
        Validate the arguments of `load_textual_inversion` before anything is loaded.

        Raises:
            `ValueError`: If `tokenizer` or `text_encoder` is `None`; if more than one source is given and the number
            of sources differs from the number of tokens; or if the non-`None` tokens contain duplicates.
        """
        if tokenizer is None:
            raise ValueError(
                f"{self.__class__.__name__} requires `self.tokenizer` or passing a `tokenizer` of type `PreTrainedTokenizer` for calling"
                f" `{self.load_textual_inversion.__name__}`"
            )

        if text_encoder is None:
            raise ValueError(
                f"{self.__class__.__name__} requires `self.text_encoder` or passing a `text_encoder` of type `PreTrainedModel` for calling"
                f" `{self.load_textual_inversion.__name__}`"
            )

        if len(pretrained_model_name_or_paths) > 1 and len(pretrained_model_name_or_paths) != len(tokens):
            raise ValueError(
                f"You have passed a list of models of length {len(pretrained_model_name_or_paths)}, and list of tokens of length {len(tokens)} "
                f"Make sure both lists have the same length."
            )

        valid_tokens = [t for t in tokens if t is not None]
        if len(set(valid_tokens)) < len(valid_tokens):
            raise ValueError(f"You have passed a list of tokens that contains duplicates: {tokens}")

    @staticmethod
    def _retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer):
        r"""
        Extract one `(token, embedding)` pair from each loaded state dict.

        Supported layouts: a bare `torch.Tensor` (an explicit token is then required), a single-entry dict in
        🤗 Diffusers format (`{token: embedding}`), or an Automatic1111 dict with a `name` entry and a
        `string_to_param["*"]` embedding. An explicitly passed token overrides the token stored in the file.

        Returns:
            `Tuple[List[str], List[torch.Tensor]]`: The tokens and their embeddings, in input order.

        Raises:
            `ValueError`: For a tensor without a token, an unrecognized state dict layout, or a token that is already
            in the tokenizer vocabulary.
        """
        all_tokens = []
        all_embeddings = []
        for state_dict, token in zip(state_dicts, tokens):
            if isinstance(state_dict, torch.Tensor):
                if token is None:
                    raise ValueError(
                        "You are trying to load a textual inversion embedding that has been saved as a PyTorch tensor. Make sure to pass the name of the corresponding token in this case: `token=...`."
                    )
                loaded_token = token
                embedding = state_dict
            elif len(state_dict) == 1:
                # diffusers
                loaded_token, embedding = next(iter(state_dict.items()))
            elif "string_to_param" in state_dict:
                # A1111
                loaded_token = state_dict["name"]
                embedding = state_dict["string_to_param"]["*"]
            else:
                raise ValueError(
                    f"Loaded state dictonary is incorrect: {state_dict}. \n\n"
                    "Please verify that the loaded state dictionary of the textual embedding either only has a single key or includes the `string_to_param`"
                    " input key."
                )

            if token is not None and loaded_token != token:
                logger.info(f"The loaded token: {loaded_token} is overwritten by the passed token {token}.")
            else:
                token = loaded_token

            if token in tokenizer.get_vocab():
                raise ValueError(
                    f"Token {token} already in tokenizer vocabulary. Please choose a different token name or remove {token} and embedding from the tokenizer and text encoder."
                )

            all_tokens.append(token)
            all_embeddings.append(embedding)

        return all_tokens, all_embeddings

    @staticmethod
    def _extend_tokens_and_embeddings(tokens, embeddings, tokenizer):
        r"""
        Expand multi-vector embeddings into one token per vector.

        An embedding with more than one row is split row by row, and its token becomes
        `[token, token_1, ..., token_{n-1}]`. A 2-D embedding with a single row is reduced to that row, and 1-D
        embeddings are kept as they are.

        Returns:
            `Tuple[List[str], List[torch.Tensor]]`: The flattened tokens and one embedding vector per token.

        Raises:
            `ValueError`: If `f"{token}_1"` is already in the tokenizer vocabulary.
        """
        all_tokens = []
        all_embeddings = []

        for embedding, token in zip(embeddings, tokens):
            if f"{token}_1" in tokenizer.get_vocab():
                multi_vector_tokens = [token]
                i = 1
                while f"{token}_{i}" in tokenizer.added_tokens_encoder:
                    multi_vector_tokens.append(f"{token}_{i}")
                    i += 1

                raise ValueError(
                    f"Multi-vector Token {multi_vector_tokens} already in tokenizer vocabulary. Please choose a different token name or remove the {multi_vector_tokens} and embedding from the tokenizer and text encoder."
                )

            is_multi_vector = len(embedding.shape) > 1 and embedding.shape[0] > 1
            if is_multi_vector:
                all_tokens += [token] + [f"{token}_{i}" for i in range(1, embedding.shape[0])]
                all_embeddings += [e for e in embedding]  # noqa: C416
            else:
                all_tokens += [token]
                all_embeddings += [embedding[0]] if len(embedding.shape) > 1 else [embedding]

        return all_tokens, all_embeddings

    @staticmethod
    def _check_embedding_dims(embeddings, expected_emb_dim):
        r"""
        Raise a `ValueError` if the last dimension of any embedding differs from `expected_emb_dim`.

        Parameters:
            embeddings (`List[torch.Tensor]`):
                The per-token embeddings about to be written into the text encoder.
            expected_emb_dim (`int`):
                The last dimension of the text encoder's input embedding matrix.
        """
        mismatched_dims = sorted({emb.shape[-1] for emb in embeddings if emb.shape[-1] != expected_emb_dim})
        if mismatched_dims:
            raise ValueError(
                "Loaded embeddings are of incorrect shape. Expected each textual inversion embedding "
                f"to be of shape {expected_emb_dim}, but got embeddings of shape {mismatched_dims}."
            )

    def load_textual_inversion(
        self,
        pretrained_model_name_or_path: Union[str, List[str], Dict[str, torch.Tensor], List[Dict[str, torch.Tensor]]],
        token: Optional[Union[str, List[str]]] = None,
        tokenizer: Optional["PreTrainedTokenizer"] = None,  # noqa: F821
        text_encoder: Optional["PreTrainedModel"] = None,  # noqa: F821
        **kwargs,
    ):
        r"""
        Load Textual Inversion embeddings into the text encoder of [`StableDiffusionPipeline`] (both 🤗 Diffusers and
        Automatic1111 formats are supported).

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike` or `List[str or os.PathLike]` or `Dict` or `List[Dict]`):
                Can be either one of the following or a list of them:

                    - A string, the *model id* (for example `sd-concepts-library/low-poly-hd-logos-icons`) of a
                      pretrained model hosted on the Hub.
                    - A path to a *directory* (for example `./my_text_inversion_directory/`) containing the textual
                      inversion weights.
                    - A path to a *file* (for example `./my_text_inversions.pt`) containing textual inversion weights.
                    - A [torch state
                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).

            token (`str` or `List[str]`, *optional*):
                Override the token to use for the textual inversion weights. If `pretrained_model_name_or_path` is a
                list, then `token` must also be a list of equal length.
            text_encoder ([`~transformers.CLIPTextModel`], *optional*):
                Frozen text-encoder ([clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14)).
                If not specified, function will take self.text_encoder.
            tokenizer ([`~transformers.CLIPTokenizer`], *optional*):
                A `CLIPTokenizer` to tokenize text. If not specified, function will take self.tokenizer.
            weight_name (`str`, *optional*):
                Name of a custom weight file. This should be used when:

                    - The saved textual inversion file is in 🤗 Diffusers format, but was saved under a specific weight
                      name such as `text_inv.bin`.
                    - The saved textual inversion file is in the Automatic1111 format.
            use_safetensors (`bool`, *optional*):
                Whether to load `.safetensors` weights. If `None` (default), safetensors weights are tried first and
                the pickle (`.bin`) weights are used as a fallback if that fails.
            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            resume_download (`bool`, *optional*, defaults to `False`):
                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
                incompletely downloaded files are deleted.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            use_auth_token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            mirror (`str`, *optional*):
                Mirror source to resolve accessibility issues if you're downloading a model in China. Currently
                accepted but not used by the loader. We do not guarantee the timeliness or safety of the source, and
                you should refer to the mirror site for more information.

        Raises:
            `ValueError`: If the tokenizer or text encoder is missing, tokens and sources do not match, a token is
            already in the vocabulary, or an embedding does not match the text encoder's embedding size.

        Example:

        To load a Textual Inversion embedding vector in 🤗 Diffusers format:

        ```py
        from diffusers import StableDiffusionPipeline
        import torch

        model_id = "runwayml/stable-diffusion-v1-5"
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

        pipe.load_textual_inversion("sd-concepts-library/cat-toy")

        prompt = "A <cat-toy> backpack"

        image = pipe(prompt, num_inference_steps=50).images[0]
        image.save("cat-backpack.png")
        ```

        To load a Textual Inversion embedding vector in Automatic1111 format, make sure to download the vector first
        (for example from [civitAI](https://civitai.com/models/3036?modelVersionId=9857)) and then load the vector
        locally:

        ```py
        from diffusers import StableDiffusionPipeline
        import torch

        model_id = "runwayml/stable-diffusion-v1-5"
        pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16).to("cuda")

        pipe.load_textual_inversion("./charturnerv2.pt", token="charturnerv2")

        prompt = "charturnerv2, multiple views of the same character in the same outfit, a character turnaround of a woman wearing a black jacket and red shirt, best quality, intricate details."

        image = pipe(prompt, num_inference_steps=50).images[0]
        image.save("character.png")
        ```

        """
        # 1. Set correct tokenizer and text encoder
        tokenizer = tokenizer or getattr(self, "tokenizer", None)
        text_encoder = text_encoder or getattr(self, "text_encoder", None)

        # 2. Normalize inputs
        pretrained_model_name_or_paths = (
            [pretrained_model_name_or_path]
            if not isinstance(pretrained_model_name_or_path, list)
            else pretrained_model_name_or_path
        )
        tokens = [token] if not isinstance(token, list) else token
        if tokens[0] is None:
            tokens = tokens * len(pretrained_model_name_or_paths)

        # 3. Check inputs
        self._check_text_inv_inputs(tokenizer, text_encoder, pretrained_model_name_or_paths, tokens)

        # 4. Load state dicts of textual embeddings
        state_dicts = load_textual_inversion_state_dicts(pretrained_model_name_or_paths, **kwargs)

        # 4.1 Handle the special case when state_dict is a tensor that contains n embeddings for n tokens
        if len(tokens) > 1 and len(state_dicts) == 1:
            if isinstance(state_dicts[0], torch.Tensor):
                state_dicts = list(state_dicts[0])
                if len(tokens) != len(state_dicts):
                    raise ValueError(
                        f"You have passed a state_dict contains {len(state_dicts)} embeddings, and list of tokens of length {len(tokens)} "
                        f"Make sure both have the same length."
                    )

        # 5. Retrieve tokens and embeddings
        tokens, embeddings = self._retrieve_tokens_and_embeddings(tokens, state_dicts, tokenizer)

        # 6. Extend tokens and embeddings for multi vector
        tokens, embeddings = self._extend_tokens_and_embeddings(tokens, embeddings, tokenizer)

        # 7. Make sure all embeddings have the correct size
        expected_emb_dim = text_encoder.get_input_embeddings().weight.shape[-1]
        self._check_embedding_dims(embeddings, expected_emb_dim)

        # 8. Now we can be sure that loading the embedding matrix works
        # < Unsafe code:

        # 8.1 Remove accelerate hooks in case the pipeline was cpu offloaded; offloading is re-enabled at the end.
        is_model_cpu_offload = False
        is_sequential_cpu_offload = False
        for _, component in self.components.items():
            if isinstance(component, nn.Module) and hasattr(component, "_hf_hook"):
                hook = getattr(component, "_hf_hook")
                # The flags are accumulated (never reset) so that the offload mode detected on any component
                # survives; assigning them per component would let the last hooked component decide.
                is_model_cpu_offload = is_model_cpu_offload or isinstance(hook, CpuOffload)
                # NOTE(review): a hook combining several hooks exposes them via `.hooks`; an `AlignDevicesHook`
                # in first position is treated as sequential offloading as well.
                nested_hooks = getattr(hook, "hooks", None)
                is_sequential_cpu_offload = (
                    is_sequential_cpu_offload
                    or isinstance(hook, AlignDevicesHook)
                    or (bool(nested_hooks) and isinstance(nested_hooks[0], AlignDevicesHook))
                )
                logger.info(
                    "Accelerate hooks detected. Since you have called `load_textual_inversion()`, the previous hooks will be first removed. Then the textual inversion parameters will be loaded and the hooks will be applied again."
                )
                remove_hook_from_module(component, recurse=is_sequential_cpu_offload)

        # 8.2 save expected device and dtype
        device = text_encoder.device
        dtype = text_encoder.dtype

        # 8.3 Increase token embedding matrix
        text_encoder.resize_token_embeddings(len(tokenizer) + len(tokens))
        input_embeddings = text_encoder.get_input_embeddings().weight

        # 8.4 Load token and embedding
        for token, embedding in zip(tokens, embeddings):
            # add tokens and get ids
            tokenizer.add_tokens(token)
            token_id = tokenizer.convert_tokens_to_ids(token)
            input_embeddings.data[token_id] = embedding
            logger.info(f"Loaded textual inversion embedding for {token}.")

        # `nn.Module.to` works in place (unlike `Tensor.to`, which returns a new tensor), so this actually moves
        # the resized embedding layer to the device/dtype the text encoder had before loading.
        text_encoder.get_input_embeddings().to(dtype=dtype, device=device)

        # 8.5 Offload the model again
        if is_model_cpu_offload:
            self.enable_model_cpu_offload()
        elif is_sequential_cpu_offload:
            self.enable_sequential_cpu_offload()

        # / Unsafe Code >