diffusers 0.17.1__py3-none-any.whl → 0.18.2__py3-none-any.whl

Files changed (120)
  1. diffusers/__init__.py +26 -1
  2. diffusers/configuration_utils.py +34 -29
  3. diffusers/dependency_versions_table.py +4 -0
  4. diffusers/image_processor.py +125 -12
  5. diffusers/loaders.py +169 -203
  6. diffusers/models/attention.py +24 -1
  7. diffusers/models/attention_flax.py +10 -5
  8. diffusers/models/attention_processor.py +3 -0
  9. diffusers/models/autoencoder_kl.py +114 -33
  10. diffusers/models/controlnet.py +131 -14
  11. diffusers/models/controlnet_flax.py +37 -26
  12. diffusers/models/cross_attention.py +17 -17
  13. diffusers/models/embeddings.py +67 -0
  14. diffusers/models/modeling_flax_utils.py +64 -56
  15. diffusers/models/modeling_utils.py +193 -104
  16. diffusers/models/prior_transformer.py +207 -37
  17. diffusers/models/resnet.py +26 -26
  18. diffusers/models/transformer_2d.py +36 -41
  19. diffusers/models/transformer_temporal.py +24 -21
  20. diffusers/models/unet_1d.py +31 -25
  21. diffusers/models/unet_2d.py +43 -30
  22. diffusers/models/unet_2d_blocks.py +210 -89
  23. diffusers/models/unet_2d_blocks_flax.py +12 -12
  24. diffusers/models/unet_2d_condition.py +172 -64
  25. diffusers/models/unet_2d_condition_flax.py +38 -24
  26. diffusers/models/unet_3d_blocks.py +34 -31
  27. diffusers/models/unet_3d_condition.py +101 -34
  28. diffusers/models/vae.py +5 -5
  29. diffusers/models/vae_flax.py +37 -34
  30. diffusers/models/vq_model.py +23 -14
  31. diffusers/pipelines/__init__.py +24 -1
  32. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +1 -1
  33. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +5 -3
  34. diffusers/pipelines/consistency_models/__init__.py +1 -0
  35. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +337 -0
  36. diffusers/pipelines/controlnet/multicontrolnet.py +120 -1
  37. diffusers/pipelines/controlnet/pipeline_controlnet.py +59 -17
  38. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +60 -15
  39. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +60 -17
  40. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +1 -1
  41. diffusers/pipelines/kandinsky/__init__.py +1 -1
  42. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +4 -6
  43. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +1 -0
  44. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +1 -0
  45. diffusers/pipelines/kandinsky2_2/__init__.py +7 -0
  46. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +317 -0
  47. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +372 -0
  48. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +434 -0
  49. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +398 -0
  50. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +531 -0
  51. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +541 -0
  52. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +605 -0
  53. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  54. diffusers/pipelines/pipeline_utils.py +124 -146
  55. diffusers/pipelines/shap_e/__init__.py +27 -0
  56. diffusers/pipelines/shap_e/camera.py +147 -0
  57. diffusers/pipelines/shap_e/pipeline_shap_e.py +390 -0
  58. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +349 -0
  59. diffusers/pipelines/shap_e/renderer.py +709 -0
  60. diffusers/pipelines/stable_diffusion/__init__.py +2 -0
  61. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +261 -66
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +3 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +5 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -2
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +1 -1
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +719 -0
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -1
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +832 -0
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +17 -7
  72. diffusers/pipelines/stable_diffusion_xl/__init__.py +26 -0
  73. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +823 -0
  74. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +896 -0
  75. diffusers/pipelines/stable_diffusion_xl/watermark.py +31 -0
  76. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -1
  77. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +5 -1
  78. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +771 -0
  79. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +92 -6
  80. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
  81. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +209 -91
  82. diffusers/schedulers/__init__.py +3 -0
  83. diffusers/schedulers/scheduling_consistency_models.py +380 -0
  84. diffusers/schedulers/scheduling_ddim.py +28 -6
  85. diffusers/schedulers/scheduling_ddim_inverse.py +19 -4
  86. diffusers/schedulers/scheduling_ddim_parallel.py +642 -0
  87. diffusers/schedulers/scheduling_ddpm.py +53 -7
  88. diffusers/schedulers/scheduling_ddpm_parallel.py +604 -0
  89. diffusers/schedulers/scheduling_deis_multistep.py +66 -11
  90. diffusers/schedulers/scheduling_dpmsolver_multistep.py +55 -13
  91. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +19 -4
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +73 -11
  93. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +23 -7
  94. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +58 -9
  95. diffusers/schedulers/scheduling_euler_discrete.py +58 -8
  96. diffusers/schedulers/scheduling_heun_discrete.py +89 -14
  97. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +73 -11
  98. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +73 -11
  99. diffusers/schedulers/scheduling_lms_discrete.py +57 -8
  100. diffusers/schedulers/scheduling_pndm.py +46 -10
  101. diffusers/schedulers/scheduling_repaint.py +19 -4
  102. diffusers/schedulers/scheduling_sde_ve.py +5 -1
  103. diffusers/schedulers/scheduling_unclip.py +43 -4
  104. diffusers/schedulers/scheduling_unipc_multistep.py +48 -7
  105. diffusers/training_utils.py +1 -1
  106. diffusers/utils/__init__.py +2 -1
  107. diffusers/utils/dummy_pt_objects.py +60 -0
  108. diffusers/utils/dummy_torch_and_transformers_and_invisible_watermark_objects.py +32 -0
  109. diffusers/utils/dummy_torch_and_transformers_objects.py +180 -0
  110. diffusers/utils/hub_utils.py +1 -1
  111. diffusers/utils/import_utils.py +20 -3
  112. diffusers/utils/logging.py +15 -18
  113. diffusers/utils/outputs.py +3 -3
  114. diffusers/utils/testing_utils.py +15 -0
  115. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/METADATA +4 -2
  116. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/RECORD +120 -94
  117. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/WHEEL +1 -1
  118. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/LICENSE +0 -0
  119. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/entry_points.txt +0 -0
  120. {diffusers-0.17.1.dist-info → diffusers-0.18.2.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,7 @@ from functools import partial
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import torch
-from torch import Tensor, device
+from torch import Tensor, device, nn
 
 from .. import __version__
 from ..utils import (
@@ -154,11 +154,10 @@ class ModelMixin(torch.nn.Module):
     r"""
     Base class for all models.
 
-    [`ModelMixin`] takes care of storing the configuration of the models and handles methods for loading, downloading
-    and saving models.
+    [`ModelMixin`] takes care of storing the model configuration and provides methods for loading, downloading and
+    saving models.
 
-        - **config_name** ([`str`]) -- A filename under which the model should be stored when calling
-          [`~models.ModelMixin.save_pretrained`].
+        - **config_name** ([`str`]) -- Filename to save a model to when calling [`~models.ModelMixin.save_pretrained`].
     """
     config_name = CONFIG_NAME
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
@@ -190,18 +189,13 @@ class ModelMixin(torch.nn.Module):
     def is_gradient_checkpointing(self) -> bool:
         """
         Whether gradient checkpointing is activated for this model or not.
-
-        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
-        activations".
         """
         return any(hasattr(m, "gradient_checkpointing") and m.gradient_checkpointing for m in self.modules())
 
     def enable_gradient_checkpointing(self):
         """
-        Activates gradient checkpointing for the current model.
-
-        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
-        activations".
+        Activates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
+        *checkpoint activations* in other frameworks).
         """
         if not self._supports_gradient_checkpointing:
             raise ValueError(f"{self.__class__.__name__} does not support gradient checkpointing.")
@@ -209,10 +203,8 @@ class ModelMixin(torch.nn.Module):
 
     def disable_gradient_checkpointing(self):
         """
-        Deactivates gradient checkpointing for the current model.
-
-        Note that in other frameworks this feature can be referred to as "activation checkpointing" or "checkpoint
-        activations".
+        Deactivates gradient checkpointing for the current model (may be referred to as *activation checkpointing* or
+        *checkpoint activations* in other frameworks).
         """
         if self._supports_gradient_checkpointing:
            self.apply(partial(self._set_gradient_checkpointing, value=False))
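
As a quick usage sketch of the checkpointing toggles above (a minimal illustration, assuming the `runwayml/stable-diffusion-v1-5` checkpoint is reachable; any model whose blocks support checkpointing works the same way):

```py
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.enable_gradient_checkpointing()  # recompute activations in backward to save memory during training
assert unet.is_gradient_checkpointing
unet.disable_gradient_checkpointing()
```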
@@ -236,13 +228,17 @@ class ModelMixin(torch.nn.Module):
 
     def enable_xformers_memory_efficient_attention(self, attention_op: Optional[Callable] = None):
         r"""
-        Enable memory efficient attention as implemented in xformers.
+        Enable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
+
+        When this option is enabled, you should observe lower GPU memory usage and a potential speed up during
+        inference. Speed up during training is not guaranteed.
+
+        <Tip warning={true}>
 
-        When this option is enabled, you should observe lower GPU memory usage and a potential speed up at inference
-        time. Speed up at training time is not guaranteed.
+        ⚠️ When memory efficient attention and sliced attention are both enabled, memory efficient attention takes
+        precedent.
 
-        Warning: When Memory Efficient Attention and Sliced attention are both enabled, the Memory Efficient Attention
-        is used.
+        </Tip>
 
         Parameters:
             attention_op (`Callable`, *optional*):
@@ -268,7 +264,7 @@ class ModelMixin(torch.nn.Module):
 
     def disable_xformers_memory_efficient_attention(self):
         r"""
-        Disable memory efficient attention as implemented in xformers.
+        Disable memory efficient attention from [xFormers](https://facebookresearch.github.io/xformers/).
         """
         self.set_use_memory_efficient_attention_xformers(False)
 
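A minimal sketch of toggling xFormers attention (assumes the optional `xformers` package is installed and a CUDA device is available; `attention_op` is left as `None` so xFormers picks a kernel):

```py
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet").to("cuda")
unet.enable_xformers_memory_efficient_attention()
# ... run inference with reduced attention memory usage ...
unet.disable_xformers_memory_efficient_attention()
```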
@@ -281,24 +277,24 @@ class ModelMixin(torch.nn.Module):
         variant: Optional[str] = None,
     ):
         """
-        Save a model and its configuration file to a directory, so that it can be re-loaded using the
-        `[`~models.ModelMixin.from_pretrained`]` class method.
+        Save a model and its configuration file to a directory so that it can be reloaded using the
+        [`~models.ModelMixin.from_pretrained`] class method.
 
         Arguments:
             save_directory (`str` or `os.PathLike`):
-                Directory to which to save. Will be created if it doesn't exist.
+                Directory to save a model and its configuration file to. Will be created if it doesn't exist.
             is_main_process (`bool`, *optional*, defaults to `True`):
-                Whether the process calling this is the main process or not. Useful when in distributed training like
-                TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
-                the main process to avoid race conditions.
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
             save_function (`Callable`):
-                The function to use to save the state dictionary. Useful on distributed training like TPUs when one
-                need to replace `torch.save` by another method. Can be configured with the environment variable
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
                 `DIFFUSERS_SAVE_MODE`.
             safe_serialization (`bool`, *optional*, defaults to `False`):
-                Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
-                If specified, weights are saved in the format pytorch_model.<variant>.bin.
+                If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
         """
         if safe_serialization and not is_safetensors_available():
             raise ImportError("`safe_serialization` requires the `safetensors library: `pip install safetensors`.")
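
A minimal sketch of the save path (assumes `safetensors` is installed; the fp16 cast and output directory are illustrative only):

```py
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.to(torch.float16).save_pretrained("./sd15-unet", safe_serialization=True, variant="fp16")
```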
@@ -335,107 +331,108 @@ class ModelMixin(torch.nn.Module):
     @classmethod
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
         r"""
-        Instantiate a pretrained pytorch model from a pre-trained model configuration.
-
-        The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
-        the model, you should first set it back in training mode with `model.train()`.
-
-        The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
-        pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
-        task.
+        Instantiate a pretrained PyTorch model from a pretrained model configuration.
 
-        The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
-        weights are discarded.
+        The model is set in evaluation mode - `model.eval()` - by default, and dropout modules are deactivated. To
+        train the model, set it back in training mode with `model.train()`.
 
         Parameters:
             pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                 Can be either:
 
-                    - A string, the *model id* of a pretrained model hosted inside a model repo on huggingface.co.
-                      Valid model ids should have an organization name, like `google/ddpm-celebahq-256`.
-                    - A path to a *directory* containing model weights saved using [`~ModelMixin.save_config`], e.g.,
-                      `./my_model_directory/`.
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`~ModelMixin.save_pretrained`].
 
             cache_dir (`Union[str, os.PathLike]`, *optional*):
-                Path to a directory in which a downloaded pretrained model configuration should be cached if the
-                standard cache should not be used.
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
             torch_dtype (`str` or `torch.dtype`, *optional*):
-                Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
-                will be automatically derived from the model's weights.
+                Override the default `torch.dtype` and load the model with another dtype. If `"auto"` is passed, the
+                dtype is automatically derived from the model's weights.
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
             resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to delete incompletely received files. Will attempt to resume the download if such a
-                file exists.
+                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
+                incompletely downloaded files are deleted.
             proxies (`Dict[str, str]`, *optional*):
-                A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
-            output_loading_info(`bool`, *optional*, defaults to `False`):
+            output_loading_info (`bool`, *optional*, defaults to `False`):
                 Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
             local_files_only(`bool`, *optional*, defaults to `False`):
-                Whether or not to only look at local files (i.e., do not try to download the model).
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
             use_auth_token (`str` or *bool*, *optional*):
-                The token to use as HTTP bearer authorization for remote files. If `True`, will use the token generated
-                when running `diffusers-cli login` (stored in `~/.huggingface`).
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
             revision (`str`, *optional*, defaults to `"main"`):
-                The specific model version to use. It can be a branch name, a tag name, or a commit id, since we use a
-                git-based system for storing models and other artifacts on huggingface.co, so `revision` can be any
-                identifier allowed by git.
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
             from_flax (`bool`, *optional*, defaults to `False`):
                 Load the model weights from a Flax checkpoint save file.
             subfolder (`str`, *optional*, defaults to `""`):
-                In case the relevant files are located inside a subfolder of the model repo (either remote in
-                huggingface.co or downloaded locally), you can specify the folder name here.
-
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
            mirror (`str`, *optional*):
-                Mirror source to accelerate downloads in China. If you are from China and have an accessibility
-                problem, you can set this option to resolve it. Note that we do not guarantee the timeliness or safety.
-                Please refer to the mirror site for more information.
+                Mirror source to resolve accessibility issues if you're downloading a model in China. We do not
+                guarantee the timeliness or safety of the source, and you should refer to the mirror site for more
+                information.
             device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
-                A map that specifies where each submodule should go. It doesn't need to be refined to each
-                parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
+                A map that specifies where each submodule should go. It doesn't need to be defined for each
+                parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
                 same device.
 
-                To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
+                Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                 more information about each option see [designing a device
                 map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
             max_memory (`Dict`, *optional*):
-                A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
-                GPU and the available CPU RAM if unset.
+                A dictionary device identifier for the maximum memory. Will default to the maximum memory available for
+                each GPU and the available CPU RAM if unset.
             offload_folder (`str` or `os.PathLike`, *optional*):
-                If the `device_map` contains any value `"disk"`, the folder where we will offload weights.
+                The path to offload weights if `device_map` contains the value `"disk"`.
             offload_state_dict (`bool`, *optional*):
-                If `True`, will temporarily offload the CPU state dict to the hard drive to avoid getting out of CPU
-                RAM if the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to
-                `True` when there is some disk offload.
+                If `True`, temporarily offloads the CPU state dict to the hard drive to avoid running out of CPU RAM if
+                the weight of the CPU state dict + the biggest shard of the checkpoint does not fit. Defaults to `True`
+                when there is some disk offload.
             low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
-                Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
-                also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
-                model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
-                setting this argument to `True` will raise an error.
+                Speed up model loading only loading the pretrained weights and not initializing the weights. This also
+                tries to not use more than 1x model size in CPU memory (including peak memory) while loading the model.
+                Only supported for PyTorch >= 1.9.0. If you are using an older version of PyTorch, setting this
+                argument to `True` will raise an error.
             variant (`str`, *optional*):
-                If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
-                ignored when using `from_flax`.
+                Load weights from a specified `variant` filename such as `"fp16"` or `"ema"`. This is ignored when
+                loading `from_flax`.
             use_safetensors (`bool`, *optional*, defaults to `None`):
-                If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
-                `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
-                `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
+                If set to `None`, the `safetensors` weights are downloaded if they're available **and** if the
+                `safetensors` library is installed. If set to `True`, the model is forcibly loaded from `safetensors`
+                weights. If set to `False`, `safetensors` weights are not loaded.
 
         <Tip>
 
-        It is required to be logged in (`huggingface-cli login`) when you want to use private or [gated
-        models](https://huggingface.co/docs/hub/models-gated#gated-models).
+        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
+        `huggingface-cli login`. You can also activate the special
+        ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
+        firewalled environment.
 
         </Tip>
 
-        <Tip>
+        Example:
 
-        Activate the special ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use
-        this method in a firewalled environment.
+        ```py
+        from diffusers import UNet2DConditionModel
 
-        </Tip>
+        unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
+        ```
 
+        If you get the error message below, you need to finetune the weights for your downstream task:
+
+        ```bash
+        Some weights of UNet2DConditionModel were not initialized from the model checkpoint at runwayml/stable-diffusion-v1-5 and are newly initialized because the shapes did not match:
+        - conv_in.weight: found shape torch.Size([320, 4, 3, 3]) in the checkpoint and torch.Size([320, 9, 3, 3]) in the model instantiated
+        You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
+        ```
        """
         cache_dir = kwargs.pop("cache_dir", DIFFUSERS_CACHE)
         ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
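
A hedged sketch combining several of the options documented above (assumes the repository ships fp16 variant weights in `safetensors` format):

```py
import torch
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5",
    subfolder="unet",
    torch_dtype=torch.float16,
    variant="fp16",
    use_safetensors=True,
)
```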
@@ -459,7 +456,7 @@ class ModelMixin(torch.nn.Module):
 
         if use_safetensors and not is_safetensors_available():
             raise ValueError(
-                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetenstors"
+                "`use_safetensors`=True but safetensors is not installed. Please install safetensors with `pip install safetensors"
             )
 
         allow_pickle = False
@@ -646,15 +643,47 @@ class ModelMixin(torch.nn.Module):
             else:  # else let accelerate handle loading and dispatching.
                 # Load weights and dispatch according to the device_map
                 # by default the device_map is None and the weights are loaded on the CPU
-                accelerate.load_checkpoint_and_dispatch(
-                    model,
-                    model_file,
-                    device_map,
-                    max_memory=max_memory,
-                    offload_folder=offload_folder,
-                    offload_state_dict=offload_state_dict,
-                    dtype=torch_dtype,
-                )
+                try:
+                    accelerate.load_checkpoint_and_dispatch(
+                        model,
+                        model_file,
+                        device_map,
+                        max_memory=max_memory,
+                        offload_folder=offload_folder,
+                        offload_state_dict=offload_state_dict,
+                        dtype=torch_dtype,
+                    )
+                except AttributeError as e:
+                    # When using accelerate loading, we do not have the ability to load the state
+                    # dict and rename the weight names manually. Additionally, accelerate skips
+                    # torch loading conventions and directly writes into `module.{_buffers, _parameters}`
+                    # (which look like they should be private variables?), so we can't use the standard hooks
+                    # to rename parameters on load. We need to mimic the original weight names so the correct
+                    # attributes are available. After we have loaded the weights, we convert the deprecated
+                    # names to the new non-deprecated names. Then we _greatly encourage_ the user to convert
+                    # the weights so we don't have to do this again.
+
+                    if "'Attention' object has no attribute" in str(e):
+                        logger.warn(
+                            f"Taking `{str(e)}` while using `accelerate.load_checkpoint_and_dispatch` to mean {pretrained_model_name_or_path}"
+                            " was saved with deprecated attention block weight names. We will load it with the deprecated attention block"
+                            " names and convert them on the fly to the new attention block format. Please re-save the model after this conversion,"
+                            " so we don't have to do the on the fly renaming in the future. If the model is from a hub checkpoint,"
+                            " please also re-upload it or open a PR on the original repository."
+                        )
+                        model._temp_convert_self_to_deprecated_attention_blocks()
+                        accelerate.load_checkpoint_and_dispatch(
+                            model,
+                            model_file,
+                            device_map,
+                            max_memory=max_memory,
+                            offload_folder=offload_folder,
+                            offload_state_dict=offload_state_dict,
+                            dtype=torch_dtype,
+                        )
+                        model._undo_temp_convert_self_to_deprecated_attention_blocks()
+                    else:
+                        raise e
 
         loading_info = {
             "missing_keys": [],
@@ -820,17 +849,27 @@ class ModelMixin(torch.nn.Module):
 
     def num_parameters(self, only_trainable: bool = False, exclude_embeddings: bool = False) -> int:
         """
-        Get number of (optionally, trainable or non-embeddings) parameters in the module.
+        Get number of (trainable or non-embedding) parameters in the module.
 
         Args:
             only_trainable (`bool`, *optional*, defaults to `False`):
-                Whether or not to return only the number of trainable parameters
-
+                Whether or not to return only the number of trainable parameters.
             exclude_embeddings (`bool`, *optional*, defaults to `False`):
-                Whether or not to return only the number of non-embeddings parameters
+                Whether or not to return only the number of non-embedding parameters.
 
         Returns:
             `int`: The number of parameters.
+
+        Example:
+
+        ```py
+        from diffusers import UNet2DConditionModel
+
+        model_id = "runwayml/stable-diffusion-v1-5"
+        unet = UNet2DConditionModel.from_pretrained(model_id, subfolder="unet")
+        unet.num_parameters(only_trainable=True)
+        859520964
+        ```
         """
 
         if exclude_embeddings:
@@ -889,3 +928,53 @@ class ModelMixin(torch.nn.Module):
                 state_dict[f"{path}.to_out.0.weight"] = state_dict.pop(f"{path}.proj_attn.weight")
             if f"{path}.proj_attn.bias" in state_dict:
                 state_dict[f"{path}.to_out.0.bias"] = state_dict.pop(f"{path}.proj_attn.bias")
+
+    def _temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.query = module.to_q
+            module.key = module.to_k
+            module.value = module.to_v
+            module.proj_attn = module.to_out[0]
+
+            # We don't _have_ to delete the old attributes, but it's helpful to ensure
+            # that _all_ the weights are loaded into the new attributes and we're not
+            # making an incorrect assumption that this model should be converted when
+            # it really shouldn't be.
+            del module.to_q
+            del module.to_k
+            del module.to_v
+            del module.to_out
+
+    def _undo_temp_convert_self_to_deprecated_attention_blocks(self):
+        deprecated_attention_block_modules = []
+
+        def recursive_find_attn_block(module):
+            if hasattr(module, "_from_deprecated_attn_block") and module._from_deprecated_attn_block:
+                deprecated_attention_block_modules.append(module)
+
+            for sub_module in module.children():
+                recursive_find_attn_block(sub_module)
+
+        recursive_find_attn_block(self)
+
+        for module in deprecated_attention_block_modules:
+            module.to_q = module.query
+            module.to_k = module.key
+            module.to_v = module.value
+            module.to_out = nn.ModuleList([module.proj_attn, nn.Dropout(module.dropout)])
+
+            del module.query
+            del module.key
+            del module.value
+            del module.proj_attn
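
To see what the rename round-trip above does to a single module, here is a toy, self-contained sketch; everything in it is hypothetical except the `_from_deprecated_attn_block` marker and the `to_q`/`to_k`/`to_v`/`to_out` attribute names these methods rely on:

```py
from torch import nn

# Stand-in attention module carrying the new attribute names.
attn = nn.Module()
attn._from_deprecated_attn_block = True
attn.dropout = 0.0
attn.to_q = nn.Linear(4, 4)
attn.to_k = nn.Linear(4, 4)
attn.to_v = nn.Linear(4, 4)
attn.to_out = nn.ModuleList([nn.Linear(4, 4), nn.Dropout(0.0)])

# Temporary conversion: alias the layers under the deprecated names and drop the new
# ones, so a checkpoint saved with query/key/value/proj_attn keys can be loaded.
attn.query, attn.key, attn.value, attn.proj_attn = attn.to_q, attn.to_k, attn.to_v, attn.to_out[0]
del attn.to_q, attn.to_k, attn.to_v, attn.to_out

# ... the old-style state dict would be loaded here ...

# Undo: restore the new names and rebuild `to_out` as [projection, dropout].
attn.to_q, attn.to_k, attn.to_v = attn.query, attn.key, attn.value
attn.to_out = nn.ModuleList([attn.proj_attn, nn.Dropout(attn.dropout)])
del attn.query, attn.key, attn.value, attn.proj_attn
```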