diffusers 0.28.2__py3-none-any.whl → 0.29.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. diffusers/__init__.py +15 -1
  2. diffusers/commands/env.py +1 -5
  3. diffusers/dependency_versions_table.py +1 -1
  4. diffusers/image_processor.py +2 -1
  5. diffusers/loaders/__init__.py +2 -2
  6. diffusers/loaders/lora.py +406 -140
  7. diffusers/loaders/lora_conversion_utils.py +7 -1
  8. diffusers/loaders/single_file.py +13 -1
  9. diffusers/loaders/single_file_model.py +15 -8
  10. diffusers/loaders/single_file_utils.py +267 -17
  11. diffusers/loaders/unet.py +307 -272
  12. diffusers/models/__init__.py +7 -3
  13. diffusers/models/attention.py +125 -1
  14. diffusers/models/attention_processor.py +169 -1
  15. diffusers/models/autoencoders/__init__.py +1 -0
  16. diffusers/models/autoencoders/autoencoder_asym_kl.py +1 -1
  17. diffusers/models/autoencoders/autoencoder_kl.py +17 -6
  18. diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +4 -2
  19. diffusers/models/autoencoders/consistency_decoder_vae.py +9 -9
  20. diffusers/models/autoencoders/vq_model.py +182 -0
  21. diffusers/models/controlnet_sd3.py +418 -0
  22. diffusers/models/controlnet_xs.py +6 -6
  23. diffusers/models/embeddings.py +112 -84
  24. diffusers/models/model_loading_utils.py +55 -0
  25. diffusers/models/modeling_utils.py +138 -20
  26. diffusers/models/normalization.py +11 -6
  27. diffusers/models/transformers/__init__.py +1 -0
  28. diffusers/models/transformers/dual_transformer_2d.py +5 -4
  29. diffusers/models/transformers/hunyuan_transformer_2d.py +149 -2
  30. diffusers/models/transformers/prior_transformer.py +5 -5
  31. diffusers/models/transformers/transformer_2d.py +2 -2
  32. diffusers/models/transformers/transformer_sd3.py +353 -0
  33. diffusers/models/transformers/transformer_temporal.py +12 -10
  34. diffusers/models/unets/unet_1d.py +3 -3
  35. diffusers/models/unets/unet_2d.py +3 -3
  36. diffusers/models/unets/unet_2d_condition.py +4 -15
  37. diffusers/models/unets/unet_3d_condition.py +5 -17
  38. diffusers/models/unets/unet_i2vgen_xl.py +4 -4
  39. diffusers/models/unets/unet_motion_model.py +4 -4
  40. diffusers/models/unets/unet_spatio_temporal_condition.py +3 -3
  41. diffusers/models/vq_model.py +8 -165
  42. diffusers/pipelines/__init__.py +11 -0
  43. diffusers/pipelines/animatediff/pipeline_animatediff.py +4 -3
  44. diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +4 -3
  45. diffusers/pipelines/auto_pipeline.py +8 -0
  46. diffusers/pipelines/controlnet/pipeline_controlnet.py +4 -3
  47. diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +4 -3
  48. diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +4 -3
  49. diffusers/pipelines/controlnet_sd3/__init__.py +53 -0
  50. diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +1062 -0
  51. diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +4 -3
  52. diffusers/pipelines/deepfloyd_if/watermark.py +1 -1
  53. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +4 -3
  54. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +4 -3
  55. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +4 -3
  56. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +4 -3
  57. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +4 -3
  58. diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +24 -5
  59. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +4 -3
  60. diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +4 -3
  61. diffusers/pipelines/marigold/marigold_image_processing.py +35 -20
  62. diffusers/pipelines/pia/pipeline_pia.py +4 -3
  63. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
  64. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
  65. diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +17 -17
  66. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +4 -3
  67. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +5 -4
  68. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +4 -3
  69. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +4 -3
  70. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +4 -3
  71. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +4 -3
  72. diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +7 -6
  73. diffusers/pipelines/stable_diffusion_3/__init__.py +52 -0
  74. diffusers/pipelines/stable_diffusion_3/pipeline_output.py +21 -0
  75. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +904 -0
  76. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +941 -0
  77. diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +4 -3
  78. diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +10 -11
  79. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +4 -3
  80. diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +4 -3
  81. diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +4 -3
  82. diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +4 -3
  83. diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +4 -3
  84. diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +4 -3
  85. diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +4 -3
  86. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +4 -3
  87. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +4 -3
  88. diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +4 -3
  89. diffusers/pipelines/unidiffuser/modeling_uvit.py +1 -1
  90. diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +4 -3
  91. diffusers/schedulers/__init__.py +2 -0
  92. diffusers/schedulers/scheduling_dpmsolver_sde.py +2 -2
  93. diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -3
  94. diffusers/schedulers/scheduling_edm_euler.py +2 -4
  95. diffusers/schedulers/scheduling_flow_match_euler_discrete.py +287 -0
  96. diffusers/schedulers/scheduling_lms_discrete.py +2 -2
  97. diffusers/training_utils.py +4 -4
  98. diffusers/utils/__init__.py +3 -0
  99. diffusers/utils/constants.py +2 -0
  100. diffusers/utils/dummy_pt_objects.py +60 -0
  101. diffusers/utils/dummy_torch_and_transformers_objects.py +45 -0
  102. diffusers/utils/dynamic_modules_utils.py +15 -13
  103. diffusers/utils/hub_utils.py +106 -0
  104. diffusers/utils/import_utils.py +0 -1
  105. diffusers/utils/logging.py +3 -1
  106. diffusers/utils/state_dict_utils.py +2 -0
  107. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/METADATA +3 -3
  108. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/RECORD +112 -112
  109. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/WHEEL +1 -1
  110. diffusers/models/dual_transformer_2d.py +0 -20
  111. diffusers/models/prior_transformer.py +0 -12
  112. diffusers/models/t5_film_transformer.py +0 -70
  113. diffusers/models/transformer_2d.py +0 -25
  114. diffusers/models/transformer_temporal.py +0 -34
  115. diffusers/models/unet_1d.py +0 -26
  116. diffusers/models/unet_1d_blocks.py +0 -203
  117. diffusers/models/unet_2d.py +0 -27
  118. diffusers/models/unet_2d_blocks.py +0 -375
  119. diffusers/models/unet_2d_condition.py +0 -25
  120. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/LICENSE +0 -0
  121. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/entry_points.txt +0 -0
  122. {diffusers-0.28.2.dist-info → diffusers-0.29.1.dist-info}/top_level.txt +0 -0
diffusers/models/embeddings.py

@@ -123,7 +123,7 @@ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
 
 
 class PatchEmbed(nn.Module):
-    """2D Image to Patch Embedding"""
+    """2D Image to Patch Embedding with support for SD3 cropping."""
 
     def __init__(
         self,
@@ -137,12 +137,14 @@ class PatchEmbed(nn.Module):
         bias=True,
         interpolation_scale=1,
         pos_embed_type="sincos",
+        pos_embed_max_size=None,  # For SD3 cropping
     ):
         super().__init__()
 
         num_patches = (height // patch_size) * (width // patch_size)
         self.flatten = flatten
         self.layer_norm = layer_norm
+        self.pos_embed_max_size = pos_embed_max_size
 
         self.proj = nn.Conv2d(
             in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
@@ -153,26 +155,55 @@ class PatchEmbed(nn.Module):
             self.norm = None
 
         self.patch_size = patch_size
-        # See:
-        # https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L161
         self.height, self.width = height // patch_size, width // patch_size
         self.base_size = height // patch_size
         self.interpolation_scale = interpolation_scale
+
+        # Calculate positional embeddings based on max size or default
+        if pos_embed_max_size:
+            grid_size = pos_embed_max_size
+        else:
+            grid_size = int(num_patches**0.5)
+
         if pos_embed_type is None:
             self.pos_embed = None
         elif pos_embed_type == "sincos":
             pos_embed = get_2d_sincos_pos_embed(
-                embed_dim,
-                int(num_patches**0.5),
-                base_size=self.base_size,
-                interpolation_scale=self.interpolation_scale,
+                embed_dim, grid_size, base_size=self.base_size, interpolation_scale=self.interpolation_scale
             )
-            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=False)
+            persistent = True if pos_embed_max_size else False
+            self.register_buffer("pos_embed", torch.from_numpy(pos_embed).float().unsqueeze(0), persistent=persistent)
         else:
             raise ValueError(f"Unsupported pos_embed_type: {pos_embed_type}")
 
+    def cropped_pos_embed(self, height, width):
+        """Crops positional embeddings for SD3 compatibility."""
+        if self.pos_embed_max_size is None:
+            raise ValueError("`pos_embed_max_size` must be set for cropping.")
+
+        height = height // self.patch_size
+        width = width // self.patch_size
+        if height > self.pos_embed_max_size:
+            raise ValueError(
+                f"Height ({height}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+            )
+        if width > self.pos_embed_max_size:
+            raise ValueError(
+                f"Width ({width}) cannot be greater than `pos_embed_max_size`: {self.pos_embed_max_size}."
+            )
+
+        top = (self.pos_embed_max_size - height) // 2
+        left = (self.pos_embed_max_size - width) // 2
+        spatial_pos_embed = self.pos_embed.reshape(1, self.pos_embed_max_size, self.pos_embed_max_size, -1)
+        spatial_pos_embed = spatial_pos_embed[:, top : top + height, left : left + width, :]
+        spatial_pos_embed = spatial_pos_embed.reshape(1, -1, spatial_pos_embed.shape[-1])
+        return spatial_pos_embed
+
     def forward(self, latent):
-        height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
+        if self.pos_embed_max_size is not None:
+            height, width = latent.shape[-2:]
+        else:
+            height, width = latent.shape[-2] // self.patch_size, latent.shape[-1] // self.patch_size
 
         latent = self.proj(latent)
         if self.flatten:
@@ -181,20 +212,20 @@ class PatchEmbed(nn.Module):
             latent = self.norm(latent)
         if self.pos_embed is None:
             return latent.to(latent.dtype)
-
-        # Interpolate positional embeddings if needed.
-        # (For PixArt-Alpha: https://github.com/PixArt-alpha/PixArt-alpha/blob/0f55e922376d8b797edd44d25d0e7464b260dcab/diffusion/model/nets/PixArtMS.py#L162C151-L162C160)
-        if self.height != height or self.width != width:
-            pos_embed = get_2d_sincos_pos_embed(
-                embed_dim=self.pos_embed.shape[-1],
-                grid_size=(height, width),
-                base_size=self.base_size,
-                interpolation_scale=self.interpolation_scale,
-            )
-            pos_embed = torch.from_numpy(pos_embed)
-            pos_embed = pos_embed.float().unsqueeze(0).to(latent.device)
+        # Interpolate or crop positional embeddings as needed
+        if self.pos_embed_max_size:
+            pos_embed = self.cropped_pos_embed(height, width)
         else:
-            pos_embed = self.pos_embed
+            if self.height != height or self.width != width:
+                pos_embed = get_2d_sincos_pos_embed(
+                    embed_dim=self.pos_embed.shape[-1],
+                    grid_size=(height, width),
+                    base_size=self.base_size,
+                    interpolation_scale=self.interpolation_scale,
+                )
+                pos_embed = torch.from_numpy(pos_embed).float().unsqueeze(0).to(latent.device)
+            else:
+                pos_embed = self.pos_embed
 
         return (latent + pos_embed).to(latent.dtype)
 
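The net effect of the PatchEmbed changes: when `pos_embed_max_size` is set (the SD3 path), a large positional-embedding grid is precomputed once and center-cropped to each input's patch grid; otherwise the old PixArt-style interpolation path is kept. A minimal sketch of both paths, with illustrative constructor values that are assumptions, not taken from this diff:

```python
import torch
from diffusers.models.embeddings import PatchEmbed

# SD3-style: a pos_embed_max_size x pos_embed_max_size patch grid is
# precomputed and center-cropped to the actual patch grid of each input.
sd3_embed = PatchEmbed(
    height=128, width=128, patch_size=2, in_channels=16,
    embed_dim=1536, pos_embed_max_size=192,
)
tokens = sd3_embed(torch.randn(1, 16, 128, 128))
print(tokens.shape)  # torch.Size([1, 4096, 1536]) -- a 64x64 patch grid

# Legacy path (pos_embed_max_size=None): embeddings are regenerated by
# sincos interpolation whenever the input resolution differs from the default.
pixart_embed = PatchEmbed(height=128, width=128, patch_size=2, in_channels=4, embed_dim=1152)
tokens = pixart_embed(torch.randn(1, 4, 96, 96))  # 48x48 patches, interpolated
print(tokens.shape)  # torch.Size([1, 2304, 1152])
```

Note that cropping also flips the buffer to `persistent=True`, so the precomputed grid is serialized with SD3 checkpoints instead of being regenerated at load time.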
@@ -626,6 +657,25 @@ class CombinedTimestepLabelEmbeddings(nn.Module):
         return conditioning
 
 
+class CombinedTimestepTextProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim, pooled_projection_dim):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(self, timestep, pooled_projection):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+        pooled_projections = self.text_embedder(pooled_projection)
+
+        conditioning = timesteps_emb + pooled_projections
+
+        return conditioning
+
+
 class HunyuanDiTAttentionPool(nn.Module):
     # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6
 
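`CombinedTimestepTextProjEmbeddings` is the SD3 conditioning embedder: a 256-channel sinusoidal timestep embedding and the pooled text embedding are each projected to `embedding_dim` and summed. A quick shape check; the dimensions below are illustrative (2048 roughly matches SD3's two concatenated CLIP pooled outputs):

```python
import torch
from diffusers.models.embeddings import CombinedTimestepTextProjEmbeddings

embedder = CombinedTimestepTextProjEmbeddings(embedding_dim=1536, pooled_projection_dim=2048)
timestep = torch.tensor([999.0, 500.0])  # one scalar timestep per sample
pooled = torch.randn(2, 2048)            # pooled text embedding, (N, pooled_projection_dim)
cond = embedder(timestep, pooled)
print(cond.shape)  # torch.Size([2, 1536])
```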
@@ -1001,6 +1051,8 @@ class PixArtAlphaTextProjection(nn.Module):
         self.linear_1 = nn.Linear(in_features=in_features, out_features=hidden_size, bias=True)
         if act_fn == "gelu_tanh":
             self.act_1 = nn.GELU(approximate="tanh")
+        elif act_fn == "silu":
+            self.act_1 = nn.SiLU()
         elif act_fn == "silu_fp32":
             self.act_1 = FP32SiLU()
         else:
@@ -1014,6 +1066,39 @@
         return hidden_states
 
 
+class IPAdapterPlusImageProjectionBlock(nn.Module):
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        dim_head: int = 64,
+        heads: int = 16,
+        ffn_ratio: float = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ln0 = nn.LayerNorm(embed_dims)
+        self.ln1 = nn.LayerNorm(embed_dims)
+        self.attn = Attention(
+            query_dim=embed_dims,
+            dim_head=dim_head,
+            heads=heads,
+            out_bias=False,
+        )
+        self.ff = nn.Sequential(
+            nn.LayerNorm(embed_dims),
+            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+        )
+
+    def forward(self, x, latents, residual):
+        encoder_hidden_states = self.ln0(x)
+        latents = self.ln1(latents)
+        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+        latents = self.attn(latents, encoder_hidden_states) + residual
+        latents = self.ff(latents) + latents
+        return latents
+
+
 class IPAdapterPlusImageProjection(nn.Module):
     """Resampler of IP-Adapter Plus.
 
@@ -1042,8 +1127,6 @@ class IPAdapterPlusImageProjection(nn.Module):
         ffn_ratio: float = 4,
     ) -> None:
         super().__init__()
-        from .attention import FeedForward  # Lazy import to avoid circular import
-
         self.latents = nn.Parameter(torch.randn(1, num_queries, hidden_dims) / hidden_dims**0.5)
 
         self.proj_in = nn.Linear(embed_dims, hidden_dims)
@@ -1051,26 +1134,9 @@
         self.proj_out = nn.Linear(hidden_dims, output_dims)
         self.norm_out = nn.LayerNorm(output_dims)
 
-        self.layers = nn.ModuleList([])
-        for _ in range(depth):
-            self.layers.append(
-                nn.ModuleList(
-                    [
-                        nn.LayerNorm(hidden_dims),
-                        nn.LayerNorm(hidden_dims),
-                        Attention(
-                            query_dim=hidden_dims,
-                            dim_head=dim_head,
-                            heads=heads,
-                            out_bias=False,
-                        ),
-                        nn.Sequential(
-                            nn.LayerNorm(hidden_dims),
-                            FeedForward(hidden_dims, hidden_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
-                        ),
-                    ]
-                )
-            )
+        self.layers = nn.ModuleList(
+            [IPAdapterPlusImageProjectionBlock(hidden_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
 
     def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass.
@@ -1084,52 +1150,14 @@
 
         x = self.proj_in(x)
 
-        for ln0, ln1, attn, ff in self.layers:
+        for block in self.layers:
             residual = latents
-
-            encoder_hidden_states = ln0(x)
-            latents = ln1(latents)
-            encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
-            latents = attn(latents, encoder_hidden_states) + residual
-            latents = ff(latents) + latents
+            latents = block(x, latents, residual)
 
         latents = self.proj_out(latents)
         return self.norm_out(latents)
 
 
-class IPAdapterPlusImageProjectionBlock(nn.Module):
-    def __init__(
-        self,
-        embed_dims: int = 768,
-        dim_head: int = 64,
-        heads: int = 16,
-        ffn_ratio: float = 4,
-    ) -> None:
-        super().__init__()
-        from .attention import FeedForward
-
-        self.ln0 = nn.LayerNorm(embed_dims)
-        self.ln1 = nn.LayerNorm(embed_dims)
-        self.attn = Attention(
-            query_dim=embed_dims,
-            dim_head=dim_head,
-            heads=heads,
-            out_bias=False,
-        )
-        self.ff = nn.Sequential(
-            nn.LayerNorm(embed_dims),
-            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
-        )
-
-    def forward(self, x, latents, residual):
-        encoder_hidden_states = self.ln0(x)
-        latents = self.ln1(latents)
-        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
-        latents = self.attn(latents, encoder_hidden_states) + residual
-        latents = self.ff(latents) + latents
-        return latents
-
-
 class IPAdapterFaceIDPlusImageProjection(nn.Module):
     """FacePerceiverResampler of IP-Adapter Plus.
 
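The resampler refactor above is compute-preserving: the anonymous `(ln0, ln1, attn, ff)` quadruples become named `IPAdapterPlusImageProjectionBlock` submodules (the class itself is simply moved above its first use), so the forward math is unchanged. State-dict key names do change, which is presumably what the `diffusers/loaders/unet.py` changes in this same release accommodate. A shape sketch with illustrative dimensions:

```python
import torch
from diffusers.models.embeddings import IPAdapterPlusImageProjection

resampler = IPAdapterPlusImageProjection(
    embed_dims=1280, output_dims=2048, hidden_dims=1280,
    depth=4, dim_head=64, heads=20, num_queries=16, ffn_ratio=4,
)
image_embeds = torch.randn(1, 257, 1280)  # e.g. CLIP vision patch embeddings
print(resampler(image_embeds).shape)      # torch.Size([1, 16, 2048])
```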
diffusers/models/model_loading_utils.py

@@ -18,13 +18,19 @@ import importlib
 import inspect
 import os
 from collections import OrderedDict
+from pathlib import Path
 from typing import List, Optional, Union
 
 import safetensors
 import torch
+from huggingface_hub.utils import EntryNotFoundError
 
 from ..utils import (
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
+    WEIGHTS_INDEX_NAME,
+    _add_variant,
+    _get_model_file,
     is_accelerate_available,
     is_torch_version,
     logging,
@@ -175,3 +181,52 @@ def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[
     load(model_to_load)
 
     return error_msgs
+
+
+def _fetch_index_file(
+    is_local,
+    pretrained_model_name_or_path,
+    subfolder,
+    use_safetensors,
+    cache_dir,
+    variant,
+    force_download,
+    resume_download,
+    proxies,
+    local_files_only,
+    token,
+    revision,
+    user_agent,
+    commit_hash,
+):
+    if is_local:
+        index_file = Path(
+            pretrained_model_name_or_path,
+            subfolder or "",
+            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
+        )
+    else:
+        index_file_in_repo = Path(
+            subfolder or "",
+            _add_variant(SAFE_WEIGHTS_INDEX_NAME if use_safetensors else WEIGHTS_INDEX_NAME, variant),
+        ).as_posix()
+        try:
+            index_file = _get_model_file(
+                pretrained_model_name_or_path,
+                weights_name=index_file_in_repo,
+                cache_dir=cache_dir,
+                force_download=force_download,
+                resume_download=resume_download,
+                proxies=proxies,
+                local_files_only=local_files_only,
+                token=token,
+                revision=revision,
+                subfolder=subfolder,
+                user_agent=user_agent,
+                commit_hash=commit_hash,
+            )
+            index_file = Path(index_file)
+        except (EntryNotFoundError, EnvironmentError):
+            index_file = None
+
+    return index_file
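`_fetch_index_file` is the probe that decides whether a checkpoint is sharded: it looks for the weight-index JSON either on disk or in the Hub repo, returning `None` when absent. For a local directory with no variant, the resolution reduces to a simple path check, roughly as below (the constant's value is assumed from `diffusers/utils/constants.py` in this release; the non-safetensors analogue is `WEIGHTS_INDEX_NAME`):

```python
from pathlib import Path

# Assumed constant value, for illustration only.
SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"

index_file = Path("./my-model", "unet", SAFE_WEIGHTS_INDEX_NAME)
is_sharded = index_file.is_file()  # what ModelMixin.from_pretrained later checks
```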
diffusers/models/modeling_utils.py

@@ -16,6 +16,7 @@
 
 import inspect
 import itertools
+import json
 import os
 import re
 from collections import OrderedDict
@@ -25,7 +26,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union
 
 import safetensors
 import torch
-from huggingface_hub import create_repo
+from huggingface_hub import create_repo, split_torch_state_dict_into_shards
 from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
 
@@ -33,9 +34,12 @@ from .. import __version__
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
+    SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
+    WEIGHTS_INDEX_NAME,
     WEIGHTS_NAME,
     _add_variant,
+    _get_checkpoint_shard_files,
     _get_model_file,
     deprecate,
     is_accelerate_available,
@@ -49,6 +53,7 @@ from ..utils.hub_utils import (
 )
 from .model_loading_utils import (
     _determine_device_map,
+    _fetch_index_file,
     _load_state_dict_into_model,
     load_model_dict_into_meta,
     load_state_dict,
@@ -57,6 +62,8 @@ from .model_loading_utils import (
 
 logger = logging.get_logger(__name__)
 
+_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")
+
 
 if is_torch_version(">=", "1.9.0"):
     _LOW_CPU_MEM_USAGE_DEFAULT = True
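`_REGEX_SHARD` recognizes shard-file stems such as `diffusion_pytorch_model-00001-of-00005` (extensions are stripped by the caller before matching); it is used further down to avoid deleting unrelated files when cleaning a save directory. A quick check:

```python
import re

_REGEX_SHARD = re.compile(r"(.*?)-\d{5}-of-\d{5}")

assert _REGEX_SHARD.fullmatch("diffusion_pytorch_model-00001-of-00005") is not None
assert _REGEX_SHARD.fullmatch("diffusion_pytorch_model") is None  # single-file stem
```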
@@ -263,6 +270,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         save_function: Optional[Callable] = None,
         safe_serialization: bool = True,
         variant: Optional[str] = None,
+        max_shard_size: Union[int, str] = "10GB",
         push_to_hub: bool = False,
         **kwargs,
     ):
@@ -285,6 +293,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
             variant (`str`, *optional*):
                 If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
+            max_shard_size (`int` or `str`, defaults to `"10GB"`):
+                The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
+                lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
+                If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
+                period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
+                This is to establish a common default size for this argument across different libraries in the Hugging
+                Face ecosystem (`transformers`, and `accelerate`, for example).
             push_to_hub (`bool`, *optional*, defaults to `False`):
                 Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
                 repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
@@ -296,6 +311,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
+        weights_name = _add_variant(weights_name, variant)
+        weight_name_split = weights_name.split(".")
+        if len(weight_name_split) in [2, 3]:
+            weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
+        else:
+            raise ValueError(f"Invalid {weights_name} provided.")
+
         os.makedirs(save_directory, exist_ok=True)
 
         if push_to_hub:
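The `{suffix}` placeholder in `weights_name_pattern` is what `huggingface_hub.split_torch_state_dict_into_shards` later expands into `-00001-of-00004`-style counters (or drops entirely for a single-file save). Traced by hand for the default safetensors name:

```python
weights_name = "diffusion_pytorch_model.safetensors"  # no variant
weight_name_split = weights_name.split(".")           # 2 parts -> accepted
weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
print(weights_name_pattern)  # diffusion_pytorch_model{suffix}.safetensors

# With variant="fp16" the name has 3 parts ("diffusion_pytorch_model.fp16.safetensors"),
# which the len-in-[2, 3] check above also accepts.
```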
@@ -317,18 +340,58 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         # Save the model
         state_dict = model_to_save.state_dict()
 
-        weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
-        weights_name = _add_variant(weights_name, variant)
-
         # Save the model
-        if safe_serialization:
-            safetensors.torch.save_file(
-                state_dict, Path(save_directory, weights_name).as_posix(), metadata={"format": "pt"}
+        state_dict_split = split_torch_state_dict_into_shards(
+            state_dict, max_shard_size=max_shard_size, filename_pattern=weights_name_pattern
+        )
+
+        # Clean the folder from a previous save
+        if is_main_process:
+            for filename in os.listdir(save_directory):
+                if filename in state_dict_split.filename_to_tensors.keys():
+                    continue
+                full_filename = os.path.join(save_directory, filename)
+                if not os.path.isfile(full_filename):
+                    continue
+                weights_without_ext = weights_name_pattern.replace(".bin", "").replace(".safetensors", "")
+                weights_without_ext = weights_without_ext.replace("{suffix}", "")
+                filename_without_ext = filename.replace(".bin", "").replace(".safetensors", "")
+                # make sure that file to be deleted matches format of sharded file, e.g. pytorch_model-00001-of-00005
+                if (
+                    filename.startswith(weights_without_ext)
+                    and _REGEX_SHARD.fullmatch(filename_without_ext) is not None
+                ):
+                    os.remove(full_filename)
+
+        for filename, tensors in state_dict_split.filename_to_tensors.items():
+            shard = {tensor: state_dict[tensor] for tensor in tensors}
+            filepath = os.path.join(save_directory, filename)
+            if safe_serialization:
+                # At some point we will need to deal better with save_function (used for TPU and other distributed
+                # joyfulness), but for now this enough.
+                safetensors.torch.save_file(shard, filepath, metadata={"format": "pt"})
+            else:
+                torch.save(shard, filepath)
+
+        if state_dict_split.is_sharded:
+            index = {
+                "metadata": state_dict_split.metadata,
+                "weight_map": state_dict_split.tensor_to_filename,
+            }
+            save_index_file = SAFE_WEIGHTS_INDEX_NAME if safe_serialization else WEIGHTS_INDEX_NAME
+            save_index_file = os.path.join(save_directory, _add_variant(save_index_file, variant))
+            # Save the index as well
+            with open(save_index_file, "w", encoding="utf-8") as f:
+                content = json.dumps(index, indent=2, sort_keys=True) + "\n"
+                f.write(content)
+            logger.info(
+                f"The model is bigger than the maximum size per checkpoint ({max_shard_size}) and is going to be "
+                f"split in {len(state_dict_split.filename_to_tensors)} checkpoint shards. You can find where each parameters has been saved in the "
+                f"index located at {save_index_file}."
             )
         else:
-            torch.save(state_dict, Path(save_directory, weights_name).as_posix())
-
-        logger.info(f"Model weights saved in {Path(save_directory, weights_name).as_posix()}")
+            path_to_weights = os.path.join(save_directory, weights_name)
+            logger.info(f"Model weights saved in {path_to_weights}")
 
         if push_to_hub:
             # Create a new empty model card and eventually tag it
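A usage sketch of the new sharded save path (repo id and sizes illustrative): with `max_shard_size` below the total weight size, `save_pretrained()` writes numbered shards plus a JSON index instead of a single file.

```python
from diffusers import UNet2DConditionModel

unet = UNet2DConditionModel.from_pretrained("runwayml/stable-diffusion-v1-5", subfolder="unet")
unet.save_pretrained("./sd15-unet-sharded", max_shard_size="1GB")
# Expected layout, roughly (shard count depends on dtype and model size):
# ./sd15-unet-sharded/
#   config.json
#   diffusion_pytorch_model-00001-of-00004.safetensors
#   ...
#   diffusion_pytorch_model.safetensors.index.json
```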
@@ -399,7 +462,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             device_map (`str` or `Dict[str, Union[int, str, torch.device]]`, *optional*):
                 A map that specifies where each submodule should go. It doesn't need to be defined for each
                 parameter/buffer name; once a given module name is inside, every submodule of it will be sent to the
-                same device.
+                same device. Defaults to `None`, meaning that the model will be loaded on CPU.
 
                 Set `device_map="auto"` to have 🤗 Accelerate automatically compute the most optimized `device_map`. For
                 more information about each option see [designing a device
@@ -566,6 +629,32 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 **kwargs,
             )
 
+        # Determine if we're loading from a directory of sharded checkpoints.
+        is_sharded = False
+        index_file = None
+        is_local = os.path.isdir(pretrained_model_name_or_path)
+        index_file = _fetch_index_file(
+            is_local=is_local,
+            pretrained_model_name_or_path=pretrained_model_name_or_path,
+            subfolder=subfolder or "",
+            use_safetensors=use_safetensors,
+            cache_dir=cache_dir,
+            variant=variant,
+            force_download=force_download,
+            resume_download=resume_download,
+            proxies=proxies,
+            local_files_only=local_files_only,
+            token=token,
+            revision=revision,
+            user_agent=user_agent,
+            commit_hash=commit_hash,
+        )
+        if index_file is not None and index_file.is_file():
+            is_sharded = True
+
+        if is_sharded and from_flax:
+            raise ValueError("Loading of sharded checkpoints is not supported when `from_flax=True`.")
+
         # load model
         model_file = None
         if from_flax:
@@ -590,7 +679,21 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
             model = load_flax_checkpoint_in_pytorch_model(model, model_file)
         else:
-            if use_safetensors:
+            if is_sharded:
+                sharded_ckpt_cached_folder, sharded_metadata = _get_checkpoint_shard_files(
+                    pretrained_model_name_or_path,
+                    index_file,
+                    cache_dir=cache_dir,
+                    proxies=proxies,
+                    resume_download=resume_download,
+                    local_files_only=local_files_only,
+                    token=token,
+                    user_agent=user_agent,
+                    revision=revision,
+                    subfolder=subfolder or "",
+                )
+
+            elif use_safetensors and not is_sharded:
                 try:
                     model_file = _get_model_file(
                         pretrained_model_name_or_path,
@@ -606,11 +709,16 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         user_agent=user_agent,
                         commit_hash=commit_hash,
                     )
+
                 except IOError as e:
+                    logger.error(f"An error occurred while trying to fetch {pretrained_model_name_or_path}: {e}")
                     if not allow_pickle:
-                        raise e
-                    pass
-            if model_file is None:
+                        raise
+                    logger.warning(
+                        "Defaulting to unsafe serialization. Pass `allow_pickle=False` to raise an error instead."
+                    )
+
+            if model_file is None and not is_sharded:
                 model_file = _get_model_file(
                     pretrained_model_name_or_path,
                     weights_name=_add_variant(WEIGHTS_NAME, variant),
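Round-tripping the directory from the save example above: `from_pretrained()` now finds the index via `_fetch_index_file`, resolves each shard with `_get_checkpoint_shard_files`, and hands the folder to accelerate, so callers need no new arguments. A sketch:

```python
from diffusers import UNet2DConditionModel

# Shard detection and reassembly are transparent; the same should apply to Hub
# repos that store sharded weights (e.g. the SD3 transformer).
unet = UNet2DConditionModel.from_pretrained("./sd15-unet-sharded")
```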
@@ -632,7 +740,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         model = cls.from_config(config, **unused_kwargs)
 
         # if device_map is None, load the state dict and move the params from meta device to the cpu
-        if device_map is None:
+        if device_map is None and not is_sharded:
             param_device = "cpu"
             state_dict = load_state_dict(model_file, variant=variant)
             model._convert_deprecated_attention_blocks(state_dict)
@@ -666,17 +774,22 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         else:  # else let accelerate handle loading and dispatching.
             # Load weights and dispatch according to the device_map
             # by default the device_map is None and the weights are loaded on the CPU
+            force_hook = True
             device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
+            if device_map is None and is_sharded:
+                # we load the parameters on the cpu
+                device_map = {"": "cpu"}
+                force_hook = False
             try:
                 accelerate.load_checkpoint_and_dispatch(
                     model,
-                    model_file,
+                    model_file if not is_sharded else sharded_ckpt_cached_folder,
                     device_map,
                     max_memory=max_memory,
                     offload_folder=offload_folder,
                     offload_state_dict=offload_state_dict,
                     dtype=torch_dtype,
-                    force_hooks=True,
+                    force_hooks=force_hook,
                     strict=True,
                 )
             except AttributeError as e:
@@ -700,12 +813,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     model._temp_convert_self_to_deprecated_attention_blocks()
                     accelerate.load_checkpoint_and_dispatch(
                         model,
-                        model_file,
+                        model_file if not is_sharded else sharded_ckpt_cached_folder,
                         device_map,
                         max_memory=max_memory,
                         offload_folder=offload_folder,
                         offload_state_dict=offload_state_dict,
                         dtype=torch_dtype,
+                        force_hook=force_hook,
+                        strict=True,
                     )
                     model._undo_temp_convert_self_to_deprecated_attention_blocks()
                 else:
@@ -1057,6 +1172,9 @@ class LegacyModelMixin(ModelMixin):
         # To prevent depedency import problem.
         from .model_loading_utils import _fetch_remapped_cls_from_config
 
+        # Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
+        kwargs_copy = kwargs.copy()
+
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         resume_download = kwargs.pop("resume_download", None)
@@ -1094,4 +1212,4 @@ class LegacyModelMixin(ModelMixin):
         # resolve remapping
         remapped_class = _fetch_remapped_cls_from_config(config, cls)
 
-        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs)
+        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
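The `kwargs_copy` fix matters because the intervening `kwargs.pop(...)` calls mutate the dict in place; before this change the final delegation to the remapped class received a stripped-down kwargs. A minimal illustration of the failure mode (helper name hypothetical):

```python
def resolve_and_delegate(**kwargs):
    kwargs_copy = kwargs.copy()       # 0.29.1 behavior: preserve the originals
    kwargs.pop("torch_dtype", None)   # consumed while resolving the config
    # delegating with `kwargs` here would silently drop torch_dtype
    return kwargs_copy

assert "torch_dtype" in resolve_and_delegate(torch_dtype="float16")
```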