diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214) hide show
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +72 -26
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
@@ -19,10 +19,10 @@ import torch
19
19
  import torch.utils.checkpoint
20
20
  from torch import Tensor, nn
21
21
 
22
- from ..configuration_utils import ConfigMixin, register_to_config
23
- from ..utils import BaseOutput, is_torch_version, logging
24
- from ..utils.torch_utils import apply_freeu
25
- from .attention_processor import (
22
+ from ...configuration_utils import ConfigMixin, register_to_config
23
+ from ...utils import BaseOutput, is_torch_version, logging
24
+ from ...utils.torch_utils import apply_freeu
25
+ from ..attention_processor import (
26
26
  ADDED_KV_ATTENTION_PROCESSORS,
27
27
  CROSS_ATTENTION_PROCESSORS,
28
28
  Attention,
@@ -31,10 +31,9 @@ from .attention_processor import (
31
31
  AttnProcessor,
32
32
  FusedAttnProcessor2_0,
33
33
  )
34
- from .controlnet import ControlNetConditioningEmbedding
35
- from .embeddings import TimestepEmbedding, Timesteps
36
- from .modeling_utils import ModelMixin
37
- from .unets.unet_2d_blocks import (
34
+ from ..embeddings import TimestepEmbedding, Timesteps
35
+ from ..modeling_utils import ModelMixin
36
+ from ..unets.unet_2d_blocks import (
38
37
  CrossAttnDownBlock2D,
39
38
  CrossAttnUpBlock2D,
40
39
  Downsample2D,
@@ -43,7 +42,8 @@ from .unets.unet_2d_blocks import (
43
42
  UNetMidBlock2DCrossAttn,
44
43
  Upsample2D,
45
44
  )
46
- from .unets.unet_2d_condition import UNet2DConditionModel
45
+ from ..unets.unet_2d_condition import UNet2DConditionModel
46
+ from .controlnet import ControlNetConditioningEmbedding
47
47
 
48
48
 
49
49
  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -1062,7 +1062,8 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
1062
1062
  added_cond_kwargs (`dict`):
1063
1063
  Additional conditions for the Stable Diffusion XL UNet.
1064
1064
  return_dict (`bool`, defaults to `True`):
1065
- Whether or not to return a [`~models.controlnet.ControlNetOutput`] instead of a plain tuple.
1065
+ Whether or not to return a [`~models.controlnets.controlnet.ControlNetOutput`] instead of a plain
1066
+ tuple.
1066
1067
  apply_control (`bool`, defaults to `True`):
1067
1068
  If `False`, the input is run only through the base model.
1068
1069
 
@@ -1465,7 +1466,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
1465
1466
  h_ctrl = torch.cat([h_ctrl, b2c(h_base)], dim=1)
1466
1467
 
1467
1468
  # apply base subblock
1468
- if self.training and self.gradient_checkpointing:
1469
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1469
1470
  ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1470
1471
  h_base = torch.utils.checkpoint.checkpoint(
1471
1472
  create_custom_forward(b_res),
@@ -1488,7 +1489,7 @@ class ControlNetXSCrossAttnDownBlock2D(nn.Module):
1488
1489
 
1489
1490
  # apply ctrl subblock
1490
1491
  if apply_control:
1491
- if self.training and self.gradient_checkpointing:
1492
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1492
1493
  ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1493
1494
  h_ctrl = torch.utils.checkpoint.checkpoint(
1494
1495
  create_custom_forward(c_res),
@@ -1897,7 +1898,7 @@ class ControlNetXSCrossAttnUpBlock2D(nn.Module):
1897
1898
  hidden_states, res_h_base = maybe_apply_freeu_to_subblock(hidden_states, res_h_base)
1898
1899
  hidden_states = torch.cat([hidden_states, res_h_base], dim=1)
1899
1900
 
1900
- if self.training and self.gradient_checkpointing:
1901
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1901
1902
  ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1902
1903
  hidden_states = torch.utils.checkpoint.checkpoint(
1903
1904
  create_custom_forward(resnet),
@@ -0,0 +1,183 @@
1
+ import os
2
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
3
+
4
+ import torch
5
+ from torch import nn
6
+
7
+ from ...models.controlnets.controlnet import ControlNetModel, ControlNetOutput
8
+ from ...models.modeling_utils import ModelMixin
9
+ from ...utils import logging
10
+
11
+
12
+ logger = logging.get_logger(__name__)
13
+
14
+
15
+ class MultiControlNetModel(ModelMixin):
16
+ r"""
17
+ Multiple `ControlNetModel` wrapper class for Multi-ControlNet
18
+
19
+ This module is a wrapper for multiple instances of the `ControlNetModel`. The `forward()` API is designed to be
20
+ compatible with `ControlNetModel`.
21
+
22
+ Args:
23
+ controlnets (`List[ControlNetModel]`):
24
+ Provides additional conditioning to the unet during the denoising process. You must set multiple
25
+ `ControlNetModel` as a list.
26
+ """
27
+
28
+ def __init__(self, controlnets: Union[List[ControlNetModel], Tuple[ControlNetModel]]):
29
+ super().__init__()
30
+ self.nets = nn.ModuleList(controlnets)
31
+
32
+ def forward(
33
+ self,
34
+ sample: torch.Tensor,
35
+ timestep: Union[torch.Tensor, float, int],
36
+ encoder_hidden_states: torch.Tensor,
37
+ controlnet_cond: List[torch.tensor],
38
+ conditioning_scale: List[float],
39
+ class_labels: Optional[torch.Tensor] = None,
40
+ timestep_cond: Optional[torch.Tensor] = None,
41
+ attention_mask: Optional[torch.Tensor] = None,
42
+ added_cond_kwargs: Optional[Dict[str, torch.Tensor]] = None,
43
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
44
+ guess_mode: bool = False,
45
+ return_dict: bool = True,
46
+ ) -> Union[ControlNetOutput, Tuple]:
47
+ for i, (image, scale, controlnet) in enumerate(zip(controlnet_cond, conditioning_scale, self.nets)):
48
+ down_samples, mid_sample = controlnet(
49
+ sample=sample,
50
+ timestep=timestep,
51
+ encoder_hidden_states=encoder_hidden_states,
52
+ controlnet_cond=image,
53
+ conditioning_scale=scale,
54
+ class_labels=class_labels,
55
+ timestep_cond=timestep_cond,
56
+ attention_mask=attention_mask,
57
+ added_cond_kwargs=added_cond_kwargs,
58
+ cross_attention_kwargs=cross_attention_kwargs,
59
+ guess_mode=guess_mode,
60
+ return_dict=return_dict,
61
+ )
62
+
63
+ # merge samples
64
+ if i == 0:
65
+ down_block_res_samples, mid_block_res_sample = down_samples, mid_sample
66
+ else:
67
+ down_block_res_samples = [
68
+ samples_prev + samples_curr
69
+ for samples_prev, samples_curr in zip(down_block_res_samples, down_samples)
70
+ ]
71
+ mid_block_res_sample += mid_sample
72
+
73
+ return down_block_res_samples, mid_block_res_sample
74
+
75
+ def save_pretrained(
76
+ self,
77
+ save_directory: Union[str, os.PathLike],
78
+ is_main_process: bool = True,
79
+ save_function: Callable = None,
80
+ safe_serialization: bool = True,
81
+ variant: Optional[str] = None,
82
+ ):
83
+ """
84
+ Save a model and its configuration file to a directory, so that it can be re-loaded using the
85
+ [`~models.controlnets.multicontrolnet.MultiControlNetModel.from_pretrained`] class method.
86
+
87
+ Arguments:
88
+ save_directory (`str` or `os.PathLike`):
89
+ Directory to which to save. Will be created if it doesn't exist.
90
+ is_main_process (`bool`, *optional*, defaults to `True`):
91
+ Whether the process calling this is the main process or not. Useful when in distributed training like
92
+ TPUs and need to call this function on all processes. In this case, set `is_main_process=True` only on
93
+ the main process to avoid race conditions.
94
+ save_function (`Callable`):
95
+ The function to use to save the state dictionary. Useful on distributed training like TPUs when one
96
+ need to replace `torch.save` by another method. Can be configured with the environment variable
97
+ `DIFFUSERS_SAVE_MODE`.
98
+ safe_serialization (`bool`, *optional*, defaults to `True`):
99
+ Whether to save the model using `safetensors` or the traditional PyTorch way (that uses `pickle`).
100
+ variant (`str`, *optional*):
101
+ If specified, weights are saved in the format pytorch_model.<variant>.bin.
102
+ """
103
+ for idx, controlnet in enumerate(self.nets):
104
+ suffix = "" if idx == 0 else f"_{idx}"
105
+ controlnet.save_pretrained(
106
+ save_directory + suffix,
107
+ is_main_process=is_main_process,
108
+ save_function=save_function,
109
+ safe_serialization=safe_serialization,
110
+ variant=variant,
111
+ )
112
+
113
+ @classmethod
114
+ def from_pretrained(cls, pretrained_model_path: Optional[Union[str, os.PathLike]], **kwargs):
115
+ r"""
116
+ Instantiate a pretrained MultiControlNet model from multiple pre-trained controlnet models.
117
+
118
+ The model is set in evaluation mode by default using `model.eval()` (Dropout modules are deactivated). To train
119
+ the model, you should first set it back in training mode with `model.train()`.
120
+
121
+ The warning *Weights from XXX not initialized from pretrained model* means that the weights of XXX do not come
122
+ pretrained with the rest of the model. It is up to you to train those weights with a downstream fine-tuning
123
+ task.
124
+
125
+ The warning *Weights from XXX not used in YYY* means that the layer XXX is not used by YYY, therefore those
126
+ weights are discarded.
127
+
128
+ Parameters:
129
+ pretrained_model_path (`os.PathLike`):
130
+ A path to a *directory* containing model weights saved using
131
+ [`~models.controlnets.multicontrolnet.MultiControlNetModel.save_pretrained`], e.g.,
132
+ `./my_model_directory/controlnet`.
133
+ torch_dtype (`str` or `torch.dtype`, *optional*):
134
+ Override the default `torch.dtype` and load the model under this dtype. If `"auto"` is passed the dtype
135
+ will be automatically derived from the model's weights.
136
+ output_loading_info(`bool`, *optional*, defaults to `False`):
137
+ Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
138
+ device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
139
+ A map that specifies where each submodule should go. It doesn't need to be refined to each
140
+ parameter/buffer name, once a given module name is inside, every submodule of it will be sent to the
141
+ same device.
142
+
143
+ To have Accelerate compute the most optimized `device_map` automatically, set `device_map="auto"`. For
144
+ more information about each option see [designing a device
145
+ map](https://hf.co/docs/accelerate/main/en/usage_guides/big_modeling#designing-a-device-map).
146
+ max_memory (`Dict`, *optional*):
147
+ A dictionary device identifier to maximum memory. Will default to the maximum memory available for each
148
+ GPU and the available CPU RAM if unset.
149
+ low_cpu_mem_usage (`bool`, *optional*, defaults to `True` if torch version >= 1.9.0 else `False`):
150
+ Speed up model loading by not initializing the weights and only loading the pre-trained weights. This
151
+ also tries to not use more than 1x model size in CPU memory (including peak memory) while loading the
152
+ model. This is only supported when torch version >= 1.9.0. If you are using an older version of torch,
153
+ setting this argument to `True` will raise an error.
154
+ variant (`str`, *optional*):
155
+ If specified load weights from `variant` filename, *e.g.* pytorch_model.<variant>.bin. `variant` is
156
+ ignored when using `from_flax`.
157
+ use_safetensors (`bool`, *optional*, defaults to `None`):
158
+ If set to `None`, the `safetensors` weights will be downloaded if they're available **and** if the
159
+ `safetensors` library is installed. If set to `True`, the model will be forcibly loaded from
160
+ `safetensors` weights. If set to `False`, loading will *not* use `safetensors`.
161
+ """
162
+ idx = 0
163
+ controlnets = []
164
+
165
+ # load controlnet and append to list until no controlnet directory exists anymore
166
+ # first controlnet has to be saved under `./mydirectory/controlnet` to be compliant with `DiffusionPipeline.from_pretrained`
167
+ # second, third, ... controlnets have to be saved under `./mydirectory/controlnet_1`, `./mydirectory/controlnet_2`, ...
168
+ model_path_to_load = pretrained_model_path
169
+ while os.path.isdir(model_path_to_load):
170
+ controlnet = ControlNetModel.from_pretrained(model_path_to_load, **kwargs)
171
+ controlnets.append(controlnet)
172
+
173
+ idx += 1
174
+ model_path_to_load = pretrained_model_path + f"_{idx}"
175
+
176
+ logger.info(f"{len(controlnets)} controlnets loaded from {pretrained_model_path}.")
177
+
178
+ if len(controlnets) == 0:
179
+ raise ValueError(
180
+ f"No ControlNets found under {os.path.dirname(pretrained_model_path)}. Expected at least {pretrained_model_path + '_0'}."
181
+ )
182
+
183
+ return cls(controlnets)