diffusers 0.29.2__py3-none-any.whl → 0.30.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2252 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +3 -14
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +293 -8
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1937 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1271 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +403 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +543 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +485 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +746 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +50 -6
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +37 -15
- diffusers/utils/loading_utils.py +80 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.1.dist-info}/top_level.txt +0 -0
@@ -22,7 +22,7 @@ import torch
|
|
22
22
|
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection
|
23
23
|
|
24
24
|
from ...image_processor import PipelineImageInput
|
25
|
-
from ...loaders import FromSingleFileMixin, IPAdapterMixin,
|
25
|
+
from ...loaders import FromSingleFileMixin, IPAdapterMixin, StableDiffusionLoraLoaderMixin, TextualInversionLoaderMixin
|
26
26
|
from ...models import AutoencoderKL, ImageProjection, UNet2DConditionModel, UNetMotionModel
|
27
27
|
from ...models.lora import adjust_lora_scale_text_encoder
|
28
28
|
from ...models.unets.unet_motion_model import MotionAdapter
|
@@ -54,22 +54,21 @@ EXAMPLE_DOC_STRING = """
|
|
54
54
|
Examples:
|
55
55
|
```py
|
56
56
|
>>> import torch
|
57
|
-
>>> from diffusers import
|
58
|
-
... EulerDiscreteScheduler,
|
59
|
-
... MotionAdapter,
|
60
|
-
... PIAPipeline,
|
61
|
-
... )
|
57
|
+
>>> from diffusers import EulerDiscreteScheduler, MotionAdapter, PIAPipeline
|
62
58
|
>>> from diffusers.utils import export_to_gif, load_image
|
63
59
|
|
64
|
-
>>> adapter = MotionAdapter.from_pretrained("
|
65
|
-
>>> pipe = PIAPipeline.from_pretrained(
|
60
|
+
>>> adapter = MotionAdapter.from_pretrained("openmmlab/PIA-condition-adapter")
|
61
|
+
>>> pipe = PIAPipeline.from_pretrained(
|
62
|
+
... "SG161222/Realistic_Vision_V6.0_B1_noVAE", motion_adapter=adapter, torch_dtype=torch.float16
|
63
|
+
... )
|
64
|
+
|
66
65
|
>>> pipe.scheduler = EulerDiscreteScheduler.from_config(pipe.scheduler.config)
|
67
66
|
>>> image = load_image(
|
68
67
|
... "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/pix2pix/cat_6.png?download=true"
|
69
68
|
... )
|
70
69
|
>>> image = image.resize((512, 512))
|
71
70
|
>>> prompt = "cat in a hat"
|
72
|
-
>>> negative_prompt = "wrong white balance, dark, sketches,worst quality,low quality, deformed, distorted
|
71
|
+
>>> negative_prompt = "wrong white balance, dark, sketches, worst quality, low quality, deformed, distorted"
|
73
72
|
>>> generator = torch.Generator("cpu").manual_seed(0)
|
74
73
|
>>> output = pipe(image=image, prompt=prompt, negative_prompt=negative_prompt, generator=generator)
|
75
74
|
>>> frames = output.frames[0]
|
@@ -128,7 +127,7 @@ class PIAPipeline(
|
|
128
127
|
StableDiffusionMixin,
|
129
128
|
TextualInversionLoaderMixin,
|
130
129
|
IPAdapterMixin,
|
131
|
-
|
130
|
+
StableDiffusionLoraLoaderMixin,
|
132
131
|
FromSingleFileMixin,
|
133
132
|
FreeInitMixin,
|
134
133
|
):
|
@@ -140,8 +139,8 @@ class PIAPipeline(
|
|
140
139
|
|
141
140
|
The pipeline also inherits the following loading methods:
|
142
141
|
- [`~loaders.TextualInversionLoaderMixin.load_textual_inversion`] for loading textual inversion embeddings
|
143
|
-
- [`~loaders.
|
144
|
-
- [`~loaders.
|
142
|
+
- [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`] for loading LoRA weights
|
143
|
+
- [`~loaders.StableDiffusionLoraLoaderMixin.save_lora_weights`] for saving LoRA weights
|
145
144
|
- [`~loaders.IPAdapterMixin.load_ip_adapter`] for loading IP Adapters
|
146
145
|
|
147
146
|
Args:
|
@@ -243,7 +242,7 @@ class PIAPipeline(
|
|
243
242
|
"""
|
244
243
|
# set lora scale so that monkey patched LoRA
|
245
244
|
# function of text encoder can correctly access it
|
246
|
-
if lora_scale is not None and isinstance(self,
|
245
|
+
if lora_scale is not None and isinstance(self, StableDiffusionLoraLoaderMixin):
|
247
246
|
self._lora_scale = lora_scale
|
248
247
|
|
249
248
|
# dynamically adjust the LoRA scale
|
@@ -376,7 +375,7 @@ class PIAPipeline(
|
|
376
375
|
negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
|
377
376
|
|
378
377
|
if self.text_encoder is not None:
|
379
|
-
if isinstance(self,
|
378
|
+
if isinstance(self, StableDiffusionLoraLoaderMixin) and USE_PEFT_BACKEND:
|
380
379
|
# Retrieve the original scale by scaling back the LoRA layers
|
381
380
|
unscale_lora_layers(self.text_encoder, lora_scale)
|
382
381
|
|
@@ -505,6 +504,9 @@ class PIAPipeline(
|
|
505
504
|
def prepare_ip_adapter_image_embeds(
|
506
505
|
self, ip_adapter_image, ip_adapter_image_embeds, device, num_images_per_prompt, do_classifier_free_guidance
|
507
506
|
):
|
507
|
+
image_embeds = []
|
508
|
+
if do_classifier_free_guidance:
|
509
|
+
negative_image_embeds = []
|
508
510
|
if ip_adapter_image_embeds is None:
|
509
511
|
if not isinstance(ip_adapter_image, list):
|
510
512
|
ip_adapter_image = [ip_adapter_image]
|
@@ -514,7 +516,6 @@ class PIAPipeline(
|
|
514
516
|
f"`ip_adapter_image` must have same length as the number of IP Adapters. Got {len(ip_adapter_image)} images and {len(self.unet.encoder_hid_proj.image_projection_layers)} IP Adapters."
|
515
517
|
)
|
516
518
|
|
517
|
-
image_embeds = []
|
518
519
|
for single_ip_adapter_image, image_proj_layer in zip(
|
519
520
|
ip_adapter_image, self.unet.encoder_hid_proj.image_projection_layers
|
520
521
|
):
|
@@ -522,36 +523,28 @@ class PIAPipeline(
|
|
522
523
|
single_image_embeds, single_negative_image_embeds = self.encode_image(
|
523
524
|
single_ip_adapter_image, device, 1, output_hidden_state
|
524
525
|
)
|
525
|
-
single_image_embeds = torch.stack([single_image_embeds] * num_images_per_prompt, dim=0)
|
526
|
-
single_negative_image_embeds = torch.stack(
|
527
|
-
[single_negative_image_embeds] * num_images_per_prompt, dim=0
|
528
|
-
)
|
529
526
|
|
527
|
+
image_embeds.append(single_image_embeds[None, :])
|
530
528
|
if do_classifier_free_guidance:
|
531
|
-
|
532
|
-
single_image_embeds = single_image_embeds.to(device)
|
533
|
-
|
534
|
-
image_embeds.append(single_image_embeds)
|
529
|
+
negative_image_embeds.append(single_negative_image_embeds[None, :])
|
535
530
|
else:
|
536
|
-
repeat_dims = [1]
|
537
|
-
image_embeds = []
|
538
531
|
for single_image_embeds in ip_adapter_image_embeds:
|
539
532
|
if do_classifier_free_guidance:
|
540
533
|
single_negative_image_embeds, single_image_embeds = single_image_embeds.chunk(2)
|
541
|
-
|
542
|
-
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
543
|
-
)
|
544
|
-
single_negative_image_embeds = single_negative_image_embeds.repeat(
|
545
|
-
num_images_per_prompt, *(repeat_dims * len(single_negative_image_embeds.shape[1:]))
|
546
|
-
)
|
547
|
-
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds])
|
548
|
-
else:
|
549
|
-
single_image_embeds = single_image_embeds.repeat(
|
550
|
-
num_images_per_prompt, *(repeat_dims * len(single_image_embeds.shape[1:]))
|
551
|
-
)
|
534
|
+
negative_image_embeds.append(single_negative_image_embeds)
|
552
535
|
image_embeds.append(single_image_embeds)
|
553
536
|
|
554
|
-
|
537
|
+
ip_adapter_image_embeds = []
|
538
|
+
for i, single_image_embeds in enumerate(image_embeds):
|
539
|
+
single_image_embeds = torch.cat([single_image_embeds] * num_images_per_prompt, dim=0)
|
540
|
+
if do_classifier_free_guidance:
|
541
|
+
single_negative_image_embeds = torch.cat([negative_image_embeds[i]] * num_images_per_prompt, dim=0)
|
542
|
+
single_image_embeds = torch.cat([single_negative_image_embeds, single_image_embeds], dim=0)
|
543
|
+
|
544
|
+
single_image_embeds = single_image_embeds.to(device=device)
|
545
|
+
ip_adapter_image_embeds.append(single_image_embeds)
|
546
|
+
|
547
|
+
return ip_adapter_image_embeds
|
555
548
|
|
556
549
|
# Copied from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth.TextToVideoSDPipeline.prepare_latents
|
557
550
|
def prepare_latents(
|
@@ -254,9 +254,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
254
254
|
force_download (`bool`, *optional*, defaults to `False`):
|
255
255
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
256
256
|
cached versions if they exist.
|
257
|
-
|
258
|
-
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
259
|
-
of Diffusers.
|
257
|
+
|
260
258
|
proxies (`Dict[str, str]`, *optional*):
|
261
259
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
262
260
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
@@ -296,7 +294,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
296
294
|
>>> # see more in [the documentation](https://huggingface.co/docs/hub/security-tokens)
|
297
295
|
>>> pipeline, params = FlaxDiffusionPipeline.from_pretrained(
|
298
296
|
... "runwayml/stable-diffusion-v1-5",
|
299
|
-
...
|
297
|
+
... variant="bf16",
|
300
298
|
... dtype=jnp.bfloat16,
|
301
299
|
... )
|
302
300
|
|
@@ -310,13 +308,12 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
310
308
|
... )
|
311
309
|
|
312
310
|
>>> dpm_pipe, dpm_params = FlaxStableDiffusionPipeline.from_pretrained(
|
313
|
-
... model_id,
|
311
|
+
... model_id, variant="bf16", dtype=jnp.bfloat16, scheduler=dpmpp
|
314
312
|
... )
|
315
313
|
>>> dpm_params["scheduler"] = dpmpp_state
|
316
314
|
```
|
317
315
|
"""
|
318
316
|
cache_dir = kwargs.pop("cache_dir", None)
|
319
|
-
resume_download = kwargs.pop("resume_download", None)
|
320
317
|
proxies = kwargs.pop("proxies", None)
|
321
318
|
local_files_only = kwargs.pop("local_files_only", False)
|
322
319
|
token = kwargs.pop("token", None)
|
@@ -332,7 +329,6 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
332
329
|
config_dict = cls.load_config(
|
333
330
|
pretrained_model_name_or_path,
|
334
331
|
cache_dir=cache_dir,
|
335
|
-
resume_download=resume_download,
|
336
332
|
proxies=proxies,
|
337
333
|
local_files_only=local_files_only,
|
338
334
|
token=token,
|
@@ -363,7 +359,6 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
363
359
|
cached_folder = snapshot_download(
|
364
360
|
pretrained_model_name_or_path,
|
365
361
|
cache_dir=cache_dir,
|
366
|
-
resume_download=resume_download,
|
367
362
|
proxies=proxies,
|
368
363
|
local_files_only=local_files_only,
|
369
364
|
token=token,
|
@@ -564,7 +559,7 @@ class FlaxDiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
564
559
|
... )
|
565
560
|
|
566
561
|
>>> text2img = FlaxStableDiffusionPipeline.from_pretrained(
|
567
|
-
... "runwayml/stable-diffusion-v1-5",
|
562
|
+
... "runwayml/stable-diffusion-v1-5", variant="bf16", dtype=jnp.bfloat16
|
568
563
|
... )
|
569
564
|
>>> img2img = FlaxStableDiffusionImg2ImgPipeline(**text2img.components)
|
570
565
|
```
|
@@ -435,7 +435,6 @@ def _load_empty_model(
|
|
435
435
|
return_unused_kwargs=True,
|
436
436
|
return_commit_hash=True,
|
437
437
|
force_download=kwargs.pop("force_download", False),
|
438
|
-
resume_download=kwargs.pop("resume_download", None),
|
439
438
|
proxies=kwargs.pop("proxies", None),
|
440
439
|
local_files_only=kwargs.pop("local_files_only", False),
|
441
440
|
token=kwargs.pop("token", None),
|
@@ -454,7 +453,6 @@ def _load_empty_model(
|
|
454
453
|
cached_folder,
|
455
454
|
subfolder=name,
|
456
455
|
force_download=kwargs.pop("force_download", False),
|
457
|
-
resume_download=kwargs.pop("resume_download", None),
|
458
456
|
proxies=kwargs.pop("proxies", None),
|
459
457
|
local_files_only=kwargs.pop("local_files_only", False),
|
460
458
|
token=kwargs.pop("token", None),
|
@@ -544,7 +542,6 @@ def _get_final_device_map(device_map, pipeline_class, passed_class_obj, init_dic
|
|
544
542
|
torch_dtype=torch_dtype,
|
545
543
|
cached_folder=kwargs.get("cached_folder", None),
|
546
544
|
force_download=kwargs.get("force_download", None),
|
547
|
-
resume_download=kwargs.get("resume_download", None),
|
548
545
|
proxies=kwargs.get("proxies", None),
|
549
546
|
local_files_only=kwargs.get("local_files_only", None),
|
550
547
|
token=kwargs.get("token", None),
|
@@ -533,9 +533,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
533
533
|
cache_dir (`Union[str, os.PathLike]`, *optional*):
|
534
534
|
Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
|
535
535
|
is not used.
|
536
|
-
|
537
|
-
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
538
|
-
of Diffusers.
|
536
|
+
|
539
537
|
proxies (`Dict[str, str]`, *optional*):
|
540
538
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
541
539
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
@@ -625,7 +623,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
625
623
|
```
|
626
624
|
"""
|
627
625
|
cache_dir = kwargs.pop("cache_dir", None)
|
628
|
-
resume_download = kwargs.pop("resume_download", None)
|
629
626
|
force_download = kwargs.pop("force_download", False)
|
630
627
|
proxies = kwargs.pop("proxies", None)
|
631
628
|
local_files_only = kwargs.pop("local_files_only", None)
|
@@ -702,7 +699,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
702
699
|
cached_folder = cls.download(
|
703
700
|
pretrained_model_name_or_path,
|
704
701
|
cache_dir=cache_dir,
|
705
|
-
resume_download=resume_download,
|
706
702
|
force_download=force_download,
|
707
703
|
proxies=proxies,
|
708
704
|
local_files_only=local_files_only,
|
@@ -842,7 +838,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
842
838
|
torch_dtype=torch_dtype,
|
843
839
|
cached_folder=cached_folder,
|
844
840
|
force_download=force_download,
|
845
|
-
resume_download=resume_download,
|
846
841
|
proxies=proxies,
|
847
842
|
local_files_only=local_files_only,
|
848
843
|
token=token,
|
@@ -910,7 +905,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
910
905
|
connected_pipes = {prefix: getattr(modelcard.data, prefix, [None])[0] for prefix in CONNECTED_PIPES_KEYS}
|
911
906
|
load_kwargs = {
|
912
907
|
"cache_dir": cache_dir,
|
913
|
-
"resume_download": resume_download,
|
914
908
|
"force_download": force_download,
|
915
909
|
"proxies": proxies,
|
916
910
|
"local_files_only": local_files_only,
|
@@ -1216,9 +1210,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1216
1210
|
force_download (`bool`, *optional*, defaults to `False`):
|
1217
1211
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
1218
1212
|
cached versions if they exist.
|
1219
|
-
|
1220
|
-
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
1221
|
-
of Diffusers.
|
1213
|
+
|
1222
1214
|
proxies (`Dict[str, str]`, *optional*):
|
1223
1215
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
1224
1216
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
@@ -1271,7 +1263,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1271
1263
|
|
1272
1264
|
"""
|
1273
1265
|
cache_dir = kwargs.pop("cache_dir", None)
|
1274
|
-
resume_download = kwargs.pop("resume_download", None)
|
1275
1266
|
force_download = kwargs.pop("force_download", False)
|
1276
1267
|
proxies = kwargs.pop("proxies", None)
|
1277
1268
|
local_files_only = kwargs.pop("local_files_only", None)
|
@@ -1311,7 +1302,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1311
1302
|
revision=revision,
|
1312
1303
|
proxies=proxies,
|
1313
1304
|
force_download=force_download,
|
1314
|
-
resume_download=resume_download,
|
1315
1305
|
token=token,
|
1316
1306
|
)
|
1317
1307
|
|
@@ -1500,7 +1490,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1500
1490
|
cached_folder = snapshot_download(
|
1501
1491
|
pretrained_model_name,
|
1502
1492
|
cache_dir=cache_dir,
|
1503
|
-
resume_download=resume_download,
|
1504
1493
|
proxies=proxies,
|
1505
1494
|
local_files_only=local_files_only,
|
1506
1495
|
token=token,
|
@@ -1523,7 +1512,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1523
1512
|
for connected_pipe_repo_id in connected_pipes:
|
1524
1513
|
download_kwargs = {
|
1525
1514
|
"cache_dir": cache_dir,
|
1526
|
-
"resume_download": resume_download,
|
1527
1515
|
"force_download": force_download,
|
1528
1516
|
"proxies": proxies,
|
1529
1517
|
"local_files_only": local_files_only,
|
@@ -661,7 +661,6 @@ class SemanticStableDiffusionPipeline(DiffusionPipeline, StableDiffusionMixin):
|
|
661
661
|
noise_guidance_edit_tmp = torch.einsum(
|
662
662
|
"cb,cbijk->bijk", concept_weights_tmp, noise_guidance_edit_tmp
|
663
663
|
)
|
664
|
-
noise_guidance_edit_tmp = noise_guidance_edit_tmp
|
665
664
|
noise_guidance = noise_guidance + noise_guidance_edit_tmp
|
666
665
|
|
667
666
|
self.sem_guidance[i] = noise_guidance_edit_tmp.detach().cpu()
|
@@ -0,0 +1,50 @@
|
|
1
|
+
from typing import TYPE_CHECKING

from ...utils import (
    DIFFUSERS_SLOW_IMPORT,
    OptionalDependencyNotAvailable,
    _LazyModule,
    get_objects_from_module,
    is_torch_available,
    is_transformers_available,
    is_transformers_version,
)


# Placeholder objects installed when optional dependencies are missing.
_dummy_objects = {}
# Submodule -> exported names, consumed by _LazyModule for on-demand imports.
_import_structure = {}

try:
    # Stable Audio needs both torch and transformers >= 4.27.0.
    if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
        raise OptionalDependencyNotAvailable()
except OptionalDependencyNotAvailable:
    from ...utils import dummy_torch_and_transformers_objects

    # Expose dummies so attribute access raises a helpful dependency error.
    _dummy_objects.update(get_objects_from_module(dummy_torch_and_transformers_objects))
else:
    _import_structure["modeling_stable_audio"] = ["StableAudioProjectionModel"]
    _import_structure["pipeline_stable_audio"] = ["StableAudioPipeline"]


if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
    # Eager path: type checkers and opt-in slow imports see real symbols.
    try:
        if not (is_transformers_available() and is_torch_available() and is_transformers_version(">=", "4.27.0")):
            raise OptionalDependencyNotAvailable()
    except OptionalDependencyNotAvailable:
        from ...utils.dummy_torch_and_transformers_objects import *

    else:
        from .modeling_stable_audio import StableAudioProjectionModel
        from .pipeline_stable_audio import StableAudioPipeline

else:
    import sys

    # Replace this module with a lazy proxy that imports submodules on demand.
    sys.modules[__name__] = _LazyModule(
        __name__,
        globals()["__file__"],
        _import_structure,
        module_spec=__spec__,
    )
    for name, value in _dummy_objects.items():
        setattr(sys.modules[__name__], name, value)
|
@@ -0,0 +1,158 @@
|
|
1
|
+
# Copyright 2024 Stability AI and The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from dataclasses import dataclass
|
16
|
+
from math import pi
|
17
|
+
from typing import Optional
|
18
|
+
|
19
|
+
import torch
|
20
|
+
import torch.nn as nn
|
21
|
+
import torch.utils.checkpoint
|
22
|
+
|
23
|
+
from ...configuration_utils import ConfigMixin, register_to_config
|
24
|
+
from ...models.modeling_utils import ModelMixin
|
25
|
+
from ...utils import BaseOutput, logging
|
26
|
+
|
27
|
+
|
28
|
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
29
|
+
|
30
|
+
|
31
|
+
class StableAudioPositionalEmbedding(nn.Module):
    """Random-Fourier-feature embedding used for continuous time values."""

    def __init__(self, dim: int):
        super().__init__()
        assert (dim % 2) == 0
        # Half of the Fourier channels come from sin, half from cos.
        self.weights = nn.Parameter(torch.randn(dim // 2))

    def forward(self, times: torch.Tensor) -> torch.Tensor:
        # Broadcast each time scalar against every learned frequency.
        expanded = times[..., None]
        angles = expanded * self.weights[None] * 2 * pi
        # Output layout along the last dim: [t, sin(2*pi*w*t), cos(2*pi*w*t)],
        # i.e. `dim + 1` channels in total.
        return torch.cat((expanded, angles.sin(), angles.cos()), dim=-1)
46
|
+
|
47
|
+
|
48
|
+
@dataclass
class StableAudioProjectionModelOutput(BaseOutput):
    """
    Class for StableAudio projection layer's outputs.

    Args:
        text_hidden_states (`torch.Tensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
            Sequence of hidden-states obtained by linearly projecting the hidden-states for the text encoder.
        seconds_start_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
            Sequence of hidden-states obtained by linearly projecting the audio start hidden states.
        seconds_end_hidden_states (`torch.Tensor` of shape `(batch_size, 1, hidden_size)`, *optional*):
            Sequence of hidden-states obtained by linearly projecting the audio end hidden states.
    """

    text_hidden_states: Optional[torch.Tensor] = None
    seconds_start_hidden_states: Optional[torch.Tensor] = None
    seconds_end_hidden_states: Optional[torch.Tensor] = None
64
|
+
|
65
|
+
|
66
|
+
class StableAudioNumberConditioner(nn.Module):
    """
    A simple linear projection model to map numbers to a latent space.

    Args:
        number_embedding_dim (`int`):
            Dimensionality of the number embeddings.
        min_value (`int`):
            The minimum value of the seconds number conditioning modules.
        max_value (`int`):
            The maximum value of the seconds number conditioning modules
        internal_dim (`int`):
            Dimensionality of the intermediate number hidden states.
    """

    def __init__(
        self,
        number_embedding_dim,
        min_value,
        max_value,
        internal_dim: Optional[int] = 256,
    ):
        super().__init__()
        # Fourier-embed the normalized value, then project to the target width.
        # The positional embedding emits `internal_dim + 1` channels.
        self.time_positional_embedding = nn.Sequential(
            StableAudioPositionalEmbedding(internal_dim),
            nn.Linear(in_features=internal_dim + 1, out_features=number_embedding_dim),
        )

        self.number_embedding_dim = number_embedding_dim
        self.min_value = min_value
        self.max_value = max_value

    def forward(
        self,
        floats: torch.Tensor,
    ):
        # Restrict inputs to the supported range, then rescale to [0, 1].
        clipped = floats.clamp(self.min_value, self.max_value)
        normalized = (clipped - self.min_value) / (self.max_value - self.min_value)

        # The embedder's parameters dictate the working dtype (e.g. fp16).
        target_dtype = next(self.time_positional_embedding.parameters()).dtype
        embedding = self.time_positional_embedding(normalized.to(target_dtype))

        # One embedding vector per input number: (batch, 1, number_embedding_dim).
        return embedding.view(-1, 1, self.number_embedding_dim)
|
114
|
+
|
115
|
+
|
116
|
+
class StableAudioProjectionModel(ModelMixin, ConfigMixin):
    """
    A simple linear projection model to map the conditioning values to a shared latent space.

    Args:
        text_encoder_dim (`int`):
            Dimensionality of the text embeddings from the text encoder (T5).
        conditioning_dim (`int`):
            Dimensionality of the output conditioning tensors.
        min_value (`int`):
            The minimum value of the seconds number conditioning modules.
        max_value (`int`):
            The maximum value of the seconds number conditioning modules
    """

    @register_to_config
    def __init__(self, text_encoder_dim, conditioning_dim, min_value, max_value):
        super().__init__()
        # No-op projection when the text encoder width already matches the
        # conditioning width; otherwise learn a linear map between them.
        self.text_projection = (
            nn.Identity() if conditioning_dim == text_encoder_dim else nn.Linear(text_encoder_dim, conditioning_dim)
        )
        # Separate conditioners for the audio start and end seconds values.
        self.start_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)
        self.end_number_conditioner = StableAudioNumberConditioner(conditioning_dim, min_value, max_value)

    def forward(
        self,
        text_hidden_states: Optional[torch.Tensor] = None,
        start_seconds: Optional[torch.Tensor] = None,
        end_seconds: Optional[torch.Tensor] = None,
    ):
        # Each conditioning input is optional; `None` passes through unchanged
        # so callers can project any subset of the three inputs.
        text_hidden_states = (
            text_hidden_states if text_hidden_states is None else self.text_projection(text_hidden_states)
        )
        seconds_start_hidden_states = (
            start_seconds if start_seconds is None else self.start_number_conditioner(start_seconds)
        )
        seconds_end_hidden_states = end_seconds if end_seconds is None else self.end_number_conditioner(end_seconds)

        return StableAudioProjectionModelOutput(
            text_hidden_states=text_hidden_states,
            seconds_start_hidden_states=seconds_start_hidden_states,
            seconds_end_hidden_states=seconds_end_hidden_states,
        )
|