diffusers-0.31.0-py3-none-any.whl → diffusers-0.32.0-py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to a supported registry. It is provided for informational purposes only.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +74 -28
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
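
The headline user-facing additions in this release are the new pipelines (Allegro, the Flux Control/Fill/Redux family, HunyuanVideo, LTX, Mochi, Sana, the SDXL ControlNet-Union pipelines) and the new GGUF/torchao quantizer backends. As a quick orientation, a minimal sketch of one of the new entry points; the checkpoint id and image URLs are illustrative placeholders, and the call assumes the standard inpainting-style pipeline signature (prompt, image, mask_image):

import torch
from diffusers import FluxFillPipeline
from diffusers.utils import load_image

pipe = FluxFillPipeline.from_pretrained("black-forest-labs/FLUX.1-Fill-dev", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # keep peak VRAM manageable

image = load_image("https://example.com/cup.png")      # image to edit (placeholder URL)
mask = load_image("https://example.com/cup_mask.png")  # white = area to regenerate

result = pipe(prompt="a ceramic cup with a gold rim", image=image, mask_image=mask).images[0]
result.save("filled.png")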
diffusers/models/unets/unet_1d_blocks.py
CHANGED
@@ -217,7 +217,7 @@ class MidResTemporalBlock1D(nn.Module):
         if self.upsample:
             hidden_states = self.upsample(hidden_states)
         if self.downsample:
-            self.downsample = self.downsample(hidden_states)
+            hidden_states = self.downsample(hidden_states)

         return hidden_states

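The removed line (reconstructed here from the 0.31.0 source) assigned the output of `self.downsample(hidden_states)` back onto `self.downsample`, which `nn.Module.__setattr__` rejects because a registered submodule cannot be rebound to a `Tensor`. A tiny sketch of that failure mode:

import torch
import torch.nn as nn

m = nn.Module()
m.conv = nn.Conv1d(8, 8, 3, padding=1)
try:
    # mirrors the old pattern: rebinding the submodule to its own output
    m.conv = m.conv(torch.randn(1, 8, 16))
except TypeError as err:
    print(err)  # cannot assign '...Tensor' as child module 'conv'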
diffusers/models/unets/unet_2d.py
CHANGED
@@ -89,6 +89,8 @@ class UNet2DModel(ModelMixin, ConfigMixin):
             conditioning with `class_embed_type` equal to `None`.
     """

+    _supports_gradient_checkpointing = True
+
     @register_to_config
     def __init__(
         self,
@@ -97,6 +99,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         out_channels: int = 3,
         center_input_sample: bool = False,
         time_embedding_type: str = "positional",
+        time_embedding_dim: Optional[int] = None,
         freq_shift: int = 0,
         flip_sin_to_cos: bool = True,
         down_block_types: Tuple[str, ...] = ("DownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D", "AttnDownBlock2D"),
@@ -122,7 +125,7 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         super().__init__()

         self.sample_size = sample_size
-        time_embed_dim = block_out_channels[0] * 4
+        time_embed_dim = time_embedding_dim or block_out_channels[0] * 4

         # Check inputs
         if len(down_block_types) != len(up_block_types):
@@ -240,6 +243,10 @@ class UNet2DModel(ModelMixin, ConfigMixin):
         self.conv_act = nn.SiLU()
         self.conv_out = nn.Conv2d(block_out_channels[0], out_channels, kernel_size=3, padding=1)

+    def _set_gradient_checkpointing(self, module, value=False):
+        if hasattr(module, "gradient_checkpointing"):
+            module.gradient_checkpointing = value
+
     def forward(
         self,
         sample: torch.Tensor,
diffusers/models/unets/unet_2d_blocks.py
CHANGED
@@ -731,12 +731,35 @@ class UNetMidBlock2D(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

+        self.gradient_checkpointing = False
+
     def forward(self, hidden_states: torch.Tensor, temb: Optional[torch.Tensor] = None) -> torch.Tensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if attn is not None:
-                hidden_states = attn(hidden_states, temb=temb)
-            hidden_states = resnet(hidden_states, temb)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                if attn is not None:
+                    hidden_states = attn(hidden_states, temb=temb)
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states

@@ -859,7 +882,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):

         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1116,6 +1139,8 @@ class AttnDownBlock2D(nn.Module):
         else:
             self.downsamplers = None

+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -1130,9 +1155,30 @@ class AttnDownBlock2D(nn.Module):
         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states, **cross_attention_kwargs)
-            output_states = output_states + (hidden_states,)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states, **cross_attention_kwargs)
+                output_states = output_states + (hidden_states,)

         if self.downsamplers is not None:
             for downsampler in self.downsamplers:
@@ -1257,7 +1303,7 @@ class CrossAttnDownBlock2D(nn.Module):
         blocks = list(zip(self.resnets, self.attentions))

         for i, (resnet, attn) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1371,7 +1417,7 @@ class DownBlock2D(nn.Module):
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1859,7 +1905,7 @@ class ResnetDownsampleBlock2D(nn.Module):
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2011,7 +2057,7 @@ class SimpleCrossAttnDownBlock2D(nn.Module):
             mask = attention_mask

         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2106,7 +2152,7 @@ class KDownBlock2D(nn.Module):
         output_states = ()

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -2215,7 +2261,7 @@ class KCrossAttnDownBlock2D(nn.Module):
         output_states = ()

         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2354,6 +2400,7 @@ class AttnUpBlock2D(nn.Module):
         else:
             self.upsamplers = None

+        self.gradient_checkpointing = False
         self.resolution_idx = resolution_idx

     def forward(
@@ -2375,8 +2422,28 @@ class AttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            hidden_states = resnet(hidden_states, temb)
-            hidden_states = attn(hidden_states)
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+                hidden_states = attn(hidden_states)
+            else:
+                hidden_states = resnet(hidden_states, temb)
+                hidden_states = attn(hidden_states)

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
@@ -2520,7 +2587,7 @@ class CrossAttnUpBlock2D(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -2653,7 +2720,7 @@ class UpBlock2D(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -3183,7 +3250,7 @@ class ResnetUpsampleBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -3341,7 +3408,7 @@ class SimpleCrossAttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -3444,7 +3511,7 @@ class KUpBlock2D(nn.Module):
         hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)

         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -3572,7 +3639,7 @@ class KCrossAttnUpBlock2D(nn.Module):
         hidden_states = torch.cat([hidden_states, res_hidden_states_tuple], dim=1)

         for resnet, attn in zip(self.resnets, self.attentions):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
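All of these blocks now gate activation checkpointing on `torch.is_grad_enabled()` rather than `self.training`: checkpointing engages whenever autograd is recording (even for a model in `eval()` mode) and is skipped under `torch.no_grad()`. A self-contained sketch of the recurring wrapper pattern, with a plain `Linear` standing in for a resnet block and a literal `True` standing in for `self.gradient_checkpointing`:

import torch
from torch.utils.checkpoint import checkpoint

def create_custom_forward(module, return_dict=None):
    # checkpoint() only forwards positional args, so keyword-only options
    # are captured in the closure instead.
    def custom_forward(*inputs):
        if return_dict is not None:
            return module(*inputs, return_dict=return_dict)
        return module(*inputs)
    return custom_forward

layer = torch.nn.Linear(16, 16)
x = torch.randn(4, 16, requires_grad=True)

use_ckpt = torch.is_grad_enabled() and True
y = checkpoint(create_custom_forward(layer), x, use_reentrant=False) if use_ckpt else layer(x)
y.sum().backward()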
diffusers/models/unets/unet_2d_condition.py
CHANGED
@@ -170,7 +170,7 @@ class UNet2DConditionModel(
     @register_to_config
     def __init__(
         self,
-        sample_size: Optional[int] = None,
+        sample_size: Optional[Union[int, Tuple[int, int]]] = None,
         in_channels: int = 4,
         out_channels: int = 4,
         center_input_sample: bool = False,
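`sample_size` may now be a `(height, width)` tuple rather than a single int, which matters for pipelines that derive default output shapes from the config. A minimal sketch with an arbitrary small configuration:

from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel(
    sample_size=(96, 64),  # non-square default latent size
    block_out_channels=(32, 64),
    down_block_types=("DownBlock2D", "CrossAttnDownBlock2D"),
    up_block_types=("CrossAttnUpBlock2D", "UpBlock2D"),
    cross_attention_dim=64,
)
print(unet.config.sample_size)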
diffusers/models/unets/unet_3d_blocks.py
CHANGED
@@ -1078,7 +1078,7 @@ class UNetMidBlockSpatioTemporal(nn.Module):
         )

         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1168,7 +1168,7 @@ class DownBlockSpatioTemporal(nn.Module):
     ) -> Tuple[torch.Tensor, Tuple[torch.Tensor, ...]]:
         output_states = ()
         for resnet in self.resnets:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1281,7 +1281,7 @@ class CrossAttnDownBlockSpatioTemporal(nn.Module):

         blocks = list(zip(self.resnets, self.attentions))
         for resnet, attn in blocks:
-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1375,6 +1375,7 @@ class UpBlockSpatioTemporal(nn.Module):
         res_hidden_states_tuple: Tuple[torch.Tensor, ...],
         temb: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
+        upsample_size: Optional[int] = None,
     ) -> torch.Tensor:
         for resnet in self.resnets:
             # pop res hidden states
@@ -1383,7 +1384,7 @@ class UpBlockSpatioTemporal(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1415,7 +1416,7 @@ class UpBlockSpatioTemporal(nn.Module):

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)

         return hidden_states

@@ -1485,6 +1486,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):
         temb: Optional[torch.Tensor] = None,
         encoder_hidden_states: Optional[torch.Tensor] = None,
         image_only_indicator: Optional[torch.Tensor] = None,
+        upsample_size: Optional[int] = None,
     ) -> torch.Tensor:
         for resnet, attn in zip(self.resnets, self.attentions):
             # pop res hidden states
@@ -1493,7 +1495,7 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:  # TODO
+            if torch.is_grad_enabled() and self.gradient_checkpointing:  # TODO

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1533,6 +1535,6 @@ class CrossAttnUpBlockSpatioTemporal(nn.Module):

         if self.upsamplers is not None:
             for upsampler in self.upsamplers:
-                hidden_states = upsampler(hidden_states)
+                hidden_states = upsampler(hidden_states, upsample_size)

         return hidden_states
diffusers/models/unets/unet_motion_model.py
CHANGED
@@ -323,7 +323,7 @@ class DownBlockMotion(nn.Module):

         blocks = zip(self.resnets, self.motion_modules)
         for resnet, motion_module in blocks:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -513,7 +513,7 @@ class CrossAttnDownBlockMotion(nn.Module):

         blocks = list(zip(self.resnets, self.attentions, self.motion_modules))
         for i, (resnet, attn, motion_module) in enumerate(blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -732,7 +732,7 @@ class CrossAttnUpBlockMotion(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -895,7 +895,7 @@ class UpBlockMotion(nn.Module):

             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -1079,7 +1079,7 @@ class UNetMidBlockCrossAttnMotion(nn.Module):
                 return_dict=False,
             )[0]

-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
diffusers/models/unets/unet_spatio_temporal_condition.py
CHANGED
@@ -382,6 +382,20 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                 If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
                 returned, otherwise a `tuple` is returned where the first element is the sample tensor.
         """
+        # By default samples have to be AT least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
+        # However, the upsampling interpolation output size can be forced to fit any upsampling size
+        # on the fly if necessary.
+        default_overall_up_factor = 2**self.num_upsamplers
+
+        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+        forward_upsample_size = False
+        upsample_size = None
+
+        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+            logger.info("Forward upsample size to force interpolation output size.")
+            forward_upsample_size = True
+
         # 1. time
         timesteps = timestep
         if not torch.is_tensor(timesteps):
@@ -457,15 +471,23 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):

         # 5. up
         for i, upsample_block in enumerate(self.up_blocks):
+            is_final_block = i == len(self.up_blocks) - 1
+
             res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
             down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

+            # if we have not reached the final block and need to forward the
+            # upsample size, we do it here
+            if not is_final_block and forward_upsample_size:
+                upsample_size = down_block_res_samples[-1].shape[2:]
+
             if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                 sample = upsample_block(
                     hidden_states=sample,
                     temb=emb,
                     res_hidden_states_tuple=res_samples,
                     encoder_hidden_states=encoder_hidden_states,
+                    upsample_size=upsample_size,
                     image_only_indicator=image_only_indicator,
                 )
             else:
@@ -473,6 +495,7 @@ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
                     hidden_states=sample,
                     temb=emb,
                     res_hidden_states_tuple=res_samples,
+                    upsample_size=upsample_size,
                     image_only_indicator=image_only_indicator,
                 )

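The spatio-temporal UNet now mirrors `UNet2DConditionModel`'s handling of inputs whose spatial dims are not a multiple of the overall upsampling factor: the target size is forwarded down to the up blocks instead of relying on `scale_factor=2`. A small arithmetic sketch of the check (assuming `num_upsamplers = 3` purely for illustration):

default_overall_up_factor = 2**3  # spatial dims must be multiples of 8

for shape in [(576, 1024), (577, 1024)]:
    needs_forward = any(s % default_overall_up_factor != 0 for s in shape)
    print(shape, "forward_upsample_size =", needs_forward)
# (576, 1024) forward_upsample_size = False
# (577, 1024) forward_upsample_size = True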
diffusers/models/unets/unet_stable_cascade.py
CHANGED
@@ -455,7 +455,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         level_outputs = []
         block_group = zip(self.down_blocks, self.down_downscalers, self.down_repeat_mappers)

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
@@ -504,7 +504,7 @@ class StableCascadeUNet(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         x = level_outputs[0]
         block_group = zip(self.up_blocks, self.up_upscalers, self.up_repeat_mappers)

-        if self.training and self.gradient_checkpointing:
+        if torch.is_grad_enabled() and self.gradient_checkpointing:

             def create_custom_forward(module):
                 def custom_forward(*inputs):
diffusers/models/unets/uvit_2d.py
CHANGED
@@ -181,7 +181,7 @@ class UVit2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         hidden_states = self.project_to_hidden(hidden_states)

         for layer in self.transformer_layers:
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:

                 def layer_(*args):
                     return checkpoint(layer, *args)
diffusers/models/upsampling.py
CHANGED
@@ -165,6 +165,14 @@ class Upsample2D(nn.Module):
         # if `output_size` is passed we force the interpolation output
         # size and do not make use of `scale_factor=2`
         if self.interpolate:
+            # upsample_nearest_nhwc also fails when the number of output elements is large
+            # https://github.com/pytorch/pytorch/issues/141831
+            scale_factor = (
+                2 if output_size is None else max([f / s for f, s in zip(output_size, hidden_states.shape[-2:])])
+            )
+            if hidden_states.numel() * scale_factor > pow(2, 31):
+                hidden_states = hidden_states.contiguous()
+
             if output_size is None:
                 hidden_states = F.interpolate(hidden_states, scale_factor=2.0, mode="nearest")
            else:
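The new guard works around the linked PyTorch issue: when the estimated upsample workload passes 2**31 elements, the input is made contiguous first, which presumably steers execution off the failing channels-last (NHWC) kernel. A sketch of the same threshold arithmetic for a 2x upsample:

def needs_contiguous(n, c, h, w, scale=2):
    # mirrors the heuristic in the diff: input numel times the scale factor
    return n * c * h * w * scale > 2**31

print(needs_contiguous(1, 512, 1024, 1024))  # False: ~1.07e9
print(needs_contiguous(2, 512, 2048, 1024))  # True: ~4.29e9 > 2**31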
diffusers/pipelines/__init__.py
CHANGED
@@ -116,6 +116,7 @@ else:
             "VersatileDiffusionTextToImagePipeline",
         ]
     )
+    _import_structure["allegro"] = ["AllegroPipeline"]
     _import_structure["amused"] = ["AmusedImg2ImgPipeline", "AmusedInpaintPipeline", "AmusedPipeline"]
     _import_structure["animatediff"] = [
         "AnimateDiffPipeline",
@@ -126,12 +127,18 @@ else:
         "AnimateDiffVideoToVideoControlNetPipeline",
     ]
     _import_structure["flux"] = [
+        "FluxControlPipeline",
+        "FluxControlInpaintPipeline",
+        "FluxControlImg2ImgPipeline",
         "FluxControlNetPipeline",
         "FluxControlNetImg2ImgPipeline",
         "FluxControlNetInpaintPipeline",
         "FluxImg2ImgPipeline",
         "FluxInpaintPipeline",
         "FluxPipeline",
+        "FluxFillPipeline",
+        "FluxPriorReduxPipeline",
+        "ReduxImageEncoder",
     ]
     _import_structure["audioldm"] = ["AudioLDMPipeline"]
     _import_structure["audioldm2"] = [
@@ -156,6 +163,9 @@ else:
             "StableDiffusionXLControlNetImg2ImgPipeline",
             "StableDiffusionXLControlNetInpaintPipeline",
             "StableDiffusionXLControlNetPipeline",
+            "StableDiffusionXLControlNetUnionPipeline",
+            "StableDiffusionXLControlNetUnionInpaintPipeline",
+            "StableDiffusionXLControlNetUnionImg2ImgPipeline",
         ]
     )
     _import_structure["pag"].extend(
@@ -165,8 +175,10 @@ else:
             "KolorsPAGPipeline",
             "HunyuanDiTPAGPipeline",
             "StableDiffusion3PAGPipeline",
+            "StableDiffusion3PAGImg2ImgPipeline",
             "StableDiffusionPAGPipeline",
             "StableDiffusionPAGImg2ImgPipeline",
+            "StableDiffusionPAGInpaintPipeline",
             "StableDiffusionControlNetPAGPipeline",
             "StableDiffusionXLPAGPipeline",
             "StableDiffusionXLPAGInpaintPipeline",
@@ -174,6 +186,7 @@ else:
             "StableDiffusionXLControlNetPAGPipeline",
             "StableDiffusionXLPAGImg2ImgPipeline",
             "PixArtSigmaPAGPipeline",
+            "SanaPAGPipeline",
         ]
     )
     _import_structure["controlnet_xs"].extend(
@@ -202,6 +215,7 @@ else:
         "IFSuperResolutionPipeline",
     ]
     _import_structure["hunyuandit"] = ["HunyuanDiTPipeline"]
+    _import_structure["hunyuan_video"] = ["HunyuanVideoPipeline"]
     _import_structure["kandinsky"] = [
         "KandinskyCombinedPipeline",
         "KandinskyImg2ImgCombinedPipeline",
@@ -239,6 +253,7 @@ else:
         ]
     )
     _import_structure["latte"] = ["LattePipeline"]
+    _import_structure["ltx"] = ["LTXPipeline", "LTXImageToVideoPipeline"]
     _import_structure["lumina"] = ["LuminaText2ImgPipeline"]
     _import_structure["marigold"].extend(
         [
@@ -246,10 +261,12 @@ else:
             "MarigoldNormalsPipeline",
         ]
     )
+    _import_structure["mochi"] = ["MochiPipeline"]
     _import_structure["musicldm"] = ["MusicLDMPipeline"]
     _import_structure["paint_by_example"] = ["PaintByExamplePipeline"]
     _import_structure["pia"] = ["PIAPipeline"]
     _import_structure["pixart_alpha"] = ["PixArtAlphaPipeline", "PixArtSigmaPipeline"]
+    _import_structure["sana"] = ["SanaPipeline"]
     _import_structure["semantic_stable_diffusion"] = ["SemanticStableDiffusionPipeline"]
     _import_structure["shap_e"] = ["ShapEImg2ImgPipeline", "ShapEPipeline"]
     _import_structure["stable_audio"] = [
@@ -454,6 +471,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
     except OptionalDependencyNotAvailable:
         from ..utils.dummy_torch_and_transformers_objects import *
     else:
+        from .allegro import AllegroPipeline
         from .amused import AmusedImg2ImgPipeline, AmusedInpaintPipeline, AmusedPipeline
         from .animatediff import (
             AnimateDiffControlNetPipeline,
@@ -486,6 +504,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             StableDiffusionXLControlNetImg2ImgPipeline,
             StableDiffusionXLControlNetInpaintPipeline,
             StableDiffusionXLControlNetPipeline,
+            StableDiffusionXLControlNetUnionImg2ImgPipeline,
+            StableDiffusionXLControlNetUnionInpaintPipeline,
+            StableDiffusionXLControlNetUnionPipeline,
         )
         from .controlnet_hunyuandit import (
             HunyuanDiTControlNetPipeline,
@@ -518,13 +539,20 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             VQDiffusionPipeline,
         )
         from .flux import (
+            FluxControlImg2ImgPipeline,
+            FluxControlInpaintPipeline,
            FluxControlNetImg2ImgPipeline,
             FluxControlNetInpaintPipeline,
             FluxControlNetPipeline,
+            FluxControlPipeline,
+            FluxFillPipeline,
             FluxImg2ImgPipeline,
             FluxInpaintPipeline,
             FluxPipeline,
+            FluxPriorReduxPipeline,
+            ReduxImageEncoder,
         )
+        from .hunyuan_video import HunyuanVideoPipeline
         from .hunyuandit import HunyuanDiTPipeline
         from .i2vgen_xl import I2VGenXLPipeline
         from .kandinsky import (
@@ -564,21 +592,26 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
             LEditsPPPipelineStableDiffusion,
             LEditsPPPipelineStableDiffusionXL,
         )
+        from .ltx import LTXImageToVideoPipeline, LTXPipeline
         from .lumina import LuminaText2ImgPipeline
         from .marigold import (
             MarigoldDepthPipeline,
             MarigoldNormalsPipeline,
         )
+        from .mochi import MochiPipeline
         from .musicldm import MusicLDMPipeline
         from .pag import (
             AnimateDiffPAGPipeline,
             HunyuanDiTPAGPipeline,
             KolorsPAGPipeline,
             PixArtSigmaPAGPipeline,
+            SanaPAGPipeline,
+            StableDiffusion3PAGImg2ImgPipeline,
             StableDiffusion3PAGPipeline,
             StableDiffusionControlNetPAGInpaintPipeline,
             StableDiffusionControlNetPAGPipeline,
             StableDiffusionPAGImg2ImgPipeline,
+            StableDiffusionPAGInpaintPipeline,
             StableDiffusionPAGPipeline,
             StableDiffusionXLControlNetPAGImg2ImgPipeline,
             StableDiffusionXLControlNetPAGPipeline,
@@ -589,6 +622,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
         from .paint_by_example import PaintByExamplePipeline
         from .pia import PIAPipeline
         from .pixart_alpha import PixArtAlphaPipeline, PixArtSigmaPipeline
+        from .sana import SanaPipeline
         from .semantic_stable_diffusion import SemanticStableDiffusionPipeline
         from .shap_e import ShapEImg2ImgPipeline, ShapEPipeline
         from .stable_audio import StableAudioPipeline, StableAudioProjectionModel