diffusers 0.19.3__py3-none-any.whl → 0.20.1__py3-none-any.whl

Files changed (114)
  1. diffusers/__init__.py +3 -1
  2. diffusers/commands/fp16_safetensors.py +2 -7
  3. diffusers/configuration_utils.py +23 -1
  4. diffusers/dependency_versions_table.py +1 -1
  5. diffusers/loaders.py +62 -64
  6. diffusers/models/__init__.py +1 -0
  7. diffusers/models/activations.py +2 -0
  8. diffusers/models/attention.py +45 -1
  9. diffusers/models/autoencoder_tiny.py +193 -0
  10. diffusers/models/controlnet.py +1 -1
  11. diffusers/models/embeddings.py +56 -0
  12. diffusers/models/lora.py +0 -6
  13. diffusers/models/modeling_flax_utils.py +28 -2
  14. diffusers/models/modeling_utils.py +33 -16
  15. diffusers/models/transformer_2d.py +26 -9
  16. diffusers/models/unet_1d.py +2 -2
  17. diffusers/models/unet_2d_blocks.py +106 -56
  18. diffusers/models/unet_2d_condition.py +20 -5
  19. diffusers/models/vae.py +106 -1
  20. diffusers/pipelines/__init__.py +1 -0
  21. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +10 -3
  22. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +10 -3
  23. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -1
  24. diffusers/pipelines/auto_pipeline.py +33 -43
  25. diffusers/pipelines/controlnet/multicontrolnet.py +4 -2
  26. diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -4
  27. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +15 -7
  28. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +14 -4
  29. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +157 -10
  30. diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -10
  31. diffusers/pipelines/deepfloyd_if/pipeline_if.py +1 -1
  32. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +1 -1
  33. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +1 -1
  34. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +1 -1
  35. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +1 -1
  36. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +1 -1
  37. diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +43 -2
  38. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +44 -2
  39. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +1 -1
  40. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  41. diffusers/pipelines/pipeline_flax_utils.py +41 -4
  42. diffusers/pipelines/pipeline_utils.py +60 -16
  43. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +2 -2
  44. diffusers/pipelines/stable_diffusion/__init__.py +1 -0
  45. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +81 -37
  46. diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +10 -3
  47. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +10 -3
  48. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +10 -3
  49. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +10 -3
  50. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +12 -5
  51. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +832 -0
  52. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +10 -3
  53. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +10 -3
  54. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +10 -3
  55. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +9 -2
  56. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +17 -8
  57. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +10 -3
  58. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +10 -3
  59. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +10 -3
  60. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +10 -3
  61. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +10 -3
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +10 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +10 -3
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +10 -3
  65. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +3 -5
  66. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +75 -3
  67. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +76 -6
  68. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +1 -2
  69. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +10 -3
  70. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +10 -3
  71. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +11 -4
  72. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +1 -1
  73. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +131 -28
  74. diffusers/schedulers/scheduling_consistency_models.py +70 -57
  75. diffusers/schedulers/scheduling_ddim.py +76 -71
  76. diffusers/schedulers/scheduling_ddim_inverse.py +76 -44
  77. diffusers/schedulers/scheduling_ddim_parallel.py +11 -8
  78. diffusers/schedulers/scheduling_ddpm.py +68 -67
  79. diffusers/schedulers/scheduling_ddpm_parallel.py +18 -15
  80. diffusers/schedulers/scheduling_deis_multistep.py +93 -85
  81. diffusers/schedulers/scheduling_dpmsolver_multistep.py +118 -120
  82. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +116 -109
  83. diffusers/schedulers/scheduling_dpmsolver_sde.py +57 -43
  84. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +122 -121
  85. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +54 -44
  86. diffusers/schedulers/scheduling_euler_discrete.py +63 -56
  87. diffusers/schedulers/scheduling_heun_discrete.py +57 -45
  88. diffusers/schedulers/scheduling_ipndm.py +27 -22
  89. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +54 -41
  90. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +52 -41
  91. diffusers/schedulers/scheduling_karras_ve.py +55 -45
  92. diffusers/schedulers/scheduling_lms_discrete.py +58 -52
  93. diffusers/schedulers/scheduling_pndm.py +77 -62
  94. diffusers/schedulers/scheduling_repaint.py +56 -38
  95. diffusers/schedulers/scheduling_sde_ve.py +62 -50
  96. diffusers/schedulers/scheduling_sde_vp.py +32 -11
  97. diffusers/schedulers/scheduling_unclip.py +3 -3
  98. diffusers/schedulers/scheduling_unipc_multistep.py +131 -91
  99. diffusers/schedulers/scheduling_utils.py +41 -35
  100. diffusers/schedulers/scheduling_utils_flax.py +8 -2
  101. diffusers/schedulers/scheduling_vq_diffusion.py +39 -68
  102. diffusers/utils/__init__.py +2 -2
  103. diffusers/utils/dummy_pt_objects.py +15 -0
  104. diffusers/utils/dummy_torch_and_transformers_objects.py +15 -0
  105. diffusers/utils/hub_utils.py +105 -2
  106. diffusers/utils/import_utils.py +0 -4
  107. diffusers/utils/pil_utils.py +19 -0
  108. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/METADATA +5 -7
  109. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/RECORD +113 -112
  110. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/WHEEL +1 -1
  111. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/entry_points.txt +0 -1
  112. diffusers/models/cross_attention.py +0 -94
  113. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/LICENSE +0 -0
  114. {diffusers-0.19.3.dist-info → diffusers-0.20.1.dist-info}/top_level.txt +0 -0
diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py

@@ -13,6 +13,7 @@
 # limitations under the License.

 import inspect
+import os
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union

 import numpy as np
@@ -221,9 +222,7 @@ def prepare_mask_and_masked_image(image, mask, height, width, return_image: bool
     return mask, masked_image


-class StableDiffusionXLInpaintPipeline(
-    DiffusionPipeline, TextualInversionLoaderMixin, LoraLoaderMixin, FromSingleFileMixin
-):
+class StableDiffusionXLInpaintPipeline(DiffusionPipeline, LoraLoaderMixin, FromSingleFileMixin):
     r"""
     Pipeline for text-to-image generation using Stable Diffusion XL.

@@ -231,7 +230,6 @@ class StableDiffusionXLInpaintPipeline(
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

     In addition the pipeline inherits the following loading methods:
-        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]
         - *Ckpt*: [`loaders.FromSingleFileMixin.from_single_file`]

@@ -458,7 +456,6 @@ class StableDiffusionXLInpaintPipeline(

                 text_input_ids = text_inputs.input_ids
                 untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
-                untruncated_ids = tokenizer(prompt, padding="longest", return_tensors="pt").input_ids

                 if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
                     text_input_ids, untruncated_ids
@@ -993,7 +990,7 @@ class StableDiffusionXLInpaintPipeline(
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
-                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             original_size (`Tuple[int]`, *optional*, defaults to (1024, 1024)):
                 If `original_size` is not the same as `target_size` the image will appear to be down- or upsampled.
                 `original_size` defaults to `(width, height)` if not specified. Part of SDXL's micro-conditioning as
@@ -1296,3 +1293,76 @@ class StableDiffusionXLInpaintPipeline(
             return (image,)

         return StableDiffusionXLPipelineOutput(images=image)
+
+    # Overrride to properly handle the loading and unloading of the additional text encoder.
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.load_lora_weights
+    def load_lora_weights(self, pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]], **kwargs):
+        # We could have accessed the unet config from `lora_state_dict()` too. We pass
+        # it here explicitly to be able to tell that it's coming from an SDXL
+        # pipeline.
+        state_dict, network_alphas = self.lora_state_dict(
+            pretrained_model_name_or_path_or_dict,
+            unet_config=self.unet.config,
+            **kwargs,
+        )
+        self.load_lora_into_unet(state_dict, network_alphas=network_alphas, unet=self.unet)
+
+        text_encoder_state_dict = {k: v for k, v in state_dict.items() if "text_encoder." in k}
+        if len(text_encoder_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder,
+                prefix="text_encoder",
+                lora_scale=self.lora_scale,
+            )
+
+        text_encoder_2_state_dict = {k: v for k, v in state_dict.items() if "text_encoder_2." in k}
+        if len(text_encoder_2_state_dict) > 0:
+            self.load_lora_into_text_encoder(
+                text_encoder_2_state_dict,
+                network_alphas=network_alphas,
+                text_encoder=self.text_encoder_2,
+                prefix="text_encoder_2",
+                lora_scale=self.lora_scale,
+            )
+
+    @classmethod
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline.save_lora_weights
+    def save_lora_weights(
+        self,
+        save_directory: Union[str, os.PathLike],
+        unet_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        text_encoder_2_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+    ):
+        state_dict = {}
+
+        def pack_weights(layers, prefix):
+            layers_weights = layers.state_dict() if isinstance(layers, torch.nn.Module) else layers
+            layers_state_dict = {f"{prefix}.{module_name}": param for module_name, param in layers_weights.items()}
+            return layers_state_dict
+
+        state_dict.update(pack_weights(unet_lora_layers, "unet"))
+
+        if text_encoder_lora_layers and text_encoder_2_lora_layers:
+            state_dict.update(pack_weights(text_encoder_lora_layers, "text_encoder"))
+            state_dict.update(pack_weights(text_encoder_2_lora_layers, "text_encoder_2"))
+
+        self.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+        )
+
+    # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl.StableDiffusionXLPipeline._remove_text_encoder_monkey_patch
+    def _remove_text_encoder_monkey_patch(self):
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder)
+        self._remove_text_encoder_monkey_patch_classmethod(self.text_encoder_2)
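These overrides mirror StableDiffusionXLPipeline, so a LoRA file carrying `unet`, `text_encoder`, and `text_encoder_2` weights can be loaded into the inpaint pipeline directly. A minimal usage sketch; the LoRA repository id and weight file name below are placeholders, not part of this release:

import torch
from diffusers import StableDiffusionXLInpaintPipeline

# Any SDXL checkpoint works here; fp16 keeps memory manageable on a single GPU.
pipe = StableDiffusionXLInpaintPipeline.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
).to("cuda")

# load_lora_weights() splits the state dict into unet / text_encoder / text_encoder_2
# parts as in the override above. "some-user/sdxl-lora" is a placeholder repository id.
pipe.load_lora_weights("some-user/sdxl-lora", weight_name="pytorch_lora_weights.safetensors")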
diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py

@@ -71,7 +71,6 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
     library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)

     In addition the pipeline inherits the following loading methods:
-        - *Textual-Inversion*: [`loaders.TextualInversionLoaderMixin.load_textual_inversion`]
         - *LoRA*: [`loaders.LoraLoaderMixin.load_lora_weights`]

     as well as the following saving methods:
@@ -688,7 +687,7 @@ class StableDiffusionXLInstructPix2PixPipeline(DiffusionPipeline, FromSingleFile
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
                 `self.processor` in
-                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             guidance_rescale (`float`, *optional*, defaults to 0.7):
                 Guidance rescale factor proposed by [Common Diffusion Noise Schedules and Sample Steps are
                 Flawed](https://arxiv.org/pdf/2305.08891.pdf) `guidance_scale` is defined as `φ` in equation 16. of
diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py

@@ -331,7 +331,14 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
             )
             prompt_embeds = prompt_embeds[0]

-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -387,7 +394,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
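The same dtype-fallback block is copied into several pipelines in this release (the text-to-video and video-to-video pipelines below receive the identical change): prompt embeddings now follow the text encoder's dtype when one is loaded, otherwise the UNet's, so a pipeline constructed without a text encoder and fed precomputed `prompt_embeds` no longer fails here. A standalone sketch of the logic; `resolve_prompt_embeds_dtype` is a hypothetical helper, not a diffusers API:

import torch

def resolve_prompt_embeds_dtype(pipe, prompt_embeds: torch.Tensor) -> torch.dtype:
    # Hypothetical helper: prefer the text encoder's dtype, fall back to the UNet's,
    # and finally to the embeddings themselves when neither component is loaded.
    if getattr(pipe, "text_encoder", None) is not None:
        return pipe.text_encoder.dtype
    if getattr(pipe, "unet", None) is not None:
        return pipe.unet.dtype
    return prompt_embeds.dtype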
@@ -625,7 +632,7 @@ class StableDiffusionAdapterPipeline(DiffusionPipeline):
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
                 `self.processor` in
-                [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+                [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             adapter_conditioning_scale (`float` or `List[float]`, *optional*, defaults to 1.0):
                 The outputs of the adapter are multiplied by `adapter_conditioning_scale` before they are added to the
                 residual in the original unet. If multiple adapters are specified in init, you can set the
diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py

@@ -258,7 +258,14 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
             )
             prompt_embeds = prompt_embeds[0]

-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -314,7 +321,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -516,7 +523,7 @@ class TextToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lora
                 every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
-                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

         Examples:

diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py

@@ -320,7 +320,14 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
             )
             prompt_embeds = prompt_embeds[0]

-        prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+        if self.text_encoder is not None:
+            prompt_embeds_dtype = self.text_encoder.dtype
+        elif self.unet is not None:
+            prompt_embeds_dtype = self.unet.dtype
+        else:
+            prompt_embeds_dtype = prompt_embeds.dtype
+
+        prompt_embeds = prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

         bs_embed, seq_len, _ = prompt_embeds.shape
         # duplicate text embeddings for each generation per prompt, using mps friendly method
@@ -376,7 +383,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
             # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
             seq_len = negative_prompt_embeds.shape[1]

-            negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+            negative_prompt_embeds = negative_prompt_embeds.to(dtype=prompt_embeds_dtype, device=device)

             negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
             negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
@@ -554,7 +561,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
                 The prompt or prompts to guide image generation. If not defined, you need to pass `prompt_embeds`.
             video (`List[np.ndarray]` or `torch.FloatTensor`):
                 `video` frames or tensor representing a video batch to be used as the starting point for the process.
-                Can also accpet video latents as `image`, if passing latents directly, it will not be encoded again.
+                Can also accept video latents as `image`, if passing latents directly, it will not be encoded again.
             strength (`float`, *optional*, defaults to 0.8):
                 Indicates extent to transform the reference `video`. Must be between 0 and 1. `video` is used as a
                 starting point, adding more noise to it the larger the `strength`. The number of denoising steps
@@ -600,7 +607,7 @@ class VideoToVideoSDPipeline(DiffusionPipeline, TextualInversionLoaderMixin, Lor
                 every step.
             cross_attention_kwargs (`dict`, *optional*):
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
-                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).

         Examples:


diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py

@@ -378,7 +378,7 @@ class TextToVideoZeroPipeline(StableDiffusionPipeline):
                 Extra_step_kwargs.
             cross_attention_kwargs:
                 A kwargs dictionary that if specified is passed along to the [`AttentionProcessor`] as defined in
-                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+                [`self.processor`](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
             num_warmup_steps:
                 number of warmup steps.

diffusers/pipelines/versatile_diffusion/modeling_text_unet.py

@@ -153,6 +153,62 @@ def get_up_block(
     raise ValueError(f"{up_block_type} is not supported.")


+class FourierEmbedder(nn.Module):
+    def __init__(self, num_freqs=64, temperature=100):
+        super().__init__()
+
+        self.num_freqs = num_freqs
+        self.temperature = temperature
+
+        freq_bands = temperature ** (torch.arange(num_freqs) / num_freqs)
+        freq_bands = freq_bands[None, None, None]
+        self.register_buffer("freq_bands", freq_bands, persistent=False)
+
+    def __call__(self, x):
+        x = self.freq_bands * x.unsqueeze(-1)
+        return torch.stack((x.sin(), x.cos()), dim=-1).permute(0, 1, 3, 4, 2).reshape(*x.shape[:2], -1)
+
+
+class PositionNet(nn.Module):
+    def __init__(self, positive_len, out_dim, fourier_freqs=8):
+        super().__init__()
+        self.positive_len = positive_len
+        self.out_dim = out_dim
+
+        self.fourier_embedder = FourierEmbedder(num_freqs=fourier_freqs)
+        self.position_dim = fourier_freqs * 2 * 4  # 2: sin/cos, 4: xyxy
+
+        if isinstance(out_dim, tuple):
+            out_dim = out_dim[0]
+        self.linears = nn.Sequential(
+            nn.Linear(self.positive_len + self.position_dim, 512),
+            nn.SiLU(),
+            nn.Linear(512, 512),
+            nn.SiLU(),
+            nn.Linear(512, out_dim),
+        )
+
+        self.null_positive_feature = torch.nn.Parameter(torch.zeros([self.positive_len]))
+        self.null_position_feature = torch.nn.Parameter(torch.zeros([self.position_dim]))
+
+    def forward(self, boxes, masks, positive_embeddings):
+        masks = masks.unsqueeze(-1)
+
+        # embedding position (it may includes padding as placeholder)
+        xyxy_embedding = self.fourier_embedder(boxes)  # B*N*4 -> B*N*C
+
+        # learnable null embedding
+        positive_null = self.null_positive_feature.view(1, 1, -1)
+        xyxy_null = self.null_position_feature.view(1, 1, -1)
+
+        # replace padding with learnable null embedding
+        positive_embeddings = positive_embeddings * masks + (1 - masks) * positive_null
+        xyxy_embedding = xyxy_embedding * masks + (1 - masks) * xyxy_null
+
+        objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))
+        return objs
+
+
 # Copied from diffusers.models.unet_2d_condition.UNet2DConditionModel with UNet2DConditionModel->UNetFlatConditionModel, nn.Conv2d->LinearMultiDim, Block2D->BlockFlat
 class UNetFlatConditionModel(ModelMixin, ConfigMixin):
     r"""
@@ -298,6 +354,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         conv_in_kernel: int = 3,
         conv_out_kernel: int = 3,
         projection_class_embeddings_input_dim: Optional[int] = None,
+        attention_type: str = "default",
         class_embeddings_concat: bool = False,
         mid_block_only_cross_attention: Optional[bool] = None,
         cross_attention_norm: Optional[str] = None,
@@ -556,6 +613,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_type=attention_type,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
                 cross_attention_norm=cross_attention_norm,
@@ -579,6 +637,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 dual_cross_attention=dual_cross_attention,
                 use_linear_projection=use_linear_projection,
                 upcast_attention=upcast_attention,
+                attention_type=attention_type,
             )
         elif mid_block_type == "UNetMidBlockFlatSimpleCrossAttn":
             self.mid_block = UNetMidBlockFlatSimpleCrossAttn(
@@ -645,6 +704,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
                 only_cross_attention=only_cross_attention[i],
                 upcast_attention=upcast_attention,
                 resnet_time_scale_shift=resnet_time_scale_shift,
+                attention_type=attention_type,
                 resnet_skip_time_act=resnet_skip_time_act,
                 resnet_out_scale_factor=resnet_out_scale_factor,
                 cross_attention_norm=cross_attention_norm,
@@ -670,6 +730,14 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             block_out_channels[0], out_channels, kernel_size=conv_out_kernel, padding=conv_out_padding
         )

+        if attention_type == "gated":
+            positive_len = 768
+            if isinstance(cross_attention_dim, int):
+                positive_len = cross_attention_dim
+            elif isinstance(cross_attention_dim, tuple) or isinstance(cross_attention_dim, list):
+                positive_len = cross_attention_dim[0]
+            self.position_net = PositionNet(positive_len=positive_len, out_dim=cross_attention_dim)
+
     @property
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
@@ -800,7 +868,7 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
             fn_recursive_set_attention_slice(module, reversed_slice_size)

     def _set_gradient_checkpointing(self, module, value=False):
-        if isinstance(module, (CrossAttnDownBlockFlat, DownBlockFlat, CrossAttnUpBlockFlat, UpBlockFlat)):
+        if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value

     def forward(
@@ -1012,6 +1080,12 @@ class UNetFlatConditionModel(ModelMixin, ConfigMixin):
         # 2. pre-process
         sample = self.conv_in(sample)

+        # 2.5 GLIGEN position net
+        if cross_attention_kwargs is not None and cross_attention_kwargs.get("gligen", None) is not None:
+            cross_attention_kwargs = cross_attention_kwargs.copy()
+            gligen_args = cross_attention_kwargs.pop("gligen")
+            cross_attention_kwargs["gligen"] = {"objs": self.position_net(**gligen_args)}
+
         # 3. down

         is_controlnet = mid_block_additional_residual is not None and down_block_additional_residuals is not None
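This is the same GLIGEN hook that the main `UNet2DConditionModel` gains in this release: when the caller puts a `gligen` entry in `cross_attention_kwargs`, the raw boxes, masks, and phrase embeddings are converted by `position_net` into `objs` tokens before the down blocks run. A hedged sketch of the calling convention; the tensor shapes and the 768-dim phrase embeddings are illustrative assumptions for an SD-style UNet:

import torch

# Hypothetical grounding inputs: 1 sample, 30 grounding slots.
boxes = torch.zeros(1, 30, 4)               # xyxy boxes, normalized to [0, 1]
masks = torch.zeros(1, 30)                  # 1.0 where a slot holds a real box, 0.0 for padding
phrase_embeds = torch.zeros(1, 30, 768)     # per-box text features (positive_len=768 assumed)

cross_attention_kwargs = {
    "gligen": {"boxes": boxes, "masks": masks, "positive_embeddings": phrase_embeds}
}
# Inside forward() the dict is copied and the entry is rewritten to
# {"objs": self.position_net(boxes=..., masks=..., positive_embeddings=...)}
# before being handed down to the attention blocks.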
@@ -1331,6 +1405,7 @@ class CrossAttnDownBlockFlat(nn.Module):
         use_linear_projection=False,
         only_cross_attention=False,
         upcast_attention=False,
+        attention_type="default",
     ):
         super().__init__()
         resnets = []
@@ -1367,6 +1442,7 @@ class CrossAttnDownBlockFlat(nn.Module):
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
                         upcast_attention=upcast_attention,
+                        attention_type=attention_type,
                     )
                 )
             else:
@@ -1429,16 +1505,13 @@ class CrossAttnDownBlockFlat(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1572,6 +1645,7 @@ class CrossAttnUpBlockFlat(nn.Module):
         use_linear_projection=False,
         only_cross_attention=False,
         upcast_attention=False,
+        attention_type="default",
     ):
         super().__init__()
         resnets = []
@@ -1610,6 +1684,7 @@ class CrossAttnUpBlockFlat(nn.Module):
                         use_linear_projection=use_linear_projection,
                         only_cross_attention=only_cross_attention,
                         upcast_attention=upcast_attention,
+                        attention_type=attention_type,
                     )
                 )
             else:
@@ -1668,16 +1743,13 @@ class CrossAttnUpBlockFlat(nn.Module):
                     temb,
                     **ckpt_kwargs,
                 )
-                hidden_states = torch.utils.checkpoint.checkpoint(
-                    create_custom_forward(attn, return_dict=False),
+                hidden_states = attn(
                     hidden_states,
-                    encoder_hidden_states,
-                    None,  # timestep
-                    None,  # class_labels
-                    cross_attention_kwargs,
-                    attention_mask,
-                    encoder_attention_mask,
-                    **ckpt_kwargs,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
                 )[0]
             else:
                 hidden_states = resnet(hidden_states, temb)
@@ -1717,6 +1789,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         dual_cross_attention=False,
         use_linear_projection=False,
         upcast_attention=False,
+        attention_type="default",
     ):
         super().__init__()

@@ -1753,6 +1826,7 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
                         norm_num_groups=resnet_groups,
                         use_linear_projection=use_linear_projection,
                         upcast_attention=upcast_attention,
+                        attention_type=attention_type,
                     )
                 )
             else:
@@ -1784,6 +1858,8 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
         self.attentions = nn.ModuleList(attentions)
         self.resnets = nn.ModuleList(resnets)

+        self.gradient_checkpointing = False
+
     def forward(
         self,
         hidden_states: torch.FloatTensor,
@@ -1795,15 +1871,42 @@ class UNetMidBlockFlatCrossAttn(nn.Module):
     ) -> torch.FloatTensor:
         hidden_states = self.resnets[0](hidden_states, temb)
         for attn, resnet in zip(self.attentions, self.resnets[1:]):
-            hidden_states = attn(
-                hidden_states,
-                encoder_hidden_states=encoder_hidden_states,
-                cross_attention_kwargs=cross_attention_kwargs,
-                attention_mask=attention_mask,
-                encoder_attention_mask=encoder_attention_mask,
-                return_dict=False,
-            )[0]
-            hidden_states = resnet(hidden_states, temb)
+            if self.training and self.gradient_checkpointing:
+
+                def create_custom_forward(module, return_dict=None):
+                    def custom_forward(*inputs):
+                        if return_dict is not None:
+                            return module(*inputs, return_dict=return_dict)
+                        else:
+                            return module(*inputs)
+
+                    return custom_forward
+
+                ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False} if is_torch_version(">=", "1.11.0") else {}
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = torch.utils.checkpoint.checkpoint(
+                    create_custom_forward(resnet),
+                    hidden_states,
+                    temb,
+                    **ckpt_kwargs,
+                )
+            else:
+                hidden_states = attn(
+                    hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
+                    cross_attention_kwargs=cross_attention_kwargs,
+                    attention_mask=attention_mask,
+                    encoder_attention_mask=encoder_attention_mask,
+                    return_dict=False,
+                )[0]
+                hidden_states = resnet(hidden_states, temb)

         return hidden_states

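With the mid block now owning a `gradient_checkpointing` flag, the relaxed `_set_gradient_checkpointing` above (any module that has the attribute gets it flipped) covers it as well. A short sketch of how the flag is normally toggled through the `ModelMixin` API, assuming an already constructed `UNetFlatConditionModel` instance named `unet`:

# enable_gradient_checkpointing() walks the module tree and calls
# _set_gradient_checkpointing(module, value=True); after this change the mid block's
# attention/resnet pairs are recomputed during backward instead of being stored.
unet.enable_gradient_checkpointing()

# ... training steps ...

unet.disable_gradient_checkpointing()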