diffusers-0.27.1-py3-none-any.whl → diffusers-0.28.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (270)
  1. diffusers/__init__.py +18 -1
  2. diffusers/callbacks.py +156 -0
  3. diffusers/commands/env.py +110 -6
  4. diffusers/configuration_utils.py +16 -11
  5. diffusers/dependency_versions_table.py +2 -1
  6. diffusers/image_processor.py +158 -45
  7. diffusers/loaders/__init__.py +2 -5
  8. diffusers/loaders/autoencoder.py +4 -4
  9. diffusers/loaders/controlnet.py +4 -4
  10. diffusers/loaders/ip_adapter.py +80 -22
  11. diffusers/loaders/lora.py +134 -20
  12. diffusers/loaders/lora_conversion_utils.py +46 -43
  13. diffusers/loaders/peft.py +4 -3
  14. diffusers/loaders/single_file.py +401 -170
  15. diffusers/loaders/single_file_model.py +290 -0
  16. diffusers/loaders/single_file_utils.py +616 -672
  17. diffusers/loaders/textual_inversion.py +41 -20
  18. diffusers/loaders/unet.py +168 -115
  19. diffusers/loaders/unet_loader_utils.py +163 -0
  20. diffusers/models/__init__.py +2 -0
  21. diffusers/models/activations.py +11 -3
  22. diffusers/models/attention.py +10 -11
  23. diffusers/models/attention_processor.py +367 -148
  24. diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
  25. diffusers/models/autoencoders/autoencoder_kl.py +18 -19
  26. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
  27. diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
  28. diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
  29. diffusers/models/autoencoders/vae.py +23 -24
  30. diffusers/models/controlnet.py +12 -9
  31. diffusers/models/controlnet_flax.py +4 -4
  32. diffusers/models/controlnet_xs.py +1915 -0
  33. diffusers/models/downsampling.py +17 -18
  34. diffusers/models/embeddings.py +147 -24
  35. diffusers/models/model_loading_utils.py +149 -0
  36. diffusers/models/modeling_flax_pytorch_utils.py +2 -1
  37. diffusers/models/modeling_flax_utils.py +4 -4
  38. diffusers/models/modeling_pytorch_flax_utils.py +1 -1
  39. diffusers/models/modeling_utils.py +118 -98
  40. diffusers/models/resnet.py +18 -23
  41. diffusers/models/transformer_temporal.py +3 -3
  42. diffusers/models/transformers/dual_transformer_2d.py +4 -4
  43. diffusers/models/transformers/prior_transformer.py +7 -7
  44. diffusers/models/transformers/t5_film_transformer.py +17 -19
  45. diffusers/models/transformers/transformer_2d.py +272 -156
  46. diffusers/models/transformers/transformer_temporal.py +10 -10
  47. diffusers/models/unets/unet_1d.py +5 -5
  48. diffusers/models/unets/unet_1d_blocks.py +29 -29
  49. diffusers/models/unets/unet_2d.py +6 -6
  50. diffusers/models/unets/unet_2d_blocks.py +137 -128
  51. diffusers/models/unets/unet_2d_condition.py +20 -15
  52. diffusers/models/unets/unet_2d_condition_flax.py +6 -5
  53. diffusers/models/unets/unet_3d_blocks.py +79 -77
  54. diffusers/models/unets/unet_3d_condition.py +13 -9
  55. diffusers/models/unets/unet_i2vgen_xl.py +14 -13
  56. diffusers/models/unets/unet_kandinsky3.py +1 -1
  57. diffusers/models/unets/unet_motion_model.py +114 -14
  58. diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
  59. diffusers/models/unets/unet_stable_cascade.py +16 -13
  60. diffusers/models/upsampling.py +17 -20
  61. diffusers/models/vq_model.py +16 -15
  62. diffusers/pipelines/__init__.py +25 -3
  63. diffusers/pipelines/amused/pipeline_amused.py +12 -12
  64. diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
  65. diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
  66. diffusers/pipelines/animatediff/__init__.py +2 -0
  67. diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
  68. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
  69. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
  70. diffusers/pipelines/animatediff/pipeline_output.py +3 -2
  71. diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
  72. diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
  73. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
  74. diffusers/pipelines/auto_pipeline.py +21 -17
  75. diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
  76. diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
  77. diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
  78. diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
  79. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
  80. diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
  81. diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
  82. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
  83. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
  84. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
  85. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
  86. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
  87. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
  88. diffusers/pipelines/controlnet_xs/__init__.py +68 -0
  89. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
  90. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
  91. diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
  92. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
  93. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
  94. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
  95. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
  96. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
  97. diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
  98. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
  99. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
  100. diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
  101. diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
  102. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
  103. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
  104. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
  105. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
  106. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
  107. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
  108. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
  109. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
  110. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
  111. diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
  112. diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
  113. diffusers/pipelines/dit/pipeline_dit.py +3 -0
  114. diffusers/pipelines/free_init_utils.py +39 -38
  115. diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
  116. diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
  117. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
  118. diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
  119. diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
  120. diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
  121. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
  122. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
  123. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
  124. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
  125. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
  126. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
  127. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
  128. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
  129. diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
  130. diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
  131. diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
  132. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
  133. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
  134. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
  135. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
  136. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
  137. diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
  138. diffusers/pipelines/marigold/__init__.py +50 -0
  139. diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
  140. diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
  141. diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
  142. diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
  143. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
  144. diffusers/pipelines/pia/pipeline_pia.py +39 -125
  145. diffusers/pipelines/pipeline_flax_utils.py +4 -4
  146. diffusers/pipelines/pipeline_loading_utils.py +268 -23
  147. diffusers/pipelines/pipeline_utils.py +266 -37
  148. diffusers/pipelines/pixart_alpha/__init__.py +8 -1
  149. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
  150. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
  151. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
  152. diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
  153. diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
  154. diffusers/pipelines/shap_e/renderer.py +1 -1
  155. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
  156. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
  157. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
  158. diffusers/pipelines/stable_diffusion/__init__.py +0 -1
  159. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
  160. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
  161. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
  163. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
  164. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
  165. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
  166. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
  167. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
  168. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
  169. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
  170. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
  171. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
  172. diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
  173. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
  174. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
  175. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
  176. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
  177. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
  178. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
  179. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
  180. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
  181. diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
  182. diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
  183. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
  184. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
  185. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
  186. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
  187. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
  188. diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
  189. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
  190. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
  191. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
  192. diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
  193. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
  194. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
  195. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
  196. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
  197. diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
  198. diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
  199. diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
  200. diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
  201. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
  202. diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
  203. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
  204. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
  205. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
  206. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
  207. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
  208. diffusers/schedulers/__init__.py +2 -2
  209. diffusers/schedulers/deprecated/__init__.py +1 -1
  210. diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
  211. diffusers/schedulers/scheduling_amused.py +5 -5
  212. diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
  213. diffusers/schedulers/scheduling_consistency_models.py +23 -25
  214. diffusers/schedulers/scheduling_ddim.py +22 -24
  215. diffusers/schedulers/scheduling_ddim_flax.py +2 -1
  216. diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
  217. diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
  218. diffusers/schedulers/scheduling_ddpm.py +20 -22
  219. diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
  220. diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
  221. diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
  222. diffusers/schedulers/scheduling_deis_multistep.py +46 -42
  223. diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
  224. diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
  225. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
  226. diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
  227. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
  228. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
  229. diffusers/schedulers/scheduling_edm_euler.py +53 -30
  230. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
  231. diffusers/schedulers/scheduling_euler_discrete.py +163 -67
  232. diffusers/schedulers/scheduling_heun_discrete.py +60 -38
  233. diffusers/schedulers/scheduling_ipndm.py +8 -8
  234. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
  235. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
  236. diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
  237. diffusers/schedulers/scheduling_lcm.py +21 -23
  238. diffusers/schedulers/scheduling_lms_discrete.py +27 -25
  239. diffusers/schedulers/scheduling_pndm.py +20 -20
  240. diffusers/schedulers/scheduling_repaint.py +20 -20
  241. diffusers/schedulers/scheduling_sasolver.py +55 -54
  242. diffusers/schedulers/scheduling_sde_ve.py +19 -19
  243. diffusers/schedulers/scheduling_tcd.py +39 -30
  244. diffusers/schedulers/scheduling_unclip.py +15 -15
  245. diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
  246. diffusers/schedulers/scheduling_utils.py +14 -5
  247. diffusers/schedulers/scheduling_utils_flax.py +3 -3
  248. diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
  249. diffusers/training_utils.py +56 -1
  250. diffusers/utils/__init__.py +7 -0
  251. diffusers/utils/doc_utils.py +1 -0
  252. diffusers/utils/dummy_pt_objects.py +30 -0
  253. diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
  254. diffusers/utils/dynamic_modules_utils.py +24 -11
  255. diffusers/utils/hub_utils.py +3 -2
  256. diffusers/utils/import_utils.py +91 -0
  257. diffusers/utils/loading_utils.py +2 -2
  258. diffusers/utils/logging.py +1 -1
  259. diffusers/utils/peft_utils.py +32 -5
  260. diffusers/utils/state_dict_utils.py +11 -2
  261. diffusers/utils/testing_utils.py +71 -6
  262. diffusers/utils/torch_utils.py +1 -0
  263. diffusers/video_processor.py +113 -0
  264. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
  265. diffusers-0.28.0.dist-info/RECORD +414 -0
  266. diffusers-0.27.1.dist-info/RECORD +0 -399
  267. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
  268. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
  269. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
  270. {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
@@ -92,6 +92,21 @@ class AnimateDiffPipeline(metaclass=DummyObject):
92
92
  requires_backends(cls, ["torch", "transformers"])
93
93
 
94
94
 
95
+ class AnimateDiffSDXLPipeline(metaclass=DummyObject):
96
+ _backends = ["torch", "transformers"]
97
+
98
+ def __init__(self, *args, **kwargs):
99
+ requires_backends(self, ["torch", "transformers"])
100
+
101
+ @classmethod
102
+ def from_config(cls, *args, **kwargs):
103
+ requires_backends(cls, ["torch", "transformers"])
104
+
105
+ @classmethod
106
+ def from_pretrained(cls, *args, **kwargs):
107
+ requires_backends(cls, ["torch", "transformers"])
108
+
109
+
95
110
  class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
96
111
  _backends = ["torch", "transformers"]
97
112
 
@@ -677,6 +692,36 @@ class LEditsPPPipelineStableDiffusionXL(metaclass=DummyObject):
677
692
  requires_backends(cls, ["torch", "transformers"])
678
693
 
679
694
 
695
+ class MarigoldDepthPipeline(metaclass=DummyObject):
696
+ _backends = ["torch", "transformers"]
697
+
698
+ def __init__(self, *args, **kwargs):
699
+ requires_backends(self, ["torch", "transformers"])
700
+
701
+ @classmethod
702
+ def from_config(cls, *args, **kwargs):
703
+ requires_backends(cls, ["torch", "transformers"])
704
+
705
+ @classmethod
706
+ def from_pretrained(cls, *args, **kwargs):
707
+ requires_backends(cls, ["torch", "transformers"])
708
+
709
+
710
+ class MarigoldNormalsPipeline(metaclass=DummyObject):
711
+ _backends = ["torch", "transformers"]
712
+
713
+ def __init__(self, *args, **kwargs):
714
+ requires_backends(self, ["torch", "transformers"])
715
+
716
+ @classmethod
717
+ def from_config(cls, *args, **kwargs):
718
+ requires_backends(cls, ["torch", "transformers"])
719
+
720
+ @classmethod
721
+ def from_pretrained(cls, *args, **kwargs):
722
+ requires_backends(cls, ["torch", "transformers"])
723
+
724
+
680
725
  class MusicLDMPipeline(metaclass=DummyObject):
681
726
  _backends = ["torch", "transformers"]
682
727
 
@@ -737,6 +782,21 @@ class PixArtAlphaPipeline(metaclass=DummyObject):
737
782
  requires_backends(cls, ["torch", "transformers"])
738
783
 
739
784
 
785
+ class PixArtSigmaPipeline(metaclass=DummyObject):
786
+ _backends = ["torch", "transformers"]
787
+
788
+ def __init__(self, *args, **kwargs):
789
+ requires_backends(self, ["torch", "transformers"])
790
+
791
+ @classmethod
792
+ def from_config(cls, *args, **kwargs):
793
+ requires_backends(cls, ["torch", "transformers"])
794
+
795
+ @classmethod
796
+ def from_pretrained(cls, *args, **kwargs):
797
+ requires_backends(cls, ["torch", "transformers"])
798
+
799
+
740
800
  class SemanticStableDiffusionPipeline(metaclass=DummyObject):
741
801
  _backends = ["torch", "transformers"]
742
802
 
@@ -902,6 +962,21 @@ class StableDiffusionControlNetPipeline(metaclass=DummyObject):
902
962
  requires_backends(cls, ["torch", "transformers"])
903
963
 
904
964
 
965
+ class StableDiffusionControlNetXSPipeline(metaclass=DummyObject):
966
+ _backends = ["torch", "transformers"]
967
+
968
+ def __init__(self, *args, **kwargs):
969
+ requires_backends(self, ["torch", "transformers"])
970
+
971
+ @classmethod
972
+ def from_config(cls, *args, **kwargs):
973
+ requires_backends(cls, ["torch", "transformers"])
974
+
975
+ @classmethod
976
+ def from_pretrained(cls, *args, **kwargs):
977
+ requires_backends(cls, ["torch", "transformers"])
978
+
979
+
905
980
  class StableDiffusionDepth2ImgPipeline(metaclass=DummyObject):
906
981
  _backends = ["torch", "transformers"]
907
982
 
@@ -1247,6 +1322,21 @@ class StableDiffusionXLControlNetPipeline(metaclass=DummyObject):
1247
1322
  requires_backends(cls, ["torch", "transformers"])
1248
1323
 
1249
1324
 
1325
+ class StableDiffusionXLControlNetXSPipeline(metaclass=DummyObject):
1326
+ _backends = ["torch", "transformers"]
1327
+
1328
+ def __init__(self, *args, **kwargs):
1329
+ requires_backends(self, ["torch", "transformers"])
1330
+
1331
+ @classmethod
1332
+ def from_config(cls, *args, **kwargs):
1333
+ requires_backends(cls, ["torch", "transformers"])
1334
+
1335
+ @classmethod
1336
+ def from_pretrained(cls, *args, **kwargs):
1337
+ requires_backends(cls, ["torch", "transformers"])
1338
+
1339
+
1250
1340
  class StableDiffusionXLImg2ImgPipeline(metaclass=DummyObject):
1251
1341
  _backends = ["torch", "transformers"]
1252
1342
 
@@ -201,7 +201,7 @@ def get_cached_module_file(
201
201
  module_file: str,
202
202
  cache_dir: Optional[Union[str, os.PathLike]] = None,
203
203
  force_download: bool = False,
204
- resume_download: bool = False,
204
+ resume_download: Optional[bool] = None,
205
205
  proxies: Optional[Dict[str, str]] = None,
206
206
  token: Optional[Union[bool, str]] = None,
207
207
  revision: Optional[str] = None,
@@ -228,9 +228,9 @@ def get_cached_module_file(
228
228
  cache should not be used.
229
229
  force_download (`bool`, *optional*, defaults to `False`):
230
230
  Whether or not to force to (re-)download the configuration files and override the cached versions if they
231
- exist.
232
- resume_download (`bool`, *optional*, defaults to `False`):
233
- Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
231
+ exist. resume_download:
232
+ Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
233
+ of Diffusers.
234
234
  proxies (`Dict[str, str]`, *optional*):
235
235
  A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
236
236
  'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
@@ -246,8 +246,8 @@ def get_cached_module_file(
246
246
 
247
247
  <Tip>
248
248
 
249
- You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private
250
- or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
249
+ You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
250
+ [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
251
251
 
252
252
  </Tip>
253
253
 
@@ -329,6 +329,11 @@ def get_cached_module_file(
329
329
  # The only reason we do the copy is to avoid putting too many folders in sys.path.
330
330
  shutil.copy(resolved_module_file, submodule_path / module_file)
331
331
  for module_needed in modules_needed:
332
+ if len(module_needed.split(".")) == 2:
333
+ module_needed = "/".join(module_needed.split("."))
334
+ module_folder = module_needed.split("/")[0]
335
+ if not os.path.exists(submodule_path / module_folder):
336
+ os.makedirs(submodule_path / module_folder)
332
337
  module_needed = f"{module_needed}.py"
333
338
  shutil.copy(os.path.join(pretrained_model_name_or_path, module_needed), submodule_path / module_needed)
334
339
  else:
@@ -343,9 +348,16 @@ def get_cached_module_file(
343
348
  create_dynamic_module(full_submodule)
344
349
 
345
350
  if not (submodule_path / module_file).exists():
351
+ if len(module_file.split("/")) == 2:
352
+ module_folder = module_file.split("/")[0]
353
+ if not os.path.exists(submodule_path / module_folder):
354
+ os.makedirs(submodule_path / module_folder)
346
355
  shutil.copy(resolved_module_file, submodule_path / module_file)
356
+
347
357
  # Make sure we also have every file with relative
348
358
  for module_needed in modules_needed:
359
+ if len(module_needed.split(".")) == 2:
360
+ module_needed = "/".join(module_needed.split("."))
349
361
  if not (submodule_path / module_needed).exists():
350
362
  get_cached_module_file(
351
363
  pretrained_model_name_or_path,
@@ -368,7 +380,7 @@ def get_class_from_dynamic_module(
368
380
  class_name: Optional[str] = None,
369
381
  cache_dir: Optional[Union[str, os.PathLike]] = None,
370
382
  force_download: bool = False,
371
- resume_download: bool = False,
383
+ resume_download: Optional[bool] = None,
372
384
  proxies: Optional[Dict[str, str]] = None,
373
385
  token: Optional[Union[bool, str]] = None,
374
386
  revision: Optional[str] = None,
@@ -405,8 +417,9 @@ def get_class_from_dynamic_module(
405
417
  force_download (`bool`, *optional*, defaults to `False`):
406
418
  Whether or not to force to (re-)download the configuration files and override the cached versions if they
407
419
  exist.
408
- resume_download (`bool`, *optional*, defaults to `False`):
409
- Whether or not to delete incompletely received file. Attempts to resume the download if such a file exists.
420
+ resume_download:
421
+ Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1 of
422
+ Diffusers.
410
423
  proxies (`Dict[str, str]`, *optional*):
411
424
  A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
412
425
  'http://hostname': 'foo.bar:4012'}.` The proxies are used on each request.
@@ -422,8 +435,8 @@ def get_class_from_dynamic_module(
422
435
 
423
436
  <Tip>
424
437
 
425
- You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private
426
- or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
438
+ You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
439
+ [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
427
440
 
428
441
  </Tip>
429
442
 
@@ -112,7 +112,8 @@ def load_or_create_model_card(
112
112
  repo_id_or_path (`str`):
113
113
  The repo id (e.g., "runwayml/stable-diffusion-v1-5") or local path where to look for the model card.
114
114
  token (`str`, *optional*):
115
- Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more details.
115
+ Authentication token. Will default to the stored token. See https://huggingface.co/settings/token for more
116
+ details.
116
117
  is_pipeline (`bool`):
117
118
  Boolean to indicate if we're adding tag to a [`DiffusionPipeline`].
118
119
  from_training: (`bool`): Boolean flag to denote if the model card is being created from a training script.
@@ -282,7 +283,7 @@ def _get_model_file(
282
283
  cache_dir: Optional[str] = None,
283
284
  force_download: bool = False,
284
285
  proxies: Optional[Dict] = None,
285
- resume_download: bool = False,
286
+ resume_download: Optional[bool] = None,
286
287
  local_files_only: bool = False,
287
288
  token: Optional[str] = None,
288
289
  user_agent: Optional[Union[Dict, str]] = None,
@@ -295,6 +295,46 @@ try:
295
295
  except importlib_metadata.PackageNotFoundError:
296
296
  _torchvision_available = False
297
297
 
298
+ _matplotlib_available = importlib.util.find_spec("matplotlib") is not None
299
+ try:
300
+ _matplotlib_version = importlib_metadata.version("matplotlib")
301
+ logger.debug(f"Successfully imported matplotlib version {_matplotlib_version}")
302
+ except importlib_metadata.PackageNotFoundError:
303
+ _matplotlib_available = False
304
+
305
+ _timm_available = importlib.util.find_spec("timm") is not None
306
+ if _timm_available:
307
+ try:
308
+ _timm_version = importlib_metadata.version("timm")
309
+ logger.info(f"Timm version {_timm_version} available.")
310
+ except importlib_metadata.PackageNotFoundError:
311
+ _timm_available = False
312
+
313
+
314
+ def is_timm_available():
315
+ return _timm_available
316
+
317
+
318
+ _bitsandbytes_available = importlib.util.find_spec("bitsandbytes") is not None
319
+ try:
320
+ _bitsandbytes_version = importlib_metadata.version("bitsandbytes")
321
+ logger.debug(f"Successfully imported bitsandbytes version {_bitsandbytes_version}")
322
+ except importlib_metadata.PackageNotFoundError:
323
+ _bitsandbytes_available = False
324
+
325
+ # Taken from `huggingface_hub`.
326
+ _is_notebook = False
327
+ try:
328
+ shell_class = get_ipython().__class__ # type: ignore # noqa: F821
329
+ for parent_class in shell_class.__mro__: # e.g. "is subclass of"
330
+ if parent_class.__name__ == "ZMQInteractiveShell":
331
+ _is_notebook = True # Jupyter notebook, Google colab or qtconsole
332
+ break
333
+ except NameError:
334
+ pass # Probably standard Python interpreter
335
+
336
+ _is_google_colab = "google.colab" in sys.modules
337
+
298
338
 
299
339
  def is_torch_available():
300
340
  return _torch_available
@@ -392,6 +432,26 @@ def is_torchvision_available():
392
432
  return _torchvision_available
393
433
 
394
434
 
435
+ def is_matplotlib_available():
436
+ return _matplotlib_available
437
+
438
+
439
+ def is_safetensors_available():
440
+ return _safetensors_available
441
+
442
+
443
+ def is_bitsandbytes_available():
444
+ return _bitsandbytes_available
445
+
446
+
447
+ def is_notebook():
448
+ return _is_notebook
449
+
450
+
451
+ def is_google_colab():
452
+ return _is_google_colab
453
+
454
+
395
455
  # docstyle-ignore
396
456
  FLAX_IMPORT_ERROR = """
397
457
  {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -499,6 +559,20 @@ INVISIBLE_WATERMARK_IMPORT_ERROR = """
499
559
  {0} requires the invisible-watermark library but it was not found in your environment. You can install it with pip: `pip install invisible-watermark>=0.2.0`
500
560
  """
501
561
 
562
# Error template paired with `is_peft_available` in BACKENDS_MAPPING. `{0}` is a positional placeholder,
# presumably filled with the name of the object that needs the backend — confirm at the formatting call site.
# docstyle-ignore
PEFT_IMPORT_ERROR = """
{0} requires the peft library but it was not found in your environment. You can install it with pip: `pip install peft`
"""

# Error template paired with `is_safetensors_available` in BACKENDS_MAPPING (same `{0}` placeholder).
# docstyle-ignore
SAFETENSORS_IMPORT_ERROR = """
{0} requires the safetensors library but it was not found in your environment. You can install it with pip: `pip install safetensors`
"""

# Error template paired with `is_bitsandbytes_available` in BACKENDS_MAPPING (same `{0}` placeholder).
# docstyle-ignore
BITSANDBYTES_IMPORT_ERROR = """
{0} requires the bitsandbytes library but it was not found in your environment. You can install it with pip: `pip install bitsandbytes`
"""
502
576
 
503
577
  BACKENDS_MAPPING = OrderedDict(
504
578
  [
@@ -520,6 +594,9 @@ BACKENDS_MAPPING = OrderedDict(
520
594
  ("ftfy", (is_ftfy_available, FTFY_IMPORT_ERROR)),
521
595
  ("torchsde", (is_torchsde_available, TORCHSDE_IMPORT_ERROR)),
522
596
  ("invisible_watermark", (is_invisible_watermark_available, INVISIBLE_WATERMARK_IMPORT_ERROR)),
597
+ ("peft", (is_peft_available, PEFT_IMPORT_ERROR)),
598
+ ("safetensors", (is_safetensors_available, SAFETENSORS_IMPORT_ERROR)),
599
+ ("bitsandbytes", (is_bitsandbytes_available, BITSANDBYTES_IMPORT_ERROR)),
523
600
  ]
524
601
  )
525
602
 
@@ -628,6 +705,20 @@ def is_accelerate_version(operation: str, version: str):
628
705
  return compare_versions(parse(_accelerate_version), operation, version)
629
706
 
630
707
 
708
def is_peft_version(operation: str, version: str):
    """
    Compares the current PEFT version to a given reference with an operation.

    Args:
        operation (`str`):
            A string representation of an operator, such as `">"` or `"<="`
        version (`str`):
            A version string

    Returns:
        `bool`: The result of the comparison, or `False` when peft is not available or its version is unknown.
    """
    # Short-circuit on availability first so `_peft_version` is only consulted when peft was actually detected.
    # NOTE(review): confirm whether `_peft_version` is always bound when peft is missing; this ordering is safe
    # either way.
    if not is_peft_available() or not _peft_version:
        return False
    return compare_versions(parse(_peft_version), operation, version)
720
+
721
+
631
722
  def is_k_diffusion_version(operation: str, version: str):
632
723
  """
633
724
  Args:
@@ -16,8 +16,8 @@ def load_image(
16
16
  image (`str` or `PIL.Image.Image`):
17
17
  The image to convert to the PIL Image format.
18
18
  convert_method (Callable[[PIL.Image.Image], PIL.Image.Image], optional):
19
- A conversion method to apply to the image after loading it.
20
- When set to `None` the image will be converted "RGB".
19
+ A conversion method to apply to the image after loading it. When set to `None` the image will be converted
20
+ "RGB".
21
21
 
22
22
  Returns:
23
23
  `PIL.Image.Image`:
@@ -12,7 +12,7 @@
12
12
  # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
13
  # See the License for the specific language governing permissions and
14
14
  # limitations under the License.
15
- """ Logging utilities."""
15
+ """Logging utilities."""
16
16
 
17
17
  import logging
18
18
  import os
@@ -14,6 +14,7 @@
14
14
  """
15
15
  PEFT utilities: Utilities related to peft library
16
16
  """
17
+
17
18
  import collections
18
19
  import importlib
19
20
  from typing import Optional
@@ -63,9 +64,11 @@ def recurse_remove_peft_layers(model):
63
64
  module_replaced = False
64
65
 
65
66
  if isinstance(module, LoraLayer) and isinstance(module, torch.nn.Linear):
66
- new_module = torch.nn.Linear(module.in_features, module.out_features, bias=module.bias is not None).to(
67
- module.weight.device
68
- )
67
+ new_module = torch.nn.Linear(
68
+ module.in_features,
69
+ module.out_features,
70
+ bias=module.bias is not None,
71
+ ).to(module.weight.device)
69
72
  new_module.weight = module.weight
70
73
  if module.bias is not None:
71
74
  new_module.bias = module.bias
@@ -109,6 +112,9 @@ def scale_lora_layers(model, weight):
109
112
  """
110
113
  from peft.tuners.tuners_utils import BaseTunerLayer
111
114
 
115
+ if weight == 1.0:
116
+ return
117
+
112
118
  for module in model.modules():
113
119
  if isinstance(module, BaseTunerLayer):
114
120
  module.scale_layer(weight)
@@ -128,6 +134,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
128
134
  """
129
135
  from peft.tuners.tuners_utils import BaseTunerLayer
130
136
 
137
+ if weight == 1.0:
138
+ return
139
+
131
140
  for module in model.modules():
132
141
  if isinstance(module, BaseTunerLayer):
133
142
  if weight is not None and weight != 0:
@@ -170,6 +179,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
170
179
 
171
180
  # layer names without the Diffusers specific
172
181
  target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
182
+ use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
173
183
 
174
184
  lora_config_kwargs = {
175
185
  "r": r,
@@ -177,6 +187,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
177
187
  "rank_pattern": rank_pattern,
178
188
  "alpha_pattern": alpha_pattern,
179
189
  "target_modules": target_modules,
190
+ "use_dora": use_dora,
180
191
  }
181
192
  return lora_config_kwargs
182
193
 
@@ -227,16 +238,32 @@ def delete_adapter_layers(model, adapter_name):
227
238
  def set_weights_and_activate_adapters(model, adapter_names, weights):
228
239
  from peft.tuners.tuners_utils import BaseTunerLayer
229
240
 
241
+ def get_module_weight(weight_for_adapter, module_name):
242
+ if not isinstance(weight_for_adapter, dict):
243
+ # If weight_for_adapter is a single number, always return it.
244
+ return weight_for_adapter
245
+
246
+ for layer_name, weight_ in weight_for_adapter.items():
247
+ if layer_name in module_name:
248
+ return weight_
249
+
250
+ parts = module_name.split(".")
251
+ # e.g. key = "down_blocks.1.attentions.0"
252
+ key = f"{parts[0]}.{parts[1]}.attentions.{parts[3]}"
253
+ block_weight = weight_for_adapter.get(key, 1.0)
254
+
255
+ return block_weight
256
+
230
257
  # iterate over each adapter, make it active and set the corresponding scaling weight
231
258
  for adapter_name, weight in zip(adapter_names, weights):
232
- for module in model.modules():
259
+ for module_name, module in model.named_modules():
233
260
  if isinstance(module, BaseTunerLayer):
234
261
  # For backward compatbility with previous PEFT versions
235
262
  if hasattr(module, "set_adapter"):
236
263
  module.set_adapter(adapter_name)
237
264
  else:
238
265
  module.active_adapter = adapter_name
239
- module.set_scale(adapter_name, weight)
266
+ module.set_scale(adapter_name, get_module_weight(weight, module_name))
240
267
 
241
268
  # set multiple active adapters
242
269
  for module in model.modules():
@@ -14,6 +14,7 @@
14
14
  """
15
15
  State dict utilities: utility methods for converting state dicts easily
16
16
  """
17
+
17
18
  import enum
18
19
 
19
20
  from .logging import get_logger
@@ -46,6 +47,7 @@ UNET_TO_DIFFUSERS = {
46
47
  ".to_v_lora.up": ".to_v.lora_B",
47
48
  ".lora.up": ".lora_B",
48
49
  ".lora.down": ".lora_A",
50
+ ".to_out.lora_magnitude_vector": ".to_out.0.lora_magnitude_vector",
49
51
  }
50
52
 
51
53
 
@@ -103,6 +105,10 @@ DIFFUSERS_OLD_TO_DIFFUSERS = {
103
105
  ".to_v_lora.down": ".v_proj.lora_linear_layer.down",
104
106
  ".to_out_lora.up": ".out_proj.lora_linear_layer.up",
105
107
  ".to_out_lora.down": ".out_proj.lora_linear_layer.down",
108
+ ".to_k.lora_magnitude_vector": ".k_proj.lora_magnitude_vector",
109
+ ".to_v.lora_magnitude_vector": ".v_proj.lora_magnitude_vector",
110
+ ".to_q.lora_magnitude_vector": ".q_proj.lora_magnitude_vector",
111
+ ".to_out.lora_magnitude_vector": ".out_proj.lora_magnitude_vector",
106
112
  }
107
113
 
108
114
  PEFT_TO_KOHYA_SS = {
@@ -247,8 +253,8 @@ def convert_unet_state_dict_to_peft(state_dict):
247
253
 
248
254
  def convert_all_state_dict_to_peft(state_dict):
249
255
  r"""
250
- Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer`
251
- for a valid `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`
256
+ Attempts to first `convert_state_dict_to_peft`, and if it doesn't detect `lora_linear_layer` for a valid
257
+ `DIFFUSERS` LoRA for example, attempts to exclusively convert the Unet `convert_unet_state_dict_to_peft`
252
258
  """
253
259
  try:
254
260
  peft_dict = convert_state_dict_to_peft(state_dict)
@@ -314,6 +320,9 @@ def convert_state_dict_to_kohya(state_dict, original_type=None, **kwargs):
314
320
  kohya_key = kohya_key.replace("text_encoder.", "lora_te1.")
315
321
  elif "unet" in kohya_key:
316
322
  kohya_key = kohya_key.replace("unet", "lora_unet")
323
+ elif "lora_magnitude_vector" in kohya_key:
324
+ kohya_key = kohya_key.replace("lora_magnitude_vector", "dora_scale")
325
+
317
326
  kohya_key = kohya_key.replace(".", "_", kohya_key.count(".") - 2)
318
327
  kohya_key = kohya_key.replace(peft_adapter_name, "") # Kohya doesn't take names
319
328
  kohya_ss_state_dict[kohya_key] = weight
@@ -14,7 +14,6 @@ import time
14
14
  import unittest
15
15
  import urllib.parse
16
16
  from contextlib import contextmanager
17
- from distutils.util import strtobool
18
17
  from io import BytesIO, StringIO
19
18
  from pathlib import Path
20
19
  from typing import Callable, Dict, List, Optional, Union
@@ -34,6 +33,7 @@ from .import_utils import (
34
33
  is_onnx_available,
35
34
  is_opencv_available,
36
35
  is_peft_available,
36
+ is_timm_available,
37
37
  is_torch_available,
38
38
  is_torch_version,
39
39
  is_torchsde_available,
@@ -106,10 +106,21 @@ def numpy_cosine_similarity_distance(a, b):
106
106
  return distance
107
107
 
108
108
 
109
- def print_tensor_test(tensor, filename="test_corrections.txt", expected_tensor_name="expected_slice"):
109
+ def print_tensor_test(
110
+ tensor,
111
+ limit_to_slices=None,
112
+ max_torch_print=None,
113
+ filename="test_corrections.txt",
114
+ expected_tensor_name="expected_slice",
115
+ ):
116
+ if max_torch_print:
117
+ torch.set_printoptions(threshold=10_000)
118
+
110
119
  test_name = os.environ.get("PYTEST_CURRENT_TEST")
111
120
  if not torch.is_tensor(tensor):
112
121
  tensor = torch.from_numpy(tensor)
122
+ if limit_to_slices:
123
+ tensor = tensor[0, -3:, -3:, -1]
113
124
 
114
125
  tensor_str = str(tensor.detach().cpu().flatten().to(torch.float32)).replace("\n", "")
115
126
  # format is usually:
@@ -118,7 +129,7 @@ def print_tensor_test(tensor, filename="test_corrections.txt", expected_tensor_n
118
129
  test_file, test_class, test_fn = test_name.split("::")
119
130
  test_fn = test_fn.split()[0]
120
131
  with open(filename, "a") as f:
121
- print(";".join([test_file, test_class, test_fn, output_str]), file=f)
132
+ print("::".join([test_file, test_class, test_fn, output_str]), file=f)
122
133
 
123
134
 
124
135
  def get_tests_dir(append_path=None):
@@ -142,6 +153,22 @@ def get_tests_dir(append_path=None):
142
153
  return tests_dir
143
154
 
144
155
 
156
# Taken from the following PR:
# https://github.com/huggingface/accelerate/pull/1964
def str_to_bool(value) -> int:
    """
    Map a textual truth value to `1` (true) or `0` (false), ignoring case.

    True spellings are `y`, `yes`, `t`, `true`, `on` and `1`; false spellings are `n`, `no`, `f`, `false`, `off`
    and `0`.

    Raises:
        ValueError: if the lower-cased `value` matches neither set (the message shows the lower-cased value).
    """
    normalized = value.lower()
    truthy = ("y", "yes", "t", "true", "on", "1")
    falsy = ("n", "no", "f", "false", "off", "0")
    if normalized in truthy:
        return 1
    if normalized in falsy:
        return 0
    raise ValueError(f"invalid truth value {normalized}")
170
+
171
+
145
172
  def parse_flag_from_env(key, default=False):
146
173
  try:
147
174
  value = os.environ[key]
@@ -151,7 +178,7 @@ def parse_flag_from_env(key, default=False):
151
178
  else:
152
179
  # KEY is set, convert it to True or False.
153
180
  try:
154
- _value = strtobool(value)
181
+ _value = str_to_bool(value)
155
182
  except ValueError:
156
183
  # More values are supported, but let's keep the message simple.
157
184
  raise ValueError(f"If set, {key} must be yes or no.")
@@ -229,6 +256,20 @@ def require_torch_accelerator(test_case):
229
256
  )
230
257
 
231
258
 
259
def require_torch_multi_gpu(test_case):
    """
    Decorator that skips `test_case` unless PyTorch is installed and more than one CUDA device is visible.

    To run *only* the multi-GPU tests, assuming all their names contain "multi_gpu":
    $ pytest -sv ./tests -k "multi_gpu"
    """
    if is_torch_available():
        # Imported lazily so this decorator can be applied even when torch is absent.
        import torch

        has_multiple_gpus = torch.cuda.device_count() > 1
        return unittest.skipUnless(has_multiple_gpus, "test requires multiple GPUs")(test_case)

    return unittest.skip("test requires PyTorch")(test_case)
271
+
272
+
232
273
  def require_torch_accelerator_with_fp16(test_case):
233
274
  """Decorator marking a test that requires an accelerator with support for the FP16 data type."""
234
275
  return unittest.skipUnless(_is_torch_fp16_available(torch_device), "test requires accelerator with fp16 support")(
@@ -300,6 +341,13 @@ def require_peft_backend(test_case):
300
341
  return unittest.skipUnless(USE_PEFT_BACKEND, "test requires PEFT backend")(test_case)
301
342
 
302
343
 
344
def require_timm(test_case):
    """
    Decorator that skips `test_case` when timm is not installed (as reported by `is_timm_available()`).
    """
    skip_without_timm = unittest.skipUnless(is_timm_available(), "test requires timm")
    return skip_without_timm(test_case)
349
+
350
+
303
351
  def require_peft_version_greater(peft_version):
304
352
  """
305
353
  Decorator marking a test that requires PEFT backend with a specific version, this would require some specific
@@ -317,6 +365,18 @@ def require_peft_version_greater(peft_version):
317
365
  return decorator
318
366
 
319
367
 
368
def require_accelerate_version_greater(accelerate_version):
    """
    Decorator factory that skips a test unless `accelerate` is installed with a version strictly greater than
    `accelerate_version`.

    The installed version is reduced to its base version (e.g. dev/rc suffixes dropped) before comparing.

    Args:
        accelerate_version (`str`): The exclusive lower bound on the installed accelerate version.

    Returns:
        A decorator that wraps a test case with `unittest.skipUnless`.
    """

    def decorator(test_case):
        # Probe accelerate itself: a missing install means "skip", not an exception at decoration time.
        try:
            installed_version = importlib.metadata.version("accelerate")
        except importlib.metadata.PackageNotFoundError:
            correct_accelerate_version = False
        else:
            correct_accelerate_version = version.parse(version.parse(installed_version).base_version) > version.parse(
                accelerate_version
            )
        return unittest.skipUnless(
            correct_accelerate_version, f"Test requires accelerate with the version greater than {accelerate_version}."
        )(test_case)

    return decorator
378
+
379
+
320
380
  def deprecate_after_peft_backend(test_case):
321
381
  """
322
382
  Decorator marking a test that will be skipped after PEFT backend
@@ -324,10 +384,15 @@ def deprecate_after_peft_backend(test_case):
324
384
  return unittest.skipUnless(not USE_PEFT_BACKEND, "test skipped in favor of PEFT backend")(test_case)
325
385
 
326
386
 
387
def get_python_version():
    """Return the running interpreter's `(major, minor)` version numbers as a plain tuple of ints."""
    return tuple(sys.version_info[:2])
391
+
392
+
327
393
def require_python39_or_higher(test_case):
    """
    Decorator that skips `test_case` unless the interpreter is Python 3.9 or newer.

    Note: only major version 3 is accepted, so any other major version is skipped as well.
    """

    def python39_available():
        major, minor = get_python_version()
        if major != 3:
            return False
        return minor >= 9

    is_supported = python39_available()
    return unittest.skipUnless(is_supported, "test requires Python 3.9 or higher")(test_case)
@@ -14,6 +14,7 @@
14
14
  """
15
15
  PyTorch utilities: Utilities related to PyTorch
16
16
  """
17
+
17
18
  from typing import List, Optional, Tuple, Union
18
19
 
19
20
  from . import logging