diffusers 0.31.0__py3-none-any.whl → 0.32.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (214):
  1. diffusers/__init__.py +66 -5
  2. diffusers/callbacks.py +56 -3
  3. diffusers/configuration_utils.py +1 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/image_processor.py +25 -17
  6. diffusers/loaders/__init__.py +22 -3
  7. diffusers/loaders/ip_adapter.py +538 -15
  8. diffusers/loaders/lora_base.py +124 -118
  9. diffusers/loaders/lora_conversion_utils.py +318 -3
  10. diffusers/loaders/lora_pipeline.py +1688 -368
  11. diffusers/loaders/peft.py +379 -0
  12. diffusers/loaders/single_file_model.py +71 -4
  13. diffusers/loaders/single_file_utils.py +519 -9
  14. diffusers/loaders/textual_inversion.py +3 -3
  15. diffusers/loaders/transformer_flux.py +181 -0
  16. diffusers/loaders/transformer_sd3.py +89 -0
  17. diffusers/loaders/unet.py +17 -4
  18. diffusers/models/__init__.py +47 -14
  19. diffusers/models/activations.py +22 -9
  20. diffusers/models/attention.py +13 -4
  21. diffusers/models/attention_flax.py +1 -1
  22. diffusers/models/attention_processor.py +2059 -281
  23. diffusers/models/autoencoders/__init__.py +5 -0
  24. diffusers/models/autoencoders/autoencoder_dc.py +620 -0
  25. diffusers/models/autoencoders/autoencoder_kl.py +2 -1
  26. diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
  27. diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
  28. diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
  29. diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
  30. diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
  31. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
  32. diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
  33. diffusers/models/autoencoders/vae.py +18 -5
  34. diffusers/models/controlnet.py +47 -802
  35. diffusers/models/controlnet_flux.py +29 -495
  36. diffusers/models/controlnet_sd3.py +25 -379
  37. diffusers/models/controlnet_sparsectrl.py +46 -718
  38. diffusers/models/controlnets/__init__.py +23 -0
  39. diffusers/models/controlnets/controlnet.py +872 -0
  40. diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
  41. diffusers/models/controlnets/controlnet_flux.py +536 -0
  42. diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
  43. diffusers/models/controlnets/controlnet_sd3.py +489 -0
  44. diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
  45. diffusers/models/controlnets/controlnet_union.py +832 -0
  46. diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
  47. diffusers/models/controlnets/multicontrolnet.py +183 -0
  48. diffusers/models/embeddings.py +838 -43
  49. diffusers/models/model_loading_utils.py +88 -6
  50. diffusers/models/modeling_flax_utils.py +1 -1
  51. diffusers/models/modeling_utils.py +72 -26
  52. diffusers/models/normalization.py +78 -13
  53. diffusers/models/transformers/__init__.py +5 -0
  54. diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
  55. diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
  56. diffusers/models/transformers/dit_transformer_2d.py +1 -1
  57. diffusers/models/transformers/latte_transformer_3d.py +4 -4
  58. diffusers/models/transformers/pixart_transformer_2d.py +1 -1
  59. diffusers/models/transformers/sana_transformer.py +488 -0
  60. diffusers/models/transformers/stable_audio_transformer.py +1 -1
  61. diffusers/models/transformers/transformer_2d.py +1 -1
  62. diffusers/models/transformers/transformer_allegro.py +422 -0
  63. diffusers/models/transformers/transformer_cogview3plus.py +1 -1
  64. diffusers/models/transformers/transformer_flux.py +30 -9
  65. diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
  66. diffusers/models/transformers/transformer_ltx.py +469 -0
  67. diffusers/models/transformers/transformer_mochi.py +499 -0
  68. diffusers/models/transformers/transformer_sd3.py +105 -17
  69. diffusers/models/transformers/transformer_temporal.py +1 -1
  70. diffusers/models/unets/unet_1d_blocks.py +1 -1
  71. diffusers/models/unets/unet_2d.py +8 -1
  72. diffusers/models/unets/unet_2d_blocks.py +88 -21
  73. diffusers/models/unets/unet_2d_condition.py +1 -1
  74. diffusers/models/unets/unet_3d_blocks.py +9 -7
  75. diffusers/models/unets/unet_motion_model.py +5 -5
  76. diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
  77. diffusers/models/unets/unet_stable_cascade.py +2 -2
  78. diffusers/models/unets/uvit_2d.py +1 -1
  79. diffusers/models/upsampling.py +8 -0
  80. diffusers/pipelines/__init__.py +34 -0
  81. diffusers/pipelines/allegro/__init__.py +48 -0
  82. diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
  83. diffusers/pipelines/allegro/pipeline_output.py +23 -0
  84. diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
  85. diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
  86. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
  87. diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
  88. diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
  89. diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
  90. diffusers/pipelines/auto_pipeline.py +53 -6
  91. diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
  92. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
  93. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
  94. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
  95. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
  96. diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
  97. diffusers/pipelines/controlnet/__init__.py +86 -80
  98. diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
  99. diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
  100. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
  101. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
  102. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
  103. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
  104. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
  105. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
  106. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
  107. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
  108. diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
  109. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
  110. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
  111. diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
  112. diffusers/pipelines/flux/__init__.py +13 -1
  113. diffusers/pipelines/flux/modeling_flux.py +47 -0
  114. diffusers/pipelines/flux/pipeline_flux.py +204 -29
  115. diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
  116. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
  117. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
  118. diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
  119. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
  120. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
  121. diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
  122. diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
  123. diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
  124. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
  125. diffusers/pipelines/flux/pipeline_output.py +16 -0
  126. diffusers/pipelines/hunyuan_video/__init__.py +48 -0
  127. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
  128. diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
  129. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
  130. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
  131. diffusers/pipelines/kolors/text_encoder.py +2 -2
  132. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
  133. diffusers/pipelines/ltx/__init__.py +50 -0
  134. diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
  135. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
  136. diffusers/pipelines/ltx/pipeline_output.py +20 -0
  137. diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
  138. diffusers/pipelines/mochi/__init__.py +48 -0
  139. diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
  140. diffusers/pipelines/mochi/pipeline_output.py +20 -0
  141. diffusers/pipelines/pag/__init__.py +7 -0
  142. diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
  143. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
  144. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
  145. diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
  146. diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
  147. diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
  148. diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
  149. diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
  150. diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
  151. diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
  152. diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
  153. diffusers/pipelines/pipeline_flax_utils.py +1 -1
  154. diffusers/pipelines/pipeline_loading_utils.py +25 -4
  155. diffusers/pipelines/pipeline_utils.py +35 -6
  156. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
  157. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
  158. diffusers/pipelines/sana/__init__.py +47 -0
  159. diffusers/pipelines/sana/pipeline_output.py +21 -0
  160. diffusers/pipelines/sana/pipeline_sana.py +884 -0
  161. diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
  162. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
  163. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
  164. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
  165. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
  166. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
  167. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
  168. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
  169. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
  170. diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
  171. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
  172. diffusers/quantizers/auto.py +14 -1
  173. diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
  174. diffusers/quantizers/gguf/__init__.py +1 -0
  175. diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
  176. diffusers/quantizers/gguf/utils.py +456 -0
  177. diffusers/quantizers/quantization_config.py +280 -2
  178. diffusers/quantizers/torchao/__init__.py +15 -0
  179. diffusers/quantizers/torchao/torchao_quantizer.py +292 -0
  180. diffusers/schedulers/scheduling_ddpm.py +2 -6
  181. diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
  182. diffusers/schedulers/scheduling_deis_multistep.py +28 -9
  183. diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
  184. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
  185. diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
  186. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
  187. diffusers/schedulers/scheduling_euler_discrete.py +4 -4
  188. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
  189. diffusers/schedulers/scheduling_heun_discrete.py +4 -4
  190. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
  191. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
  192. diffusers/schedulers/scheduling_lcm.py +2 -6
  193. diffusers/schedulers/scheduling_lms_discrete.py +4 -4
  194. diffusers/schedulers/scheduling_repaint.py +1 -1
  195. diffusers/schedulers/scheduling_sasolver.py +28 -9
  196. diffusers/schedulers/scheduling_tcd.py +2 -6
  197. diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
  198. diffusers/training_utils.py +16 -2
  199. diffusers/utils/__init__.py +5 -0
  200. diffusers/utils/constants.py +1 -0
  201. diffusers/utils/dummy_pt_objects.py +180 -0
  202. diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
  203. diffusers/utils/dynamic_modules_utils.py +3 -3
  204. diffusers/utils/hub_utils.py +31 -39
  205. diffusers/utils/import_utils.py +67 -0
  206. diffusers/utils/peft_utils.py +3 -0
  207. diffusers/utils/testing_utils.py +56 -1
  208. diffusers/utils/torch_utils.py +3 -0
  209. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/METADATA +6 -6
  210. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/RECORD +214 -162
  211. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/WHEEL +1 -1
  212. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/LICENSE +0 -0
  213. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/entry_points.txt +0 -0
  214. {diffusers-0.31.0.dist-info → diffusers-0.32.1.dist-info}/top_level.txt +0 -0
@@ -217,7 +217,7 @@ class MidResTemporalBlock1D(nn.Module):
217
217
  if self.upsample:
218
218
  hidden_states = self.upsample(hidden_states)
219
219
  if self.downsample:
220
- self.downsample = self.downsample(hidden_states)
220
+ hidden_states = self.downsample(hidden_states)
221
221
 
222
222
  return hidden_states
223
223
 
@@ -89,6 +89,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
89
89
  conditioning with `class_embed_type` equal to `None`.
90
90
  """
91
91
 
92
+ _supports_gradient_checkpointing = True
93
+
92
94
  @register_to_config
93
95
  def __init__(
94
96
  self,
@@ -97,6 +99,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
97
99
  out_channels: int = 3,
98
100
  center_input_sample: bool = False,
99
101
  time_embedding_type: str = "positional",
102
+ time_embedding_dim: Optional[int] = None,
100
103
  freq_shift: int = 0,
101
104
  flip_sin_to_cos: bool = True,
102
105
  down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
@@ -122,7 +125,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
122
125
  super().__init__()
123
126
 
124
127
  self.sample_size = sample_size
125
- time_embed_dim = block_out_channels[0] * 4
128
+ time_embed_dim = time_embedding_dim or block_out_channels[0] * 4
126
129
 
127
130
  # Check inputs
128
131
  if len(down_block_types) != len(up_block_types):
@@ -240,6 +243,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
240
243
  self.conv_act = nn.SiLU()
241
244
  self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
242
245
 
246
+ def _set_gradient_checkpointing(self, module, value=False):
247
+ if hasattr(module, "gradient_checkpointing"):
248
+ module.gradient_checkpointing = value
249
+
243
250
  def forward(
244
251
  self,
245
252
  sample: torch.Tensor,
@@ -731,12 +731,35 @@ class UNetMidBlock2D(nn.Module):
731
731
  self.attentions = nn.ModuleList(attentions)
732
732
  self.resnets = nn.ModuleList(resnets)
733
733
 
734
+ self.gradient_checkpointing = False
735
+
734
736
  def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
735
737
  hidden_states = self.resnets[0](hidden_states, temb)
736
738
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
737
- if attn is not None:
738
- hidden_states = attn(hidden_states, temb=temb)
739
- hidden_states = resnet(hidden_states, temb)
739
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
740
+
741
+ def create_custom_forward(module, return_dict=None):
742
+ def custom_forward(*inputs):
743
+ if return_dict is not None:
744
+ return module(*inputs, return_dict=return_dict)
745
+ else:
746
+ return module(*inputs)
747
+
748
+ return custom_forward
749
+
750
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
751
+ if attn is not None:
752
+ hidden_states = attn(hidden_states, temb=temb)
753
+ hidden_states = torch.utils.checkpoint.checkpoint(
754
+ create_custom_forward(resnet),
755
+ hidden_states,
756
+ temb,
757
+ **ckpt_kwargs,
758
+ )
759
+ else:
760
+ if attn is not None:
761
+ hidden_states = attn(hidden_states, temb=temb)
762
+ hidden_states = resnet(hidden_states, temb)
740
763
 
741
764
  return hidden_states
742
765
 
@@ -859,7 +882,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
859
882
 
860
883
  hidden_states = self.resnets[0](hidden_states, temb)
861
884
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
862
- if self.training and self.gradient_checkpointing:
885
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
863
886
 
864
887
  def create_custom_forward(module, return_dict=None):
865
888
  def custom_forward(*inputs):
@@ -1116,6 +1139,8 @@ class AttnDownBlock2D(nn.Module):
1116
1139
  else:
1117
1140
  self.downsamplers = None
1118
1141
 
1142
+ self.gradient_checkpointing = False
1143
+
1119
1144
  def forward(
1120
1145
  self,
1121
1146
  hidden_states: torch.Tensor,
@@ -1130,9 +1155,30 @@ class AttnDownBlock2D(nn.Module):
1130
1155
  output_states = ()
1131
1156
 
1132
1157
  for resnet, attn in zip(self.resnets, self.attentions):
1133
- hidden_states = resnet(hidden_states, temb)
1134
- hidden_states = attn(hidden_states, **cross_attention_kwargs)
1135
- output_states = output_states + (hidden_states,)
1158
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1159
+
1160
+ def create_custom_forward(module, return_dict=None):
1161
+ def custom_forward(*inputs):
1162
+ if return_dict is not None:
1163
+ return module(*inputs, return_dict=return_dict)
1164
+ else:
1165
+ return module(*inputs)
1166
+
1167
+ return custom_forward
1168
+
1169
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
1170
+ hidden_states = torch.utils.checkpoint.checkpoint(
1171
+ create_custom_forward(resnet),
1172
+ hidden_states,
1173
+ temb,
1174
+ **ckpt_kwargs,
1175
+ )
1176
+ hidden_states = attn(hidden_states, **cross_attention_kwargs)
1177
+ output_states = output_states + (hidden_states,)
1178
+ else:
1179
+ hidden_states = resnet(hidden_states, temb)
1180
+ hidden_states = attn(hidden_states, **cross_attention_kwargs)
1181
+ output_states = output_states + (hidden_states,)
1136
1182
 
1137
1183
  if self.downsamplers is not None:
1138
1184
  for downsampler in self.downsamplers:
@@ -1257,7 +1303,7 @@ class CrossAttnDownBlock2D(nn.Module):
1257
1303
  blocks = list(zip(self.resnets, self.attentions))
1258
1304
 
1259
1305
  for i, (resnet, attn) in enumerate(blocks):
1260
- if self.training and self.gradient_checkpointing:
1306
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1261
1307
 
1262
1308
  def create_custom_forward(module, return_dict=None):
1263
1309
  def custom_forward(*inputs):
@@ -1371,7 +1417,7 @@ class DownBlock2D(nn.Module):
1371
1417
  output_states = ()
1372
1418
 
1373
1419
  for resnet in self.resnets:
1374
- if self.training and self.gradient_checkpointing:
1420
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1375
1421
 
1376
1422
  def create_custom_forward(module):
1377
1423
  def custom_forward(*inputs):
@@ -1859,7 +1905,7 @@ class ResnetDownsampleBlock2D(nn.Module):
1859
1905
  output_states = ()
1860
1906
 
1861
1907
  for resnet in self.resnets:
1862
- if self.training and self.gradient_checkpointing:
1908
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1863
1909
 
1864
1910
  def create_custom_forward(module):
1865
1911
  def custom_forward(*inputs):
@@ -2011,7 +2057,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
2011
2057
  mask = attention_mask
2012
2058
 
2013
2059
  for resnet, attn in zip(self.resnets, self.attentions):
2014
- if self.training and self.gradient_checkpointing:
2060
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2015
2061
 
2016
2062
  def create_custom_forward(module, return_dict=None):
2017
2063
  def custom_forward(*inputs):
@@ -2106,7 +2152,7 @@ class KDownBlock2D(nn.Module):
2106
2152
  output_states = ()
2107
2153
 
2108
2154
  for resnet in self.resnets:
2109
- if self.training and self.gradient_checkpointing:
2155
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2110
2156
 
2111
2157
  def create_custom_forward(module):
2112
2158
  def custom_forward(*inputs):
@@ -2215,7 +2261,7 @@ class KCrossAttnDownBlock2D(nn.Module):
2215
2261
  output_states = ()
2216
2262
 
2217
2263
  for resnet, attn in zip(self.resnets, self.attentions):
2218
- if self.training and self.gradient_checkpointing:
2264
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2219
2265
 
2220
2266
  def create_custom_forward(module, return_dict=None):
2221
2267
  def custom_forward(*inputs):
@@ -2354,6 +2400,7 @@ class AttnUpBlock2D(nn.Module):
2354
2400
  else:
2355
2401
  self.upsamplers = None
2356
2402
 
2403
+ self.gradient_checkpointing = False
2357
2404
  self.resolution_idx = resolution_idx
2358
2405
 
2359
2406
  def forward(
@@ -2375,8 +2422,28 @@ class AttnUpBlock2D(nn.Module):
2375
2422
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
2376
2423
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2377
2424
 
2378
- hidden_states = resnet(hidden_states, temb)
2379
- hidden_states = attn(hidden_states)
2425
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2426
+
2427
+ def create_custom_forward(module, return_dict=None):
2428
+ def custom_forward(*inputs):
2429
+ if return_dict is not None:
2430
+ return module(*inputs, return_dict=return_dict)
2431
+ else:
2432
+ return module(*inputs)
2433
+
2434
+ return custom_forward
2435
+
2436
+ ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
2437
+ hidden_states = torch.utils.checkpoint.checkpoint(
2438
+ create_custom_forward(resnet),
2439
+ hidden_states,
2440
+ temb,
2441
+ **ckpt_kwargs,
2442
+ )
2443
+ hidden_states = attn(hidden_states)
2444
+ else:
2445
+ hidden_states = resnet(hidden_states, temb)
2446
+ hidden_states = attn(hidden_states)
2380
2447
 
2381
2448
  if self.upsamplers is not None:
2382
2449
  for upsampler in self.upsamplers:
@@ -2520,7 +2587,7 @@ class CrossAttnUpBlock2D(nn.Module):
2520
2587
 
2521
2588
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2522
2589
 
2523
- if self.training and self.gradient_checkpointing:
2590
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2524
2591
 
2525
2592
  def create_custom_forward(module, return_dict=None):
2526
2593
  def custom_forward(*inputs):
@@ -2653,7 +2720,7 @@ class UpBlock2D(nn.Module):
2653
2720
 
2654
2721
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
2655
2722
 
2656
- if self.training and self.gradient_checkpointing:
2723
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
2657
2724
 
2658
2725
  def create_custom_forward(module):
2659
2726
  def custom_forward(*inputs):
@@ -3183,7 +3250,7 @@ class ResnetUpsampleBlock2D(nn.Module):
3183
3250
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
3184
3251
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
3185
3252
 
3186
- if self.training and self.gradient_checkpointing:
3253
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
3187
3254
 
3188
3255
  def create_custom_forward(module):
3189
3256
  def custom_forward(*inputs):
@@ -3341,7 +3408,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
3341
3408
  res_hidden_states_tuple = res_hidden_states_tuple[:-1]
3342
3409
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
3343
3410
 
3344
- if self.training and self.gradient_checkpointing:
3411
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
3345
3412
 
3346
3413
  def create_custom_forward(module, return_dict=None):
3347
3414
  def custom_forward(*inputs):
@@ -3444,7 +3511,7 @@ class KUpBlock2D(nn.Module):
3444
3511
  hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
3445
3512
 
3446
3513
  for resnet in self.resnets:
3447
- if self.training and self.gradient_checkpointing:
3514
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
3448
3515
 
3449
3516
  def create_custom_forward(module):
3450
3517
  def custom_forward(*inputs):
@@ -3572,7 +3639,7 @@ class KCrossAttnUpBlock2D(nn.Module):
3572
3639
  hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)
3573
3640
 
3574
3641
  for resnet, attn in zip(self.resnets, self.attentions):
3575
- if self.training and self.gradient_checkpointing:
3642
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
3576
3643
 
3577
3644
  def create_custom_forward(module, return_dict=None):
3578
3645
  def custom_forward(*inputs):
@@ -170,7 +170,7 @@ class UNet2DConditionModel(
170
170
  @register_to_config
171
171
  def __init__(
172
172
  self,
173
- sample_size: Optional[int] = None,
173
+ sample_size: Optional[Union[int, Tuple[int, int]]] = None,
174
174
  in_channels: int = 4,
175
175
  out_channels: int = 4,
176
176
  center_input_sample: bool = False,
@@ -1078,7 +1078,7 @@ class UNetMidBlockSpatioTemporal(nn.Module):
1078
1078
  )
1079
1079
 
1080
1080
  for attn, resnet in zip(self.attentions, self.resnets[1:]):
1081
- if self.training and self.gradient_checkpointing: # TODO
1081
+ if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO
1082
1082
 
1083
1083
  def create_custom_forward(module, return_dict=None):
1084
1084
  def custom_forward(*inputs):
@@ -1168,7 +1168,7 @@ class DownBlockSpatioTemporal(nn.Module):
1168
1168
  ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
1169
1169
  output_states = ()
1170
1170
  for resnet in self.resnets:
1171
- if self.training and self.gradient_checkpointing:
1171
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1172
1172
 
1173
1173
  def create_custom_forward(module):
1174
1174
  def custom_forward(*inputs):
@@ -1281,7 +1281,7 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):
1281
1281
 
1282
1282
  blocks = list(zip(self.resnets, self.attentions))
1283
1283
  for resnet, attn in blocks:
1284
- if self.training and self.gradient_checkpointing: # TODO
1284
+ if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO
1285
1285
 
1286
1286
  def create_custom_forward(module, return_dict=None):
1287
1287
  def custom_forward(*inputs):
@@ -1375,6 +1375,7 @@ class UpBlockSpatioTemporal(nn.Module):
1375
1375
  res_hidden_states_tuple: Tuple[torch.Tensor, ...],
1376
1376
  temb: Optional[torch.Tensor] = None,
1377
1377
  image_only_indicator: Optional[torch.Tensor] = None,
1378
+ upsample_size: Optional[int] = None,
1378
1379
  ) -> torch.Tensor:
1379
1380
  for resnet in self.resnets:
1380
1381
  # pop res hidden states
@@ -1383,7 +1384,7 @@ class UpBlockSpatioTemporal(nn.Module):
1383
1384
 
1384
1385
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1385
1386
 
1386
- if self.training and self.gradient_checkpointing:
1387
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1387
1388
 
1388
1389
  def create_custom_forward(module):
1389
1390
  def custom_forward(*inputs):
@@ -1415,7 +1416,7 @@ class UpBlockSpatioTemporal(nn.Module):
1415
1416
 
1416
1417
  if self.upsamplers is not None:
1417
1418
  for upsampler in self.upsamplers:
1418
- hidden_states = upsampler(hidden_states)
1419
+ hidden_states = upsampler(hidden_states, upsample_size)
1419
1420
 
1420
1421
  return hidden_states
1421
1422
 
@@ -1485,6 +1486,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
1485
1486
  temb: Optional[torch.Tensor] = None,
1486
1487
  encoder_hidden_states: Optional[torch.Tensor] = None,
1487
1488
  image_only_indicator: Optional[torch.Tensor] = None,
1489
+ upsample_size: Optional[int] = None,
1488
1490
  ) -> torch.Tensor:
1489
1491
  for resnet, attn in zip(self.resnets, self.attentions):
1490
1492
  # pop res hidden states
@@ -1493,7 +1495,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
1493
1495
 
1494
1496
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
1495
1497
 
1496
- if self.training and self.gradient_checkpointing: # TODO
1498
+ if torch.is_grad_enabled() and self.gradient_checkpointing: # TODO
1497
1499
 
1498
1500
  def create_custom_forward(module, return_dict=None):
1499
1501
  def custom_forward(*inputs):
@@ -1533,6 +1535,6 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
1533
1535
 
1534
1536
  if self.upsamplers is not None:
1535
1537
  for upsampler in self.upsamplers:
1536
- hidden_states = upsampler(hidden_states)
1538
+ hidden_states = upsampler(hidden_states, upsample_size)
1537
1539
 
1538
1540
  return hidden_states
@@ -323,7 +323,7 @@ class DownBlockMotion(nn.Module):
323
323
 
324
324
  blocks = zip(self.resnets, self.motion_modules)
325
325
  for resnet, motion_module in blocks:
326
- if self.training and self.gradient_checkpointing:
326
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
327
327
 
328
328
  def create_custom_forward(module):
329
329
  def custom_forward(*inputs):
@@ -513,7 +513,7 @@ class CrossAttnDownBlockMotion(nn.Module):
513
513
 
514
514
  blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
515
515
  for i, (resnet, attn, motion_module) in enumerate(blocks):
516
- if self.training and self.gradient_checkpointing:
516
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
517
517
 
518
518
  def create_custom_forward(module, return_dict=None):
519
519
  def custom_forward(*inputs):
@@ -732,7 +732,7 @@ class CrossAttnUpBlockMotion(nn.Module):
732
732
 
733
733
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
734
734
 
735
- if self.training and self.gradient_checkpointing:
735
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
736
736
 
737
737
  def create_custom_forward(module, return_dict=None):
738
738
  def custom_forward(*inputs):
@@ -895,7 +895,7 @@ class UpBlockMotion(nn.Module):
895
895
 
896
896
  hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
897
897
 
898
- if self.training and self.gradient_checkpointing:
898
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
899
899
 
900
900
  def create_custom_forward(module):
901
901
  def custom_forward(*inputs):
@@ -1079,7 +1079,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
1079
1079
  return_dict=False,
1080
1080
  )[0]
1081
1081
 
1082
- if self.training and self.gradient_checkpointing:
1082
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
1083
1083
 
1084
1084
  def create_custom_forward(module, return_dict=None):
1085
1085
  def custom_forward(*inputs):
@@ -382,6 +382,20 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
382
382
  If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
383
383
  returned, otherwise a `tuple` is returned where the first element is the sample tensor.
384
384
  """
385
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
386
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
387
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
388
+ # on the fly if necessary.
389
+ default_overall_up_factor = 2**self.num_upsamplers
390
+
391
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
392
+ forward_upsample_size = False
393
+ upsample_size = None
394
+
395
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
396
+ logger.info("Forward upsample size to force interpolation output size.")
397
+ forward_upsample_size = True
398
+
385
399
  # 1. time
386
400
  timesteps = timestep
387
401
  if not torch.is_tensor(timesteps):
@@ -457,15 +471,23 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
457
471
 
458
472
  # 5. up
459
473
  for i, upsample_block in enumerate(self.up_blocks):
474
+ is_final_block = i == len(self.up_blocks) - 1
475
+
460
476
  res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
461
477
  down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
462
478
 
479
+ # if we have not reached the final block and need to forward the
480
+ # upsample size, we do it here
481
+ if not is_final_block and forward_upsample_size:
482
+ upsample_size = down_block_res_samples[-1].shape[2:]
483
+
463
484
  if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
464
485
  sample = upsample_block(
465
486
  hidden_states=sample,
466
487
  temb=emb,
467
488
  res_hidden_states_tuple=res_samples,
468
489
  encoder_hidden_states=encoder_hidden_states,
490
+ upsample_size=upsample_size,
469
491
  image_only_indicator=image_only_indicator,
470
492
  )
471
493
  else:
@@ -473,6 +495,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionL
473
495
  hidden_states=sample,
474
496
  temb=emb,
475
497
  res_hidden_states_tuple=res_samples,
498
+ upsample_size=upsample_size,
476
499
  image_only_indicator=image_only_indicator,
477
500
  )
478
501
 
@@ -455,7 +455,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
455
455
  level_outputs = []
456
456
  block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)
457
457
 
458
- if self.training and self.gradient_checkpointing:
458
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
459
459
 
460
460
  def create_custom_forward(module):
461
461
  def custom_forward(*inputs):
@@ -504,7 +504,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
504
504
  x = level_outputs[0]
505
505
  block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)
506
506
 
507
- if self.training and self.gradient_checkpointing:
507
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
508
508
 
509
509
  def create_custom_forward(module):
510
510
  def custom_forward(*inputs):
@@ -181,7 +181,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
181
181
  hidden_states = self.project_to_hidden(hidden_states)
182
182
 
183
183
  for layer in self.transformer_layers:
184
- if self.training and self.gradient_checkpointing:
184
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
185
185
 
186
186
  def layer_(*args):
187
187
  return checkpoint(layer, *args)
@@ -165,6 +165,14 @@ class Upsample2D(nn.Module):
165
165
  # if `output_size` is passed we force the interpolation output
166
166
  # size and do not make use of `scale_factor=2`
167
167
  if self.interpolate:
168
+ # upsample_nearest_nhwc also fails when the number of output elements is large
169
+ # https://github.com/pytorch/pytorch/issues/141831
170
+ scale_factor = (
171
+ 2 if output_size is None else max([f / s for f, s in zip(output_size, hidden_states.shape[-2:])])
172
+ )
173
+ if hidden_states.numel() * scale_factor > pow(2, 31):
174
+ hidden_states = hidden_states.contiguous()
175
+
168
176
  if output_size is None:
169
177
  hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
170
178
  else:
@@ -116,6 +116,7 @@ else:
116
116
  "VersatileDiffusionTextToImagePipeline",
117
117
  ]
118
118
  )
119
+ _import_structure["allegro"] = ["AllegroPipeline"]
119
120
  _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
120
121
  _import_structure["animatediff"] = [
121
122
  "AnimateDiffPipeline",
@@ -126,12 +127,18 @@ else:
126
127
  "AnimateDiffVideoToVideoControlNetPipeline",
127
128
  ]
128
129
  _import_structure["flux"] = [
130
+ "FluxControlPipeline",
131
+ "FluxControlInpaintPipeline",
132
+ "FluxControlImg2ImgPipeline",
129
133
  "FluxControlNetPipeline",
130
134
  "FluxControlNetImg2ImgPipeline",
131
135
  "FluxControlNetInpaintPipeline",
132
136
  "FluxImg2ImgPipeline",
133
137
  "FluxInpaintPipeline",
134
138
  "FluxPipeline",
139
+ "FluxFillPipeline",
140
+ "FluxPriorReduxPipeline",
141
+ "ReduxImageEncoder",
135
142
  ]
136
143
  _import_structure["audioldm"] = ["AudioLDMPipeline"]
137
144
  _import_structure["audioldm2"] = [
@@ -156,6 +163,9 @@ else:
156
163
  "StableDiffusionXLControlNetImg2ImgPipeline",
157
164
  "StableDiffusionXLControlNetInpaintPipeline",
158
165
  "StableDiffusionXLControlNetPipeline",
166
+ "StableDiffusionXLControlNetUnionPipeline",
167
+ "StableDiffusionXLControlNetUnionInpaintPipeline",
168
+ "StableDiffusionXLControlNetUnionImg2ImgPipeline",
159
169
  ]
160
170
  )
161
171
  _import_structure["pag"].extend(
@@ -165,8 +175,10 @@ else:
165
175
  "KolorsPAGPipeline",
166
176
  "HunyuanDiTPAGPipeline",
167
177
  "StableDiffusion3PAGPipeline",
178
+ "StableDiffusion3PAGImg2ImgPipeline",
168
179
  "StableDiffusionPAGPipeline",
169
180
  "StableDiffusionPAGImg2ImgPipeline",
181
+ "StableDiffusionPAGInpaintPipeline",
170
182
  "StableDiffusionControlNetPAGPipeline",
171
183
  "StableDiffusionXLPAGPipeline",
172
184
  "StableDiffusionXLPAGInpaintPipeline",
@@ -174,6 +186,7 @@ else:
174
186
  "StableDiffusionXLControlNetPAGPipeline",
175
187
  "StableDiffusionXLPAGImg2ImgPipeline",
176
188
  "PixArtSigmaPAGPipeline",
189
+ "SanaPAGPipeline",
177
190
  ]
178
191
  )
179
192
  _import_structure["controlnet_xs"].extend(
@@ -202,6 +215,7 @@ else:
202
215
  "IFSuperResolutionPipeline",
203
216
  ]
204
217
  _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
218
+ _import_structure["hunyuan_video"] = ["HunyuanVideoPipeline"]
205
219
  _import_structure["kandinsky"] = [
206
220
  "KandinskyCombinedPipeline",
207
221
  "KandinskyImg2ImgCombinedPipeline",
@@ -239,6 +253,7 @@ else:
239
253
  ]
240
254
  )
241
255
  _import_structure["latte"] = ["LattePipeline"]
256
+ _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
242
257
  _import_structure["lumina"] = ["LuminaText2ImgPipeline"]
243
258
  _import_structure["marigold"].extend(
244
259
  [
@@ -246,10 +261,12 @@ else:
246
261
  "MarigoldNormalsPipeline",
247
262
  ]
248
263
  )
264
+ _import_structure["mochi"] = ["MochiPipeline"]
249
265
  _import_structure["musicldm"] = ["MusicLDMPipeline"]
250
266
  _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
251
267
  _import_structure["pia"] = ["PIAPipeline"]
252
268
  _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
269
+ _import_structure["sana"] = ["SanaPipeline"]
253
270
  _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
254
271
  _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
255
272
  _import_structure["stable_audio"] = [
@@ -454,6 +471,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
454
471
  except OptionalDependencyNotAvailable:
455
472
  from ..utils.dummy_torch_and_transformers_objects import *
456
473
  else:
474
+ from .allegro import AllegroPipeline
457
475
  from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
458
476
  from .animatediff import (
459
477
  AnimateDiffControlNetPipeline,
@@ -486,6 +504,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
486
504
  StableDiffusionXLControlNetImg2ImgPipeline,
487
505
  StableDiffusionXLControlNetInpaintPipeline,
488
506
  StableDiffusionXLControlNetPipeline,
507
+ StableDiffusionXLControlNetUnionImg2ImgPipeline,
508
+ StableDiffusionXLControlNetUnionInpaintPipeline,
509
+ StableDiffusionXLControlNetUnionPipeline,
489
510
  )
490
511
  from .controlnet_hunyuandit import (
491
512
  HunyuanDiTControlNetPipeline,
@@ -518,13 +539,20 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
518
539
  VQDiffusionPipeline,
519
540
  )
520
541
  from .flux import (
542
+ FluxControlImg2ImgPipeline,
543
+ FluxControlInpaintPipeline,
521
544
  FluxControlNetImg2ImgPipeline,
522
545
  FluxControlNetInpaintPipeline,
523
546
  FluxControlNetPipeline,
547
+ FluxControlPipeline,
548
+ FluxFillPipeline,
524
549
  FluxImg2ImgPipeline,
525
550
  FluxInpaintPipeline,
526
551
  FluxPipeline,
552
+ FluxPriorReduxPipeline,
553
+ ReduxImageEncoder,
527
554
  )
555
+ from .hunyuan_video import HunyuanVideoPipeline
528
556
  from .hunyuandit import HunyuanDiTPipeline
529
557
  from .i2vgen_xl import I2VGenXLPipeline
530
558
  from .kandinsky import (
@@ -564,21 +592,26 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
564
592
  LEditsPPPipelineStableDiffusion,
565
593
  LEditsPPPipelineStableDiffusionXL,
566
594
  )
595
+ from .ltx import LTXImageToVideoPipeline, LTXPipeline
567
596
  from .lumina import LuminaText2ImgPipeline
568
597
  from .marigold import (
569
598
  MarigoldDepthPipeline,
570
599
  MarigoldNormalsPipeline,
571
600
  )
601
+ from .mochi import MochiPipeline
572
602
  from .musicldm import MusicLDMPipeline
573
603
  from .pag import (
574
604
  AnimateDiffPAGPipeline,
575
605
  HunyuanDiTPAGPipeline,
576
606
  KolorsPAGPipeline,
577
607
  PixArtSigmaPAGPipeline,
608
+ SanaPAGPipeline,
609
+ StableDiffusion3PAGImg2ImgPipeline,
578
610
  StableDiffusion3PAGPipeline,
579
611
  StableDiffusionControlNetPAGInpaintPipeline,
580
612
  StableDiffusionControlNetPAGPipeline,
581
613
  StableDiffusionPAGImg2ImgPipeline,
614
+ StableDiffusionPAGInpaintPipeline,
582
615
  StableDiffusionPAGPipeline,
583
616
  StableDiffusionXLControlNetPAGImg2ImgPipeline,
584
617
  StableDiffusionXLControlNetPAGPipeline,
@@ -589,6 +622,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
589
622
  from .paint_by_example import PaintByExamplePipeline
590
623
  from .pia import PIAPipeline
591
624
  from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
625
+ from .sana import SanaPipeline
592
626
  from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
593
627
  from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
594
628
  from .stable_audio import StableAudioPipeline, StableAudioProjectionModel