diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
@@ -150,7 +150,9 @@ def unscale_lora_layers(model, weight: Optional[float] = None):
150
150
  module.set_scale(adapter_name, 1.0)
151
151
 
152
152
 
153
- def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True):
153
+ def get_peft_kwargs(
154
+ rank_dict, network_alpha_dict, peft_state_dict, is_unet=True, model_state_dict=None, adapter_name=None
155
+ ):
154
156
  rank_pattern = {}
155
157
  alpha_pattern = {}
156
158
  r = lora_alpha = list(rank_dict.values())[0]
@@ -180,7 +182,6 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
180
182
  else:
181
183
  lora_alpha = set(network_alpha_dict.values()).pop()
182
184
 
183
- # layer names without the Diffusers specific
184
185
  target_modules = list({name.split(".lora")[0] for name in peft_state_dict.keys()})
185
186
  use_dora = any("lora_magnitude_vector" in k for k in peft_state_dict)
186
187
  # for now we know that the "bias" keys are only associated with `lora_B`.
@@ -195,6 +196,7 @@ def get_peft_kwargs(rank_dict, network_alpha_dict, peft_state_dict, is_unet=True
195
196
  "use_dora": use_dora,
196
197
  "lora_bias": lora_bias,
197
198
  }
199
+
198
200
  return lora_config_kwargs
199
201
 
200
202
 
@@ -294,11 +296,7 @@ def check_peft_version(min_version: str) -> None:
294
296
 
295
297
 
296
298
  def _create_lora_config(
297
- state_dict,
298
- network_alphas,
299
- metadata,
300
- rank_pattern_dict,
301
- is_unet: bool = True,
299
+ state_dict, network_alphas, metadata, rank_pattern_dict, is_unet=True, model_state_dict=None, adapter_name=None
302
300
  ):
303
301
  from peft import LoraConfig
304
302
 
@@ -306,7 +304,12 @@ def _create_lora_config(
306
304
  lora_config_kwargs = metadata
307
305
  else:
308
306
  lora_config_kwargs = get_peft_kwargs(
309
- rank_pattern_dict, network_alpha_dict=network_alphas, peft_state_dict=state_dict, is_unet=is_unet
307
+ rank_pattern_dict,
308
+ network_alpha_dict=network_alphas,
309
+ peft_state_dict=state_dict,
310
+ is_unet=is_unet,
311
+ model_state_dict=model_state_dict,
312
+ adapter_name=adapter_name,
310
313
  )
311
314
 
312
315
  _maybe_raise_error_for_ambiguous_keys(lora_config_kwargs)
@@ -1,4 +1,5 @@
1
1
  import functools
2
+ import glob
2
3
  import importlib
3
4
  import importlib.metadata
4
5
  import inspect
@@ -18,7 +19,7 @@ from collections import UserDict
18
19
  from contextlib import contextmanager
19
20
  from io import BytesIO, StringIO
20
21
  from pathlib import Path
21
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
22
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Set, Tuple, Union
22
23
 
23
24
  import numpy as np
24
25
  import PIL.Image
@@ -35,6 +36,7 @@ from .import_utils import (
35
36
  is_compel_available,
36
37
  is_flax_available,
37
38
  is_gguf_available,
39
+ is_kernels_available,
38
40
  is_note_seq_available,
39
41
  is_onnx_available,
40
42
  is_opencv_available,
@@ -421,6 +423,10 @@ def require_big_accelerator(test_case):
421
423
  Decorator marking a test that requires a bigger hardware accelerator (24GB) for execution. Some example pipelines:
422
424
  Flux, SD3, Cog, etc.
423
425
  """
426
+ import pytest
427
+
428
+ test_case = pytest.mark.big_accelerator(test_case)
429
+
424
430
  if not is_torch_available():
425
431
  return unittest.skip("test requires PyTorch")(test_case)
426
432
 
@@ -629,6 +635,18 @@ def require_torchao_version_greater_or_equal(torchao_version):
629
635
  return decorator
630
636
 
631
637
 
638
+ def require_kernels_version_greater_or_equal(kernels_version):
639
+ def decorator(test_case):
640
+ correct_kernels_version = is_kernels_available() and version.parse(
641
+ version.parse(importlib.metadata.version("kernels")).base_version
642
+ ) >= version.parse(kernels_version)
643
+ return unittest.skipUnless(
644
+ correct_kernels_version, f"Test requires kernels with version greater than {kernels_version}."
645
+ )(test_case)
646
+
647
+ return decorator
648
+
649
+
632
650
  def deprecate_after_peft_backend(test_case):
633
651
  """
634
652
  Decorator marking a test that will be skipped after PEFT backend
@@ -990,10 +1008,10 @@ def pytest_terminal_summary_main(tr, id):
990
1008
  config.option.tbstyle = orig_tbstyle
991
1009
 
992
1010
 
993
- # Copied from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
1011
+ # Adapted from https://github.com/huggingface/transformers/blob/000e52aec8850d3fe2f360adc6fd256e5b47fe4c/src/transformers/testing_utils.py#L1905
994
1012
  def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, description: Optional[str] = None):
995
1013
  """
996
- To decorate flaky tests. They will be retried on failures.
1014
+ To decorate flaky tests (methods or entire classes). They will be retried on failures.
997
1015
 
998
1016
  Args:
999
1017
  max_attempts (`int`, *optional*, defaults to 5):
@@ -1005,22 +1023,33 @@ def is_flaky(max_attempts: int = 5, wait_before_retry: Optional[float] = None, d
1005
1023
  etc.)
1006
1024
  """
1007
1025
 
1008
- def decorator(test_func_ref):
1009
- @functools.wraps(test_func_ref)
1026
+ def decorator(obj):
1027
+ # If decorating a class, wrap each test method on it
1028
+ if inspect.isclass(obj):
1029
+ for attr_name, attr_value in list(obj.__dict__.items()):
1030
+ if callable(attr_value) and attr_name.startswith("test"):
1031
+ # recursively decorate the method
1032
+ setattr(obj, attr_name, decorator(attr_value))
1033
+ return obj
1034
+
1035
+ # Otherwise we're decorating a single test function / method
1036
+ @functools.wraps(obj)
1010
1037
  def wrapper(*args, **kwargs):
1011
1038
  retry_count = 1
1012
-
1013
1039
  while retry_count < max_attempts:
1014
1040
  try:
1015
- return test_func_ref(*args, **kwargs)
1016
-
1041
+ return obj(*args, **kwargs)
1017
1042
  except Exception as err:
1018
- print(f"Test failed with {err} at try {retry_count}/{max_attempts}.", file=sys.stderr)
1043
+ msg = (
1044
+ f"[FLAKY] {description or obj.__name__!r} "
1045
+ f"failed on attempt {retry_count}/{max_attempts}: {err}"
1046
+ )
1047
+ print(msg, file=sys.stderr)
1019
1048
  if wait_before_retry is not None:
1020
1049
  time.sleep(wait_before_retry)
1021
1050
  retry_count += 1
1022
1051
 
1023
- return test_func_ref(*args, **kwargs)
1052
+ return obj(*args, **kwargs)
1024
1053
 
1025
1054
  return wrapper
1026
1055
 
@@ -1377,6 +1406,103 @@ if TYPE_CHECKING:
1377
1406
  else:
1378
1407
  DevicePropertiesUserDict = UserDict
1379
1408
 
1409
+ if is_torch_available():
1410
+ from diffusers.hooks._common import _GO_LC_SUPPORTED_PYTORCH_LAYERS
1411
+ from diffusers.hooks.group_offloading import (
1412
+ _GROUP_ID_LAZY_LEAF,
1413
+ _compute_group_hash,
1414
+ _find_parent_module_in_module_dict,
1415
+ _gather_buffers_with_no_group_offloading_parent,
1416
+ _gather_parameters_with_no_group_offloading_parent,
1417
+ )
1418
+
1419
+ def _get_expected_safetensors_files(
1420
+ module: torch.nn.Module,
1421
+ offload_to_disk_path: str,
1422
+ offload_type: str,
1423
+ num_blocks_per_group: Optional[int] = None,
1424
+ ) -> Set[str]:
1425
+ expected_files = set()
1426
+
1427
+ def get_hashed_filename(group_id: str) -> str:
1428
+ short_hash = _compute_group_hash(group_id)
1429
+ return os.path.join(offload_to_disk_path, f"group_{short_hash}.safetensors")
1430
+
1431
+ if offload_type == "block_level":
1432
+ if num_blocks_per_group is None:
1433
+ raise ValueError("num_blocks_per_group must be provided for 'block_level' offloading.")
1434
+
1435
+ # Handle groups of ModuleList and Sequential blocks
1436
+ unmatched_modules = []
1437
+ for name, submodule in module.named_children():
1438
+ if not isinstance(submodule, (torch.nn.ModuleList, torch.nn.Sequential)):
1439
+ unmatched_modules.append(module)
1440
+ continue
1441
+
1442
+ for i in range(0, len(submodule), num_blocks_per_group):
1443
+ current_modules = submodule[i : i + num_blocks_per_group]
1444
+ if not current_modules:
1445
+ continue
1446
+ group_id = f"{name}_{i}_{i + len(current_modules) - 1}"
1447
+ expected_files.add(get_hashed_filename(group_id))
1448
+
1449
+ # Handle the group for unmatched top-level modules and parameters
1450
+ for module in unmatched_modules:
1451
+ expected_files.add(get_hashed_filename(f"{module.__class__.__name__}_unmatched_group"))
1452
+
1453
+ elif offload_type == "leaf_level":
1454
+ # Handle leaf-level module groups
1455
+ for name, submodule in module.named_modules():
1456
+ if isinstance(submodule, _GO_LC_SUPPORTED_PYTORCH_LAYERS):
1457
+ # These groups will always have parameters, so a file is expected
1458
+ expected_files.add(get_hashed_filename(name))
1459
+
1460
+ # Handle groups for non-leaf parameters/buffers
1461
+ modules_with_group_offloading = {
1462
+ name for name, sm in module.named_modules() if isinstance(sm, _GO_LC_SUPPORTED_PYTORCH_LAYERS)
1463
+ }
1464
+ parameters = _gather_parameters_with_no_group_offloading_parent(module, modules_with_group_offloading)
1465
+ buffers = _gather_buffers_with_no_group_offloading_parent(module, modules_with_group_offloading)
1466
+
1467
+ all_orphans = parameters + buffers
1468
+ if all_orphans:
1469
+ parent_to_tensors = {}
1470
+ module_dict = dict(module.named_modules())
1471
+ for tensor_name, _ in all_orphans:
1472
+ parent_name = _find_parent_module_in_module_dict(tensor_name, module_dict)
1473
+ if parent_name not in parent_to_tensors:
1474
+ parent_to_tensors[parent_name] = []
1475
+ parent_to_tensors[parent_name].append(tensor_name)
1476
+
1477
+ for parent_name in parent_to_tensors:
1478
+ # A file is expected for each parent that gathers orphaned tensors
1479
+ expected_files.add(get_hashed_filename(parent_name))
1480
+ expected_files.add(get_hashed_filename(_GROUP_ID_LAZY_LEAF))
1481
+
1482
+ else:
1483
+ raise ValueError(f"Unsupported offload_type: {offload_type}")
1484
+
1485
+ return expected_files
1486
+
1487
+ def _check_safetensors_serialization(
1488
+ module: torch.nn.Module,
1489
+ offload_to_disk_path: str,
1490
+ offload_type: str,
1491
+ num_blocks_per_group: Optional[int] = None,
1492
+ ) -> bool:
1493
+ if not os.path.isdir(offload_to_disk_path):
1494
+ return False, None, None
1495
+
1496
+ expected_files = _get_expected_safetensors_files(
1497
+ module, offload_to_disk_path, offload_type, num_blocks_per_group
1498
+ )
1499
+ actual_files = set(glob.glob(os.path.join(offload_to_disk_path, "*.safetensors")))
1500
+ missing_files = expected_files - actual_files
1501
+ extra_files = actual_files - expected_files
1502
+
1503
+ is_correct = not missing_files and not extra_files
1504
+ return is_correct, extra_files, missing_files
1505
+
1380
1506
 
1381
1507
  class Expectations(DevicePropertiesUserDict):
1382
1508
  def get_expectation(self) -> Any:
@@ -15,6 +15,7 @@
15
15
  PyTorch utilities: Utilities related to PyTorch
16
16
  """
17
17
 
18
+ import functools
18
19
  from typing import List, Optional, Tuple, Union
19
20
 
20
21
  from . import logging
@@ -92,6 +93,11 @@ def is_compiled_module(module) -> bool:
92
93
  return isinstance(module, torch._dynamo.eval_frame.OptimizedModule)
93
94
 
94
95
 
96
+ def unwrap_module(module):
97
+ """Unwraps a module if it was compiled with torch.compile()"""
98
+ return module._orig_mod if is_compiled_module(module) else module
99
+
100
+
95
101
  def fourier_filter(x_in: "torch.Tensor", threshold: int, scale: int) -> "torch.Tensor":
96
102
  """Fourier filter as introduced in FreeU (https://huggingface.co/papers/2309.11497).
97
103
 
@@ -163,6 +169,7 @@ def get_torch_cuda_device_capability():
163
169
  return None
164
170
 
165
171
 
172
+ @functools.lru_cache
166
173
  def get_device():
167
174
  if torch.cuda.is_available():
168
175
  return "cuda"
@@ -170,6 +177,8 @@ def get_device():
170
177
  return "npu"
171
178
  elif hasattr(torch, "xpu") and torch.xpu.is_available():
172
179
  return "xpu"
180
+ elif torch.backends.mps.is_available():
181
+ return "mps"
173
182
  else:
174
183
  return "cpu"
175
184
 
@@ -177,5 +186,14 @@ def get_device():
177
186
  def empty_device_cache(device_type: Optional[str] = None):
178
187
  if device_type is None:
179
188
  device_type = get_device()
189
+ if device_type in ["cpu"]:
190
+ return
180
191
  device_mod = getattr(torch, device_type, torch.cuda)
181
192
  device_mod.empty_cache()
193
+
194
+
195
+ def device_synchronize(device_type: Optional[str] = None):
196
+ if device_type is None:
197
+ device_type = get_device()
198
+ device_mod = getattr(torch, device_type, torch.cuda)
199
+ device_mod.synchronize()
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.2
2
2
  Name: diffusers
3
- Version: 0.34.0
3
+ Version: 0.35.1
4
4
  Summary: State-of-the-art diffusion in PyTorch and JAX.
5
5
  Home-page: https://github.com/huggingface/diffusers
6
6
  Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/diffusers/graphs/contributors)
@@ -23,7 +23,7 @@ Description-Content-Type: text/markdown
23
23
  License-File: LICENSE
24
24
  Requires-Dist: importlib_metadata
25
25
  Requires-Dist: filelock
26
- Requires-Dist: huggingface-hub>=0.27.0
26
+ Requires-Dist: huggingface-hub>=0.34.0
27
27
  Requires-Dist: numpy
28
28
  Requires-Dist: regex!=2019.12.17
29
29
  Requires-Dist: requests
@@ -42,14 +42,14 @@ Requires-Dist: datasets; extra == "training"
42
42
  Requires-Dist: protobuf<4,>=3.20.3; extra == "training"
43
43
  Requires-Dist: tensorboard; extra == "training"
44
44
  Requires-Dist: Jinja2; extra == "training"
45
- Requires-Dist: peft>=0.15.0; extra == "training"
45
+ Requires-Dist: peft>=0.17.0; extra == "training"
46
46
  Provides-Extra: test
47
47
  Requires-Dist: compel==0.1.8; extra == "test"
48
48
  Requires-Dist: GitPython<3.1.19; extra == "test"
49
49
  Requires-Dist: datasets; extra == "test"
50
50
  Requires-Dist: Jinja2; extra == "test"
51
51
  Requires-Dist: invisible-watermark>=0.2.0; extra == "test"
52
- Requires-Dist: k-diffusion>=0.0.12; extra == "test"
52
+ Requires-Dist: k-diffusion==0.0.12; extra == "test"
53
53
  Requires-Dist: librosa; extra == "test"
54
54
  Requires-Dist: parameterized; extra == "test"
55
55
  Requires-Dist: pytest; extra == "test"
@@ -92,7 +92,7 @@ Requires-Dist: GitPython<3.1.19; extra == "dev"
92
92
  Requires-Dist: datasets; extra == "dev"
93
93
  Requires-Dist: Jinja2; extra == "dev"
94
94
  Requires-Dist: invisible-watermark>=0.2.0; extra == "dev"
95
- Requires-Dist: k-diffusion>=0.0.12; extra == "dev"
95
+ Requires-Dist: k-diffusion==0.0.12; extra == "dev"
96
96
  Requires-Dist: librosa; extra == "dev"
97
97
  Requires-Dist: parameterized; extra == "dev"
98
98
  Requires-Dist: pytest; extra == "dev"
@@ -111,7 +111,7 @@ Requires-Dist: datasets; extra == "dev"
111
111
  Requires-Dist: protobuf<4,>=3.20.3; extra == "dev"
112
112
  Requires-Dist: tensorboard; extra == "dev"
113
113
  Requires-Dist: Jinja2; extra == "dev"
114
- Requires-Dist: peft>=0.15.0; extra == "dev"
114
+ Requires-Dist: peft>=0.17.0; extra == "dev"
115
115
  Requires-Dist: hf-doc-builder>=0.3.0; extra == "dev"
116
116
  Requires-Dist: torch>=1.4; extra == "dev"
117
117
  Requires-Dist: accelerate>=0.31.0; extra == "dev"