diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191) hide show
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,96 @@
2
2
  from ..utils import DummyObject, requires_backends
3
3
 
4
4
 
5
+ class FluxAutoBlocks(metaclass=DummyObject):
6
+ _backends = ["torch", "transformers"]
7
+
8
+ def __init__(self, *args, **kwargs):
9
+ requires_backends(self, ["torch", "transformers"])
10
+
11
+ @classmethod
12
+ def from_config(cls, *args, **kwargs):
13
+ requires_backends(cls, ["torch", "transformers"])
14
+
15
+ @classmethod
16
+ def from_pretrained(cls, *args, **kwargs):
17
+ requires_backends(cls, ["torch", "transformers"])
18
+
19
+
20
+ class FluxModularPipeline(metaclass=DummyObject):
21
+ _backends = ["torch", "transformers"]
22
+
23
+ def __init__(self, *args, **kwargs):
24
+ requires_backends(self, ["torch", "transformers"])
25
+
26
+ @classmethod
27
+ def from_config(cls, *args, **kwargs):
28
+ requires_backends(cls, ["torch", "transformers"])
29
+
30
+ @classmethod
31
+ def from_pretrained(cls, *args, **kwargs):
32
+ requires_backends(cls, ["torch", "transformers"])
33
+
34
+
35
+ class StableDiffusionXLAutoBlocks(metaclass=DummyObject):
36
+ _backends = ["torch", "transformers"]
37
+
38
+ def __init__(self, *args, **kwargs):
39
+ requires_backends(self, ["torch", "transformers"])
40
+
41
+ @classmethod
42
+ def from_config(cls, *args, **kwargs):
43
+ requires_backends(cls, ["torch", "transformers"])
44
+
45
+ @classmethod
46
+ def from_pretrained(cls, *args, **kwargs):
47
+ requires_backends(cls, ["torch", "transformers"])
48
+
49
+
50
+ class StableDiffusionXLModularPipeline(metaclass=DummyObject):
51
+ _backends = ["torch", "transformers"]
52
+
53
+ def __init__(self, *args, **kwargs):
54
+ requires_backends(self, ["torch", "transformers"])
55
+
56
+ @classmethod
57
+ def from_config(cls, *args, **kwargs):
58
+ requires_backends(cls, ["torch", "transformers"])
59
+
60
+ @classmethod
61
+ def from_pretrained(cls, *args, **kwargs):
62
+ requires_backends(cls, ["torch", "transformers"])
63
+
64
+
65
+ class WanAutoBlocks(metaclass=DummyObject):
66
+ _backends = ["torch", "transformers"]
67
+
68
+ def __init__(self, *args, **kwargs):
69
+ requires_backends(self, ["torch", "transformers"])
70
+
71
+ @classmethod
72
+ def from_config(cls, *args, **kwargs):
73
+ requires_backends(cls, ["torch", "transformers"])
74
+
75
+ @classmethod
76
+ def from_pretrained(cls, *args, **kwargs):
77
+ requires_backends(cls, ["torch", "transformers"])
78
+
79
+
80
+ class WanModularPipeline(metaclass=DummyObject):
81
+ _backends = ["torch", "transformers"]
82
+
83
+ def __init__(self, *args, **kwargs):
84
+ requires_backends(self, ["torch", "transformers"])
85
+
86
+ @classmethod
87
+ def from_config(cls, *args, **kwargs):
88
+ requires_backends(cls, ["torch", "transformers"])
89
+
90
+ @classmethod
91
+ def from_pretrained(cls, *args, **kwargs):
92
+ requires_backends(cls, ["torch", "transformers"])
93
+
94
+
5
95
  class AllegroPipeline(metaclass=DummyObject):
6
96
  _backends = ["torch", "transformers"]
7
97
 
@@ -692,6 +782,36 @@ class FluxInpaintPipeline(metaclass=DummyObject):
692
782
  requires_backends(cls, ["torch", "transformers"])
693
783
 
694
784
 
785
+ class FluxKontextInpaintPipeline(metaclass=DummyObject):
786
+ _backends = ["torch", "transformers"]
787
+
788
+ def __init__(self, *args, **kwargs):
789
+ requires_backends(self, ["torch", "transformers"])
790
+
791
+ @classmethod
792
+ def from_config(cls, *args, **kwargs):
793
+ requires_backends(cls, ["torch", "transformers"])
794
+
795
+ @classmethod
796
+ def from_pretrained(cls, *args, **kwargs):
797
+ requires_backends(cls, ["torch", "transformers"])
798
+
799
+
800
+ class FluxKontextPipeline(metaclass=DummyObject):
801
+ _backends = ["torch", "transformers"]
802
+
803
+ def __init__(self, *args, **kwargs):
804
+ requires_backends(self, ["torch", "transformers"])
805
+
806
+ @classmethod
807
+ def from_config(cls, *args, **kwargs):
808
+ requires_backends(cls, ["torch", "transformers"])
809
+
810
+ @classmethod
811
+ def from_pretrained(cls, *args, **kwargs):
812
+ requires_backends(cls, ["torch", "transformers"])
813
+
814
+
695
815
  class FluxPipeline(metaclass=DummyObject):
696
816
  _backends = ["torch", "transformers"]
697
817
 
@@ -1622,6 +1742,66 @@ class PixArtSigmaPipeline(metaclass=DummyObject):
1622
1742
  requires_backends(cls, ["torch", "transformers"])
1623
1743
 
1624
1744
 
1745
+ class QwenImageEditPipeline(metaclass=DummyObject):
1746
+ _backends = ["torch", "transformers"]
1747
+
1748
+ def __init__(self, *args, **kwargs):
1749
+ requires_backends(self, ["torch", "transformers"])
1750
+
1751
+ @classmethod
1752
+ def from_config(cls, *args, **kwargs):
1753
+ requires_backends(cls, ["torch", "transformers"])
1754
+
1755
+ @classmethod
1756
+ def from_pretrained(cls, *args, **kwargs):
1757
+ requires_backends(cls, ["torch", "transformers"])
1758
+
1759
+
1760
+ class QwenImageImg2ImgPipeline(metaclass=DummyObject):
1761
+ _backends = ["torch", "transformers"]
1762
+
1763
+ def __init__(self, *args, **kwargs):
1764
+ requires_backends(self, ["torch", "transformers"])
1765
+
1766
+ @classmethod
1767
+ def from_config(cls, *args, **kwargs):
1768
+ requires_backends(cls, ["torch", "transformers"])
1769
+
1770
+ @classmethod
1771
+ def from_pretrained(cls, *args, **kwargs):
1772
+ requires_backends(cls, ["torch", "transformers"])
1773
+
1774
+
1775
+ class QwenImageInpaintPipeline(metaclass=DummyObject):
1776
+ _backends = ["torch", "transformers"]
1777
+
1778
+ def __init__(self, *args, **kwargs):
1779
+ requires_backends(self, ["torch", "transformers"])
1780
+
1781
+ @classmethod
1782
+ def from_config(cls, *args, **kwargs):
1783
+ requires_backends(cls, ["torch", "transformers"])
1784
+
1785
+ @classmethod
1786
+ def from_pretrained(cls, *args, **kwargs):
1787
+ requires_backends(cls, ["torch", "transformers"])
1788
+
1789
+
1790
+ class QwenImagePipeline(metaclass=DummyObject):
1791
+ _backends = ["torch", "transformers"]
1792
+
1793
+ def __init__(self, *args, **kwargs):
1794
+ requires_backends(self, ["torch", "transformers"])
1795
+
1796
+ @classmethod
1797
+ def from_config(cls, *args, **kwargs):
1798
+ requires_backends(cls, ["torch", "transformers"])
1799
+
1800
+ @classmethod
1801
+ def from_pretrained(cls, *args, **kwargs):
1802
+ requires_backends(cls, ["torch", "transformers"])
1803
+
1804
+
1625
1805
  class ReduxImageEncoder(metaclass=DummyObject):
1626
1806
  _backends = ["torch", "transformers"]
1627
1807
 
@@ -1757,6 +1937,81 @@ class ShapEPipeline(metaclass=DummyObject):
1757
1937
  requires_backends(cls, ["torch", "transformers"])
1758
1938
 
1759
1939
 
1940
+ class SkyReelsV2DiffusionForcingImageToVideoPipeline(metaclass=DummyObject):
1941
+ _backends = ["torch", "transformers"]
1942
+
1943
+ def __init__(self, *args, **kwargs):
1944
+ requires_backends(self, ["torch", "transformers"])
1945
+
1946
+ @classmethod
1947
+ def from_config(cls, *args, **kwargs):
1948
+ requires_backends(cls, ["torch", "transformers"])
1949
+
1950
+ @classmethod
1951
+ def from_pretrained(cls, *args, **kwargs):
1952
+ requires_backends(cls, ["torch", "transformers"])
1953
+
1954
+
1955
+ class SkyReelsV2DiffusionForcingPipeline(metaclass=DummyObject):
1956
+ _backends = ["torch", "transformers"]
1957
+
1958
+ def __init__(self, *args, **kwargs):
1959
+ requires_backends(self, ["torch", "transformers"])
1960
+
1961
+ @classmethod
1962
+ def from_config(cls, *args, **kwargs):
1963
+ requires_backends(cls, ["torch", "transformers"])
1964
+
1965
+ @classmethod
1966
+ def from_pretrained(cls, *args, **kwargs):
1967
+ requires_backends(cls, ["torch", "transformers"])
1968
+
1969
+
1970
+ class SkyReelsV2DiffusionForcingVideoToVideoPipeline(metaclass=DummyObject):
1971
+ _backends = ["torch", "transformers"]
1972
+
1973
+ def __init__(self, *args, **kwargs):
1974
+ requires_backends(self, ["torch", "transformers"])
1975
+
1976
+ @classmethod
1977
+ def from_config(cls, *args, **kwargs):
1978
+ requires_backends(cls, ["torch", "transformers"])
1979
+
1980
+ @classmethod
1981
+ def from_pretrained(cls, *args, **kwargs):
1982
+ requires_backends(cls, ["torch", "transformers"])
1983
+
1984
+
1985
+ class SkyReelsV2ImageToVideoPipeline(metaclass=DummyObject):
1986
+ _backends = ["torch", "transformers"]
1987
+
1988
+ def __init__(self, *args, **kwargs):
1989
+ requires_backends(self, ["torch", "transformers"])
1990
+
1991
+ @classmethod
1992
+ def from_config(cls, *args, **kwargs):
1993
+ requires_backends(cls, ["torch", "transformers"])
1994
+
1995
+ @classmethod
1996
+ def from_pretrained(cls, *args, **kwargs):
1997
+ requires_backends(cls, ["torch", "transformers"])
1998
+
1999
+
2000
+ class SkyReelsV2Pipeline(metaclass=DummyObject):
2001
+ _backends = ["torch", "transformers"]
2002
+
2003
+ def __init__(self, *args, **kwargs):
2004
+ requires_backends(self, ["torch", "transformers"])
2005
+
2006
+ @classmethod
2007
+ def from_config(cls, *args, **kwargs):
2008
+ requires_backends(cls, ["torch", "transformers"])
2009
+
2010
+ @classmethod
2011
+ def from_pretrained(cls, *args, **kwargs):
2012
+ requires_backends(cls, ["torch", "transformers"])
2013
+
2014
+
1760
2015
  class StableAudioPipeline(metaclass=DummyObject):
1761
2016
  _backends = ["torch", "transformers"]
1762
2017
 
@@ -20,8 +20,11 @@ import json
20
20
  import os
21
21
  import re
22
22
  import shutil
23
+ import signal
23
24
  import sys
25
+ import threading
24
26
  from pathlib import Path
27
+ from types import ModuleType
25
28
  from typing import Dict, Optional, Union
26
29
  from urllib import request
27
30
 
@@ -37,6 +40,8 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
37
40
 
38
41
  # See https://huggingface.co/datasets/diffusers/community-pipelines-mirror
39
42
  COMMUNITY_PIPELINES_MIRROR_ID = "diffusers/community-pipelines-mirror"
43
+ TIME_OUT_REMOTE_CODE = int(os.getenv("DIFFUSERS_TIMEOUT_REMOTE_CODE", 15))
44
+ _HF_REMOTE_CODE_LOCK = threading.Lock()
40
45
 
41
46
 
42
47
  def get_diffusers_versions():
@@ -154,33 +159,87 @@ def check_imports(filename):
154
159
  return get_relative_imports(filename)
155
160
 
156
161
 
157
- def get_class_in_module(class_name, module_path, pretrained_model_name_or_path=None):
162
+ def _raise_timeout_error(signum, frame):
163
+ raise ValueError(
164
+ "Loading this model requires you to execute custom code contained in the model repository on your local "
165
+ "machine. Please set the option `trust_remote_code=True` to permit loading of this model."
166
+ )
167
+
168
+
169
+ def resolve_trust_remote_code(trust_remote_code, model_name, has_remote_code):
170
+ if trust_remote_code is None:
171
+ if has_remote_code and TIME_OUT_REMOTE_CODE > 0:
172
+ prev_sig_handler = None
173
+ try:
174
+ prev_sig_handler = signal.signal(signal.SIGALRM, _raise_timeout_error)
175
+ signal.alarm(TIME_OUT_REMOTE_CODE)
176
+ while trust_remote_code is None:
177
+ answer = input(
178
+ f"The repository for {model_name} contains custom code which must be executed to correctly "
179
+ f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
180
+ f"You can avoid this prompt in future by passing the argument `trust_remote_code=True`.\n\n"
181
+ f"Do you wish to run the custom code? [y/N] "
182
+ )
183
+ if answer.lower() in ["yes", "y", "1"]:
184
+ trust_remote_code = True
185
+ elif answer.lower() in ["no", "n", "0", ""]:
186
+ trust_remote_code = False
187
+ signal.alarm(0)
188
+ except Exception:
189
+ # OS which does not support signal.SIGALRM
190
+ raise ValueError(
191
+ f"The repository for {model_name} contains custom code which must be executed to correctly "
192
+ f"load the model. You can inspect the repository content at https://hf.co/{model_name}.\n"
193
+ f"Please pass the argument `trust_remote_code=True` to allow custom code to be run."
194
+ )
195
+ finally:
196
+ if prev_sig_handler is not None:
197
+ signal.signal(signal.SIGALRM, prev_sig_handler)
198
+ signal.alarm(0)
199
+ elif has_remote_code:
200
+ # For the CI which puts the timeout at 0
201
+ _raise_timeout_error(None, None)
202
+
203
+ if has_remote_code and not trust_remote_code:
204
+ raise ValueError(
205
+ f"Loading {model_name} requires you to execute the configuration file in that"
206
+ " repo on your local machine. Make sure you have read the code there to avoid malicious use, then"
207
+ " set the option `trust_remote_code=True` to remove this error."
208
+ )
209
+
210
+ return trust_remote_code
211
+
212
+
213
+ def get_class_in_module(class_name, module_path, force_reload=False):
158
214
  """
159
215
  Import a module on the cache directory for modules and extract a class from it.
160
216
  """
161
- module_path = module_path.replace(os.path.sep, ".")
162
- try:
163
- module = importlib.import_module(module_path)
164
- except ModuleNotFoundError as e:
165
- # This can happen when the repo id contains ".", which Python's import machinery interprets as a directory
166
- # separator. We do a bit of monkey patching to detect and fix this case.
167
- if not (
168
- pretrained_model_name_or_path is not None
169
- and "." in pretrained_model_name_or_path
170
- and module_path.startswith("diffusers_modules")
171
- and pretrained_model_name_or_path.replace("/", "--") in module_path
172
- ):
173
- raise e # We can't figure this one out, just reraise the original error
217
+ name = os.path.normpath(module_path)
218
+ if name.endswith(".py"):
219
+ name = name[:-3]
220
+ name = name.replace(os.path.sep, ".")
221
+ module_file: Path = Path(HF_MODULES_CACHE) / module_path
222
+
223
+ with _HF_REMOTE_CODE_LOCK:
224
+ if force_reload:
225
+ sys.modules.pop(name, None)
226
+ importlib.invalidate_caches()
227
+ cached_module: Optional[ModuleType] = sys.modules.get(name)
228
+ module_spec = importlib.util.spec_from_file_location(name, location=module_file)
229
+
230
+ module: ModuleType
231
+ if cached_module is None:
232
+ module = importlib.util.module_from_spec(module_spec)
233
+ # insert it into sys.modules before any loading begins
234
+ sys.modules[name] = module
235
+ else:
236
+ module = cached_module
174
237
 
175
- corrected_path = os.path.join(HF_MODULES_CACHE, module_path.replace(".", "/")) + ".py"
176
- corrected_path = corrected_path.replace(
177
- pretrained_model_name_or_path.replace("/", "--").replace(".", "/"),
178
- pretrained_model_name_or_path.replace("/", "--"),
179
- )
180
- module = importlib.machinery.SourceFileLoader(module_path, corrected_path).load_module()
238
+ module_spec.loader.exec_module(module)
181
239
 
182
240
  if class_name is None:
183
241
  return find_pipeline_class(module)
242
+
184
243
  return getattr(module, class_name)
185
244
 
186
245
 
@@ -259,8 +318,8 @@ def get_cached_module_file(
259
318
 
260
319
  <Tip>
261
320
 
262
- You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
263
- [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
321
+ You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or [gated
322
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
264
323
 
265
324
  </Tip>
266
325
 
@@ -446,8 +505,8 @@ def get_class_from_dynamic_module(
446
505
 
447
506
  <Tip>
448
507
 
449
- You may pass a token in `token` if you are not logged in (`huggingface-cli login`) and want to use private or
450
- [gated models](https://huggingface.co/docs/hub/models-gated#gated-models).
508
+ You may pass a token in `token` if you are not logged in (`hf auth login`) and want to use private or [gated
509
+ models](https://huggingface.co/docs/hub/models-gated#gated-models).
451
510
 
452
511
  </Tip>
453
512
 
@@ -472,4 +531,4 @@ def get_class_from_dynamic_module(
472
531
  revision=revision,
473
532
  local_files_only=local_files_only,
474
533
  )
475
- return get_class_in_module(class_name, final_module.replace(".py", ""), pretrained_model_name_or_path)
534
+ return get_class_in_module(class_name, final_module)
@@ -304,8 +304,7 @@ def _get_model_file(
304
304
  raise EnvironmentError(
305
305
  f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier "
306
306
  "listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a "
307
- "token having permission to this repo with `token` or log in with `huggingface-cli "
308
- "login`."
307
+ "token having permission to this repo with `token` or log in with `hf auth login`."
309
308
  ) from e
310
309
  except RevisionNotFoundError as e:
311
310
  raise EnvironmentError(
@@ -403,15 +402,17 @@ def _get_checkpoint_shard_files(
403
402
  allow_patterns = [os.path.join(subfolder, p) for p in allow_patterns]
404
403
 
405
404
  ignore_patterns = ["*.json", "*.md"]
406
- # `model_info` call must guarded with the above condition.
407
- model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
408
- for shard_file in original_shard_filenames:
409
- shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
410
- if not shard_file_present:
411
- raise EnvironmentError(
412
- f"{shards_path} does not appear to have a file named {shard_file} which is "
413
- "required according to the checkpoint index."
414
- )
405
+
406
+ # If the repo doesn't have the required shards, error out early even before downloading anything.
407
+ if not local_files_only:
408
+ model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
409
+ for shard_file in original_shard_filenames:
410
+ shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
411
+ if not shard_file_present:
412
+ raise EnvironmentError(
413
+ f"{shards_path} does not appear to have a file named {shard_file} which is "
414
+ "required according to the checkpoint index."
415
+ )
415
416
 
416
417
  try:
417
418
  # Load from URL
@@ -438,6 +439,11 @@ def _get_checkpoint_shard_files(
438
439
  ) from e
439
440
 
440
441
  cached_filenames = [os.path.join(cached_folder, f) for f in original_shard_filenames]
442
+ for cached_file in cached_filenames:
443
+ if not os.path.isfile(cached_file):
444
+ raise EnvironmentError(
445
+ f"{cached_folder} does not have a file named {cached_file} which is required according to the checkpoint index."
446
+ )
441
447
 
442
448
  return cached_filenames, sharded_metadata
443
449
 
@@ -467,6 +473,7 @@ class PushToHubMixin:
467
473
  token: Optional[str] = None,
468
474
  commit_message: Optional[str] = None,
469
475
  create_pr: bool = False,
476
+ subfolder: Optional[str] = None,
470
477
  ):
471
478
  """
472
479
  Uploads all files in `working_dir` to `repo_id`.
@@ -481,7 +488,12 @@ class PushToHubMixin:
481
488
 
482
489
  logger.info(f"Uploading the files of {working_dir} to {repo_id}.")
483
490
  return upload_folder(
484
- repo_id=repo_id, folder_path=working_dir, token=token, commit_message=commit_message, create_pr=create_pr
491
+ repo_id=repo_id,
492
+ folder_path=working_dir,
493
+ token=token,
494
+ commit_message=commit_message,
495
+ create_pr=create_pr,
496
+ path_in_repo=subfolder,
485
497
  )
486
498
 
487
499
  def push_to_hub(
@@ -493,6 +505,7 @@ class PushToHubMixin:
493
505
  create_pr: bool = False,
494
506
  safe_serialization: bool = True,
495
507
  variant: Optional[str] = None,
508
+ subfolder: Optional[str] = None,
496
509
  ) -> str:
497
510
  """
498
511
  Upload model, scheduler, or pipeline files to the 🤗 Hugging Face Hub.
@@ -508,8 +521,8 @@ class PushToHubMixin:
508
521
  Whether to make the repo private. If `None` (default), the repo will be public unless the
509
522
  organization's default is private. This value is ignored if the repo already exists.
510
523
  token (`str`, *optional*):
511
- The token to use as HTTP bearer authorization for remote files. The token generated when running
512
- `huggingface-cli login` (stored in `~/.huggingface`).
524
+ The token to use as HTTP bearer authorization for remote files. The token generated when running `hf
525
+ auth login` (stored in `~/.huggingface`).
513
526
  create_pr (`bool`, *optional*, defaults to `False`):
514
527
  Whether or not to create a PR with the uploaded files or directly commit.
515
528
  safe_serialization (`bool`, *optional*, defaults to `True`):
@@ -534,8 +547,9 @@ class PushToHubMixin:
534
547
  repo_id = create_repo(repo_id, private=private, token=token, exist_ok=True).repo_id
535
548
 
536
549
  # Create a new empty model card and eventually tag it
537
- model_card = load_or_create_model_card(repo_id, token=token)
538
- model_card = populate_model_card(model_card)
550
+ if not subfolder:
551
+ model_card = load_or_create_model_card(repo_id, token=token)
552
+ model_card = populate_model_card(model_card)
539
553
 
540
554
  # Save all files.
541
555
  save_kwargs = {"safe_serialization": safe_serialization}
@@ -546,7 +560,8 @@ class PushToHubMixin:
546
560
  self.save_pretrained(tmpdir, **save_kwargs)
547
561
 
548
562
  # Update model card if needed:
549
- model_card.save(os.path.join(tmpdir, "README.md"))
563
+ if not subfolder:
564
+ model_card.save(os.path.join(tmpdir, "README.md"))
550
565
 
551
566
  return self._upload_folder(
552
567
  tmpdir,
@@ -554,4 +569,5 @@ class PushToHubMixin:
554
569
  token=token,
555
570
  commit_message=commit_message,
556
571
  create_pr=create_pr,
572
+ subfolder=subfolder,
557
573
  )
@@ -192,6 +192,7 @@ _torch_xla_available, _torch_xla_version = _is_package_available("torch_xla")
192
192
  _torch_npu_available, _torch_npu_version = _is_package_available("torch_npu")
193
193
  _transformers_available, _transformers_version = _is_package_available("transformers")
194
194
  _hf_hub_available, _hf_hub_version = _is_package_available("huggingface_hub")
195
+ _kernels_available, _kernels_version = _is_package_available("kernels")
195
196
  _inflect_available, _inflect_version = _is_package_available("inflect")
196
197
  _unidecode_available, _unidecode_version = _is_package_available("unidecode")
197
198
  _k_diffusion_available, _k_diffusion_version = _is_package_available("k_diffusion")
@@ -220,6 +221,10 @@ _pytorch_retinaface_available, _pytorch_retinaface_version = _is_package_availab
220
221
  _better_profanity_available, _better_profanity_version = _is_package_available("better_profanity")
221
222
  _nltk_available, _nltk_version = _is_package_available("nltk")
222
223
  _cosmos_guardrail_available, _cosmos_guardrail_version = _is_package_available("cosmos_guardrail")
224
+ _sageattention_available, _sageattention_version = _is_package_available("sageattention")
225
+ _flash_attn_available, _flash_attn_version = _is_package_available("flash_attn")
226
+ _flash_attn_3_available, _flash_attn_3_version = _is_package_available("flash_attn_3")
227
+ _kornia_available, _kornia_version = _is_package_available("kornia")
223
228
 
224
229
 
225
230
  def is_torch_available():
@@ -274,6 +279,10 @@ def is_accelerate_available():
274
279
  return _accelerate_available
275
280
 
276
281
 
282
+ def is_kernels_available():
283
+ return _kernels_available
284
+
285
+
277
286
  def is_k_diffusion_available():
278
287
  return _k_diffusion_available
279
288
 
@@ -378,6 +387,22 @@ def is_hpu_available():
378
387
  return all(importlib.util.find_spec(lib) for lib in ("habana_frameworks", "habana_frameworks.torch"))
379
388
 
380
389
 
390
+ def is_sageattention_available():
391
+ return _sageattention_available
392
+
393
+
394
+ def is_flash_attn_available():
395
+ return _flash_attn_available
396
+
397
+
398
+ def is_flash_attn_3_available():
399
+ return _flash_attn_3_available
400
+
401
+
402
+ def is_kornia_available():
403
+ return _kornia_available
404
+
405
+
381
406
  # docstyle-ignore
382
407
  FLAX_IMPORT_ERROR = """
383
408
  {0} requires the FLAX library but it was not found in your environment. Checkout the instructions on the
@@ -804,6 +829,51 @@ def is_optimum_quanto_version(operation: str, version: str):
804
829
  return compare_versions(parse(_optimum_quanto_version), operation, version)
805
830
 
806
831
 
832
+ def is_xformers_version(operation: str, version: str):
833
+ """
834
+ Compares the current xformers version to a given reference with an operation.
835
+
836
+ Args:
837
+ operation (`str`):
838
+ A string representation of an operator, such as `">"` or `"<="`
839
+ version (`str`):
840
+ A version string
841
+ """
842
+ if not _xformers_available:
843
+ return False
844
+ return compare_versions(parse(_xformers_version), operation, version)
845
+
846
+
847
+ def is_sageattention_version(operation: str, version: str):
848
+ """
849
+ Compares the current sageattention version to a given reference with an operation.
850
+
851
+ Args:
852
+ operation (`str`):
853
+ A string representation of an operator, such as `">"` or `"<="`
854
+ version (`str`):
855
+ A version string
856
+ """
857
+ if not _sageattention_available:
858
+ return False
859
+ return compare_versions(parse(_sageattention_version), operation, version)
860
+
861
+
862
+ def is_flash_attn_version(operation: str, version: str):
863
+ """
864
+ Compares the current flash-attention version to a given reference with an operation.
865
+
866
+ Args:
867
+ operation (`str`):
868
+ A string representation of an operator, such as `">"` or `"<="`
869
+ version (`str`):
870
+ A version string
871
+ """
872
+ if not _flash_attn_available:
873
+ return False
874
+ return compare_versions(parse(_flash_attn_version), operation, version)
875
+
876
+
807
877
  def get_objects_from_module(module):
808
878
  """
809
879
  Returns a dict of object names and values in a module, while skipping private/internal objects