diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl
- diffusers/__init__.py +26 -1
- diffusers/configuration_utils.py +34 -29
- diffusers/dependency_versions_table.py +4 -0
- diffusers/image_processor.py +125 -12
- diffusers/loaders.py +169 -203
- diffusers/models/attention.py +24 -1
- diffusers/models/attention_flax.py +10 -5
- diffusers/models/attention_processor.py +3 -0
- diffusers/models/autoencoder_kl.py +114 -33
- diffusers/models/controlnet.py +131 -14
- diffusers/models/controlnet_flax.py +37 -26
- diffusers/models/cross_attention.py +17 -17
- diffusers/models/embeddings.py +67 -0
- diffusers/models/modeling_flax_utils.py +64 -56
- diffusers/models/modeling_utils.py +193 -104
- diffusers/models/prior_transformer.py +207 -37
- diffusers/models/resnet.py +26 -26
- diffusers/models/transformer_2d.py +36 -41
- diffusers/models/transformer_temporal.py +24 -21
- diffusers/models/unet_1d.py +31 -25
- diffusers/models/unet_2d.py +43 -30
- diffusers/models/unet_2d_blocks.py +210 -89
- diffusers/models/unet_2d_blocks_flax.py +12 -12
- diffusers/models/unet_2d_condition.py +172 -64
- diffusers/models/unet_2d_condition_flax.py +38 -24
- diffusers/models/unet_3d_blocks.py +34 -31
- diffusers/models/unet_3d_condition.py +101 -34
- diffusers/models/vae.py +5 -5
- diffusers/models/vae_flax.py +37 -34
- diffusers/models/vq_model.py +23 -14
- diffusers/pipelines/__init__.py +24 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
- diffusers/pipelines/consistency_models/__init__.py +1 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
- diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
- diffusers/pipelines/kandinsky/__init__.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
- diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_utils.py +124 -146
- diffusers/pipelines/shap_e/__init__.py +27 -0
- diffusers/pipelines/shap_e/camera.py +147 -0
- diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
- diffusers/pipelines/shap_e/renderer.py +709 -0
- diffusers/pipelines/stable_diffusion/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
- diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
- diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
- diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
- diffusers/schedulers/__init__.py +3 -0
- diffusers/schedulers/scheduling_consistency_models.py +380 -0
- diffusers/schedulers/scheduling_ddim.py +28 -6
- diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
- diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
- diffusers/schedulers/scheduling_ddpm.py +53 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
- diffusers/schedulers/scheduling_deis_multistep.py +66 -11
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
- diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
- diffusers/schedulers/scheduling_euler_discrete.py +58 -8
- diffusers/schedulers/scheduling_heun_discrete.py +89 -14
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
- diffusers/schedulers/scheduling_lms_discrete.py +57 -8
- diffusers/schedulers/scheduling_pndm.py +46 -10
- diffusers/schedulers/scheduling_repaint.py +19 -4
- diffusers/schedulers/scheduling_sde_ve.py +5 -1
- diffusers/schedulers/scheduling_unclip.py +43 -4
- diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
- diffusers/training_utils.py +1 -1
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
- diffusers/utils/hub_utils.py +1 -1
- diffusers/utils/import_utils.py +20 -3
- diffusers/utils/logging.py +15 -18
- diffusers/utils/outputs.py +3 -3
- diffusers/utils/testing_utils.py +15 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
- {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
diffusers/models/modeling_utils.py
@@ -22,7 +22,7 @@ from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
-from torch import Tensor, device
+from torch import Tensor, device, nn
 
 from .. import __version__
 from ..utils import (
@@ -154,11 +154,10 @@ class ModelMixin(torch.nn.Module):
     r"""
     Base class for all models.
 
-    [`ModelMixin`] takes care of storing the configuration
-
+    [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
+    saving models.
 
-        - **config_name** ([`str`]) --
-          [`~models.ModelMixin.save_pretrained`].
+        - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
     """
     config_name = CONFIG_NAME
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
@@ -190,18 +189,13 @@ class ModelMixin(torch.nn.Module):
     def is_gradient_checkpointing(self) -> bool:
         """
         Whether gradient checkpointing is activated for this model or not.
-
-        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
-        activations".
         """
         return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
 
     def enable_gradient_checkpointing(self):
         """
-        Activates gradient checkpointing for the current model
-
-        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
-        activations".
+        Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
+        *checkpoint activations* in other frameworks).
         """
         if not self._supports_gradient_checkpointing:
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
@@ -209,10 +203,8 @@
 
     def disable_gradient_checkpointing(self):
         """
-        Deactivates gradient checkpointing for the current model
-
-        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
-        activations".
+        Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
+        *checkpoint activations* in other frameworks).
         """
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
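The gradient checkpointing toggles in the two hunks above trade compute for memory: activations are recomputed during the backward pass instead of stored. A minimal usage sketch (the checkpoint id is illustrative):

```py
from diffusers import UNet2DConditionModel

# Trade compute for memory during training by recomputing activations
# in the backward pass.
unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.enable_gradient_checkpointing()
assert unet.is_gradient_checkpointing

# Switch back to storing activations for faster backward passes.
unet.disable_gradient_checkpointing()
assert not unet.is_gradient_checkpointing
```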
@@ -236,13 +228,17 @@
 
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         r"""
-        Enable memory efficient attention
+        Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
+        inference. Speed up during training is not guaranteed.
+
+        <Tip warning={true}>
 
-        When
-
+        ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
+        precedent.
 
-
-        is used.
+        </Tip>
 
         Parameters:
             attention_op (`Callable`, *optional*):
@@ -268,7 +264,7 @@
 
     def disable_xformers_memory_efficient_attention(self):
         r"""
-        Disable memory efficient attention
+        Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
         """
         self.set_use_memory_efficient_attention_xformers(False)
 
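Both xFormers methods are thin wrappers around `set_use_memory_efficient_attention_xformers`, as the hunk above shows. A short sketch of their use, assuming `xformers` is installed:

```py
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Route attention through xFormers' memory-efficient kernels; this should
# lower GPU memory usage and may speed up inference.
unet.enable_xformers_memory_efficient_attention()

# Revert to the default attention implementation.
unet.disable_xformers_memory_efficient_attention()
```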
@@ -281,24 +277,24 @@
         variant: Optional[str] = None,
     ):
         """
-        Save a model and its configuration file to a directory
-
+        Save a model and its configuration file to a directory so that it can be reloaded using the
+        [`~models.ModelMixin.from_pretrained`] class method.
 
         Arguments:
             save_directory (`str` or `os.PathLike`):
-                Directory to
+                Directory to save a model and its configuration file to. Will be created if it doesn't exist.
             is_main_process (`bool`, *optional*, defaults to `True`):
-                Whether the process calling this is the main process or not. Useful
-
-
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
             save_function (`Callable`):
-                The function to use to save the state dictionary. Useful
-
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `False`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
-                If specified, weights are saved in the format pytorch_model.<variant>.bin
+                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
         """
         if safe_serialization and not is_safetensors_available():
             raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
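A sketch of the `save_pretrained` options spelled out in the docstring above; the paths are illustrative, and `safe_serialization=True` assumes `safetensors` is installed:

```py
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")

# Default: pickle-based PyTorch serialization.
unet.save_pretrained("./my_unet")

# Serialize with safetensors instead of pickle (requires `pip install safetensors`).
unet.save_pretrained("./my_unet", safe_serialization=True)

# Save half-precision weights; the variant is inserted into the weight filename.
unet.to(torch.float16).save_pretrained("./my_unet", variant="fp16")

# Reload later with the matching variant.
unet_fp16 = UNet2DConditionModel.from_pretrained("./my_unet", variant="fp16", torch_dtype=torch.float16)
```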
@@ -335,107 +331,108 @@
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         r"""
-        Instantiate a pretrained
-
-        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
-        the model, you should first set it back in training mode with `model.train()`.
-
-        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
-        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
-        task.
+        Instantiate a pretrained PyTorch model from a pretrained model configuration.
 
-        The
-
+        The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
+        train the model, set it back in training mode with `model.train()`.
 
         Parameters:
             pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                 Can be either:
 
-                    - A string, the *model id* of a pretrained model hosted
-
-                    - A path to a *directory* containing model weights saved
-
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`~ModelMixin.save_pretrained`].
 
             cache_dir (`Union[str, os.PathLike]`, *optional*):
-                Path to a directory
-
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
             torch_dtype (`str` or `torch.dtype`, *optional*):
-                Override the default `torch.dtype` and load the model
-
+                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+                dtype is automatically derived from the model's weights.
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
             resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to
-
+                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+                incompletely downloaded files are deleted.
             proxies (`Dict[str, str]`, *optional*):
-                A dictionary of proxy servers to use by protocol or endpoint,
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
-            output_loading_info(`bool`, *optional*, defaults to `False`):
+            output_loading_info (`bool`, *optional*, defaults to `False`):
                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
             local_files_only(`bool`, *optional*, defaults to `False`):
-                Whether
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
             use_auth_token (`str` or *bool*, *optional*):
-                The token to use as HTTP bearer authorization for remote files. If `True`,
-
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
             revision (`str`, *optional*, defaults to `"main"`):
-                The specific model version to use. It can be a branch name, a tag name,
-
-                identifier allowed by git.
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
             from_flax (`bool`, *optional*, defaults to `False`):
                 Load the model weights from a Flax checkpoint save file.
             subfolder (`str`, *optional*, defaults to `""`):
-
-                huggingface.co or downloaded locally), you can specify the folder name here.
-
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
             mirror (`str`, *optional*):
-                Mirror source to
-
-
+                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+                information.
             device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
-                A map that specifies where each submodule should go. It doesn't need to be
-                parameter/buffer name
+                A map that specifies where each submodule should go. It doesn't need to be defined for each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
                 same device.
 
-
+                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                 more information about each option see [designing a device
                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
             max_memory (`Dict`, *optional*):
-                A dictionary device identifier
-                GPU and the available CPU RAM if unset.
+                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+                each GPU and the available CPU RAM if unset.
             offload_folder (`str` or `os.PathLike`, *optional*):
-
+                The path to offload weights if `device_map` contains the value `"disk"`.
             offload_state_dict (`bool`, *optional*):
-                If `True`,
-
-
+                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+                when there is some disk offload.
             low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading
-
-
-
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
             variant (`str`, *optional*):
-
-
+                Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
+                loading `from_flax`.
             use_safetensors (`bool`, *optional*, defaults to `None`):
-                If set to `None`, the `safetensors` weights
-                `safetensors` library is installed. If set to `True`, the model
-
+                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
+                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
+                weights. If set to `False`, `safetensors` weights are not loaded.
 
         <Tip>
 
-
-
+        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
+        `huggingface-cli login`. You can also activate the special
+        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+        firewalled environment.
 
         </Tip>
 
-
+        Example:
 
-
-
+        ```py
+        from diffusers import UNet2DConditionModel
 
-
+        unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+        ```
 
+        If you get the error message below, you need to finetune the weights for your downstream task:
+
+        ```bash
+        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+        ```
         """
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
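The loading options documented above compose. A sketch combining several of them, assuming `accelerate` is installed for `device_map="auto"` and that the repository ships `fp16` variant weights:

```py
import torch
from diffusers import UNet2DConditionModel

# Load fp16 variant weights in half precision and let Accelerate
# place submodules across the available devices.
unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    variant="fp16",
    torch_dtype=torch.float16,
    device_map="auto",
)

# Request loading diagnostics alongside the model.
unet, loading_info = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", output_loading_info=True
)
print(loading_info["missing_keys"])
```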
@@ -459,7 +456,7 @@
 
         if use_safetensors and not is_safetensors_available():
             raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install
+                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
             )
 
         allow_pickle = False
@@ -646,15 +643,47 @@
             else:  # else let accelerate handle loading and dispatching.
                 # Load weights and dispatch according to the device_map
                 # by default the device_map is None and the weights are loaded on the CPU
-                accelerate.load_checkpoint_and_dispatch(
-                    model,
-                    model_file,
-                    device_map,
-                    max_memory=max_memory,
-                    offload_folder=offload_folder,
-                    offload_state_dict=offload_state_dict,
-                    dtype=torch_dtype,
-                )
+                try:
+                    accelerate.load_checkpoint_and_dispatch(
+                        model,
+                        model_file,
+                        device_map,
+                        max_memory=max_memory,
+                        offload_folder=offload_folder,
+                        offload_state_dict=offload_state_dict,
+                        dtype=torch_dtype,
+                    )
+                except AttributeError as e:
+                    # When using accelerate loading, we do not have the ability to load the state
+                    # dict and rename the weight names manually. Additionally, accelerate skips
+                    # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
+                    # (which look like they should be private variables?), so we can't use the standard hooks
+                    # to rename parameters on load. We need to mimic the original weight names so the correct
+                    # attributes are available. After we have loaded the weights, we convert the deprecated
+                    # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
+                    # the weights so we don't have to do this again.
+
+                    if "'Attention' object has no attribute" in str(e):
+                        logger.warn(
+                            f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
+                            " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
+                            " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
+                            " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
+                            " please also re-upload it or open a PR on the original repository."
+                        )
+                        model._temp_convert_self_to_deprecated_attention_blocks()
+                        accelerate.load_checkpoint_and_dispatch(
+                            model,
+                            model_file,
+                            device_map,
+                            max_memory=max_memory,
+                            offload_folder=offload_folder,
+                            offload_state_dict=offload_state_dict,
+                            dtype=torch_dtype,
+                        )
+                        model._undo_temp_convert_self_to_deprecated_attention_blocks()
+                    else:
+                        raise e
 
             loading_info = {
                 "missing_keys": [],
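The new `try`/`except` fallback only renames attributes in memory for the duration of the load; as the warning text asks, re-saving the checkpoint persists the new attention block names so the fallback is skipped on future loads. A hypothetical round trip, where `./old_vae` stands in for a checkpoint saved with the deprecated `query`/`key`/`value`/`proj_attn` names:

```py
from diffusers import AutoencoderKL

# Loading with accelerate dispatch triggers the on-the-fly rename if the
# checkpoint still uses the deprecated attention weight names.
vae = AutoencoderKL.from_pretrained("./old_vae", device_map="auto")

# Re-save once so future loads find to_q/to_k/to_v/to_out.0 directly.
vae.save_pretrained("./old_vae")
```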
@@ -820,17 +849,27 @@
 
     def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
         """
-        Get number of (
+        Get number of (trainable or non-embedding) parameters in the module.
 
         Args:
             only_trainable (`bool`, *optional*, defaults to `False`):
-                Whether or not to return only the number of trainable parameters
-
+                Whether or not to return only the number of trainable parameters.
             exclude_embeddings (`bool`, *optional*, defaults to `False`):
-                Whether or not to return only the number of non-
+                Whether or not to return only the number of non-embedding parameters.
 
         Returns:
             `int`: The number of parameters.
+
+        Example:
+
+        ```py
+        from diffusers import UNet2DConditionModel
+
+        model_id = "runwayml/stable-diffusion-v1-5"
+        unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
+        unet.num_parameters(only_trainable=True)
+        859520964
+        ```
         """
 
         if exclude_embeddings:
@@ -889,3 +928,53 @@
             state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
         if f"{path}.proj_attn.bias" in state_dict:
             state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
+
+    def _temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.query = module.to_q
+            module.key = module.to_k
+            module.value = module.to_v
+            module.proj_attn = module.to_out[0]
+
+            # We don't _have_ to delete the old attributes, but it's helpful to ensure
+            # that _all_ the weights are loaded into the new attributes and we're not
+            # making an incorrect assumption that this model should be converted when
+            # it really shouldn't be.
+            del module.to_q
+            del module.to_k
+            del module.to_v
+            del module.to_out
+
+    def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.to_q = module.query
+            module.to_k = module.key
+            module.to_v = module.value
+            module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
+
+            del module.query
+            del module.key
+            del module.value
+            del module.proj_attn