diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (173)
  1. diffusers/__init__.py +38 -2
  2. diffusers/configuration_utils.py +12 -0
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +257 -54
  5. diffusers/loaders/__init__.py +2 -0
  6. diffusers/loaders/ip_adapter.py +5 -1
  7. diffusers/loaders/lora_base.py +14 -7
  8. diffusers/loaders/lora_conversion_utils.py +332 -0
  9. diffusers/loaders/lora_pipeline.py +707 -41
  10. diffusers/loaders/peft.py +1 -0
  11. diffusers/loaders/single_file_utils.py +81 -4
  12. diffusers/loaders/textual_inversion.py +2 -0
  13. diffusers/loaders/unet.py +39 -8
  14. diffusers/models/__init__.py +4 -0
  15. diffusers/models/adapter.py +53 -53
  16. diffusers/models/attention.py +86 -10
  17. diffusers/models/attention_processor.py +169 -133
  18. diffusers/models/autoencoders/autoencoder_kl.py +71 -11
  19. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
  20. diffusers/models/controlnet_flux.py +536 -0
  21. diffusers/models/controlnet_sd3.py +7 -3
  22. diffusers/models/controlnet_sparsectrl.py +0 -1
  23. diffusers/models/embeddings.py +238 -61
  24. diffusers/models/embeddings_flax.py +23 -9
  25. diffusers/models/model_loading_utils.py +182 -14
  26. diffusers/models/modeling_utils.py +283 -46
  27. diffusers/models/normalization.py +79 -0
  28. diffusers/models/transformers/__init__.py +1 -0
  29. diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
  30. diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
  31. diffusers/models/transformers/pixart_transformer_2d.py +9 -1
  32. diffusers/models/transformers/transformer_cogview3plus.py +386 -0
  33. diffusers/models/transformers/transformer_flux.py +161 -44
  34. diffusers/models/transformers/transformer_sd3.py +7 -1
  35. diffusers/models/unets/unet_2d_condition.py +8 -8
  36. diffusers/models/unets/unet_motion_model.py +41 -63
  37. diffusers/models/upsampling.py +6 -6
  38. diffusers/pipelines/__init__.py +40 -7
  39. diffusers/pipelines/animatediff/__init__.py +2 -0
  40. diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
  41. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
  42. diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
  43. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
  45. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
  46. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
  47. diffusers/pipelines/auto_pipeline.py +39 -8
  48. diffusers/pipelines/cogvideo/__init__.py +6 -0
  49. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
  50. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
  51. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
  52. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
  53. diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
  54. diffusers/pipelines/cogview3/__init__.py +47 -0
  55. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
  56. diffusers/pipelines/cogview3/pipeline_output.py +21 -0
  57. diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
  58. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
  59. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
  60. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
  61. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
  62. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
  63. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
  64. diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
  65. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
  66. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
  67. diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
  68. diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
  69. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
  70. diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
  71. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
  72. diffusers/pipelines/flux/__init__.py +10 -0
  73. diffusers/pipelines/flux/pipeline_flux.py +53 -20
  74. diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
  75. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
  76. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
  77. diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
  78. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
  79. diffusers/pipelines/free_noise_utils.py +365 -5
  80. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
  81. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
  82. diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
  83. diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
  84. diffusers/pipelines/kolors/tokenizer.py +4 -0
  85. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
  86. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
  87. diffusers/pipelines/latte/pipeline_latte.py +2 -2
  88. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
  89. diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
  90. diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
  91. diffusers/pipelines/pag/__init__.py +6 -0
  92. diffusers/pipelines/pag/pag_utils.py +8 -2
  93. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
  94. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
  95. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
  96. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
  97. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
  98. diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
  99. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
  100. diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
  101. diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
  102. diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
  103. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
  104. diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
  105. diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
  106. diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
  107. diffusers/pipelines/pia/pipeline_pia.py +2 -0
  108. diffusers/pipelines/pipeline_loading_utils.py +225 -27
  109. diffusers/pipelines/pipeline_utils.py +123 -180
  110. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  111. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  112. diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
  113. diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
  114. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
  115. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
  116. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
  117. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
  118. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
  119. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
  120. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
  121. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
  122. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
  123. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
  124. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
  125. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
  126. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
  127. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
  128. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
  129. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
  130. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
  131. diffusers/quantizers/__init__.py +16 -0
  132. diffusers/quantizers/auto.py +126 -0
  133. diffusers/quantizers/base.py +233 -0
  134. diffusers/quantizers/bitsandbytes/__init__.py +2 -0
  135. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
  136. diffusers/quantizers/bitsandbytes/utils.py +306 -0
  137. diffusers/quantizers/quantization_config.py +391 -0
  138. diffusers/schedulers/scheduling_ddim.py +4 -1
  139. diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
  140. diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
  141. diffusers/schedulers/scheduling_ddpm.py +4 -1
  142. diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
  143. diffusers/schedulers/scheduling_deis_multistep.py +78 -1
  144. diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
  145. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
  146. diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
  147. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
  148. diffusers/schedulers/scheduling_edm_euler.py +8 -6
  149. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
  150. diffusers/schedulers/scheduling_euler_discrete.py +92 -7
  151. diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
  152. diffusers/schedulers/scheduling_heun_discrete.py +114 -8
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
  155. diffusers/schedulers/scheduling_lms_discrete.py +76 -1
  156. diffusers/schedulers/scheduling_sasolver.py +78 -1
  157. diffusers/schedulers/scheduling_unclip.py +4 -1
  158. diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
  159. diffusers/training_utils.py +48 -18
  160. diffusers/utils/__init__.py +2 -1
  161. diffusers/utils/dummy_pt_objects.py +60 -0
  162. diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
  163. diffusers/utils/hub_utils.py +16 -4
  164. diffusers/utils/import_utils.py +31 -8
  165. diffusers/utils/loading_utils.py +28 -4
  166. diffusers/utils/peft_utils.py +3 -3
  167. diffusers/utils/testing_utils.py +59 -0
  168. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
  169. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
  170. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
  171. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
  172. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
  173. {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
@@ -1,5 +1,6 @@
1
1
  import contextlib
2
2
  import copy
3
+ import gc
3
4
  import math
4
5
  import random
5
6
  from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
@@ -23,6 +24,9 @@ from .utils import (
23
24
  if is_transformers_available():
24
25
  import transformers
25
26
 
27
+ if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
28
+ import deepspeed
29
+
26
30
  if is_peft_available():
27
31
  from peft import set_peft_model_state_dict
28
32
 
@@ -35,8 +39,9 @@ if is_torch_npu_available():
35
39
 
36
40
  def set_seed(seed: int):
37
41
  """
38
- Args:
39
42
  Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
43
+
44
+ Args:
40
45
  seed (`int`): The seed to set.
41
46
  """
42
47
  random.seed(seed)
@@ -193,6 +198,13 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
193
198
 
194
199
 
195
200
  def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
201
+ """
202
+ Casts the training parameters of the model to the specified data type.
203
+
204
+ Args:
205
+ model: The PyTorch model whose parameters will be cast.
206
+ dtype: The data type to which the model parameters will be cast.
207
+ """
196
208
  if not isinstance(model, list):
197
209
  model = [model]
198
210
  for m in model:
@@ -224,7 +236,8 @@ def _set_state_dict_into_text_encoder(
224
236
  def compute_density_for_timestep_sampling(
225
237
  weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
226
238
  ):
227
- """Compute the density for sampling the timesteps when doing SD3 training.
239
+ """
240
+ Compute the density for sampling the timesteps when doing SD3 training.
228
241
 
229
242
  Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
230
243
 
@@ -243,7 +256,8 @@ def compute_density_for_timestep_sampling(
243
256
 
244
257
 
245
258
  def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
246
- """Computes loss weighting scheme for SD3 training.
259
+ """
260
+ Computes loss weighting scheme for SD3 training.
247
261
 
248
262
  Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
249
263
 
@@ -259,6 +273,20 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
259
273
  return weighting
260
274
 
261
275
 
276
+ def free_memory():
277
+ """
278
+ Runs garbage collection. Then clears the cache of the available accelerator.
279
+ """
280
+ gc.collect()
281
+
282
+ if torch.cuda.is_available():
283
+ torch.cuda.empty_cache()
284
+ elif torch.backends.mps.is_available():
285
+ torch.mps.empty_cache()
286
+ elif is_torch_npu_available():
287
+ torch_npu.empty_cache()
288
+
289
+
262
290
  # Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
263
291
  class EMAModel:
264
292
  """
@@ -417,15 +445,13 @@ class EMAModel:
417
445
  self.cur_decay_value = decay
418
446
  one_minus_decay = 1 - decay
419
447
 
420
- context_manager = contextlib.nullcontext
421
- if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
422
- import deepspeed
448
+ context_manager = contextlib.nullcontext()
423
449
 
424
450
  if self.foreach:
425
- if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
451
+ if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
426
452
  context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
427
453
 
428
- with context_manager():
454
+ with context_manager:
429
455
  params_grad = [param for param in parameters if param.requires_grad]
430
456
  s_params_grad = [
431
457
  s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
@@ -444,10 +470,10 @@ class EMAModel:
444
470
 
445
471
  else:
446
472
  for s_param, param in zip(self.shadow_params, parameters):
447
- if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
473
+ if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
448
474
  context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
449
475
 
450
- with context_manager():
476
+ with context_manager:
451
477
  if param.requires_grad:
452
478
  s_param.sub_(one_minus_decay * (s_param - param))
453
479
  else:
@@ -481,7 +507,8 @@ class EMAModel:
481
507
  self.shadow_params = [p.pin_memory() for p in self.shadow_params]
482
508
 
483
509
  def to(self, device=None, dtype=None, non_blocking=False) -> None:
484
- r"""Move internal buffers of the ExponentialMovingAverage to `device`.
510
+ r"""
511
+ Move internal buffers of the ExponentialMovingAverage to `device`.
485
512
 
486
513
  Args:
487
514
  device: like `device` argument to `torch.Tensor.to`
@@ -515,23 +542,25 @@ class EMAModel:
515
542
 
516
543
  def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
517
544
  r"""
545
+ Saves the current parameters for restoring later.
546
+
518
547
  Args:
519
- Save the current parameters for restoring later.
520
- parameters: Iterable of `torch.nn.Parameter`; the parameters to be
521
- temporarily stored.
548
+ parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
522
549
  """
523
550
  self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
524
551
 
525
552
  def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
526
553
  r"""
527
- Args:
528
- Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters without:
529
- affecting the original optimization process. Store the parameters before the `copy_to()` method. After
554
+ Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
555
+ without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After
530
556
  validation (or model saving), use this to restore the former parameters.
557
+
558
+ Args:
531
559
  parameters: Iterable of `torch.nn.Parameter`; the parameters to be
532
560
  updated with the stored parameters. If `None`, the parameters with which this
533
561
  `ExponentialMovingAverage` was initialized will be used.
534
562
  """
563
+
535
564
  if self.temp_stored_params is None:
536
565
  raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
537
566
  if self.foreach:
@@ -547,9 +576,10 @@ class EMAModel:
547
576
 
548
577
  def load_state_dict(self, state_dict: dict) -> None:
549
578
  r"""
550
- Args:
551
579
  Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
552
580
  ema state dict.
581
+
582
+ Args:
553
583
  state_dict (dict): EMA state. Should be an object returned
554
584
  from a call to :meth:`state_dict`.
555
585
  """
@@ -62,6 +62,7 @@ from .import_utils import (
62
62
  is_accelerate_available,
63
63
  is_accelerate_version,
64
64
  is_bitsandbytes_available,
65
+ is_bitsandbytes_version,
65
66
  is_bs4_available,
66
67
  is_flax_available,
67
68
  is_ftfy_available,
@@ -94,7 +95,7 @@ from .import_utils import (
94
95
  is_xformers_available,
95
96
  requires_backends,
96
97
  )
97
- from .loading_utils import load_image, load_video
98
+ from .loading_utils import get_module_from_name, load_image, load_video
98
99
  from .logging import get_logger
99
100
  from .outputs import BaseOutput
100
101
  from .peft_utils import (
@@ -122,6 +122,21 @@ class CogVideoXTransformer3DModel(metaclass=DummyObject):
122
122
  requires_backends(cls, ["torch"])
123
123
 
124
124
 
125
+ class CogView3PlusTransformer2DModel(metaclass=DummyObject):
126
+ _backends = ["torch"]
127
+
128
+ def __init__(self, *args, **kwargs):
129
+ requires_backends(self, ["torch"])
130
+
131
+ @classmethod
132
+ def from_config(cls, *args, **kwargs):
133
+ requires_backends(cls, ["torch"])
134
+
135
+ @classmethod
136
+ def from_pretrained(cls, *args, **kwargs):
137
+ requires_backends(cls, ["torch"])
138
+
139
+
125
140
  class ConsistencyDecoderVAE(metaclass=DummyObject):
126
141
  _backends = ["torch"]
127
142
 
@@ -182,6 +197,36 @@ class DiTTransformer2DModel(metaclass=DummyObject):
182
197
  requires_backends(cls, ["torch"])
183
198
 
184
199
 
200
+ class FluxControlNetModel(metaclass=DummyObject):
201
+ _backends = ["torch"]
202
+
203
+ def __init__(self, *args, **kwargs):
204
+ requires_backends(self, ["torch"])
205
+
206
+ @classmethod
207
+ def from_config(cls, *args, **kwargs):
208
+ requires_backends(cls, ["torch"])
209
+
210
+ @classmethod
211
+ def from_pretrained(cls, *args, **kwargs):
212
+ requires_backends(cls, ["torch"])
213
+
214
+
215
+ class FluxMultiControlNetModel(metaclass=DummyObject):
216
+ _backends = ["torch"]
217
+
218
+ def __init__(self, *args, **kwargs):
219
+ requires_backends(self, ["torch"])
220
+
221
+ @classmethod
222
+ def from_config(cls, *args, **kwargs):
223
+ requires_backends(cls, ["torch"])
224
+
225
+ @classmethod
226
+ def from_pretrained(cls, *args, **kwargs):
227
+ requires_backends(cls, ["torch"])
228
+
229
+
185
230
  class FluxTransformer2DModel(metaclass=DummyObject):
186
231
  _backends = ["torch"]
187
232
 
@@ -975,6 +1020,21 @@ class StableDiffusionMixin(metaclass=DummyObject):
975
1020
  requires_backends(cls, ["torch"])
976
1021
 
977
1022
 
1023
+ class DiffusersQuantizer(metaclass=DummyObject):
1024
+ _backends = ["torch"]
1025
+
1026
+ def __init__(self, *args, **kwargs):
1027
+ requires_backends(self, ["torch"])
1028
+
1029
+ @classmethod
1030
+ def from_config(cls, *args, **kwargs):
1031
+ requires_backends(cls, ["torch"])
1032
+
1033
+ @classmethod
1034
+ def from_pretrained(cls, *args, **kwargs):
1035
+ requires_backends(cls, ["torch"])
1036
+
1037
+
978
1038
  class AmusedScheduler(metaclass=DummyObject):
979
1039
  _backends = ["torch"]
980
1040
 
@@ -152,6 +152,21 @@ class AnimateDiffSparseControlNetPipeline(metaclass=DummyObject):
152
152
  requires_backends(cls, ["torch", "transformers"])
153
153
 
154
154
 
155
+ class AnimateDiffVideoToVideoControlNetPipeline(metaclass=DummyObject):
156
+ _backends = ["torch", "transformers"]
157
+
158
+ def __init__(self, *args, **kwargs):
159
+ requires_backends(self, ["torch", "transformers"])
160
+
161
+ @classmethod
162
+ def from_config(cls, *args, **kwargs):
163
+ requires_backends(cls, ["torch", "transformers"])
164
+
165
+ @classmethod
166
+ def from_pretrained(cls, *args, **kwargs):
167
+ requires_backends(cls, ["torch", "transformers"])
168
+
169
+
155
170
  class AnimateDiffVideoToVideoPipeline(metaclass=DummyObject):
156
171
  _backends = ["torch", "transformers"]
157
172
 
@@ -257,6 +272,36 @@ class CLIPImageProjection(metaclass=DummyObject):
257
272
  requires_backends(cls, ["torch", "transformers"])
258
273
 
259
274
 
275
+ class CogVideoXFunControlPipeline(metaclass=DummyObject):
276
+ _backends = ["torch", "transformers"]
277
+
278
+ def __init__(self, *args, **kwargs):
279
+ requires_backends(self, ["torch", "transformers"])
280
+
281
+ @classmethod
282
+ def from_config(cls, *args, **kwargs):
283
+ requires_backends(cls, ["torch", "transformers"])
284
+
285
+ @classmethod
286
+ def from_pretrained(cls, *args, **kwargs):
287
+ requires_backends(cls, ["torch", "transformers"])
288
+
289
+
290
+ class CogVideoXImageToVideoPipeline(metaclass=DummyObject):
291
+ _backends = ["torch", "transformers"]
292
+
293
+ def __init__(self, *args, **kwargs):
294
+ requires_backends(self, ["torch", "transformers"])
295
+
296
+ @classmethod
297
+ def from_config(cls, *args, **kwargs):
298
+ requires_backends(cls, ["torch", "transformers"])
299
+
300
+ @classmethod
301
+ def from_pretrained(cls, *args, **kwargs):
302
+ requires_backends(cls, ["torch", "transformers"])
303
+
304
+
260
305
  class CogVideoXPipeline(metaclass=DummyObject):
261
306
  _backends = ["torch", "transformers"]
262
307
 
@@ -272,6 +317,36 @@ class CogVideoXPipeline(metaclass=DummyObject):
272
317
  requires_backends(cls, ["torch", "transformers"])
273
318
 
274
319
 
320
+ class CogVideoXVideoToVideoPipeline(metaclass=DummyObject):
321
+ _backends = ["torch", "transformers"]
322
+
323
+ def __init__(self, *args, **kwargs):
324
+ requires_backends(self, ["torch", "transformers"])
325
+
326
+ @classmethod
327
+ def from_config(cls, *args, **kwargs):
328
+ requires_backends(cls, ["torch", "transformers"])
329
+
330
+ @classmethod
331
+ def from_pretrained(cls, *args, **kwargs):
332
+ requires_backends(cls, ["torch", "transformers"])
333
+
334
+
335
+ class CogView3PlusPipeline(metaclass=DummyObject):
336
+ _backends = ["torch", "transformers"]
337
+
338
+ def __init__(self, *args, **kwargs):
339
+ requires_backends(self, ["torch", "transformers"])
340
+
341
+ @classmethod
342
+ def from_config(cls, *args, **kwargs):
343
+ requires_backends(cls, ["torch", "transformers"])
344
+
345
+ @classmethod
346
+ def from_pretrained(cls, *args, **kwargs):
347
+ requires_backends(cls, ["torch", "transformers"])
348
+
349
+
275
350
  class CycleDiffusionPipeline(metaclass=DummyObject):
276
351
  _backends = ["torch", "transformers"]
277
352
 
@@ -287,6 +362,81 @@ class CycleDiffusionPipeline(metaclass=DummyObject):
287
362
  requires_backends(cls, ["torch", "transformers"])
288
363
 
289
364
 
365
+ class FluxControlNetImg2ImgPipeline(metaclass=DummyObject):
366
+ _backends = ["torch", "transformers"]
367
+
368
+ def __init__(self, *args, **kwargs):
369
+ requires_backends(self, ["torch", "transformers"])
370
+
371
+ @classmethod
372
+ def from_config(cls, *args, **kwargs):
373
+ requires_backends(cls, ["torch", "transformers"])
374
+
375
+ @classmethod
376
+ def from_pretrained(cls, *args, **kwargs):
377
+ requires_backends(cls, ["torch", "transformers"])
378
+
379
+
380
+ class FluxControlNetInpaintPipeline(metaclass=DummyObject):
381
+ _backends = ["torch", "transformers"]
382
+
383
+ def __init__(self, *args, **kwargs):
384
+ requires_backends(self, ["torch", "transformers"])
385
+
386
+ @classmethod
387
+ def from_config(cls, *args, **kwargs):
388
+ requires_backends(cls, ["torch", "transformers"])
389
+
390
+ @classmethod
391
+ def from_pretrained(cls, *args, **kwargs):
392
+ requires_backends(cls, ["torch", "transformers"])
393
+
394
+
395
+ class FluxControlNetPipeline(metaclass=DummyObject):
396
+ _backends = ["torch", "transformers"]
397
+
398
+ def __init__(self, *args, **kwargs):
399
+ requires_backends(self, ["torch", "transformers"])
400
+
401
+ @classmethod
402
+ def from_config(cls, *args, **kwargs):
403
+ requires_backends(cls, ["torch", "transformers"])
404
+
405
+ @classmethod
406
+ def from_pretrained(cls, *args, **kwargs):
407
+ requires_backends(cls, ["torch", "transformers"])
408
+
409
+
410
+ class FluxImg2ImgPipeline(metaclass=DummyObject):
411
+ _backends = ["torch", "transformers"]
412
+
413
+ def __init__(self, *args, **kwargs):
414
+ requires_backends(self, ["torch", "transformers"])
415
+
416
+ @classmethod
417
+ def from_config(cls, *args, **kwargs):
418
+ requires_backends(cls, ["torch", "transformers"])
419
+
420
+ @classmethod
421
+ def from_pretrained(cls, *args, **kwargs):
422
+ requires_backends(cls, ["torch", "transformers"])
423
+
424
+
425
+ class FluxInpaintPipeline(metaclass=DummyObject):
426
+ _backends = ["torch", "transformers"]
427
+
428
+ def __init__(self, *args, **kwargs):
429
+ requires_backends(self, ["torch", "transformers"])
430
+
431
+ @classmethod
432
+ def from_config(cls, *args, **kwargs):
433
+ requires_backends(cls, ["torch", "transformers"])
434
+
435
+ @classmethod
436
+ def from_pretrained(cls, *args, **kwargs):
437
+ requires_backends(cls, ["torch", "transformers"])
438
+
439
+
290
440
  class FluxPipeline(metaclass=DummyObject):
291
441
  _backends = ["torch", "transformers"]
292
442
 
@@ -1232,6 +1382,21 @@ class StableDiffusionControlNetInpaintPipeline(metaclass=DummyObject):
1232
1382
  requires_backends(cls, ["torch", "transformers"])
1233
1383
 
1234
1384
 
1385
+ class StableDiffusionControlNetPAGInpaintPipeline(metaclass=DummyObject):
1386
+ _backends = ["torch", "transformers"]
1387
+
1388
+ def __init__(self, *args, **kwargs):
1389
+ requires_backends(self, ["torch", "transformers"])
1390
+
1391
+ @classmethod
1392
+ def from_config(cls, *args, **kwargs):
1393
+ requires_backends(cls, ["torch", "transformers"])
1394
+
1395
+ @classmethod
1396
+ def from_pretrained(cls, *args, **kwargs):
1397
+ requires_backends(cls, ["torch", "transformers"])
1398
+
1399
+
1235
1400
  class StableDiffusionControlNetPAGPipeline(metaclass=DummyObject):
1236
1401
  _backends = ["torch", "transformers"]
1237
1402
 
@@ -1457,6 +1622,21 @@ class StableDiffusionModelEditingPipeline(metaclass=DummyObject):
1457
1622
  requires_backends(cls, ["torch", "transformers"])
1458
1623
 
1459
1624
 
1625
+ class StableDiffusionPAGImg2ImgPipeline(metaclass=DummyObject):
1626
+ _backends = ["torch", "transformers"]
1627
+
1628
+ def __init__(self, *args, **kwargs):
1629
+ requires_backends(self, ["torch", "transformers"])
1630
+
1631
+ @classmethod
1632
+ def from_config(cls, *args, **kwargs):
1633
+ requires_backends(cls, ["torch", "transformers"])
1634
+
1635
+ @classmethod
1636
+ def from_pretrained(cls, *args, **kwargs):
1637
+ requires_backends(cls, ["torch", "transformers"])
1638
+
1639
+
1460
1640
  class StableDiffusionPAGPipeline(metaclass=DummyObject):
1461
1641
  _backends = ["torch", "transformers"]
1462
1642
 
@@ -1622,6 +1802,21 @@ class StableDiffusionXLControlNetInpaintPipeline(metaclass=DummyObject):
1622
1802
  requires_backends(cls, ["torch", "transformers"])
1623
1803
 
1624
1804
 
1805
+ class StableDiffusionXLControlNetPAGImg2ImgPipeline(metaclass=DummyObject):
1806
+ _backends = ["torch", "transformers"]
1807
+
1808
+ def __init__(self, *args, **kwargs):
1809
+ requires_backends(self, ["torch", "transformers"])
1810
+
1811
+ @classmethod
1812
+ def from_config(cls, *args, **kwargs):
1813
+ requires_backends(cls, ["torch", "transformers"])
1814
+
1815
+ @classmethod
1816
+ def from_pretrained(cls, *args, **kwargs):
1817
+ requires_backends(cls, ["torch", "transformers"])
1818
+
1819
+
1625
1820
  class StableDiffusionXLControlNetPAGPipeline(metaclass=DummyObject):
1626
1821
  _backends = ["torch", "transformers"]
1627
1822
 
@@ -271,8 +271,7 @@ if cache_version < 1:
271
271
  def _add_variant(weights_name: str, variant: Optional[str] = None) -> str:
272
272
  if variant is not None:
273
273
  splits = weights_name.split(".")
274
- split_index = -2 if weights_name.endswith(".index.json") else -1
275
- splits = splits[:-split_index] + [variant] + splits[-split_index:]
274
+ splits = splits[:-1] + [variant] + splits[-1:]
276
275
  weights_name = ".".join(splits)
277
276
 
278
277
  return weights_name
@@ -458,7 +457,7 @@ def _get_checkpoint_shard_files(
458
457
  ignore_patterns = ["*.json", "*.md"]
459
458
  if not local_files_only:
460
459
  # `model_info` call must guarded with the above condition.
461
- model_files_info = model_info(pretrained_model_name_or_path, revision=revision)
460
+ model_files_info = model_info(pretrained_model_name_or_path, revision=revision, token=token)
462
461
  for shard_file in original_shard_filenames:
463
462
  shard_file_present = any(shard_file in k.rfilename for k in model_files_info.siblings)
464
463
  if not shard_file_present:
@@ -497,11 +496,24 @@ def _get_checkpoint_shard_files(
497
496
  local_dir=cache_dir, subfolder=subfolder, original_shard_filenames=original_shard_filenames
498
497
  )
499
498
  if subfolder is not None:
500
- cached_folder = os.path.join(cached_folder, subfolder)
499
+ cached_folder = os.path.join(cache_dir, subfolder)
501
500
 
502
501
  return cached_folder, sharded_metadata
503
502
 
504
503
 
504
+ def _check_legacy_sharding_variant_format(folder: str = None, filenames: List[str] = None, variant: str = None):
505
+ if filenames and folder:
506
+ raise ValueError("Both `filenames` and `folder` cannot be provided.")
507
+ if not filenames:
508
+ filenames = []
509
+ for _, _, files in os.walk(folder):
510
+ for file in files:
511
+ filenames.append(os.path.basename(file))
512
+ transformers_index_format = r"\d{5}-of-\d{5}"
513
+ variant_file_re = re.compile(rf".*-{transformers_index_format}\.{variant}\.[a-z]+$")
514
+ return any(variant_file_re.match(f) is not None for f in filenames)
515
+
516
+
505
517
  class PushToHubMixin:
506
518
  """
507
519
  A Mixin to push a model, scheduler, or pipeline to the Hugging Face Hub.
@@ -668,8 +668,9 @@ class DummyObject(type):
668
668
  # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L319
669
669
  def compare_versions(library_or_version: Union[str, Version], operation: str, requirement_version: str):
670
670
  """
671
- Args:
672
671
  Compares a library version to some requirement using a given operation.
672
+
673
+ Args:
673
674
  library_or_version (`str` or `packaging.version.Version`):
674
675
  A library name or a version to check.
675
676
  operation (`str`):
@@ -688,8 +689,9 @@ def compare_versions(library_or_version: Union[str, Version], operation: str, re
688
689
  # This function was copied from: https://github.com/huggingface/accelerate/blob/874c4967d94badd24f893064cc3bef45f57cadf7/src/accelerate/utils/versions.py#L338
689
690
  def is_torch_version(operation: str, version: str):
690
691
  """
691
- Args:
692
692
  Compares the current PyTorch version to a given reference with an operation.
693
+
694
+ Args:
693
695
  operation (`str`):
694
696
  A string representation of an operator, such as `">"` or `"<="`
695
697
  version (`str`):
@@ -700,8 +702,9 @@ def is_torch_version(operation: str, version: str):
700
702
 
701
703
  def is_transformers_version(operation: str, version: str):
702
704
  """
703
- Args:
704
705
  Compares the current Transformers version to a given reference with an operation.
706
+
707
+ Args:
705
708
  operation (`str`):
706
709
  A string representation of an operator, such as `">"` or `"<="`
707
710
  version (`str`):
@@ -714,8 +717,9 @@ def is_transformers_version(operation: str, version: str):
714
717
 
715
718
  def is_accelerate_version(operation: str, version: str):
716
719
  """
717
- Args:
718
720
  Compares the current Accelerate version to a given reference with an operation.
721
+
722
+ Args:
719
723
  operation (`str`):
720
724
  A string representation of an operator, such as `">"` or `"<="`
721
725
  version (`str`):
@@ -728,8 +732,9 @@ def is_accelerate_version(operation: str, version: str):
728
732
 
729
733
  def is_peft_version(operation: str, version: str):
730
734
  """
731
- Args:
732
735
  Compares the current PEFT version to a given reference with an operation.
736
+
737
+ Args:
733
738
  operation (`str`):
734
739
  A string representation of an operator, such as `">"` or `"<="`
735
740
  version (`str`):
@@ -740,10 +745,25 @@ def is_peft_version(operation: str, version: str):
740
745
  return compare_versions(parse(_peft_version), operation, version)
741
746
 
742
747
 
743
- def is_k_diffusion_version(operation: str, version: str):
748
+ def is_bitsandbytes_version(operation: str, version: str):
744
749
  """
745
750
  Args:
751
+ Compares the current bitsandbytes version to a given reference with an operation.
752
+ operation (`str`):
753
+ A string representation of an operator, such as `">"` or `"<="`
754
+ version (`str`):
755
+ A version string
756
+ """
757
+ if not _bitsandbytes_version:
758
+ return False
759
+ return compare_versions(parse(_bitsandbytes_version), operation, version)
760
+
761
+
762
+ def is_k_diffusion_version(operation: str, version: str):
763
+ """
746
764
  Compares the current k-diffusion version to a given reference with an operation.
765
+
766
+ Args:
747
767
  operation (`str`):
748
768
  A string representation of an operator, such as `">"` or `"<="`
749
769
  version (`str`):
@@ -756,8 +776,9 @@ def is_k_diffusion_version(operation: str, version: str):
756
776
 
757
777
  def get_objects_from_module(module):
758
778
  """
759
- Args:
760
779
  Returns a dict of object names and values in a module, while skipping private/internal objects
780
+
781
+ Args:
761
782
  module (ModuleType):
762
783
  Module to extract the objects from.
763
784
 
@@ -775,7 +796,9 @@ def get_objects_from_module(module):
775
796
 
776
797
 
777
798
  class OptionalDependencyNotAvailable(BaseException):
778
- """An error indicating that an optional dependency of Diffusers was not found in the environment."""
799
+ """
800
+ An error indicating that an optional dependency of Diffusers was not found in the environment.
801
+ """
779
802
 
780
803
 
781
804
  class _LazyModule(ModuleType):