diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +74 -28
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/pipelines/allegro/pipeline_output.py (new file):

```diff
@@ -0,0 +1,23 @@
+from dataclasses import dataclass
+from typing import List, Union
+
+import numpy as np
+import PIL
+import torch
+
+from diffusers.utils import BaseOutput
+
+
+@dataclass
+class AllegroPipelineOutput(BaseOutput):
+    r"""
+    Output class for Allegro pipelines.
+
+    Args:
+        frames (`torch.Tensor`, `np.ndarray`, or List[List[PIL.Image.Image]]):
+            List of video outputs - It can be a nested list of length `batch_size,` with each sub-list containing
+            denoised PIL image sequences of length `num_frames.` It can also be a NumPy array or Torch tensor of shape
+            `(batch_size, num_frames, channels, height, width)`.
+    """
+
+    frames: Union[torch.Tensor, np.ndarray, List[List[PIL.Image.Image]]]
```
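For orientation, a minimal usage sketch of the new Allegro video pipeline that returns this output class; the checkpoint id, dtype, and call arguments are assumptions for illustration and are not taken from this diff.

```python
# Hypothetical usage sketch for the Allegro video pipeline added in 0.32.0.
# The checkpoint id and generation settings below are assumptions, not from the diff.
import torch
from diffusers import AllegroPipeline
from diffusers.utils import export_to_video

pipe = AllegroPipeline.from_pretrained("rhymes-ai/Allegro", torch_dtype=torch.bfloat16)
pipe.to("cuda")

out = pipe(prompt="A sailboat gliding across a calm lake at sunset", num_inference_steps=50)
# `out` is an AllegroPipelineOutput; `frames[0]` holds the frame sequence for the first prompt
export_to_video(out.frames[0], "allegro_sample.mp4", fps=15)
```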
diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py (lines cut off by the diff viewer are kept truncated):

```diff
@@ -21,14 +21,20 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 
 from ...image_processor import PipelineImageInput
 from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import
+from ...models import (
+    AutoencoderKL,
+    ControlNetModel,
+    ImageProjection,
+    MultiControlNetModel,
+    UNet2DConditionModel,
+    UNetMotionModel,
+)
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...models.unets.unet_motion_model import MotionAdapter
 from ...schedulers import KarrasDiffusionSchedulers
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ...video_processor import VideoProcessor
-from ..controlnet.multicontrolnet import MultiControlNetModel
 from ..free_init_utils import FreeInitMixin
 from ..free_noise_utils import AnimateDiffFreeNoiseMixin
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
```
diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py:

```diff
@@ -24,7 +24,7 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
 from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
-from ...models.controlnet_sparsectrl import SparseControlNetModel
+from ...models.controlnets.controlnet_sparsectrl import SparseControlNetModel
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...models.unets.unet_motion_model import MotionAdapter
 from ...schedulers import KarrasDiffusionSchedulers
```
diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py:

```diff
@@ -662,12 +662,6 @@ class AnimateDiffVideoToVideoPipeline(
             self.vae.to(dtype=torch.float32)
 
         if isinstance(generator, list):
-            if len(generator) != batch_size:
-                raise ValueError(
-                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                )
-
             init_latents = [
                 self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
                 for i in range(batch_size)
```
diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py:

```diff
@@ -21,7 +21,14 @@ from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPV
 
 from ...image_processor import PipelineImageInput
 from ...loaders import IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
-from ...models import
+from ...models import (
+    AutoencoderKL,
+    ControlNetModel,
+    ImageProjection,
+    MultiControlNetModel,
+    UNet2DConditionModel,
+    UNetMotionModel,
+)
 from ...models.lora import adjust_lora_scale_text_encoder
 from ...models.unets.unet_motion_model import MotionAdapter
 from ...schedulers import (
@@ -35,7 +42,6 @@ from ...schedulers import (
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import is_compiled_module, randn_tensor
 from ...video_processor import VideoProcessor
-from ..controlnet.multicontrolnet import MultiControlNetModel
 from ..free_init_utils import FreeInitMixin
 from ..free_noise_utils import AnimateDiffFreeNoiseMixin
 from ..pipeline_utils import DiffusionPipeline, StableDiffusionMixin
@@ -788,12 +794,6 @@ class AnimateDiffVideoToVideoControlNetPipeline(
             self.vae.to(dtype=torch.float32)
 
         if isinstance(generator, list):
-            if len(generator) != batch_size:
-                raise ValueError(
-                    f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
-                    f" size of {batch_size}. Make sure the batch size matches the length of the generators."
-                )
-
             init_latents = [
                 self.encode_video(video[i], generator[i], decode_chunk_size).unsqueeze(0)
                 for i in range(batch_size)
```
diffusers/pipelines/audioldm2/modeling_audioldm2.py:

```diff
@@ -1112,7 +1112,7 @@ class CrossAttnDownBlock2D(nn.Module):
         )
 
         for i in range(num_layers):
-            if
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1290,7 +1290,7 @@ class UNetMidBlock2DCrossAttn(nn.Module):
         )
 
         for i in range(len(self.resnets[1:])):
-            if
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -1464,7 +1464,7 @@ class CrossAttnUpBlock2D(nn.Module):
             res_hidden_states_tuple = res_hidden_states_tuple[:-1]
             hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
 
-            if
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
```
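The three hunks above make the same change: gradient checkpointing is now gated on `torch.is_grad_enabled()` rather than only on the module's training state, so these blocks take the plain forward path whenever autograd is disabled (for example under `torch.no_grad()`). A self-contained sketch of that pattern, as illustration only (this module is not diffusers code):

```python
# Minimal sketch of the gating pattern introduced above: checkpointing is
# skipped whenever autograd is disabled, instead of keying off module.training.
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint


class CheckpointedBlock(nn.Module):
    def __init__(self, dim: int = 64):
        super().__init__()
        self.gradient_checkpointing = True
        self.layer = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if torch.is_grad_enabled() and self.gradient_checkpointing:
            # recompute activations during backward to save memory
            return checkpoint(self.layer, x, use_reentrant=False)
        # plain forward path, e.g. during inference under torch.no_grad()
        return self.layer(x)


block = CheckpointedBlock()
with torch.no_grad():
    _ = block(torch.randn(2, 64))  # takes the non-checkpointed branch
```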
diffusers/pipelines/aura_flow/pipeline_aura_flow.py:

```diff
@@ -387,7 +387,6 @@ class AuraFlowPipeline(DiffusionPipeline):
         prompt: Union[str, List[str]] = None,
         negative_prompt: Union[str, List[str]] = None,
         num_inference_steps: int = 50,
-        timesteps: List[int] = None,
         sigmas: List[float] = None,
         guidance_scale: float = 3.5,
         num_images_per_prompt: Optional[int] = 1,
@@ -424,10 +423,6 @@
             sigmas (`List[float]`, *optional*):
                 Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
                 `num_inference_steps` and `timesteps` must be `None`.
-            timesteps (`List[int]`, *optional*):
-                Custom timesteps to use for the denoising process with schedulers which support a `timesteps` argument
-                in their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is
-                passed will be used. Must be in descending order.
             guidance_scale (`float`, *optional*, defaults to 5.0):
                 Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
                 `guidance_scale` is defined as `w` of equation 2. of [Imagen
@@ -522,9 +517,7 @@
         # 4. Prepare timesteps
 
         # sigmas = np.linspace(1.0, 1 / num_inference_steps, num_inference_steps)
-        timesteps, num_inference_steps = retrieve_timesteps(
-            self.scheduler, num_inference_steps, device, timesteps, sigmas
-        )
+        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, sigmas=sigmas)
 
         # 5. Prepare latents.
         latent_channels = self.transformer.config.in_channels
```
diffusers/pipelines/auto_pipeline.py:

```diff
@@ -18,6 +18,7 @@ from collections import OrderedDict
 from huggingface_hub.utils import validate_hf_hub_args
 
 from ..configuration_utils import ConfigMixin
+from ..models.controlnets import ControlNetUnionModel
 from ..utils import is_sentencepiece_available
 from .aura_flow import AuraFlowPipeline
 from .cogview3 import CogView3PlusPipeline
@@ -28,12 +29,18 @@ from .controlnet import (
     StableDiffusionXLControlNetImg2ImgPipeline,
     StableDiffusionXLControlNetInpaintPipeline,
     StableDiffusionXLControlNetPipeline,
+    StableDiffusionXLControlNetUnionImg2ImgPipeline,
+    StableDiffusionXLControlNetUnionInpaintPipeline,
+    StableDiffusionXLControlNetUnionPipeline,
 )
 from .deepfloyd_if import IFImg2ImgPipeline, IFInpaintingPipeline, IFPipeline
 from .flux import (
+    FluxControlImg2ImgPipeline,
+    FluxControlInpaintPipeline,
     FluxControlNetImg2ImgPipeline,
     FluxControlNetInpaintPipeline,
     FluxControlNetPipeline,
+    FluxControlPipeline,
     FluxImg2ImgPipeline,
     FluxInpaintPipeline,
     FluxPipeline,
@@ -61,10 +68,12 @@ from .lumina import LuminaText2ImgPipeline
 from .pag import (
     HunyuanDiTPAGPipeline,
     PixArtSigmaPAGPipeline,
+    StableDiffusion3PAGImg2ImgPipeline,
     StableDiffusion3PAGPipeline,
     StableDiffusionControlNetPAGInpaintPipeline,
     StableDiffusionControlNetPAGPipeline,
     StableDiffusionPAGImg2ImgPipeline,
+    StableDiffusionPAGInpaintPipeline,
     StableDiffusionPAGPipeline,
     StableDiffusionXLControlNetPAGImg2ImgPipeline,
     StableDiffusionXLControlNetPAGPipeline,
@@ -106,6 +115,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
         ("kandinsky3", Kandinsky3Pipeline),
         ("stable-diffusion-controlnet", StableDiffusionControlNetPipeline),
         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetPipeline),
+        ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionPipeline),
         ("wuerstchen", WuerstchenCombinedPipeline),
         ("cascade", StableCascadeCombinedPipeline),
         ("lcm", LatentConsistencyModelPipeline),
@@ -118,6 +128,7 @@ AUTO_TEXT2IMAGE_PIPELINES_MAPPING = OrderedDict(
         ("pixart-sigma-pag", PixArtSigmaPAGPipeline),
         ("auraflow", AuraFlowPipeline),
         ("flux", FluxPipeline),
+        ("flux-control", FluxControlPipeline),
         ("flux-controlnet", FluxControlNetPipeline),
         ("lumina", LuminaText2ImgPipeline),
         ("cogview3", CogView3PlusPipeline),
@@ -129,6 +140,7 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
         ("stable-diffusion", StableDiffusionImg2ImgPipeline),
         ("stable-diffusion-xl", StableDiffusionXLImg2ImgPipeline),
         ("stable-diffusion-3", StableDiffusion3Img2ImgPipeline),
+        ("stable-diffusion-3-pag", StableDiffusion3PAGImg2ImgPipeline),
         ("if", IFImg2ImgPipeline),
         ("kandinsky", KandinskyImg2ImgCombinedPipeline),
         ("kandinsky22", KandinskyV22Img2ImgCombinedPipeline),
@@ -136,11 +148,13 @@ AUTO_IMAGE2IMAGE_PIPELINES_MAPPING = OrderedDict(
         ("stable-diffusion-controlnet", StableDiffusionControlNetImg2ImgPipeline),
         ("stable-diffusion-pag", StableDiffusionPAGImg2ImgPipeline),
         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetImg2ImgPipeline),
+        ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionImg2ImgPipeline),
         ("stable-diffusion-xl-pag", StableDiffusionXLPAGImg2ImgPipeline),
         ("stable-diffusion-xl-controlnet-pag", StableDiffusionXLControlNetPAGImg2ImgPipeline),
         ("lcm", LatentConsistencyModelImg2ImgPipeline),
         ("flux", FluxImg2ImgPipeline),
         ("flux-controlnet", FluxControlNetImg2ImgPipeline),
+        ("flux-control", FluxControlImg2ImgPipeline),
     ]
 )
 
@@ -155,9 +169,12 @@ AUTO_INPAINT_PIPELINES_MAPPING = OrderedDict(
         ("stable-diffusion-controlnet", StableDiffusionControlNetInpaintPipeline),
         ("stable-diffusion-controlnet-pag", StableDiffusionControlNetPAGInpaintPipeline),
         ("stable-diffusion-xl-controlnet", StableDiffusionXLControlNetInpaintPipeline),
+        ("stable-diffusion-xl-controlnet-union", StableDiffusionXLControlNetUnionInpaintPipeline),
         ("stable-diffusion-xl-pag", StableDiffusionXLPAGInpaintPipeline),
         ("flux", FluxInpaintPipeline),
         ("flux-controlnet", FluxControlNetInpaintPipeline),
+        ("flux-control", FluxControlInpaintPipeline),
+        ("stable-diffusion-pag", StableDiffusionPAGInpaintPipeline),
     ]
 )
 
```
diffusers/pipelines/auto_pipeline.py (continued; removed lines whose text the viewer truncated appear as bare `-` markers):

```diff
@@ -390,13 +407,20 @@ class AutoPipelineForText2Image(ConfigMixin):
 
         config = cls.load_config(pretrained_model_or_path, **load_config_kwargs)
         orig_class_name = config["_class_name"]
+        if "ControlPipeline" in orig_class_name:
+            to_replace = "ControlPipeline"
+        else:
+            to_replace = "Pipeline"
 
         if "controlnet" in kwargs:
-
+            if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                orig_class_name = config["_class_name"].replace(to_replace, "ControlNetUnionPipeline")
+            else:
+                orig_class_name = config["_class_name"].replace(to_replace, "ControlNetPipeline")
         if "enable_pag" in kwargs:
             enable_pag = kwargs.pop("enable_pag")
             if enable_pag:
-                orig_class_name = orig_class_name.replace(
+                orig_class_name = orig_class_name.replace(to_replace, "PAGPipeline")
 
         text_2_image_cls = _get_task_class(AUTO_TEXT2IMAGE_PIPELINES_MAPPING, orig_class_name)
 
```
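A hedged sketch of how this dispatch is exercised: passing the new `ControlNetUnionModel` to `AutoPipelineForText2Image` should resolve an SDXL base class name to `StableDiffusionXLControlNetUnionPipeline` via the mapping added above. The checkpoint ids below are placeholders, and the top-level export of `ControlNetUnionModel` is assumed.

```python
# Sketch of the new auto-pipeline dispatch for ControlNet Union (SDXL).
# Repo ids are placeholders, not taken from this diff.
import torch
from diffusers import AutoPipelineForText2Image, ControlNetUnionModel

controlnet = ControlNetUnionModel.from_pretrained(
    "path/to/controlnet-union-sdxl",  # placeholder repo id
    torch_dtype=torch.float16,
)
pipe = AutoPipelineForText2Image.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0",
    controlnet=controlnet,
    torch_dtype=torch.float16,
)
# With a ControlNetUnionModel, the class name is rewritten to the *ControlNetUnion* variant
print(type(pipe).__name__)  # StableDiffusionXLControlNetUnionPipeline
```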
diffusers/pipelines/auto_pipeline.py (continued):

```diff
@@ -680,16 +704,28 @@ class AutoPipelineForImage2Image(ConfigMixin):
 
         # the `orig_class_name` can be:
         # `- *Pipeline` (for regular text-to-image checkpoint)
+        # - `*ControlPipeline` (for Flux tools specific checkpoint)
         # `- *Img2ImgPipeline` (for refiner checkpoint)
-
+        if "Img2Img" in orig_class_name:
+            to_replace = "Img2ImgPipeline"
+        elif "ControlPipeline" in orig_class_name:
+            to_replace = "ControlPipeline"
+        else:
+            to_replace = "Pipeline"
 
         if "controlnet" in kwargs:
-
+            if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
+            else:
+                orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
         if "enable_pag" in kwargs:
             enable_pag = kwargs.pop("enable_pag")
             if enable_pag:
                 orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
 
+        if to_replace == "ControlPipeline":
+            orig_class_name = orig_class_name.replace(to_replace, "ControlImg2ImgPipeline")
+
         image_2_image_cls = _get_task_class(AUTO_IMAGE2IMAGE_PIPELINES_MAPPING, orig_class_name)
 
         kwargs = {**load_config_kwargs, **kwargs}
@@ -977,15 +1013,26 @@ class AutoPipelineForInpainting(ConfigMixin):
 
         # The `orig_class_name`` can be:
         # `- *InpaintPipeline` (for inpaint-specific checkpoint)
+        # - `*ControlPipeline` (for Flux tools specific checkpoint)
         # - or *Pipeline (for regular text-to-image checkpoint)
-
+        if "Inpaint" in orig_class_name:
+            to_replace = "InpaintPipeline"
+        elif "ControlPipeline" in orig_class_name:
+            to_replace = "ControlPipeline"
+        else:
+            to_replace = "Pipeline"
 
         if "controlnet" in kwargs:
-
+            if isinstance(kwargs["controlnet"], ControlNetUnionModel):
+                orig_class_name = orig_class_name.replace(to_replace, "ControlNetUnion" + to_replace)
+            else:
+                orig_class_name = orig_class_name.replace(to_replace, "ControlNet" + to_replace)
         if "enable_pag" in kwargs:
             enable_pag = kwargs.pop("enable_pag")
             if enable_pag:
                 orig_class_name = orig_class_name.replace(to_replace, "PAG" + to_replace)
+        if to_replace == "ControlPipeline":
+            orig_class_name = orig_class_name.replace(to_replace, "ControlInpaintPipeline")
         inpainting_cls = _get_task_class(AUTO_INPAINT_PIPELINES_MAPPING, orig_class_name)
 
         kwargs = {**load_config_kwargs, **kwargs}
```
diffusers/pipelines/blip_diffusion/modeling_blip2.py:

```diff
@@ -167,7 +167,7 @@ class Blip2QFormerEncoder(nn.Module):
             layer_head_mask = head_mask[i] if head_mask is not None else None
             past_key_value = past_key_values[i] if past_key_values is not None else None
 
-            if getattr(self.config, "gradient_checkpointing", False) and
+            if getattr(self.config, "gradient_checkpointing", False) and torch.is_grad_enabled():
                 if use_cache:
                     logger.warning(
                         "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..."
```
diffusers/pipelines/cogvideo/pipeline_cogvideox.py:

```diff
@@ -442,21 +442,39 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 
-
-
-
-
-
-
-
-
-
+        p = self.transformer.config.patch_size
+        p_t = self.transformer.config.patch_size_t
+
+        base_size_width = self.transformer.config.sample_width // p
+        base_size_height = self.transformer.config.sample_height // p
+
+        if p_t is None:
+            # CogVideoX 1.0
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+                device=device,
+            )
+        else:
+            # CogVideoX 1.5
+            base_num_frames = (num_frames + p_t - 1) // p_t
+
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+                device=device,
+            )
 
-        freqs_cos = freqs_cos.to(device=device)
-        freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
     @property
```
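A small worked example of the new grid arithmetic, using illustrative config values that are not read from this diff: the base grid now comes from the transformer's `sample_width`/`sample_height` divided by the spatial patch size, and the CogVideoX 1.5 branch ceil-divides the frame count by the temporal patch size `p_t`.

```python
# Illustrative numbers for the new RoPE grid sizing (example values only).
p, p_t = 2, 2                            # spatial / temporal patch sizes
sample_width, sample_height = 300, 300   # transformer's native latent sample size
num_frames = 81                          # requested temporal size

base_size_width = sample_width // p                # 150
base_size_height = sample_height // p              # 150
base_num_frames = (num_frames + p_t - 1) // p_t    # ceil(81 / 2) = 41  (CogVideoX 1.5 branch)

print(base_size_width, base_size_height, base_num_frames)
```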
diffusers/pipelines/cogvideo/pipeline_cogvideox.py (continued):

```diff
@@ -481,9 +499,9 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         self,
         prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
-        height: int =
-        width: int =
-        num_frames: int =
+        height: Optional[int] = None,
+        width: Optional[int] = None,
+        num_frames: Optional[int] = None,
         num_inference_steps: int = 50,
         timesteps: Optional[List[int]] = None,
         guidance_scale: float = 6,
@@ -583,14 +601,13 @@ class CogVideoXPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
                 `tuple`. When returning a tuple, the first element is a list with the generated images.
         """
 
-        if num_frames > 49:
-            raise ValueError(
-                "The number of frames must be less than 49 for now due to static positional embeddings. This will be updated in the future to remove this limitation."
-            )
-
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
+        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+        num_frames = num_frames or self.transformer.config.sample_frames
+
         num_videos_per_prompt = 1
 
         # 1. Check inputs. Raise error if not correct
```
diffusers/pipelines/cogvideo/pipeline_cogvideox.py (continued):

```diff
@@ -640,7 +657,16 @@
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
         self._num_timesteps = len(timesteps)
 
-        # 5. Prepare latents
+        # 5. Prepare latents
+        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+        patch_size_t = self.transformer.config.patch_size_t
+        additional_frames = 0
+        if patch_size_t is not None and latent_frames % patch_size_t != 0:
+            additional_frames = patch_size_t - latent_frames % patch_size_t
+            num_frames += additional_frames * self.vae_scale_factor_temporal
+
         latent_channels = self.transformer.config.in_channels
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
```
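A worked example of the frame-padding arithmetic introduced above, with illustrative values for `vae_scale_factor_temporal` and `patch_size_t`; the padded latent frames are sliced off again before decoding (see the next hunk).

```python
# Worked example of the CogVideoX 1.5 frame-padding arithmetic (example values only).
vae_scale_factor_temporal = 4
patch_size_t = 2
num_frames = 81                                                       # requested pixel-space frames

latent_frames = (num_frames - 1) // vae_scale_factor_temporal + 1     # 21
additional_frames = 0
if patch_size_t is not None and latent_frames % patch_size_t != 0:
    additional_frames = patch_size_t - latent_frames % patch_size_t   # 1
    num_frames += additional_frames * vae_scale_factor_temporal       # 85

# The padded latent frames are discarded again before decoding:
# latents = latents[:, additional_frames:]
print(latent_frames, additional_frames, num_frames)
```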
diffusers/pipelines/cogvideo/pipeline_cogvideox.py (continued):

```diff
@@ -730,6 +756,8 @@
                     progress_bar.update()
 
         if not output_type == "latent":
+            # Discard any padding frames that were added for CogVideoX 1.5
+            latents = latents[:, additional_frames:]
             video = self.decode_latents(latents)
             video = self.video_processor.postprocess_video(video=video, output_type=output_type)
         else:
```
diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py:

```diff
@@ -488,21 +488,39 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
     ) -> Tuple[torch.Tensor, torch.Tensor]:
         grid_height = height // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
         grid_width = width // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-        base_size_width = 720 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
-        base_size_height = 480 // (self.vae_scale_factor_spatial * self.transformer.config.patch_size)
 
-
-
-
-
-
-
-
-
-
+        p = self.transformer.config.patch_size
+        p_t = self.transformer.config.patch_size_t
+
+        base_size_width = self.transformer.config.sample_width // p
+        base_size_height = self.transformer.config.sample_height // p
+
+        if p_t is None:
+            # CogVideoX 1.0
+            grid_crops_coords = get_resize_crop_region_for_grid(
+                (grid_height, grid_width), base_size_width, base_size_height
+            )
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=grid_crops_coords,
+                grid_size=(grid_height, grid_width),
+                temporal_size=num_frames,
+                device=device,
+            )
+        else:
+            # CogVideoX 1.5
+            base_num_frames = (num_frames + p_t - 1) // p_t
+
+            freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
+                embed_dim=self.transformer.config.attention_head_dim,
+                crops_coords=None,
+                grid_size=(grid_height, grid_width),
+                temporal_size=base_num_frames,
+                grid_type="slice",
+                max_size=(base_size_height, base_size_width),
+                device=device,
+            )
 
-        freqs_cos = freqs_cos.to(device=device)
-        freqs_sin = freqs_sin.to(device=device)
         return freqs_cos, freqs_sin
 
     @property
@@ -528,8 +546,8 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         prompt: Optional[Union[str, List[str]]] = None,
         negative_prompt: Optional[Union[str, List[str]]] = None,
         control_video: Optional[List[Image.Image]] = None,
-        height: int =
-        width: int =
+        height: Optional[int] = None,
+        width: Optional[int] = None,
         num_inference_steps: int = 50,
         timesteps: Optional[List[int]] = None,
         guidance_scale: float = 6,
@@ -634,6 +652,13 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
             callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
 
+        if control_video is not None and isinstance(control_video[0], Image.Image):
+            control_video = [control_video]
+
+        height = height or self.transformer.config.sample_height * self.vae_scale_factor_spatial
+        width = width or self.transformer.config.sample_width * self.vae_scale_factor_spatial
+        num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2)
+
         num_videos_per_prompt = 1
 
         # 1. Check inputs. Raise error if not correct
@@ -660,9 +685,6 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         else:
             batch_size = prompt_embeds.shape[0]
 
-        if control_video is not None and isinstance(control_video[0], Image.Image):
-            control_video = [control_video]
-
         device = self._execution_device
 
         # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
@@ -688,9 +710,18 @@ class CogVideoXFunControlPipeline(DiffusionPipeline, CogVideoXLoraLoaderMixin):
         timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, timesteps)
         self._num_timesteps = len(timesteps)
 
-        # 5. Prepare latents
+        # 5. Prepare latents
+        latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
+
+        # For CogVideoX 1.5, the latent frames should be padded to make it divisible by patch_size_t
+        patch_size_t = self.transformer.config.patch_size_t
+        if patch_size_t is not None and latent_frames % patch_size_t != 0:
+            raise ValueError(
+                f"The number of latent frames must be divisible by `{patch_size_t=}` but the given video "
+                f"contains {latent_frames=}, which is not divisible."
+            )
+
         latent_channels = self.transformer.config.in_channels // 2
-        num_frames = len(control_video[0]) if control_video is not None else control_video_latents.size(2)
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             latent_channels,
```