diffusers 0.23.0__py3-none-any.whl → 0.24.0__py3-none-any.whl

Files changed (177)
  1. diffusers/__init__.py +16 -2
  2. diffusers/configuration_utils.py +1 -0
  3. diffusers/dependency_versions_check.py +1 -14
  4. diffusers/dependency_versions_table.py +5 -4
  5. diffusers/image_processor.py +186 -14
  6. diffusers/loaders/__init__.py +82 -0
  7. diffusers/loaders/ip_adapter.py +157 -0
  8. diffusers/loaders/lora.py +1415 -0
  9. diffusers/loaders/lora_conversion_utils.py +284 -0
  10. diffusers/loaders/single_file.py +631 -0
  11. diffusers/loaders/textual_inversion.py +459 -0
  12. diffusers/loaders/unet.py +735 -0
  13. diffusers/loaders/utils.py +59 -0
  14. diffusers/models/__init__.py +12 -1
  15. diffusers/models/attention.py +165 -14
  16. diffusers/models/attention_flax.py +9 -1
  17. diffusers/models/attention_processor.py +286 -1
  18. diffusers/models/autoencoder_asym_kl.py +14 -9
  19. diffusers/models/autoencoder_kl.py +3 -18
  20. diffusers/models/autoencoder_kl_temporal_decoder.py +402 -0
  21. diffusers/models/autoencoder_tiny.py +20 -24
  22. diffusers/models/consistency_decoder_vae.py +37 -30
  23. diffusers/models/controlnet.py +59 -39
  24. diffusers/models/controlnet_flax.py +19 -18
  25. diffusers/models/embeddings_flax.py +2 -0
  26. diffusers/models/lora.py +131 -1
  27. diffusers/models/modeling_flax_utils.py +2 -1
  28. diffusers/models/modeling_outputs.py +17 -0
  29. diffusers/models/modeling_utils.py +27 -19
  30. diffusers/models/normalization.py +2 -2
  31. diffusers/models/resnet.py +390 -59
  32. diffusers/models/transformer_2d.py +20 -3
  33. diffusers/models/transformer_temporal.py +183 -1
  34. diffusers/models/unet_2d_blocks_flax.py +5 -0
  35. diffusers/models/unet_2d_condition.py +9 -0
  36. diffusers/models/unet_2d_condition_flax.py +13 -13
  37. diffusers/models/unet_3d_blocks.py +957 -173
  38. diffusers/models/unet_3d_condition.py +16 -8
  39. diffusers/models/unet_kandi3.py +589 -0
  40. diffusers/models/unet_motion_model.py +48 -33
  41. diffusers/models/unet_spatio_temporal_condition.py +489 -0
  42. diffusers/models/vae.py +63 -13
  43. diffusers/models/vae_flax.py +7 -0
  44. diffusers/models/vq_model.py +3 -1
  45. diffusers/optimization.py +16 -9
  46. diffusers/pipelines/__init__.py +65 -12
  47. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion.py +93 -23
  48. diffusers/pipelines/alt_diffusion/pipeline_alt_diffusion_img2img.py +97 -25
  49. diffusers/pipelines/animatediff/pipeline_animatediff.py +34 -4
  50. diffusers/pipelines/audioldm/pipeline_audioldm.py +1 -0
  51. diffusers/pipelines/auto_pipeline.py +6 -0
  52. diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -0
  53. diffusers/pipelines/controlnet/pipeline_controlnet.py +217 -31
  54. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +101 -32
  55. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +136 -39
  56. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +119 -37
  57. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +196 -35
  58. diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +102 -31
  59. diffusers/pipelines/dance_diffusion/pipeline_dance_diffusion.py +1 -0
  60. diffusers/pipelines/ddim/pipeline_ddim.py +1 -0
  61. diffusers/pipelines/ddpm/pipeline_ddpm.py +1 -0
  62. diffusers/pipelines/deepfloyd_if/pipeline_if.py +13 -1
  63. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +13 -1
  64. diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +13 -1
  65. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +13 -1
  66. diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +13 -1
  67. diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +13 -1
  68. diffusers/pipelines/dit/pipeline_dit.py +1 -0
  69. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +1 -1
  70. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +3 -3
  71. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +1 -1
  72. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +1 -1
  73. diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +1 -1
  74. diffusers/pipelines/kandinsky3/__init__.py +49 -0
  75. diffusers/pipelines/kandinsky3/kandinsky3_pipeline.py +452 -0
  76. diffusers/pipelines/kandinsky3/kandinsky3img2img_pipeline.py +460 -0
  77. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +65 -6
  78. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +55 -3
  79. diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -0
  80. diffusers/pipelines/musicldm/pipeline_musicldm.py +1 -1
  81. diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +7 -2
  82. diffusers/pipelines/pipeline_flax_utils.py +4 -2
  83. diffusers/pipelines/pipeline_utils.py +33 -13
  84. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +196 -36
  85. diffusers/pipelines/score_sde_ve/pipeline_score_sde_ve.py +1 -0
  86. diffusers/pipelines/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +1 -0
  87. diffusers/pipelines/stable_diffusion/__init__.py +64 -21
  88. diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +8 -3
  89. diffusers/pipelines/stable_diffusion/pipeline_cycle_diffusion.py +18 -2
  90. diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +2 -2
  91. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_img2img.py +2 -4
  92. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -0
  93. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint_legacy.py +1 -0
  94. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +88 -9
  95. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_attend_and_excite.py +1 -0
  96. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +8 -3
  97. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_diffedit.py +1 -0
  98. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen.py +1 -0
  99. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_gligen_text_image.py +1 -0
  100. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +1 -0
  101. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +92 -9
  102. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +92 -9
  103. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint_legacy.py +1 -0
  104. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +17 -13
  105. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -0
  106. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +1 -0
  107. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_ldm3d.py +1 -0
  108. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_model_editing.py +1 -0
  109. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_panorama.py +1 -0
  110. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_paradigms.py +1 -0
  111. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_pix2pix_zero.py +1 -0
  112. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_sag.py +1 -0
  113. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +1 -0
  114. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +103 -8
  115. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +113 -8
  116. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +115 -9
  117. diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +16 -12
  118. diffusers/pipelines/stable_video_diffusion/__init__.py +58 -0
  119. diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +649 -0
  120. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +108 -12
  121. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +109 -14
  122. diffusers/pipelines/text_to_video_synthesis/__init__.py +2 -0
  123. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +1 -0
  124. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +18 -3
  125. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -2
  126. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +872 -0
  127. diffusers/pipelines/versatile_diffusion/modeling_text_unet.py +29 -40
  128. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +1 -0
  129. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +1 -0
  130. diffusers/pipelines/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +1 -0
  131. diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +14 -4
  132. diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +9 -5
  133. diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +1 -1
  134. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +2 -2
  135. diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +1 -1
  136. diffusers/schedulers/__init__.py +2 -4
  137. diffusers/schedulers/deprecated/__init__.py +50 -0
  138. diffusers/schedulers/{scheduling_karras_ve.py → deprecated/scheduling_karras_ve.py} +4 -4
  139. diffusers/schedulers/{scheduling_sde_vp.py → deprecated/scheduling_sde_vp.py} +4 -6
  140. diffusers/schedulers/scheduling_ddim.py +1 -3
  141. diffusers/schedulers/scheduling_ddim_inverse.py +1 -3
  142. diffusers/schedulers/scheduling_ddim_parallel.py +1 -3
  143. diffusers/schedulers/scheduling_ddpm.py +1 -3
  144. diffusers/schedulers/scheduling_ddpm_parallel.py +1 -3
  145. diffusers/schedulers/scheduling_deis_multistep.py +15 -5
  146. diffusers/schedulers/scheduling_dpmsolver_multistep.py +15 -5
  147. diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +15 -5
  148. diffusers/schedulers/scheduling_dpmsolver_sde.py +1 -3
  149. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +15 -5
  150. diffusers/schedulers/scheduling_euler_ancestral_discrete.py +1 -3
  151. diffusers/schedulers/scheduling_euler_discrete.py +40 -13
  152. diffusers/schedulers/scheduling_heun_discrete.py +15 -5
  153. diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +15 -5
  154. diffusers/schedulers/scheduling_k_dpm_2_discrete.py +15 -5
  155. diffusers/schedulers/scheduling_lcm.py +123 -29
  156. diffusers/schedulers/scheduling_lms_discrete.py +1 -3
  157. diffusers/schedulers/scheduling_pndm.py +1 -3
  158. diffusers/schedulers/scheduling_repaint.py +1 -3
  159. diffusers/schedulers/scheduling_unipc_multistep.py +15 -5
  160. diffusers/utils/__init__.py +1 -0
  161. diffusers/utils/constants.py +11 -6
  162. diffusers/utils/dummy_pt_objects.py +45 -0
  163. diffusers/utils/dummy_torch_and_transformers_objects.py +60 -0
  164. diffusers/utils/dynamic_modules_utils.py +4 -4
  165. diffusers/utils/export_utils.py +8 -3
  166. diffusers/utils/logging.py +10 -10
  167. diffusers/utils/outputs.py +5 -5
  168. diffusers/utils/peft_utils.py +88 -44
  169. diffusers/utils/torch_utils.py +2 -2
  170. diffusers/utils/versions.py +117 -0
  171. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/METADATA +83 -64
  172. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/RECORD +176 -157
  173. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/WHEEL +1 -1
  174. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/entry_points.txt +1 -0
  175. diffusers/loaders.py +0 -3336
  176. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/LICENSE +0 -0
  177. {diffusers-0.23.0.dist-info → diffusers-0.24.0.dist-info}/top_level.txt +0 -0
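
The headline changes in 0.24.0 are the new Stable Video Diffusion stack (`pipeline_stable_video_diffusion.py`, `unet_spatio_temporal_condition.py`, `autoencoder_kl_temporal_decoder.py`), the Kandinsky 3 pipelines, IP-Adapter support, and the split of the monolithic `diffusers/loaders.py` (removed, -3336 lines) into the new `diffusers/loaders/` package. As a rough sketch of what the new video pipeline looks like after upgrading — the checkpoint id and call arguments below are assumptions based on the 0.24 release, not taken from this diff:

```python
# pip install -U "diffusers==0.24.0"
import torch

from diffusers import StableVideoDiffusionPipeline
from diffusers.utils import export_to_video, load_image

# Assumed checkpoint and settings; swap in whatever weights you actually use.
pipe = StableVideoDiffusionPipeline.from_pretrained(
    "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
)
pipe.enable_model_cpu_offload()

image = load_image("input.png").resize((1024, 576))
frames = pipe(image, decode_chunk_size=8).frames[0]  # decode_chunk_size trades speed for VRAM
export_to_video(frames, "generated.mp4", fps=7)
```

The per-file hunks below are reproduced from the diff.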
diffusers/loaders/utils.py (new file)
@@ -0,0 +1,59 @@
+ # Copyright 2023 The HuggingFace Team. All rights reserved.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ from typing import Dict
+
+ import torch
+
+
+ class AttnProcsLayers(torch.nn.Module):
+     def __init__(self, state_dict: Dict[str, torch.Tensor]):
+         super().__init__()
+         self.layers = torch.nn.ModuleList(state_dict.values())
+         self.mapping = dict(enumerate(state_dict.keys()))
+         self.rev_mapping = {v: k for k, v in enumerate(state_dict.keys())}
+
+         # .processor for unet, .self_attn for text encoder
+         self.split_keys = [".processor", ".self_attn"]
+
+         # we add a hook to state_dict() and load_state_dict() so that the
+         # naming fits with `unet.attn_processors`
+         def map_to(module, state_dict, *args, **kwargs):
+             new_state_dict = {}
+             for key, value in state_dict.items():
+                 num = int(key.split(".")[1])  # 0 is always "layers"
+                 new_key = key.replace(f"layers.{num}", module.mapping[num])
+                 new_state_dict[new_key] = value
+
+             return new_state_dict
+
+         def remap_key(key, state_dict):
+             for k in self.split_keys:
+                 if k in key:
+                     return key.split(k)[0] + k
+
+             raise ValueError(
+                 f"There seems to be a problem with the state_dict: {set(state_dict.keys())}. {key} has to have one of {self.split_keys}."
+             )
+
+         def map_from(module, state_dict, *args, **kwargs):
+             all_keys = list(state_dict.keys())
+             for key in all_keys:
+                 replace_key = remap_key(key, state_dict)
+                 new_key = key.replace(replace_key, f"layers.{module.rev_mapping[replace_key]}")
+                 state_dict[new_key] = state_dict[key]
+                 del state_dict[key]
+
+         self._register_state_dict_hook(map_to)
+         self._register_load_state_dict_pre_hook(map_from, with_module=True)
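
`AttnProcsLayers` moves unchanged from the old monolithic `loaders.py` into this new module. Its state-dict hooks rename parameters from the internal `layers.{i}.*` naming back to the `unet.attn_processors` keys, so saved attention-processor checkpoints keep their original key layout. A minimal sketch of that round trip, assuming the class is imported from its new location and using toy `nn.Linear` stand-ins rather than real attention processors:

```python
import torch
from diffusers.loaders.utils import AttnProcsLayers

# Toy stand-ins keyed the way unet.attn_processors keys its processors.
procs = {
    "down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor": torch.nn.Linear(4, 4),
    "up_blocks.1.attentions.0.transformer_blocks.0.attn2.processor": torch.nn.Linear(4, 4),
}
wrapper = AttnProcsLayers(procs)

# The state_dict hook rewrites "layers.{i}.weight" back to
# "<original processor key>.weight"; the load hook reverses the mapping.
print(sorted(wrapper.state_dict().keys()))
```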
diffusers/models/__init__.py
@@ -14,7 +14,12 @@

  from typing import TYPE_CHECKING

- from ..utils import DIFFUSERS_SLOW_IMPORT, _LazyModule, is_flax_available, is_torch_available
+ from ..utils import (
+     DIFFUSERS_SLOW_IMPORT,
+     _LazyModule,
+     is_flax_available,
+     is_torch_available,
+ )


  _import_structure = {}
@@ -23,6 +28,7 @@ if is_torch_available():
      _import_structure["adapter"] = ["MultiAdapter", "T2IAdapter"]
      _import_structure["autoencoder_asym_kl"] = ["AsymmetricAutoencoderKL"]
      _import_structure["autoencoder_kl"] = ["AutoencoderKL"]
+     _import_structure["autoencoder_kl_temporal_decoder"] = ["AutoencoderKLTemporalDecoder"]
      _import_structure["autoencoder_tiny"] = ["AutoencoderTiny"]
      _import_structure["consistency_decoder_vae"] = ["ConsistencyDecoderVAE"]
      _import_structure["controlnet"] = ["ControlNetModel"]
@@ -36,7 +42,9 @@ if is_torch_available():
      _import_structure["unet_2d"] = ["UNet2DModel"]
      _import_structure["unet_2d_condition"] = ["UNet2DConditionModel"]
      _import_structure["unet_3d_condition"] = ["UNet3DConditionModel"]
+     _import_structure["unet_kandi3"] = ["Kandinsky3UNet"]
      _import_structure["unet_motion_model"] = ["MotionAdapter", "UNetMotionModel"]
+     _import_structure["unet_spatio_temporal_condition"] = ["UNetSpatioTemporalConditionModel"]
      _import_structure["vq_model"] = ["VQModel"]

  if is_flax_available():
@@ -50,6 +58,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
      from .adapter import MultiAdapter, T2IAdapter
      from .autoencoder_asym_kl import AsymmetricAutoencoderKL
      from .autoencoder_kl import AutoencoderKL
+     from .autoencoder_kl_temporal_decoder import AutoencoderKLTemporalDecoder
      from .autoencoder_tiny import AutoencoderTiny
      from .consistency_decoder_vae import ConsistencyDecoderVAE
      from .controlnet import ControlNetModel
@@ -63,7 +72,9 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
      from .unet_2d import UNet2DModel
      from .unet_2d_condition import UNet2DConditionModel
      from .unet_3d_condition import UNet3DConditionModel
+     from .unet_kandi3 import Kandinsky3UNet
      from .unet_motion_model import MotionAdapter, UNetMotionModel
+     from .unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
      from .vq_model import VQModel

  if is_flax_available():
diffusers/models/attention.py
@@ -25,6 +25,31 @@ from .lora import LoRACompatibleLinear
  from .normalization import AdaLayerNorm, AdaLayerNormZero


+ def _chunked_feed_forward(
+     ff: nn.Module, hidden_states: torch.Tensor, chunk_dim: int, chunk_size: int, lora_scale: Optional[float] = None
+ ):
+     # "feed_forward_chunk_size" can be used to save memory
+     if hidden_states.shape[chunk_dim] % chunk_size != 0:
+         raise ValueError(
+             f"`hidden_states` dimension to be chunked: {hidden_states.shape[chunk_dim]} has to be divisible by chunk size: {chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
+         )
+
+     num_chunks = hidden_states.shape[chunk_dim] // chunk_size
+     if lora_scale is None:
+         ff_output = torch.cat(
+             [ff(hid_slice) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+             dim=chunk_dim,
+         )
+     else:
+         # TOOD(Patrick): LoRA scale can be removed once PEFT refactor is complete
+         ff_output = torch.cat(
+             [ff(hid_slice, scale=lora_scale) for hid_slice in hidden_states.chunk(num_chunks, dim=chunk_dim)],
+             dim=chunk_dim,
+         )
+
+     return ff_output
+
+
  @maybe_allow_in_graph
  class GatedSelfAttentionDense(nn.Module):
      r"""
@@ -194,7 +219,12 @@ class BasicTransformerBlock(nn.Module):
          if not self.use_ada_layer_norm_single:
              self.norm3 = nn.LayerNorm(dim, elementwise_affine=norm_elementwise_affine, eps=norm_eps)

-         self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn, final_dropout=final_dropout)
+         self.ff = FeedForward(
+             dim,
+             dropout=dropout,
+             activation_fn=activation_fn,
+             final_dropout=final_dropout,
+         )

          # 4. Fuser
          if attention_type == "gated" or attention_type == "gated-text-image":
@@ -208,7 +238,7 @@ class BasicTransformerBlock(nn.Module):
          self._chunk_size = None
          self._chunk_dim = 0

-     def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int):
+     def set_chunk_feed_forward(self, chunk_size: Optional[int], dim: int = 0):
          # Sets chunk feed-forward
          self._chunk_size = chunk_size
          self._chunk_dim = dim
@@ -311,18 +341,8 @@ class BasicTransformerBlock(nn.Module):

          if self._chunk_size is not None:
              # "feed_forward_chunk_size" can be used to save memory
-             if norm_hidden_states.shape[self._chunk_dim] % self._chunk_size != 0:
-                 raise ValueError(
-                     f"`hidden_states` dimension to be chunked: {norm_hidden_states.shape[self._chunk_dim]} has to be divisible by chunk size: {self._chunk_size}. Make sure to set an appropriate `chunk_size` when calling `unet.enable_forward_chunking`."
-                 )
-
-             num_chunks = norm_hidden_states.shape[self._chunk_dim] // self._chunk_size
-             ff_output = torch.cat(
-                 [
-                     self.ff(hid_slice, scale=lora_scale)
-                     for hid_slice in norm_hidden_states.chunk(num_chunks, dim=self._chunk_dim)
-                 ],
-                 dim=self._chunk_dim,
+             ff_output = _chunked_feed_forward(
+                 self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size, lora_scale=lora_scale
              )
          else:
              ff_output = self.ff(norm_hidden_states, scale=lora_scale)
@@ -339,6 +359,137 @@ class BasicTransformerBlock(nn.Module):
          return hidden_states


+ @maybe_allow_in_graph
+ class TemporalBasicTransformerBlock(nn.Module):
+     r"""
+     A basic Transformer block for video like data.
+
+     Parameters:
+         dim (`int`): The number of channels in the input and output.
+         time_mix_inner_dim (`int`): The number of channels for temporal attention.
+         num_attention_heads (`int`): The number of heads to use for multi-head attention.
+         attention_head_dim (`int`): The number of channels in each head.
+         cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+     """
+
+     def __init__(
+         self,
+         dim: int,
+         time_mix_inner_dim: int,
+         num_attention_heads: int,
+         attention_head_dim: int,
+         cross_attention_dim: Optional[int] = None,
+     ):
+         super().__init__()
+         self.is_res = dim == time_mix_inner_dim
+
+         self.norm_in = nn.LayerNorm(dim)
+
+         # Define 3 blocks. Each block has its own normalization layer.
+         # 1. Self-Attn
+         self.norm_in = nn.LayerNorm(dim)
+         self.ff_in = FeedForward(
+             dim,
+             dim_out=time_mix_inner_dim,
+             activation_fn="geglu",
+         )
+
+         self.norm1 = nn.LayerNorm(time_mix_inner_dim)
+         self.attn1 = Attention(
+             query_dim=time_mix_inner_dim,
+             heads=num_attention_heads,
+             dim_head=attention_head_dim,
+             cross_attention_dim=None,
+         )
+
+         # 2. Cross-Attn
+         if cross_attention_dim is not None:
+             # We currently only use AdaLayerNormZero for self attention where there will only be one attention block.
+             # I.e. the number of returned modulation chunks from AdaLayerZero would not make sense if returned during
+             # the second cross attention block.
+             self.norm2 = nn.LayerNorm(time_mix_inner_dim)
+             self.attn2 = Attention(
+                 query_dim=time_mix_inner_dim,
+                 cross_attention_dim=cross_attention_dim,
+                 heads=num_attention_heads,
+                 dim_head=attention_head_dim,
+             )  # is self-attn if encoder_hidden_states is none
+         else:
+             self.norm2 = None
+             self.attn2 = None
+
+         # 3. Feed-forward
+         self.norm3 = nn.LayerNorm(time_mix_inner_dim)
+         self.ff = FeedForward(time_mix_inner_dim, activation_fn="geglu")
+
+         # let chunk size default to None
+         self._chunk_size = None
+         self._chunk_dim = None
+
+     def set_chunk_feed_forward(self, chunk_size: Optional[int], **kwargs):
+         # Sets chunk feed-forward
+         self._chunk_size = chunk_size
+         # chunk dim should be hardcoded to 1 to have better speed vs. memory trade-off
+         self._chunk_dim = 1
+
+     def forward(
+         self,
+         hidden_states: torch.FloatTensor,
+         num_frames: int,
+         encoder_hidden_states: Optional[torch.FloatTensor] = None,
+     ) -> torch.FloatTensor:
+         # Notice that normalization is always applied before the real computation in the following blocks.
+         # 0. Self-Attention
+         batch_size = hidden_states.shape[0]
+
+         batch_frames, seq_length, channels = hidden_states.shape
+         batch_size = batch_frames // num_frames
+
+         hidden_states = hidden_states[None, :].reshape(batch_size, num_frames, seq_length, channels)
+         hidden_states = hidden_states.permute(0, 2, 1, 3)
+         hidden_states = hidden_states.reshape(batch_size * seq_length, num_frames, channels)
+
+         residual = hidden_states
+         hidden_states = self.norm_in(hidden_states)
+
+         if self._chunk_size is not None:
+             hidden_states = _chunked_feed_forward(self.ff, hidden_states, self._chunk_dim, self._chunk_size)
+         else:
+             hidden_states = self.ff_in(hidden_states)
+
+         if self.is_res:
+             hidden_states = hidden_states + residual
+
+         norm_hidden_states = self.norm1(hidden_states)
+         attn_output = self.attn1(norm_hidden_states, encoder_hidden_states=None)
+         hidden_states = attn_output + hidden_states
+
+         # 3. Cross-Attention
+         if self.attn2 is not None:
+             norm_hidden_states = self.norm2(hidden_states)
+             attn_output = self.attn2(norm_hidden_states, encoder_hidden_states=encoder_hidden_states)
+             hidden_states = attn_output + hidden_states
+
+         # 4. Feed-forward
+         norm_hidden_states = self.norm3(hidden_states)
+
+         if self._chunk_size is not None:
+             ff_output = _chunked_feed_forward(self.ff, norm_hidden_states, self._chunk_dim, self._chunk_size)
+         else:
+             ff_output = self.ff(norm_hidden_states)
+
+         if self.is_res:
+             hidden_states = ff_output + hidden_states
+         else:
+             hidden_states = ff_output
+
+         hidden_states = hidden_states[None, :].reshape(batch_size, seq_length, num_frames, channels)
+         hidden_states = hidden_states.permute(0, 2, 1, 3)
+         hidden_states = hidden_states.reshape(batch_size * num_frames, seq_length, channels)
+
+         return hidden_states
+
+
  class FeedForward(nn.Module):
      r"""
      A feed-forward layer.
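
The chunking logic that used to live inline in `BasicTransformerBlock.forward` is now the module-level `_chunked_feed_forward` helper above, so the new `TemporalBasicTransformerBlock` can reuse it. The idea, sketched below outside of diffusers (in practice it is switched on via `unet.enable_forward_chunking`, as the error message notes): run the feed-forward over slices of the sequence dimension and concatenate, which lowers peak activation memory at the cost of some speed.

```python
import torch

def chunked_ff(ff, hidden_states, chunk_dim, chunk_size):
    # Apply `ff` slice by slice along `chunk_dim` and stitch the outputs back together.
    if hidden_states.shape[chunk_dim] % chunk_size != 0:
        raise ValueError("dimension to be chunked must be divisible by chunk_size")
    num_chunks = hidden_states.shape[chunk_dim] // chunk_size
    return torch.cat(
        [ff(chunk) for chunk in hidden_states.chunk(num_chunks, dim=chunk_dim)],
        dim=chunk_dim,
    )

ff = torch.nn.Sequential(torch.nn.Linear(64, 256), torch.nn.GELU(), torch.nn.Linear(256, 64))
x = torch.randn(2, 1024, 64)  # (batch, seq_len, channels)

# Same result as ff(x); only the peak size of the intermediate activations changes.
assert torch.allclose(ff(x), chunked_ff(ff, x, chunk_dim=1, chunk_size=256), atol=1e-6)
```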
diffusers/models/attention_flax.py
@@ -110,7 +110,10 @@ def jax_memory_efficient_attention(
      )

      _, res = jax.lax.scan(
-         f=chunk_scanner, init=0, xs=None, length=math.ceil(num_q / query_chunk_size)  # start counter  # stop counter
+         f=chunk_scanner,
+         init=0,
+         xs=None,
+         length=math.ceil(num_q / query_chunk_size),  # start counter  # stop counter
      )

      return jnp.concatenate(res, axis=-3)  # fuse the chunked result back
@@ -138,6 +141,7 @@ class FlaxAttention(nn.Module):
              Parameters `dtype`

      """
+
      query_dim: int
      heads: int = 8
      dim_head: int = 64
@@ -262,6 +266,7 @@ class FlaxBasicTransformerBlock(nn.Module):
              Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
              enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
      """
+
      dim: int
      n_heads: int
      d_head: int
@@ -347,6 +352,7 @@ class FlaxTransformer2DModel(nn.Module):
              Whether to split the head dimension into a new axis for the self-attention computation. In most cases,
              enabling this flag should speed up the computation for Stable Diffusion 2.x and Stable Diffusion XL.
      """
+
      in_channels: int
      n_heads: int
      d_head: int
@@ -442,6 +448,7 @@ class FlaxFeedForward(nn.Module):
          dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
              Parameters `dtype`
      """
+
      dim: int
      dropout: float = 0.0
      dtype: jnp.dtype = jnp.float32
@@ -471,6 +478,7 @@ class FlaxGEGLU(nn.Module):
          dtype (:obj:`jnp.dtype`, *optional*, defaults to jnp.float32):
              Parameters `dtype`
      """
+
      dim: int
      dropout: float = 0.0
      dtype: jnp.dtype = jnp.float32
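
The only functional hunk here reformats the `jax.lax.scan` call in `jax_memory_efficient_attention`; the remaining hunks add a blank line after each Flax module docstring. For reference, a rough standalone sketch (not the diffusers implementation) of the scan pattern that call relies on: an integer offset is carried through `lax.scan`, one query chunk is sliced per step, and the stacked per-chunk results are concatenated back.

```python
import math

import jax
import jax.numpy as jnp

def apply_in_query_chunks(fn, query, chunk_size):
    num_chunks = math.ceil(query.shape[0] / chunk_size)

    def chunk_scanner(offset, _):
        # slice_sizes must be static, so the carried offset is the only traced value here
        chunk = jax.lax.dynamic_slice(query, (offset, 0), slice_sizes=(chunk_size, query.shape[-1]))
        return offset + chunk_size, fn(chunk)

    _, chunks = jax.lax.scan(f=chunk_scanner, init=0, xs=None, length=num_chunks)
    return jnp.concatenate(chunks, axis=0)  # fuse the chunked results back

out = apply_in_query_chunks(jnp.tanh, jnp.ones((128, 16)), chunk_size=32)
assert out.shape == (128, 16)
```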
diffusers/models/attention_processor.py
@@ -16,7 +16,7 @@ from typing import Callable, Optional, Union

  import torch
  import torch.nn.functional as F
- from torch import nn
+ from torch import einsum, nn

  from ..utils import USE_PEFT_BACKEND, deprecate, logging
  from ..utils.import_utils import is_xformers_available
@@ -1975,6 +1975,288 @@ class LoRAAttnAddedKVProcessor(nn.Module):
          return attn.processor(attn, hidden_states, *args, **kwargs)


+ class IPAdapterAttnProcessor(nn.Module):
+     r"""
+     Attention processor for IP-Adapater.
+
+     Args:
+         hidden_size (`int`):
+             The hidden size of the attention layer.
+         cross_attention_dim (`int`):
+             The number of channels in the `encoder_hidden_states`.
+         num_tokens (`int`, defaults to 4):
+             The context length of the image features.
+         scale (`float`, defaults to 1.0):
+             the weight scale of image prompt.
+     """
+
+     def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
+         super().__init__()
+
+         self.hidden_size = hidden_size
+         self.cross_attention_dim = cross_attention_dim
+         self.num_tokens = num_tokens
+         self.scale = scale
+
+         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+         self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+         scale=1.0,
+     ):
+         if scale != 1.0:
+             logger.warning("`scale` of IPAttnProcessor should be set with `set_ip_adapter_scale`.")
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         # split hidden states
+         end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+         encoder_hidden_states, ip_hidden_states = (
+             encoder_hidden_states[:, :end_pos, :],
+             encoder_hidden_states[:, end_pos:, :],
+         )
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # for ip-adapter
+         ip_key = self.to_k_ip(ip_hidden_states)
+         ip_value = self.to_v_ip(ip_hidden_states)
+
+         ip_key = attn.head_to_batch_dim(ip_key)
+         ip_value = attn.head_to_batch_dim(ip_value)
+
+         ip_attention_probs = attn.get_attention_scores(query, ip_key, None)
+         ip_hidden_states = torch.bmm(ip_attention_probs, ip_value)
+         ip_hidden_states = attn.batch_to_head_dim(ip_hidden_states)
+
+         hidden_states = hidden_states + self.scale * ip_hidden_states
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
+
+
+ class IPAdapterAttnProcessor2_0(torch.nn.Module):
+     r"""
+     Attention processor for IP-Adapater for PyTorch 2.0.
+
+     Args:
+         hidden_size (`int`):
+             The hidden size of the attention layer.
+         cross_attention_dim (`int`):
+             The number of channels in the `encoder_hidden_states`.
+         num_tokens (`int`, defaults to 4):
+             The context length of the image features.
+         scale (`float`, defaults to 1.0):
+             the weight scale of image prompt.
+     """
+
+     def __init__(self, hidden_size, cross_attention_dim=None, num_tokens=4, scale=1.0):
+         super().__init__()
+
+         if not hasattr(F, "scaled_dot_product_attention"):
+             raise ImportError(
+                 f"{self.__class__.__name__} requires PyTorch 2.0, to use it, please upgrade PyTorch to 2.0."
+             )
+
+         self.hidden_size = hidden_size
+         self.cross_attention_dim = cross_attention_dim
+         self.num_tokens = num_tokens
+         self.scale = scale
+
+         self.to_k_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+         self.to_v_ip = nn.Linear(cross_attention_dim or hidden_size, hidden_size, bias=False)
+
+     def __call__(
+         self,
+         attn,
+         hidden_states,
+         encoder_hidden_states=None,
+         attention_mask=None,
+         temb=None,
+         scale=1.0,
+     ):
+         if scale != 1.0:
+             logger.warning("`scale` of IPAttnProcessor should be set by `set_ip_adapter_scale`.")
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+
+         if attention_mask is not None:
+             attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+             # scaled_dot_product_attention expects attention_mask shape to be
+             # (batch, heads, source_length, target_length)
+             attention_mask = attention_mask.view(batch_size, attn.heads, -1, attention_mask.shape[-1])
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states)
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         # split hidden states
+         end_pos = encoder_hidden_states.shape[1] - self.num_tokens
+         encoder_hidden_states, ip_hidden_states = (
+             encoder_hidden_states[:, :end_pos, :],
+             encoder_hidden_states[:, end_pos:, :],
+         )
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         inner_dim = key.shape[-1]
+         head_dim = inner_dim // attn.heads
+
+         query = query.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         key = key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         value = value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+         # TODO: add support for attn.scale when we move to Torch 2.1
+         hidden_states = F.scaled_dot_product_attention(
+             query, key, value, attn_mask=attention_mask, dropout_p=0.0, is_causal=False
+         )
+
+         hidden_states = hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         hidden_states = hidden_states.to(query.dtype)
+
+         # for ip-adapter
+         ip_key = self.to_k_ip(ip_hidden_states)
+         ip_value = self.to_v_ip(ip_hidden_states)
+
+         ip_key = ip_key.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+         ip_value = ip_value.view(batch_size, -1, attn.heads, head_dim).transpose(1, 2)
+
+         # the output of sdp = (batch, num_heads, seq_len, head_dim)
+         # TODO: add support for attn.scale when we move to Torch 2.1
+         ip_hidden_states = F.scaled_dot_product_attention(
+             query, ip_key, ip_value, attn_mask=None, dropout_p=0.0, is_causal=False
+         )
+
+         ip_hidden_states = ip_hidden_states.transpose(1, 2).reshape(batch_size, -1, attn.heads * head_dim)
+         ip_hidden_states = ip_hidden_states.to(query.dtype)
+
+         hidden_states = hidden_states + self.scale * ip_hidden_states
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
+
+
+ # TODO(Yiyi): This class should not exist, we can replace it with a normal attention processor I believe
+ # this way torch.compile and co. will work as well
+ class Kandi3AttnProcessor:
+     r"""
+     Default kandinsky3 proccesor for performing attention-related computations.
+     """
+
+     @staticmethod
+     def _reshape(hid_states, h):
+         b, n, f = hid_states.shape
+         d = f // h
+         return hid_states.unsqueeze(-1).reshape(b, n, h, d).permute(0, 2, 1, 3)
+
+     def __call__(
+         self,
+         attn,
+         x,
+         context,
+         context_mask=None,
+     ):
+         query = self._reshape(attn.to_q(x), h=attn.num_heads)
+         key = self._reshape(attn.to_k(context), h=attn.num_heads)
+         value = self._reshape(attn.to_v(context), h=attn.num_heads)
+
+         attention_matrix = einsum("b h i d, b h j d -> b h i j", query, key)
+
+         if context_mask is not None:
+             max_neg_value = -torch.finfo(attention_matrix.dtype).max
+             context_mask = context_mask.unsqueeze(1).unsqueeze(1)
+             attention_matrix = attention_matrix.masked_fill(~(context_mask != 0), max_neg_value)
+         attention_matrix = (attention_matrix * attn.scale).softmax(dim=-1)
+
+         out = einsum("b h i j, b h j d -> b h i d", attention_matrix, value)
+         out = out.permute(0, 2, 1, 3).reshape(out.shape[0], out.shape[2], -1)
+         out = attn.to_out[0](out)
+         return out
+
+
  LORA_ATTENTION_PROCESSORS = (
      LoRAAttnProcessor,
      LoRAAttnProcessor2_0,
@@ -1998,6 +2280,9 @@ CROSS_ATTENTION_PROCESSORS = (
      LoRAAttnProcessor,
      LoRAAttnProcessor2_0,
      LoRAXFormersAttnProcessor,
+     IPAdapterAttnProcessor,
+     IPAdapterAttnProcessor2_0,
+     Kandi3AttnProcessor,
  )

  AttentionProcessor = Union[