diffusers 0.28.2__py3-none-any.whl → 0.29.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. diffusers/__init__.py +9 -1
  2. diffusers/commands/env.py +1 -5
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +2 -1
  5. diffusers/loaders/__init__.py +2 -2
  6. diffusers/loaders/lora.py +406 -140
  7. diffusers/loaders/lora_conversion_utils.py +7 -1
  8. diffusers/loaders/single_file.py +1 -1
  9. diffusers/loaders/single_file_model.py +5 -0
  10. diffusers/loaders/single_file_utils.py +242 -2
  11. diffusers/loaders/unet.py +307 -272
  12. diffusers/models/__init__.py +5 -3
  13. diffusers/models/attention.py +125 -1
  14. diffusers/models/attention_processor.py +169 -1
  15. diffusers/models/autoencoders/__init__.py +1 -0
  16. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  17. diffusers/models/autoencoders/autoencoder_kl.py +17 -6
  18. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
  19. diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
  20. diffusers/models/autoencoders/vq_model.py +182 -0
  21. diffusers/models/controlnet_xs.py +6 -6
  22. diffusers/models/embeddings.py +112 -84
  23. diffusers/models/model_loading_utils.py +55 -0
  24. diffusers/models/modeling_utils.py +128 -17
  25. diffusers/models/normalization.py +11 -6
  26. diffusers/models/transformers/__init__.py +1 -0
  27. diffusers/models/transformers/dual_transformer_2d.py +5 -4
  28. diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
  29. diffusers/models/transformers/prior_transformer.py +5 -5
  30. diffusers/models/transformers/transformer_2d.py +2 -2
  31. diffusers/models/transformers/transformer_sd3.py +344 -0
  32. diffusers/models/transformers/transformer_temporal.py +12 -10
  33. diffusers/models/unets/unet_1d.py +3 -3
  34. diffusers/models/unets/unet_2d.py +3 -3
  35. diffusers/models/unets/unet_2d_condition.py +4 -15
  36. diffusers/models/unets/unet_3d_condition.py +5 -17
  37. diffusers/models/unets/unet_i2vgen_xl.py +4 -4
  38. diffusers/models/unets/unet_motion_model.py +4 -4
  39. diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
  40. diffusers/models/vq_model.py +8 -165
  41. diffusers/pipelines/__init__.py +2 -0
  42. diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
  43. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
  44. diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
  45. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
  46. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
  47. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
  48. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  49. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
  50. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
  51. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
  52. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
  53. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
  54. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
  55. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
  56. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
  57. diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
  58. diffusers/pipelines/pia/pipeline_pia.py +4 -3
  59. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  60. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  61. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
  62. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
  63. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  64. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
  65. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
  69. diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
  70. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  71. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +886 -0
  72. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +923 -0
  73. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
  74. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
  75. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
  76. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  77. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
  78. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
  79. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
  80. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
  81. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
  82. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
  83. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
  84. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
  85. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  86. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
  87. diffusers/schedulers/__init__.py +2 -0
  88. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  89. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
  90. diffusers/schedulers/scheduling_edm_euler.py +2 -4
  91. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
  92. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  93. diffusers/training_utils.py +4 -4
  94. diffusers/utils/__init__.py +3 -0
  95. diffusers/utils/constants.py +2 -0
  96. diffusers/utils/dummy_pt_objects.py +30 -0
  97. diffusers/utils/dummy_torch_and_transformers_objects.py +30 -0
  98. diffusers/utils/dynamic_modules_utils.py +15 -13
  99. diffusers/utils/hub_utils.py +106 -0
  100. diffusers/utils/import_utils.py +0 -1
  101. diffusers/utils/logging.py +3 -1
  102. diffusers/utils/state_dict_utils.py +2 -0
  103. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/METADATA +45 -45
  104. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/RECORD +108 -111
  105. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/WHEEL +1 -1
  106. diffusers/models/dual_transformer_2d.py +0 -20
  107. diffusers/models/prior_transformer.py +0 -12
  108. diffusers/models/t5_film_transformer.py +0 -70
  109. diffusers/models/transformer_2d.py +0 -25
  110. diffusers/models/transformer_temporal.py +0 -34
  111. diffusers/models/unet_1d.py +0 -26
  112. diffusers/models/unet_1d_blocks.py +0 -203
  113. diffusers/models/unet_2d.py +0 -27
  114. diffusers/models/unet_2d_blocks.py +0 -375
  115. diffusers/models/unet_2d_condition.py +0 -25
  116. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/LICENSE +0 -0
  117. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/entry_points.txt +0 -0
  118. {diffusers-0.28.2.dist-info → diffusers-0.29.0.dist-info}/top_level.txt +0 -0

diffusers/models/autoencoders/vq_model.py
@@ -0,0 +1,182 @@
+# Copyright 2024 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from ...configuration_utils import ConfigMixin, register_to_config
+from ...utils import BaseOutput
+from ...utils.accelerate_utils import apply_forward_hook
+from ..autoencoders.vae import Decoder, DecoderOutput, Encoder, VectorQuantizer
+from ..modeling_utils import ModelMixin
+
+
+@dataclass
+class VQEncoderOutput(BaseOutput):
+    """
+    Output of VQModel encoding method.
+
+    Args:
+        latents (`torch.Tensor` of shape `(batch_size, num_channels, height, width)`):
+            The encoded output sample from the last layer of the model.
+    """
+
+    latents: torch.Tensor
+
+
+class VQModel(ModelMixin, ConfigMixin):
+    r"""
+    A VQ-VAE model for decoding latent representations.
+
+    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
+    for all models (such as downloading or saving).
+
+    Parameters:
+        in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+        out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+        down_block_types (`Tuple[str]`, *optional*, defaults to `("DownEncoderBlock2D",)`):
+            Tuple of downsample block types.
+        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpDecoderBlock2D",)`):
+            Tuple of upsample block types.
+        block_out_channels (`Tuple[int]`, *optional*, defaults to `(64,)`):
+            Tuple of block output channels.
+        layers_per_block (`int`, *optional*, defaults to `1`): Number of layers per block.
+        act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+        latent_channels (`int`, *optional*, defaults to `3`): Number of channels in the latent space.
+        sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+        num_vq_embeddings (`int`, *optional*, defaults to `256`): Number of codebook vectors in the VQ-VAE.
+        norm_num_groups (`int`, *optional*, defaults to `32`): Number of groups for normalization layers.
+        vq_embed_dim (`int`, *optional*): Hidden dim of codebook vectors in the VQ-VAE.
+        scaling_factor (`float`, *optional*, defaults to `0.18215`):
+            The component-wise standard deviation of the trained latent space computed using the first batch of the
+            training set. This is used to scale the latent space to have unit variance when training the diffusion
+            model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+            diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+            / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+            Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+        norm_type (`str`, *optional*, defaults to `"group"`):
+            Type of normalization layer to use. Can be one of `"group"` or `"spatial"`.
+    """
+
+    @register_to_config
+    def __init__(
+        self,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        down_block_types: Tuple[str, ...] = ("DownEncoderBlock2D",),
+        up_block_types: Tuple[str, ...] = ("UpDecoderBlock2D",),
+        block_out_channels: Tuple[int, ...] = (64,),
+        layers_per_block: int = 1,
+        act_fn: str = "silu",
+        latent_channels: int = 3,
+        sample_size: int = 32,
+        num_vq_embeddings: int = 256,
+        norm_num_groups: int = 32,
+        vq_embed_dim: Optional[int] = None,
+        scaling_factor: float = 0.18215,
+        norm_type: str = "group",  # group, spatial
+        mid_block_add_attention=True,
+        lookup_from_codebook=False,
+        force_upcast=False,
+    ):
+        super().__init__()
+
+        # pass init params to Encoder
+        self.encoder = Encoder(
+            in_channels=in_channels,
+            out_channels=latent_channels,
+            down_block_types=down_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
+            double_z=False,
+            mid_block_add_attention=mid_block_add_attention,
+        )
+
+        vq_embed_dim = vq_embed_dim if vq_embed_dim is not None else latent_channels
+
+        self.quant_conv = nn.Conv2d(latent_channels, vq_embed_dim, 1)
+        self.quantize = VectorQuantizer(num_vq_embeddings, vq_embed_dim, beta=0.25, remap=None, sane_index_shape=False)
+        self.post_quant_conv = nn.Conv2d(vq_embed_dim, latent_channels, 1)
+
+        # pass init params to Decoder
+        self.decoder = Decoder(
+            in_channels=latent_channels,
+            out_channels=out_channels,
+            up_block_types=up_block_types,
+            block_out_channels=block_out_channels,
+            layers_per_block=layers_per_block,
+            act_fn=act_fn,
+            norm_num_groups=norm_num_groups,
+            norm_type=norm_type,
+            mid_block_add_attention=mid_block_add_attention,
+        )
+
+    @apply_forward_hook
+    def encode(self, x: torch.Tensor, return_dict: bool = True) -> VQEncoderOutput:
+        h = self.encoder(x)
+        h = self.quant_conv(h)
+
+        if not return_dict:
+            return (h,)
+
+        return VQEncoderOutput(latents=h)
+
+    @apply_forward_hook
+    def decode(
+        self, h: torch.Tensor, force_not_quantize: bool = False, return_dict: bool = True, shape=None
+    ) -> Union[DecoderOutput, torch.Tensor]:
+        # also go through quantization layer
+        if not force_not_quantize:
+            quant, commit_loss, _ = self.quantize(h)
+        elif self.config.lookup_from_codebook:
+            quant = self.quantize.get_codebook_entry(h, shape)
+            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
+        else:
+            quant = h
+            commit_loss = torch.zeros((h.shape[0])).to(h.device, dtype=h.dtype)
+        quant2 = self.post_quant_conv(quant)
+        dec = self.decoder(quant2, quant if self.config.norm_type == "spatial" else None)
+
+        if not return_dict:
+            return dec, commit_loss
+
+        return DecoderOutput(sample=dec, commit_loss=commit_loss)
+
+    def forward(
+        self, sample: torch.Tensor, return_dict: bool = True
+    ) -> Union[DecoderOutput, Tuple[torch.Tensor, ...]]:
+        r"""
+        The [`VQModel`] forward method.
+
+        Args:
+            sample (`torch.Tensor`): Input sample.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`models.vq_model.VQEncoderOutput`] instead of a plain tuple.
+
+        Returns:
+            [`~models.vq_model.VQEncoderOutput`] or `tuple`:
+                If return_dict is True, a [`~models.vq_model.VQEncoderOutput`] is returned, otherwise a plain `tuple`
+                is returned.
+        """
+
+        h = self.encode(sample).latents
+        dec = self.decode(h)
+
+        if not return_dict:
+            return dec.sample, dec.commit_loss
+        return dec
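
The hunk above is the new `diffusers/models/autoencoders/vq_model.py` (file 20 in the list); the old `diffusers/models/vq_model.py` (file 40, +8 -165) appears to be reduced to a thin backwards-compatibility layer. Not part of the diff: a minimal sketch exercising the relocated class through the unchanged top-level import, with tiny made-up sizes.

```python
# Minimal sketch (not part of the diff); tiny made-up sizes, top-level import unchanged.
import torch
from diffusers import VQModel

model = VQModel(block_out_channels=(64,), latent_channels=3, num_vq_embeddings=256, sample_size=32)

image = torch.randn(1, 3, 32, 32)
latents = model.encode(image).latents   # encode() only projects; quantization happens in decode()
decoded = model.decode(latents)         # DecoderOutput(sample=..., commit_loss=...)
print(decoded.sample.shape)             # torch.Size([1, 3, 32, 32])
```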

diffusers/models/controlnet_xs.py
@@ -851,8 +851,8 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
         if hasattr(module, "gradient_checkpointing"):
             module.gradient_checkpointing = value
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel
     @property
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
     def attn_processors(self) -> Dict[str, AttentionProcessor]:
         r"""
         Returns:
@@ -911,7 +911,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_default_attn_processor
     def set_default_attn_processor(self):
         """
         Disables custom attention processors and sets the default attention implementation.
@@ -927,7 +927,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
 
         self.set_attn_processor(processor)
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.enable_freeu
     def enable_freeu(self, s1: float, s2: float, b1: float, b2: float):
         r"""Enables the FreeU mechanism from https://arxiv.org/abs/2309.11497.
 
@@ -952,7 +952,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             setattr(upsample_block, "b1", b1)
             setattr(upsample_block, "b2", b2)
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.disable_freeu
     def disable_freeu(self):
         """Disables the FreeU mechanism."""
         freeu_keys = {"s1", "s2", "b1", "b2"}
@@ -961,7 +961,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
                 if hasattr(upsample_block, k) or getattr(upsample_block, k, None) is not None:
                     setattr(upsample_block, k, None)
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
         Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
@@ -985,7 +985,7 @@ class UNetControlNetXSModel(ModelMixin, ConfigMixin):
             if isinstance(module, Attention):
                 module.fuse_projections(fuse=True)
 
-    # copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
+    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
     def unfuse_qkv_projections(self):
         """Disables the fused QKV projection if enabled.
 
diffusers/models/embeddings.py
@@ -123,7 +123,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
 
 class PatchEmbed(nn.Module):
-    """2D Image to Patch Embedding"""
+    """2D Image to Patch Embedding with support for SD3 cropping."""
 
     def __init__(
         self,
@@ -137,12 +137,14 @@ class PatchEmbed(nn.Module):
         bias=True,
         interpolation_scale=1,
         pos_embed_type="sincos",
+        pos_embed_max_size=None,  # For SD3 cropping
     ):
         super().__init__()
 
         num_patches = (height // patch_size) * (width // patch_size)
         self.flatten = flatten
         self.layer_norm = layer_norm
+        self.pos_embed_max_size = pos_embed_max_size
 
         self.proj = nn.Conv2d(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
@@ -153,26 +155,55 @@ class PatchEmbed(nn.Module):
             self.norm = None
 
         self.patch_size = patch_size
-        # See:
-        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
         self.height, self.width = height // patch_size, width // patch_size
         self.base_size = height // patch_size
         self.interpolation_scale = interpolation_scale
+
+        # Calculate positional embeddings based on max size or default
+        if pos_embed_max_size:
+            grid_size = pos_embed_max_size
+        else:
+            grid_size = int(num_patches**0.5)
+
         if pos_embed_type is None:
             self.pos_embed = None
         elif pos_embed_type == "sincos":
             pos_embed = get_2d_sincos_pos_embed(
-                embed_dim,
-                int(num_patches**0.5),
-                base_size=self.base_size,
-                interpolation_scale=self.interpolation_scale,
+                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
             )
-            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+            persistent = True if pos_embed_max_size else False
+            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
         else:
             raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
 
+    def cropped_pos_embed(self, height, width):
+        """Crops positional embeddings for SD3 compatibility."""
+        if self.pos_embed_max_size is None:
+            raise ValueError("`pos_embed_max_size` must be set for cropping.")
+
+        height = height // self.patch_size
+        width = width // self.patch_size
+        if height > self.pos_embed_max_size:
+            raise ValueError(
+                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+            )
+        if width > self.pos_embed_max_size:
+            raise ValueError(
+                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+            )
+
+        top = (self.pos_embed_max_size - height) // 2
+        left = (self.pos_embed_max_size - width) // 2
+        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
+        spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
+        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
+        return spatial_pos_embed
+
     def forward(self, latent):
-        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+        if self.pos_embed_max_size is not None:
+            height, width = latent.shape[-2:]
+        else:
+            height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
 
         latent = self.proj(latent)
         if self.flatten:
@@ -181,20 +212,20 @@ class PatchEmbed(nn.Module):
             latent = self.norm(latent)
         if self.pos_embed is None:
             return latent.to(latent.dtype)
-
-        # Interpolate positional embeddings if needed.
-        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
-        if self.height != height or self.width != width:
-            pos_embed = get_2d_sincos_pos_embed(
-                embed_dim=self.pos_embed.shape[-1],
-                grid_size=(height, width),
-                base_size=self.base_size,
-                interpolation_scale=self.interpolation_scale,
-            )
-            pos_embed = torch.from_numpy(pos_embed)
-            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
+        # Interpolate or crop positional embeddings as needed
+        if self.pos_embed_max_size:
+            pos_embed = self.cropped_pos_embed(height, width)
         else:
-            pos_embed = self.pos_embed
+            if self.height != height or self.width != width:
+                pos_embed = get_2d_sincos_pos_embed(
+                    embed_dim=self.pos_embed.shape[-1],
+                    grid_size=(height, width),
+                    base_size=self.base_size,
+                    interpolation_scale=self.interpolation_scale,
+                )
+                pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+            else:
+                pos_embed = self.pos_embed
 
         return (latent + pos_embed).to(latent.dtype)
 
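The hunks above give `PatchEmbed` a `pos_embed_max_size`: when set, the sin-cos table is built once for that maximum patch grid (and registered with `persistent=True`), and `cropped_pos_embed` center-crops it to the actual grid at runtime instead of re-interpolating. Not part of the diff: a minimal sketch with made-up sizes.

```python
# Minimal sketch (not part of the diff); all sizes are made up.
import torch
from diffusers.models.embeddings import PatchEmbed

patch_embed = PatchEmbed(
    height=64,
    width=64,
    patch_size=2,
    in_channels=16,
    embed_dim=64,
    pos_embed_max_size=96,  # table built for a 96x96 patch grid and stored persistently
)

latent = torch.randn(1, 16, 64, 64)  # 64x64 latent -> 32x32 patch grid
tokens = patch_embed(latent)         # pos_embed center-cropped from 96x96 to 32x32 before being added
print(tokens.shape)                  # torch.Size([1, 1024, 64])
```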
@@ -626,6 +657,25 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
         return conditioning
 
 
+class CombinedTimestepTextProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim, pooled_projection_dim):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(self, timestep, pooled_projection):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+        pooled_projections = self.text_embedder(pooled_projection)
+
+        conditioning = timesteps_emb + pooled_projections
+
+        return conditioning
+
+
 class HunyuanDiTAttentionPool(nn.Module):
     # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
 
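`CombinedTimestepTextProjEmbeddings` produces the conditioning vector used by the new SD3 transformer: a sinusoidal timestep embedding summed with a SiLU projection of the pooled text embedding (hence the new `act_fn="silu"` branch in `PixArtAlphaTextProjection` further down). Not part of the diff: a minimal sketch with made-up dimensions.

```python
# Minimal sketch (not part of the diff); dimensions are made up.
import torch
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings

embedder = CombinedTimestepTextProjEmbeddings(embedding_dim=1536, pooled_projection_dim=2048)

timestep = torch.tensor([999])
pooled_text = torch.randn(1, 2048)
conditioning = embedder(timestep, pooled_text)  # timestep embedding + projected pooled text, shape (N, embedding_dim)
print(conditioning.shape)                       # torch.Size([1, 1536])
```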
@@ -1001,6 +1051,8 @@ class PixArtAlphaTextProjection(nn.Module):
         self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
         if act_fn == "gelu_tanh":
             self.act_1 = nn.GELU(approximate="tanh")
+        elif act_fn == "silu":
+            self.act_1 = nn.SiLU()
         elif act_fn == "silu_fp32":
             self.act_1 = FP32SiLU()
         else:
@@ -1014,6 +1066,39 @@ class PixArtAlphaTextProjection(nn.Module):
         return hidden_states
 
 
+class IPAdapterPlusImageProjectionBlock(nn.Module):
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        dim_head: int = 64,
+        heads: int = 16,
+        ffn_ratio: float = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ln0 = nn.LayerNorm(embed_dims)
+        self.ln1 = nn.LayerNorm(embed_dims)
+        self.attn = Attention(
+            query_dim=embed_dims,
+            dim_head=dim_head,
+            heads=heads,
+            out_bias=False,
+        )
+        self.ff = nn.Sequential(
+            nn.LayerNorm(embed_dims),
+            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+        )
+
+    def forward(self, x, latents, residual):
+        encoder_hidden_states = self.ln0(x)
+        latents = self.ln1(latents)
+        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+        latents = self.attn(latents, encoder_hidden_states) + residual
+        latents = self.ff(latents) + latents
+        return latents
+
+
 class IPAdapterPlusImageProjection(nn.Module):
     """Resampler of IP-Adapter Plus.
 
@@ -1042,8 +1127,6 @@ class IPAdapterPlusImageProjection(nn.Module):
         ffn_ratio: float = 4,
     ) -> None:
         super().__init__()
-        from .attention import FeedForward  # Lazy import to avoid circular import
-
         self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
 
         self.proj_in = nn.Linear(embed_dims, hidden_dims)
@@ -1051,26 +1134,9 @@ class IPAdapterPlusImageProjection(nn.Module):
         self.proj_out = nn.Linear(hidden_dims, output_dims)
         self.norm_out = nn.LayerNorm(output_dims)
 
-        self.layers = nn.ModuleList([])
-        for _ in range(depth):
-            self.layers.append(
-                nn.ModuleList(
-                    [
-                        nn.LayerNorm(hidden_dims),
-                        nn.LayerNorm(hidden_dims),
-                        Attention(
-                            query_dim=hidden_dims,
-                            dim_head=dim_head,
-                            heads=heads,
-                            out_bias=False,
-                        ),
-                        nn.Sequential(
-                            nn.LayerNorm(hidden_dims),
-                            FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
-                        ),
-                    ]
-                )
-            )
+        self.layers = nn.ModuleList(
+            [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass.
@@ -1084,52 +1150,14 @@ class IPAdapterPlusImageProjection(nn.Module):
 
         x = self.proj_in(x)
 
-        for ln0, ln1, attn, ff in self.layers:
+        for block in self.layers:
             residual = latents
-
-            encoder_hidden_states = ln0(x)
-            latents = ln1(latents)
-            encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
-            latents = attn(latents, encoder_hidden_states) + residual
-            latents = ff(latents) + latents
+            latents = block(x, latents, residual)
 
         latents = self.proj_out(latents)
         return self.norm_out(latents)
 
 
-class IPAdapterPlusImageProjectionBlock(nn.Module):
-    def __init__(
-        self,
-        embed_dims: int = 768,
-        dim_head: int = 64,
-        heads: int = 16,
-        ffn_ratio: float = 4,
-    ) -> None:
-        super().__init__()
-        from .attention import FeedForward
-
-        self.ln0 = nn.LayerNorm(embed_dims)
-        self.ln1 = nn.LayerNorm(embed_dims)
-        self.attn = Attention(
-            query_dim=embed_dims,
-            dim_head=dim_head,
-            heads=heads,
-            out_bias=False,
-        )
-        self.ff = nn.Sequential(
-            nn.LayerNorm(embed_dims),
-            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
-        )
-
-    def forward(self, x, latents, residual):
-        encoder_hidden_states = self.ln0(x)
-        latents = self.ln1(latents)
-        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
-        latents = self.attn(latents, encoder_hidden_states) + residual
-        latents = self.ff(latents) + latents
-        return latents
-
-
 class IPAdapterFaceIDPlusImageProjection(nn.Module):
     """FacePerceiverResampler of IP-Adapter Plus.
 
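The hunks above factor the per-layer `(LayerNorm, LayerNorm, Attention, Sequential)` ModuleLists of `IPAdapterPlusImageProjection` into the reusable `IPAdapterPlusImageProjectionBlock`, now defined above its caller; the computation is unchanged. Not part of the diff: a minimal sketch of a single block with made-up shapes.

```python
# Minimal sketch (not part of the diff); shapes are made up.
import torch
from diffusers.models.embeddings import IPAdapterPlusImageProjectionBlock

block = IPAdapterPlusImageProjectionBlock(embed_dims=768, dim_head=64, heads=16, ffn_ratio=4)

image_tokens = torch.randn(1, 257, 768)  # e.g. CLIP vision patch embeddings
latents = torch.randn(1, 16, 768)        # query latents attending to the image tokens
out = block(image_tokens, latents, residual=latents)
print(out.shape)                         # torch.Size([1, 16, 768])
```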
diffusers/models/model_loading_utils.py
@@ -18,13 +18,19 @@ import importlib
 import inspect
 import os
 from collections import OrderedDict
+from pathlib import Path
 from typing import List, Optional, Union
 
 import safetensors
 import torch
+from huggingface_hub.utils import EntryNotFoundError
 
 from ..utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
+    WEIGHTS_INDEX_NAME,
+    _add_variant,
+    _get_model_file,
     is_accelerate_available,
     is_torch_version,
     logging,
@@ -175,3 +181,52 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[
     load(model_to_load)
 
     return error_msgs
+
+
+def _fetch_index_file(
+    is_local,
+    pretrained_model_name_or_path,
+    subfolder,
+    use_safetensors,
+    cache_dir,
+    variant,
+    force_download,
+    resume_download,
+    proxies,
+    local_files_only,
+    token,
+    revision,
+    user_agent,
+    commit_hash,
+):
+    if is_local:
+        index_file = Path(
+            pretrained_model_name_or_path,
+            subfolder or "",
+            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
+        )
+    else:
+        index_file_in_repo = Path(
+            subfolder or "",
+            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
+        ).as_posix()
+        try:
+            index_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=index_file_in_repo,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+                commit_hash=commit_hash,
+            )
+            index_file = Path(index_file)
+        except (EntryNotFoundError, EnvironmentError):
+            index_file = None
+
+    return index_file
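
`_fetch_index_file` appears to back the new sharded-checkpoint support in model loading: it resolves the `*.index.json` that sits next to the weights (locally or on the Hub) and returns `None` when the checkpoint is not sharded. Not part of the diff: a minimal sketch of calling the private helper directly; the repo id and subfolder are placeholders and network access is assumed.

```python
# Minimal sketch (not part of the diff); placeholder repo id/subfolder, network access assumed.
from diffusers.models.model_loading_utils import _fetch_index_file

index_file = _fetch_index_file(
    is_local=False,
    pretrained_model_name_or_path="some-org/some-sharded-model",  # placeholder
    subfolder="transformer",                                      # placeholder
    use_safetensors=True,
    cache_dir=None,
    variant=None,
    force_download=False,
    resume_download=None,
    proxies=None,
    local_files_only=False,
    token=None,
    revision=None,
    user_agent={},
    commit_hash=None,
)
# Path to the sharding index (e.g. a *.safetensors.index.json) if one exists, otherwise None.
print(index_file)
```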