diffusers 0.19.3__py3-none-any.whl → 0.20.1__py3-none-any.whl
- diffusers/__init__.py +3 -1
- diffusers/commands/fp16_safetensors.py +2 -7
- diffusers/configuration_utils.py +23 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/loaders.py +62 -64
- diffusers/models/__init__.py +1 -0
- diffusers/models/activations.py +2 -0
- diffusers/models/attention.py +45 -1
- diffusers/models/autoencoder_tiny.py +193 -0
- diffusers/models/controlnet.py +1 -1
- diffusers/models/embeddings.py +56 -0
- diffusers/models/lora.py +0 -6
- diffusers/models/modeling_flax_utils.py +28 -2
- diffusers/models/modeling_utils.py +33 -16
- diffusers/models/transformer_2d.py +26 -9
- diffusers/models/unet_1d.py +2 -2
- diffusers/models/unet_2d_blocks.py +106 -56
- diffusers/models/unet_2d_condition.py +20 -5
- diffusers/models/vae.py +106 -1
- diffusers/pipelines/__init__.py +1 -0
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +10 -3
- diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -3
- diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
- diffusers/pipelines/auto_pipeline.py +33 -43
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -2
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +15 -7
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +14 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +157 -10
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -10
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1 -1
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +1 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +43 -2
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +44 -2
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
- diffusers/pipelines/pipeline_flax_utils.py +41 -4
- diffusers/pipelines/pipeline_utils.py +60 -16
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/__init__.py +1 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +81 -37
- diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +12 -5
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +832 -0
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +9 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +17 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +10 -3
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +10 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +3 -5
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +75 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +76 -6
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +1 -2
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +10 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +10 -3
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +11 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +1 -1
- diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +131 -28
- diffusers/schedulers/scheduling_consistency_models.py +70 -57
- diffusers/schedulers/scheduling_ddim.py +76 -71
- diffusers/schedulers/scheduling_ddim_inverse.py +76 -44
- diffusers/schedulers/scheduling_ddim_parallel.py +11 -8
- diffusers/schedulers/scheduling_ddpm.py +68 -67
- diffusers/schedulers/scheduling_ddpm_parallel.py +18 -15
- diffusers/schedulers/scheduling_deis_multistep.py +93 -85
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +118 -120
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +116 -109
- diffusers/schedulers/scheduling_dpmsolver_sde.py +57 -43
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +122 -121
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +54 -44
- diffusers/schedulers/scheduling_euler_discrete.py +63 -56
- diffusers/schedulers/scheduling_heun_discrete.py +57 -45
- diffusers/schedulers/scheduling_ipndm.py +27 -22
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +54 -41
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +52 -41
- diffusers/schedulers/scheduling_karras_ve.py +55 -45
- diffusers/schedulers/scheduling_lms_discrete.py +58 -52
- diffusers/schedulers/scheduling_pndm.py +77 -62
- diffusers/schedulers/scheduling_repaint.py +56 -38
- diffusers/schedulers/scheduling_sde_ve.py +62 -50
- diffusers/schedulers/scheduling_sde_vp.py +32 -11
- diffusers/schedulers/scheduling_unclip.py +3 -3
- diffusers/schedulers/scheduling_unipc_multistep.py +131 -91
- diffusers/schedulers/scheduling_utils.py +41 -35
- diffusers/schedulers/scheduling_utils_flax.py +8 -2
- diffusers/schedulers/scheduling_vq_diffusion.py +39 -68
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +15 -0
- diffusers/utils/hub_utils.py +105 -2
- diffusers/utils/import_utils.py +0 -4
- diffusers/utils/pil_utils.py +19 -0
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/METADATA +5 -7
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/RECORD +113 -112
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/WHEEL +1 -1
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/entry_points.txt +0 -1
- diffusers/models/cross_attention.py +0 -94
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/LICENSE +0 -0
- {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/top_level.txt +0 -0
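Two entirely new modules stand out in the list above: diffusers/models/autoencoder_tiny.py and diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py. A minimal sketch of using the new `AutoencoderTiny` as a drop-in VAE for an existing Stable Diffusion pipeline follows; the `madebyollin/taesd` checkpoint id is the one referenced in the diffusers documentation, and the base model id is only an example — adjust both to your setup.

```python
# Minimal sketch: swap the new AutoencoderTiny (added in 0.20.x) into an
# existing Stable Diffusion pipeline for faster decoding.
import torch
from diffusers import AutoencoderTiny, StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
)
pipe.vae = AutoencoderTiny.from_pretrained("madebyollin/taesd", torch_dtype=torch.float16)
pipe = pipe.to("cuda")

image = pipe("a photo of an astronaut riding a horse on mars").images[0]
image.save("astronaut.png")
```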
diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import inspect
+import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
 import numpy as np
@@ -221,9 +222,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     return mask, masked_image
 
 
-class StableDiffusionXLInpaintPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
-):
+class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromSingleFileMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion XL.
 
@@ -231,7 +230,6 @@ class StableDiffusionXLInpaintPipeline(
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
     In addition the pipeline inherits the following loading methods:
-        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
         - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]
 
@@ -458,7 +456,6 @@ class StableDiffusionXLInpaintPipeline(
 
                 text_input_ids = text_inputs.input_ids
                 untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
 
                 if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                     text_input_ids, untruncated_ids
@@ -993,7 +990,7 @@ class StableDiffusionXLInpaintPipeline(
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
-                [diffusers.
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                 If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                 `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
@@ -1296,3 +1293,76 @@ class StableDiffusionXLInpaintPipeline(
             return (image,)
 
         return StableDiffusionXLPipelineOutput(images=image)
+
+    # Overrride to properly handle the loading and unloading of the additional text encoder.
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
+    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+        # We could have accessed the unet config from `lora_state_dict()` too. We pass
+        # it here explicitly to be able to tell that it's coming from an SDXL
+        # pipeline.
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict,
+            unet_config=self.unet.config,
+            **kwargs,
+        )
+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        if len(text_encoder_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder,
+                prefix="text_encoder",
+                lora_scale=self.lora_scale,
+            )
+
+        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+        if len(text_encoder_2_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_2_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder_2,
+                prefix="text_encoder_2",
+                lora_scale=self.lora_scale,
+            )
+
+    @classmethod
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
+    def save_lora_weights(
+        self,
+        save_directory: Union[str, os.PathLike],
+        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        state_dict = {}
+
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict
+
+        state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+        if text_encoder_lora_layers and text_encoder_2_lora_layers:
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+        self.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
+    def _remove_text_encoder_monkey_patch(self):
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
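The hunk above gives the SDXL inpaint pipeline its own `load_lora_weights`/`save_lora_weights` overrides so that LoRA layers are routed into the UNet and into both text encoders. A hedged usage sketch (the LoRA repository id is a placeholder, not a real checkpoint):

```python
# Sketch: loading LoRA weights into the SDXL inpaint pipeline via the
# load_lora_weights override added in the hunk above.
import torch
from diffusers import StableDiffusionXLInpaintPipeline

pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# "unet." keys go to the UNet, "text_encoder." / "text_encoder_2." keys go to
# the two text encoders, as implemented above.
pipe.load_lora_weights("your-user/your-sdxl-lora")  # placeholder repo id
```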
diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py
@@ -71,7 +71,6 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
 
     In addition the pipeline inherits the following loading methods:
-        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
 
     as well as the following saving methods:
@@ -688,7 +687,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
-                [diffusers.
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             guidance_rescale (`float`, *optional*, defaults to 0.7):
                 Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
                 Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py
@@ -331,7 +331,14 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
             )
             prompt_embeds = prompt_embeds[0]
 
-
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -387,7 +394,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
            seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -625,7 +632,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                 `self.processor` in
-                [diffusers.
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                 The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the
                 residual in the original unet. If multiple adapters are specified in init, you can set the
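The `prompt_embeds_dtype` change above (and its twins in the text-to-video pipelines below) removes the assumption that `self.text_encoder` exists when casting prompt embeddings: pipelines built with `text_encoder=None` and precomputed `prompt_embeds` now fall back to the UNet dtype, or to the embeddings' own dtype. An illustrative helper mirroring that resolution order (not a diffusers API, just a sketch of the logic):

```python
# Illustrative helper (not part of diffusers) mirroring the dtype fallback
# introduced above: text encoder dtype -> unet dtype -> embeddings dtype.
import torch

def resolve_prompt_embeds_dtype(text_encoder, unet, prompt_embeds: torch.Tensor) -> torch.dtype:
    if text_encoder is not None:
        return text_encoder.dtype
    if unet is not None:
        return unet.dtype
    return prompt_embeds.dtype
```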
diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py
@@ -258,7 +258,14 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
             )
             prompt_embeds = prompt_embeds[0]
 
-
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -314,7 +321,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -516,7 +523,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
                 every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
-                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
 
         Examples:
 
diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py
@@ -320,7 +320,14 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
             )
             prompt_embeds = prompt_embeds[0]
 
-
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -376,7 +383,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]
 
-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)
 
             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -554,7 +561,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
             video (`List[np.ndarray]` or `torch.FloatTensor`):
                 `video` frames or tensor representing a video batch to be used as the starting point for the process.
-                Can also
+                Can also accept video latents as `image`, if passing latents directly, it will not be encoded again.
             strength (`float`, *optional*, defaults to 0.8):
                 Indicates extent to transform the reference `video`. Must be between 0 and 1. `video` is used as a
                 starting point, adding more noise to it the larger the `strength`. The number of denoising steps
@@ -600,7 +607,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
                 every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
-                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
 
         Examples:
 
diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py
@@ -378,7 +378,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
             Extra_step_kwargs.
         cross_attention_kwargs:
             A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
-            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/
+            [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
         num_warmup_steps:
             number of warmup steps.
 
diffusers/pipelines/versatile_diffusion/modeling_text_unet.py
@@ -153,6 +153,62 @@ def get_up_block(
     raise ValueError(f"{up_block_type} is not supported.")
 
 
+class FourierEmbedder(nn.Module):
+    def __init__(self, num_freqs=64, temperature=100):
+        super().__init__()
+
+        self.num_freqs = num_freqs
+        self.temperature = temperature
+
+        freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
+        freq_bands = freq_bands[None, None, None]
+        self.register_buffer("freq_bands", freq_bands, persistent=False)
+
+    def __call__(self, x):
+        x = self.freq_bands * x.unsqueeze(-1)
+        return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
+
+
+class PositionNet(nn.Module):
+    def __init__(self, positive_len, out_dim, fourier_freqs=8):
+        super().__init__()
+        self.positive_len = positive_len
+        self.out_dim = out_dim
+
+        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
+        self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy
+
+        if isinstance(out_dim, tuple):
+            out_dim = out_dim[0]
+        self.linears = nn.Sequential(
+            nn.Linear(self.positive_len + self.position_dim, 512),
+            nn.SiLU(),
+            nn.Linear(512, 512),
+            nn.SiLU(),
+            nn.Linear(512, out_dim),
+        )
+
+        self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+        self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
+
+    def forward(self, boxes, masks, positive_embeddings):
+        masks = masks.unsqueeze(-1)
+
+        # embedding position (it may includes padding as placeholder)
+        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C
+
+        # learnable null embedding
+        positive_null = self.null_positive_feature.view(1, 1, -1)
+        xyxy_null = self.null_position_feature.view(1, 1, -1)
+
+        # replace padding with learnable null embedding
+        positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
+        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
+
+        objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
+        return objs
+
+
 # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
 class UNetFlatConditionModel(ModelMixin, ConfigMixin):
     r"""
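`PositionNet` turns GLIGEN grounding boxes into extra attention tokens: each `xyxy` box is lifted to a Fourier feature of width `fourier_freqs * 2 * 4`, padded entries are swapped for learned null embeddings, and the result is concatenated with the phrase embedding and passed through a small MLP. A standalone shape check of the Fourier step, re-implemented here for illustration rather than imported from diffusers:

```python
# Standalone shape check mirroring FourierEmbedder from the hunk above.
import torch

num_freqs, temperature = 8, 100
freq_bands = (temperature ** (torch.arange(num_freqs) / num_freqs))[None, None, None]  # (1, 1, 1, 8)

boxes = torch.rand(2, 30, 4)                    # B x N x 4 boxes in xyxy form
x = freq_bands * boxes.unsqueeze(-1)            # B x N x 4 x num_freqs
emb = torch.stack((x.sin(), x.cos()), dim=-1)   # B x N x 4 x num_freqs x 2
emb = emb.permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)

print(emb.shape)  # torch.Size([2, 30, 64]) -- matches position_dim = fourier_freqs * 2 * 4
```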
diffusers/pipelines/versatile_diffusion/modeling_text_unet.py (continued)
@@ -298,6 +354,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         conv_in_kernel: int = 3,
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
+        attention_type: str = "default",
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
         cross_attention_norm: Optional[str] = None,
@@ -556,6 +613,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_type=attention_type,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
                 cross_attention_norm=cross_attention_norm,
@@ -579,6 +637,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 upcast_attention=upcast_attention,
+                attention_type=attention_type,
             )
         elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn":
             self.mid_block = UNetMidBlockFlatSimpleCrossAttn(
@@ -645,6 +704,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_type=attention_type,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
                 cross_attention_norm=cross_attention_norm,
@@ -670,6 +730,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )
 
+        if attention_type == "gated":
+            positive_len = 768
+            if isinstance(cross_attention_dim, int):
+                positive_len = cross_attention_dim
+            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+                positive_len = cross_attention_dim[0]
+            self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
+
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
@@ -800,7 +868,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 fn_recursive_set_attention_slice(module, reversed_slice_size)
 
     def _set_gradient_checkpointing(self, module, value=False):
-        if
+        if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
     def forward(
@@ -1012,6 +1080,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         # 2. pre-process
         sample = self.conv_in(sample)
 
+        # 2.5 GLIGEN position net
+        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
+            gligen_args = cross_attention_kwargs.pop("gligen")
+            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
         # 3. down
 
         is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
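With the hook above, a `"gligen"` entry in `cross_attention_kwargs` is run through `position_net` and forwarded to the attention blocks as grounding tokens (`objs`). In practice this is driven by the new `StableDiffusionGLIGENPipeline` (see pipeline_stable_diffusion_gligen.py in the file list); a sketch of its box-grounded generation follows, with the checkpoint id and argument names taken from the 0.20 documentation, so treat them as assumptions:

```python
# Sketch: grounded text-to-image with the new GLIGEN pipeline (0.20.x).
import torch
from diffusers import StableDiffusionGLIGENPipeline

pipe = StableDiffusionGLIGENPipeline.from_pretrained(
    "masterful/gligen-1-4-generation-text-box", torch_dtype=torch.float16
).to("cuda")

image = pipe(
    prompt="a waterfall and a modern high speed train in a beautiful forest",
    gligen_phrases=["a waterfall", "a modern high speed train"],
    gligen_boxes=[[0.1387, 0.2051, 0.4277, 0.7090], [0.4980, 0.4355, 0.8516, 0.7266]],
    gligen_scheduled_sampling_beta=1,
    num_inference_steps=50,
).images[0]
image.save("gligen_boxes.png")
```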
diffusers/pipelines/versatile_diffusion/modeling_text_unet.py (continued)
@@ -1331,6 +1405,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         use_linear_projection=False,
         only_cross_attention=False,
         upcast_attention=False,
+        attention_type="default",
     ):
         super().__init__()
         resnets = []
@@ -1367,6 +1442,7 @@ class CrossAttnDownBlockFlat(nn.Module):
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
                         upcast_attention=upcast_attention,
+                        attention_type=attention_type,
                     )
                 )
             else:
@@ -1429,16 +1505,13 @@ class CrossAttnDownBlockFlat(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states =
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-
-
-
-
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1572,6 +1645,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         use_linear_projection=False,
         only_cross_attention=False,
         upcast_attention=False,
+        attention_type="default",
     ):
         super().__init__()
         resnets = []
@@ -1610,6 +1684,7 @@ class CrossAttnUpBlockFlat(nn.Module):
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
                         upcast_attention=upcast_attention,
+                        attention_type=attention_type,
                     )
                 )
             else:
@@ -1668,16 +1743,13 @@ class CrossAttnUpBlockFlat(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states =
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-
-
-
-
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1717,6 +1789,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         upcast_attention=False,
+        attention_type="default",
     ):
         super().__init__()
 
@@ -1753,6 +1826,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
+                        attention_type=attention_type,
                     )
                 )
             else:
@@ -1784,6 +1858,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)
 
+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -1795,15 +1871,42 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
     ) -> torch.FloatTensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-
-
-
-
-
-
-
-
-
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = resnet(hidden_states, temb)
 
         return hidden_states
 
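The last two hunks make the flat mid block participate in gradient checkpointing: it now carries a `gradient_checkpointing` flag and checkpoints its resnets during training. Per the file list, the same pattern lands in diffusers/models/unet_2d_blocks.py; a minimal sketch of switching it on for a standard UNet (the checkpoint id is only an example):

```python
# Minimal sketch: enable gradient checkpointing on a UNet for training;
# _set_gradient_checkpointing now flips the flag on every module that has it.
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet"
)
unet.enable_gradient_checkpointing()
unet.train()
```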