diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +19 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +20 -26
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +42 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
  229. diffusers/schedulers/scheduling_edm_euler.py +50 -31
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
  231. diffusers/schedulers/scheduling_euler_discrete.py +160 -68
  232. diffusers/schedulers/scheduling_heun_discrete.py +57 -39
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +24 -26
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
  267. diffusers-0.27.2.dist-info/RECORD +0 -399
  268. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  269. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/models/modeling_utils.py

@@ -20,6 +20,7 @@ import os
 import re
 from collections import OrderedDict
 from functools import partial
+from pathlib import Path
 from typing import Any, Callable, List, Optional, Tuple, Union
 
 import safetensors
@@ -32,7 +33,6 @@ from .. import __version__
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
-    SAFETENSORS_FILE_EXTENSION,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_NAME,
     _add_variant,
@@ -43,6 +43,12 @@ from ..utils import (
     logging,
 )
 from ..utils.hub_utils import PushToHubMixin, load_or_create_model_card, populate_model_card
+from .model_loading_utils import (
+    _determine_device_map,
+    _load_state_dict_into_model,
+    load_model_dict_into_meta,
+    load_state_dict,
+)
 
 
 logger = logging.get_logger(__name__)
@@ -56,8 +62,6 @@ else:
 
 if is_accelerate_available():
     import accelerate
-    from accelerate.utils import set_module_tensor_to_device
-    from accelerate.utils.versions import is_torch_version
 
 
 def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
@@ -98,89 +102,6 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     return first_tuple[1].dtype
 
 
-def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
-    """
-    Reads a checkpoint file, returning properly formatted errors if they arise.
-    """
-    try:
-        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
-        if file_extension == SAFETENSORS_FILE_EXTENSION:
-            return safetensors.torch.load_file(checkpoint_file, device="cpu")
-        else:
-            return torch.load(checkpoint_file, map_location="cpu")
-    except Exception as e:
-        try:
-            with open(checkpoint_file) as f:
-                if f.read().startswith("version"):
-                    raise OSError(
-                        "You seem to have cloned a repository without having git-lfs installed. Please install "
-                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
-                        "you cloned."
-                    )
-                else:
-                    raise ValueError(
-                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
-                        "model. Make sure you have saved the model properly."
-                    ) from e
-        except (UnicodeDecodeError, ValueError):
-            raise OSError(
-                f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
-            )
-
-
-def load_model_dict_into_meta(
-    model,
-    state_dict: OrderedDict,
-    device: Optional[Union[str, torch.device]] = None,
-    dtype: Optional[Union[str, torch.dtype]] = None,
-    model_name_or_path: Optional[str] = None,
-) -> List[str]:
-    device = device or torch.device("cpu")
-    dtype = dtype or torch.float32
-
-    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
-
-    unexpected_keys = []
-    empty_state_dict = model.state_dict()
-    for param_name, param in state_dict.items():
-        if param_name not in empty_state_dict:
-            unexpected_keys.append(param_name)
-            continue
-
-        if empty_state_dict[param_name].shape != param.shape:
-            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
-            raise ValueError(
-                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
-            )
-
-        if accepts_dtype:
-            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
-        else:
-            set_module_tensor_to_device(model, param_name, device, value=param)
-    return unexpected_keys
-
-
-def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
-    # Convert old format to new format if needed from a PyTorch state_dict
-    # copy state_dict so _load_from_state_dict can modify it
-    state_dict = state_dict.copy()
-    error_msgs = []
-
-    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
-    # so we need to apply the function recursively.
-    def load(module: torch.nn.Module, prefix: str = ""):
-        args = (state_dict, prefix, {}, True, [], [], error_msgs)
-        module._load_from_state_dict(*args)
-
-        for name, child in module._modules.items():
-            if child is not None:
-                load(child, prefix + name + ".")
-
-    load(model_to_load)
-
-    return error_msgs
-
-
 class ModelMixin(torch.nn.Module, PushToHubMixin):
     r"""
     Base class for all models.
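The three loading helpers removed above were not dropped: they moved into the new diffusers/models/model_loading_utils.py (file 35 in the list, +149 lines) and are re-imported at the top of this module. A minimal sketch of updating a downstream import, assuming code relied on the old private location; the checkpoint path is hypothetical:

    # 0.27.2 location (removed in 0.28.0):
    # from diffusers.models.modeling_utils import load_state_dict

    # 0.28.0 location:
    from diffusers.models.model_loading_utils import load_state_dict

    state_dict = load_state_dict("unet/diffusion_pytorch_model.safetensors")  # hypothetical local path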
@@ -195,6 +116,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _automatically_saved_args = ["_diffusers_version", "_class_name", "_name_or_path"]
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
+    _no_split_modules = None
 
     def __init__(self):
         super().__init__()
@@ -241,6 +163,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         if self._supports_gradient_checkpointing:
             self.apply(partial(self._set_gradient_checkpointing, value=False))
 
+    def set_use_npu_flash_attention(self, valid: bool) -> None:
+        r"""
+        Set the switch for the npu flash attention.
+        """
+
+        def fn_recursive_set_npu_flash_attention(module: torch.nn.Module):
+            if hasattr(module, "set_use_npu_flash_attention"):
+                module.set_use_npu_flash_attention(valid)
+
+            for child in module.children():
+                fn_recursive_set_npu_flash_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_npu_flash_attention(module)
+
+    def enable_npu_flash_attention(self) -> None:
+        r"""
+        Enable npu flash attention from torch_npu
+
+        """
+        self.set_use_npu_flash_attention(True)
+
+    def disable_npu_flash_attention(self) -> None:
+        r"""
+        disable npu flash attention from torch_npu
+
+        """
+        self.set_use_npu_flash_attention(False)
+
     def set_use_memory_efficient_attention_xformers(
         self, valid: bool, attention_op: Optional[Callable] = None
     ) -> None:
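The new methods give every ModelMixin subclass a recursive on/off switch for NPU flash attention. A usage sketch, assuming an Ascend environment with torch_npu installed (the switch only takes effect on attention modules that expose `set_use_npu_flash_attention`):

    from diffusers import UNet2DConditionModel

    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )
    unet.enable_npu_flash_attention()   # walks children, flips set_use_npu_flash_attention(True)
    # ... run inference ...
    unet.disable_npu_flash_attention()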
@@ -367,18 +319,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # Save the model
         if safe_serialization:
             safetensors.torch.save_file(
-                state_dict, os.path.join(save_directory, weights_name), metadata={"format": "pt"}
+                state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"}
             )
         else:
-            torch.save(state_dict, os.path.join(save_directory, weights_name))
+            torch.save(state_dict, Path(save_directory, weights_name).as_posix())
 
-        logger.info(f"Model weights saved in {os.path.join(save_directory, weights_name)}")
+        logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}")
 
         if push_to_hub:
             # Create a new empty model card and eventually tag it
             model_card = load_or_create_model_card(repo_id, token=token)
             model_card = populate_model_card(model_card)
-            model_card.save(os.path.join(save_directory, "README.md"))
+            model_card.save(Path(save_directory, "README.md").as_posix())
 
             self._upload_folder(
                 save_directory,
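Swapping os.path.join for Path(...).as_posix() normalizes the separators that end up in logs and Hub uploads: as_posix() always yields forward slashes, while os.path.join uses the platform separator. A quick illustration of the difference (shown for Windows; on POSIX both agree):

    import os
    from pathlib import Path

    os.path.join("weights", "model.safetensors")     # 'weights\\model.safetensors' on Windows
    Path("weights", "model.safetensors").as_posix()  # 'weights/model.safetensors' on any OS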
@@ -415,9 +367,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -499,7 +451,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
         force_download = kwargs.pop("force_download", False)
         from_flax = kwargs.pop("from_flax", False)
-        resume_download = kwargs.pop("resume_download", False)
+        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         output_loading_info = kwargs.pop("output_loading_info", False)
         local_files_only = kwargs.pop("local_files_only", None)
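Together, these two hunks make resume_download a no-op that mirrors the huggingface_hub deprecation: downloads now resume by default when possible. Callers can simply drop the argument; a sketch:

    from diffusers import UNet2DConditionModel

    # Before: UNet2DConditionModel.from_pretrained(..., resume_download=True)
    # After: omit it; it is ignored and slated for removal in v1 of Diffusers.
    unet = UNet2DConditionModel.from_pretrained(
        "runwayml/stable-diffusion-v1-5", subfolder="unet"
    )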
@@ -554,6 +506,36 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 " dispatching. Please make sure to set `low_cpu_mem_usage=True`."
             )
 
+        # change device_map into a map if we passed an int, a str or a torch.device
+        if isinstance(device_map, torch.device):
+            device_map = {"": device_map}
+        elif isinstance(device_map, str) and device_map not in ["auto", "balanced", "balanced_low_0", "sequential"]:
+            try:
+                device_map = {"": torch.device(device_map)}
+            except RuntimeError:
+                raise ValueError(
+                    "When passing device_map as a string, the value needs to be a device name (e.g. cpu, cuda:0) or "
+                    f"'auto', 'balanced', 'balanced_low_0', 'sequential' but found {device_map}."
+                )
+        elif isinstance(device_map, int):
+            if device_map < 0:
+                raise ValueError(
+                    "You can't pass device_map as a negative int. If you want to put the model on the cpu, pass device_map = 'cpu' "
+                )
+            else:
+                device_map = {"": device_map}
+
+        if device_map is not None:
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+            elif not low_cpu_mem_usage:
+                raise ValueError("Passing along a `device_map` requires `low_cpu_mem_usage=True`")
+
+        if low_cpu_mem_usage:
+            if device_map is not None and not is_torch_version(">=", "1.10"):
+                # The max memory utils require PyTorch >= 1.10 to have torch.cuda.mem_get_info.
+                raise ValueError("`low_cpu_mem_usage` and `device_map` require PyTorch >= 1.10.")
+
         # Load config if we don't provide a configuration
         config_path = pretrained_model_name_or_path
 
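With this normalization, device_map accepts a concrete device (string name, torch.device, or non-negative int) in addition to the accelerate strategy strings; everything collapses to a {"": device} map. A sketch of the accepted forms, assuming accelerate is installed; the strategy strings additionally require the model class to declare `_no_split_modules` (see the hunk further down):

    import torch
    from diffusers import UNet2DConditionModel

    repo = "runwayml/stable-diffusion-v1-5"
    m1 = UNet2DConditionModel.from_pretrained(repo, subfolder="unet", device_map="cuda:0")
    m2 = UNet2DConditionModel.from_pretrained(repo, subfolder="unet", device_map=torch.device("cpu"))
    m3 = UNet2DConditionModel.from_pretrained(repo, subfolder="unet", device_map=0)       # cuda:0
    m4 = UNet2DConditionModel.from_pretrained(repo, subfolder="unet", device_map="auto")  # accelerate strategy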
@@ -576,10 +558,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             token=token,
             revision=revision,
             subfolder=subfolder,
-            device_map=device_map,
-            max_memory=max_memory,
-            offload_folder=offload_folder,
-            offload_state_dict=offload_state_dict,
             user_agent=user_agent,
             **kwargs,
         )
@@ -684,6 +662,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         else:  # else let accelerate handle loading and dispatching.
             # Load weights and dispatch according to the device_map
             # by default the device_map is None and the weights are loaded on the CPU
+            device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
             try:
                 accelerate.load_checkpoint_and_dispatch(
                     model,
@@ -693,6 +672,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     offload_folder=offload_folder,
                     offload_state_dict=offload_state_dict,
                     dtype=torch_dtype,
+                    force_hooks=True,
+                    strict=True,
                 )
             except AttributeError as e:
                 # When using accelerate loading, we do not have the ability to load the state
@@ -873,6 +854,45 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         return model, missing_keys, unexpected_keys, mismatched_keys, error_msgs
 
+    @classmethod
+    def _get_signature_keys(cls, obj):
+        parameters = inspect.signature(obj.__init__).parameters
+        required_parameters = {k: v for k, v in parameters.items() if v.default == inspect._empty}
+        optional_parameters = set({k for k, v in parameters.items() if v.default != inspect._empty})
+        expected_modules = set(required_parameters.keys()) - {"self"}
+
+        return expected_modules, optional_parameters
+
+    # Adapted from `transformers` modeling_utils.py
+    def _get_no_split_modules(self, device_map: str):
+        """
+        Get the modules of the model that should not be split when using device_map. We iterate through the modules
+        to get the underlying `_no_split_modules`.
+
+        Args:
+            device_map (`str`):
+                The device map value. Options are ["auto", "balanced", "balanced_low_0", "sequential"]
+
+        Returns:
+            `List[str]`: List of modules that should not be split
+        """
+        _no_split_modules = set()
+        modules_to_check = [self]
+        while len(modules_to_check) > 0:
+            module = modules_to_check.pop(-1)
+            # if the module does not appear in _no_split_modules, we also check the children
+            if module.__class__.__name__ not in _no_split_modules:
+                if isinstance(module, ModelMixin):
+                    if module._no_split_modules is None:
+                        raise ValueError(
+                            f"{module.__class__.__name__} does not support `device_map='{device_map}'`. To implement support, the model "
+                            "class needs to implement the `_no_split_modules` attribute."
+                        )
+                    else:
+                        _no_split_modules = _no_split_modules | set(module._no_split_modules)
+                modules_to_check += list(module.children())
+        return list(_no_split_modules)
+
     @property
     def device(self) -> torch.device:
         """
diffusers/models/resnet.py

@@ -58,7 +58,7 @@ class ResnetBlockCondNorm2D(nn.Module):
         non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
         time_embedding_norm (`str`, *optional*, default to `"ada_group"` ):
             The normalization layer for time embedding `temb`. Currently only support "ada_group" or "spatial".
-        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
+        kernel (`torch.Tensor`, optional, default to None): FIR filter, see
             [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
         output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
         use_in_shortcut (`bool`, *optional*, default to `True`):
@@ -101,8 +101,6 @@ class ResnetBlockCondNorm2D(nn.Module):
         self.output_scale_factor = output_scale_factor
         self.time_embedding_norm = time_embedding_norm
 
-        conv_cls = nn.Conv2d
-
         if groups_out is None:
             groups_out = groups
 
@@ -113,7 +111,7 @@ class ResnetBlockCondNorm2D(nn.Module):
         else:
             raise ValueError(f" unsupported time_embedding_norm: {self.time_embedding_norm}")
 
-        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
         if self.time_embedding_norm == "ada_group":  # ada_group
             self.norm2 = AdaGroupNorm(temb_channels, out_channels, groups_out, eps=eps)
@@ -125,7 +123,7 @@ class ResnetBlockCondNorm2D(nn.Module):
         self.dropout = torch.nn.Dropout(dropout)
 
         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
         self.nonlinearity = get_activation(non_linearity)
 
@@ -139,7 +137,7 @@ class ResnetBlockCondNorm2D(nn.Module):
 
         self.conv_shortcut = None
         if self.use_in_shortcut:
-            self.conv_shortcut = conv_cls(
+            self.conv_shortcut = nn.Conv2d(
                 in_channels,
                 conv_2d_out_channels,
                 kernel_size=1,
@@ -148,7 +146,7 @@ class ResnetBlockCondNorm2D(nn.Module):
                 bias=conv_shortcut_bias,
             )
 
-    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -204,9 +202,9 @@ class ResnetBlock2D(nn.Module):
         eps (`float`, *optional*, defaults to `1e-6`): The epsilon to use for the normalization.
         non_linearity (`str`, *optional*, default to `"swish"`): the activation function to use.
         time_embedding_norm (`str`, *optional*, default to `"default"` ): Time scale shift config.
-            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift"
-            for a stronger conditioning with scale and shift.
-        kernel (`torch.FloatTensor`, optional, default to None): FIR filter, see
+            By default, apply timestep embedding conditioning with a simple shift mechanism. Choose "scale_shift" for a
+            stronger conditioning with scale and shift.
+        kernel (`torch.Tensor`, optional, default to None): FIR filter, see
             [`~models.resnet.FirUpsample2D`] and [`~models.resnet.FirDownsample2D`].
         output_scale_factor (`float`, *optional*, default to be `1.0`): the scale factor to use for the output.
         use_in_shortcut (`bool`, *optional*, default to `True`):
@@ -234,7 +232,7 @@ class ResnetBlock2D(nn.Module):
         non_linearity: str = "swish",
         skip_time_act: bool = False,
         time_embedding_norm: str = "default",  # default, scale_shift,
-        kernel: Optional[torch.FloatTensor] = None,
+        kernel: Optional[torch.Tensor] = None,
         output_scale_factor: float = 1.0,
         use_in_shortcut: Optional[bool] = None,
         up: bool = False,
@@ -263,21 +261,18 @@ class ResnetBlock2D(nn.Module):
         self.time_embedding_norm = time_embedding_norm
         self.skip_time_act = skip_time_act
 
-        linear_cls = nn.Linear
-        conv_cls = nn.Conv2d
-
         if groups_out is None:
             groups_out = groups
 
         self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
 
-        self.conv1 = conv_cls(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
 
         if temb_channels is not None:
             if self.time_embedding_norm == "default":
-                self.time_emb_proj = linear_cls(temb_channels, out_channels)
+                self.time_emb_proj = nn.Linear(temb_channels, out_channels)
             elif self.time_embedding_norm == "scale_shift":
-                self.time_emb_proj = linear_cls(temb_channels, 2 * out_channels)
+                self.time_emb_proj = nn.Linear(temb_channels, 2 * out_channels)
             else:
                 raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
         else:
@@ -287,7 +282,7 @@ class ResnetBlock2D(nn.Module):
 
         self.dropout = torch.nn.Dropout(dropout)
         conv_2d_out_channels = conv_2d_out_channels or out_channels
-        self.conv2 = conv_cls(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
+        self.conv2 = nn.Conv2d(out_channels, conv_2d_out_channels, kernel_size=3, stride=1, padding=1)
 
         self.nonlinearity = get_activation(non_linearity)
 
@@ -313,7 +308,7 @@ class ResnetBlock2D(nn.Module):
 
         self.conv_shortcut = None
         if self.use_in_shortcut:
-            self.conv_shortcut = conv_cls(
+            self.conv_shortcut = nn.Conv2d(
                 in_channels,
                 conv_2d_out_channels,
                 kernel_size=1,
@@ -322,7 +317,7 @@ class ResnetBlock2D(nn.Module):
                 bias=conv_shortcut_bias,
             )
 
-    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -610,7 +605,7 @@ class TemporalResnetBlock(nn.Module):
             padding=0,
         )
 
-    def forward(self, input_tensor: torch.FloatTensor, temb: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, input_tensor: torch.Tensor, temb: torch.Tensor) -> torch.Tensor:
         hidden_states = input_tensor
 
         hidden_states = self.norm1(hidden_states)
@@ -690,8 +685,8 @@ class SpatioTemporalResBlock(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        temb: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        temb: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
     ):
         num_frames = image_only_indicator.shape[-1]
diffusers/models/transformer_temporal.py

@@ -20,15 +20,15 @@ from .transformers.transformer_temporal import (
 
 
 class TransformerTemporalModelOutput(TransformerTemporalModelOutput):
-    deprecation_message = "Importing `TransformerTemporalModelOutput` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerTemporalModelOutput`, instead."
+    deprecation_message = "Importing `TransformerTemporalModelOutput` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_temporal import TransformerTemporalModelOutput`, instead."
     deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message)
 
 
 class TransformerTemporalModel(TransformerTemporalModel):
-    deprecation_message = "Importing `TransformerTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerTemporalModel`, instead."
+    deprecation_message = "Importing `TransformerTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_temporal import TransformerTemporalModel`, instead."
     deprecate("TransformerTemporalModel", "0.29", deprecation_message)
 
 
 class TransformerSpatioTemporalModel(TransformerSpatioTemporalModel):
-    deprecation_message = "Importing `TransformerSpatioTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.tranformer_temporal import TransformerSpatioTemporalModel`, instead."
+    deprecation_message = "Importing `TransformerSpatioTemporalModel` from `diffusers.models.transformer_temporal` is deprecated and this will be removed in a future version. Please use `from diffusers.models.transformers.transformer_temporal import TransformerSpatioTemporalModel`, instead."
     deprecate("TransformerTemporalModelOutput", "0.29", deprecation_message)
diffusers/models/transformers/dual_transformer_2d.py

@@ -106,21 +106,21 @@ class DualTransformer2DModel(nn.Module):
         """
         Args:
             hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
-                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
-                hidden_states.
+                When continuous, `torch.Tensor` of shape `(batch size, channel, height, width)`): Input hidden_states.
             encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
                 Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
                 self-attention.
             timestep ( `torch.long`, *optional*):
                 Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
-            attention_mask (`torch.FloatTensor`, *optional*):
+            attention_mask (`torch.Tensor`, *optional*):
                 Optional attention mask to be applied in Attention.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
                 [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             return_dict (`bool`, *optional*, defaults to `True`):
-                Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+                Whether or not to return a [`models.unets.unet_2d_condition.UNet2DConditionOutput`] instead of a plain
+                tuple.
 
         Returns:
             [`~models.transformer_2d.Transformer2DModelOutput`] or `tuple`:
diffusers/models/transformers/prior_transformer.py

@@ -26,11 +26,11 @@ class PriorTransformerOutput(BaseOutput):
     The output of [`PriorTransformer`].
 
     Args:
-        predicted_image_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+        predicted_image_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
             The predicted CLIP image embedding conditioned on the CLIP text embedding input.
     """
 
-    predicted_image_embedding: torch.FloatTensor
+    predicted_image_embedding: torch.Tensor
 
 
 class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
@@ -246,8 +246,8 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
         self,
         hidden_states,
         timestep: Union[torch.Tensor, float, int],
-        proj_embedding: torch.FloatTensor,
-        encoder_hidden_states: Optional[torch.FloatTensor] = None,
+        proj_embedding: torch.Tensor,
+        encoder_hidden_states: Optional[torch.Tensor] = None,
         attention_mask: Optional[torch.BoolTensor] = None,
         return_dict: bool = True,
     ):
@@ -255,13 +255,13 @@ class PriorTransformer(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin, PeftAdapterMixin):
         The [`PriorTransformer`] forward method.
 
         Args:
-            hidden_states (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+            hidden_states (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
                 The currently predicted image embeddings.
             timestep (`torch.LongTensor`):
                 Current denoising step.
-            proj_embedding (`torch.FloatTensor` of shape `(batch_size, embedding_dim)`):
+            proj_embedding (`torch.Tensor` of shape `(batch_size, embedding_dim)`):
                 Projected embedding vector the denoising process is conditioned on.
-            encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
+            encoder_hidden_states (`torch.Tensor` of shape `(batch_size, num_embeddings, embedding_dim)`):
                 Hidden states of the text embeddings the denoising process is conditioned on.
             attention_mask (`torch.BoolTensor` of shape `(batch_size, num_embeddings)`):
                 Text mask for the text embeddings.
diffusers/models/transformers/t5_film_transformer.py

@@ -86,7 +86,7 @@ class T5FilmDecoder(ModelMixin, ConfigMixin):
         self.post_dropout = nn.Dropout(p=dropout_rate)
         self.spec_out = nn.Linear(d_model, input_dims, bias=False)
 
-    def encoder_decoder_mask(self, query_input: torch.FloatTensor, key_input: torch.FloatTensor) -> torch.FloatTensor:
+    def encoder_decoder_mask(self, query_input: torch.Tensor, key_input: torch.Tensor) -> torch.Tensor:
         mask = torch.mul(query_input.unsqueeze(-1), key_input.unsqueeze(-2))
         return mask.unsqueeze(-3)
 
@@ -195,13 +195,13 @@ class DecoderLayer(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        conditioning_emb: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        conditioning_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         encoder_attention_mask: Optional[torch.Tensor] = None,
         encoder_decoder_position_bias=None,
-    ) -> Tuple[torch.FloatTensor]:
+    ) -> Tuple[torch.Tensor]:
         hidden_states = self.layer[0](
             hidden_states,
             conditioning_emb=conditioning_emb,
@@ -249,10 +249,10 @@ class T5LayerSelfAttentionCond(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        conditioning_emb: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        conditioning_emb: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         # pre_self_attention_layer_norm
         normed_hidden_states = self.layer_norm(hidden_states)
 
@@ -292,10 +292,10 @@ class T5LayerCrossAttention(nn.Module):
 
     def forward(
         self,
-        hidden_states: torch.FloatTensor,
-        key_value_states: Optional[torch.FloatTensor] = None,
-        attention_mask: Optional[torch.FloatTensor] = None,
-    ) -> torch.FloatTensor:
+        hidden_states: torch.Tensor,
+        key_value_states: Optional[torch.Tensor] = None,
+        attention_mask: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
         normed_hidden_states = self.layer_norm(hidden_states)
         attention_output = self.attention(
             normed_hidden_states,
@@ -328,9 +328,7 @@ class T5LayerFFCond(nn.Module):
         self.layer_norm = T5LayerNorm(d_model, eps=layer_norm_epsilon)
         self.dropout = nn.Dropout(dropout_rate)
 
-    def forward(
-        self, hidden_states: torch.FloatTensor, conditioning_emb: Optional[torch.FloatTensor] = None
-    ) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, conditioning_emb: Optional[torch.Tensor] = None) -> torch.Tensor:
         forwarded_states = self.layer_norm(hidden_states)
         if conditioning_emb is not None:
             forwarded_states = self.film(forwarded_states, conditioning_emb)
@@ -361,7 +359,7 @@ class T5DenseGatedActDense(nn.Module):
         self.dropout = nn.Dropout(dropout_rate)
         self.act = NewGELUActivation()
 
-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         hidden_gelu = self.act(self.wi_0(hidden_states))
         hidden_linear = self.wi_1(hidden_states)
         hidden_states = hidden_gelu * hidden_linear
@@ -390,7 +388,7 @@ class T5LayerNorm(nn.Module):
         self.weight = nn.Parameter(torch.ones(hidden_size))
         self.variance_epsilon = eps
 
-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         # T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
         # Square Layer Normalization https://arxiv.org/abs/1910.07467 thus variance is calculated
         # w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
@@ -431,7 +429,7 @@ class T5FiLMLayer(nn.Module):
         super().__init__()
         self.scale_bias = nn.Linear(in_features, out_features * 2, bias=False)
 
-    def forward(self, x: torch.FloatTensor, conditioning_emb: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, x: torch.Tensor, conditioning_emb: torch.Tensor) -> torch.Tensor:
         emb = self.scale_bias(conditioning_emb)
         scale, shift = torch.chunk(emb, 2, -1)
         x = x * (1 + scale) + shift
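For context on that last hunk: T5FiLMLayer applies feature-wise linear modulation, projecting the conditioning embedding to per-channel scale and shift values and computing x * (1 + scale) + shift. A minimal numeric sketch of the same arithmetic, independent of the class:

    import torch

    x = torch.ones(1, 4)
    emb = torch.tensor([[0.5, 0.5, 0.5, 0.5, 0.1, 0.1, 0.1, 0.1]])  # first half -> scale, second half -> shift
    scale, shift = torch.chunk(emb, 2, -1)
    out = x * (1 + scale) + shift  # tensor([[1.6, 1.6, 1.6, 1.6]])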