diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2222 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +1 -12
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +262 -2
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1795 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +319 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +1 -4
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +19 -16
- diffusers/utils/loading_utils.py +76 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
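Before the per-file diffs: 0.30.0 is a feature release, and the file list above shows the newly added pipeline families (AuraFlow, CogVideoX, Flux, Kolors, Latte, Lumina, Stable Audio, PAG) alongside the split of the old `loaders/lora.py` into `lora_base.py` and `lora_pipeline.py`. As orientation, a minimal sketch of one of the new pipelines; the model ID and call arguments below are illustrative and not part of this diff:

    import torch
    from diffusers import FluxPipeline

    pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell", torch_dtype=torch.bfloat16)
    pipe.enable_model_cpu_offload()  # offload idle submodules to keep VRAM usage down
    image = pipe(
        "a cat holding a sign that says hello world",
        guidance_scale=0.0,        # schnell is guidance-distilled, so CFG is disabled
        num_inference_steps=4,
        max_sequence_length=256,
    ).images[0]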
diffusers/models/embeddings.py
CHANGED
@@ -35,10 +35,21 @@ def get_timestep_embedding(
     """
     This matches the implementation in Denoising Diffusion Probabilistic Models: Create sinusoidal timestep embeddings.

-
-
-
-
+    Args
+        timesteps (torch.Tensor):
+            a 1-D Tensor of N indices, one per batch element. These may be fractional.
+        embedding_dim (int):
+            the dimension of the output.
+        flip_sin_to_cos (bool):
+            Whether the embedding order should be `cos, sin` (if True) or `sin, cos` (if False)
+        downscale_freq_shift (float):
+            Controls the delta between frequencies between dimensions
+        scale (float):
+            Scaling factor applied to the embeddings.
+        max_period (int):
+            Controls the maximum frequency of the embeddings
+    Returns
+        torch.Tensor: an [N x dim] Tensor of positional embeddings.
     """
     assert len(timesteps.shape) == 1, "Timesteps should be a 1d-array"

@@ -67,6 +78,53 @@ def get_timestep_embedding(
     return emb


+def get_3d_sincos_pos_embed(
+    embed_dim: int,
+    spatial_size: Union[int, Tuple[int, int]],
+    temporal_size: int,
+    spatial_interpolation_scale: float = 1.0,
+    temporal_interpolation_scale: float = 1.0,
+) -> np.ndarray:
+    r"""
+    Args:
+        embed_dim (`int`):
+        spatial_size (`int` or `Tuple[int, int]`):
+        temporal_size (`int`):
+        spatial_interpolation_scale (`float`, defaults to 1.0):
+        temporal_interpolation_scale (`float`, defaults to 1.0):
+    """
+    if embed_dim % 4 != 0:
+        raise ValueError("`embed_dim` must be divisible by 4")
+    if isinstance(spatial_size, int):
+        spatial_size = (spatial_size, spatial_size)
+
+    embed_dim_spatial = 3 * embed_dim // 4
+    embed_dim_temporal = embed_dim // 4
+
+    # 1. Spatial
+    grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
+    grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
+    grid = np.meshgrid(grid_w, grid_h)  # here w goes first
+    grid = np.stack(grid, axis=0)
+
+    grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
+    pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
+
+    # 2. Temporal
+    grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
+    pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
+
+    # 3. Concat
+    pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
+    pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0)  # [T, H*W, D // 4 * 3]
+
+    pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
+    pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1)  # [T, H*W, D // 4]
+
+    pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1)  # [T, H*W, D]
+    return pos_embed
+
+
 def get_2d_sincos_pos_embed(
     embed_dim, grid_size, cls_token=False, extra_tokens=0, interpolation_scale=1.0, base_size=16
 ):
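The new `get_3d_sincos_pos_embed` splits `embed_dim` 3:1 between spatial and temporal sin-cos tables, then broadcasts both over a [T, H*W] grid. A shape-only sketch (assuming the helper is imported straight from this module):

    from diffusers.models.embeddings import get_3d_sincos_pos_embed

    # embed_dim must be divisible by 4: 48 of 64 channels encode (w, h), 16 encode time
    pos = get_3d_sincos_pos_embed(embed_dim=64, spatial_size=(8, 8), temporal_size=4)
    print(pos.shape)  # (4, 64, 64) -> [T, H*W, embed_dim], a NumPy array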
@@ -230,6 +288,92 @@ class PatchEmbed(nn.Module):
         return (latent + pos_embed).to(latent.dtype)


+class LuminaPatchEmbed(nn.Module):
+    """2D Image to Patch Embedding with support for Lumina-T2X"""
+
+    def __init__(self, patch_size=2, in_channels=4, embed_dim=768, bias=True):
+        super().__init__()
+        self.patch_size = patch_size
+        self.proj = nn.Linear(
+            in_features=patch_size * patch_size * in_channels,
+            out_features=embed_dim,
+            bias=bias,
+        )
+
+    def forward(self, x, freqs_cis):
+        """
+        Patchifies and embeds the input tensor(s).
+
+        Args:
+            x (List[torch.Tensor] | torch.Tensor): The input tensor(s) to be patchified and embedded.
+
+        Returns:
+            Tuple[torch.Tensor, torch.Tensor, List[Tuple[int, int]], torch.Tensor]: A tuple containing the patchified
+            and embedded tensor(s), the mask indicating the valid patches, the original image size(s), and the
+            frequency tensor(s).
+        """
+        freqs_cis = freqs_cis.to(x[0].device)
+        patch_height = patch_width = self.patch_size
+        batch_size, channel, height, width = x.size()
+        height_tokens, width_tokens = height // patch_height, width // patch_width
+
+        x = x.view(batch_size, channel, height_tokens, patch_height, width_tokens, patch_width).permute(
+            0, 2, 4, 1, 3, 5
+        )
+        x = x.flatten(3)
+        x = self.proj(x)
+        x = x.flatten(1, 2)
+
+        mask = torch.ones(x.shape[0], x.shape[1], dtype=torch.int32, device=x.device)
+
+        return (
+            x,
+            mask,
+            [(height, width)] * batch_size,
+            freqs_cis[:height_tokens, :width_tokens].flatten(0, 1).unsqueeze(0),
+        )
+
+
+class CogVideoXPatchEmbed(nn.Module):
+    def __init__(
+        self,
+        patch_size: int = 2,
+        in_channels: int = 16,
+        embed_dim: int = 1920,
+        text_embed_dim: int = 4096,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        self.patch_size = patch_size
+
+        self.proj = nn.Conv2d(
+            in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
+        )
+        self.text_proj = nn.Linear(text_embed_dim, embed_dim)
+
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
+        r"""
+        Args:
+            text_embeds (`torch.Tensor`):
+                Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
+            image_embeds (`torch.Tensor`):
+                Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
+        """
+        text_embeds = self.text_proj(text_embeds)
+
+        batch, num_frames, channels, height, width = image_embeds.shape
+        image_embeds = image_embeds.reshape(-1, channels, height, width)
+        image_embeds = self.proj(image_embeds)
+        image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
+        image_embeds = image_embeds.flatten(3).transpose(2, 3)  # [batch, num_frames, height x width, channels]
+        image_embeds = image_embeds.flatten(1, 2)  # [batch, num_frames x height x width, channels]
+
+        embeds = torch.cat(
+            [text_embeds, image_embeds], dim=1
+        ).contiguous()  # [batch, seq_length + num_frames x height x width, channels]
+        return embeds
+
+
 def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
     """
     RoPE for image tokens with 2d structure.
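`CogVideoXPatchEmbed` above patchifies each latent frame with a strided Conv2d, flattens all frames into a single token stream, and prepends the projected text tokens. A sketch with toy shapes showing where the sequence length comes from:

    import torch
    from diffusers.models.embeddings import CogVideoXPatchEmbed

    embed = CogVideoXPatchEmbed(patch_size=2, in_channels=16, embed_dim=1920, text_embed_dim=4096)
    text = torch.randn(1, 226, 4096)        # [batch, text_seq_len, text_embed_dim]
    video = torch.randn(1, 13, 16, 60, 90)  # [batch, frames, channels, height, width]
    tokens = embed(text, video)
    print(tokens.shape)  # [1, 17776, 1920]: 226 text tokens + 13 frames * 30 * 45 patches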
@@ -245,7 +389,7 @@ def get_2d_rotary_pos_embed(embed_dim, crops_coords, grid_size, use_real=True):
         If True, return real part and imaginary part separately. Otherwise, return complex numbers.

     Returns:
-        `torch.Tensor`: positional
+        `torch.Tensor`: positional embedding with shape `( grid_size * grid_size, embed_dim/2)`.
     """
     start, stop = crops_coords
     grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
@@ -262,19 +406,47 @@ def get_2d_rotary_pos_embed_from_grid(embed_dim, grid, use_real=False):
     assert embed_dim % 4 == 0

     # use half of dimensions to encode grid_h
-    emb_h = get_1d_rotary_pos_embed(
-
+    emb_h = get_1d_rotary_pos_embed(
+        embed_dim // 2, grid[0].reshape(-1), use_real=use_real
+    )  # (H*W, D/2) if use_real else (H*W, D/4)
+    emb_w = get_1d_rotary_pos_embed(
+        embed_dim // 2, grid[1].reshape(-1), use_real=use_real
+    )  # (H*W, D/2) if use_real else (H*W, D/4)

     if use_real:
-        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D
-        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D
+        cos = torch.cat([emb_h[0], emb_w[0]], dim=1)  # (H*W, D)
+        sin = torch.cat([emb_h[1], emb_w[1]], dim=1)  # (H*W, D)
         return cos, sin
     else:
         emb = torch.cat([emb_h, emb_w], dim=1)  # (H*W, D/2)
         return emb


-def
+def get_2d_rotary_pos_embed_lumina(embed_dim, len_h, len_w, linear_factor=1.0, ntk_factor=1.0):
+    assert embed_dim % 4 == 0
+
+    emb_h = get_1d_rotary_pos_embed(
+        embed_dim // 2, len_h, linear_factor=linear_factor, ntk_factor=ntk_factor
+    )  # (H, D/4)
+    emb_w = get_1d_rotary_pos_embed(
+        embed_dim // 2, len_w, linear_factor=linear_factor, ntk_factor=ntk_factor
+    )  # (W, D/4)
+    emb_h = emb_h.view(len_h, 1, embed_dim // 4, 1).repeat(1, len_w, 1, 1)  # (H, W, D/4, 1)
+    emb_w = emb_w.view(1, len_w, embed_dim // 4, 1).repeat(len_h, 1, 1, 1)  # (H, W, D/4, 1)
+
+    emb = torch.cat([emb_h, emb_w], dim=-1).flatten(2)  # (H, W, D/2)
+    return emb
+
+
+def get_1d_rotary_pos_embed(
+    dim: int,
+    pos: Union[np.ndarray, int],
+    theta: float = 10000.0,
+    use_real=False,
+    linear_factor=1.0,
+    ntk_factor=1.0,
+    repeat_interleave_real=True,
+):
     """
     Precompute the frequency tensor for complex exponentials (cis) with given dimensions.

@@ -289,19 +461,32 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
             Scaling factor for frequency computation. Defaults to 10000.0.
         use_real (`bool`, *optional*):
             If True, return real part and imaginary part separately. Otherwise, return complex numbers.
-
+        linear_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for the context extrapolation. Defaults to 1.0.
+        ntk_factor (`float`, *optional*, defaults to 1.0):
+            Scaling factor for the NTK-Aware RoPE. Defaults to 1.0.
+        repeat_interleave_real (`bool`, *optional*, defaults to `True`):
+            If `True` and `use_real`, real part and imaginary part are each interleaved with themselves to reach `dim`.
+            Otherwise, they are concatenated with themselves.
     Returns:
         `torch.Tensor`: Precomputed frequency tensor with complex exponentials. [S, D/2]
     """
+    assert dim % 2 == 0
+
     if isinstance(pos, int):
         pos = np.arange(pos)
-
+    theta = theta * ntk_factor
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) / linear_factor  # [D/2]
     t = torch.from_numpy(pos).to(freqs.device)  # type: ignore  # [S]
     freqs = torch.outer(t, freqs).float()  # type: ignore  # [S, D/2]
-    if use_real:
+    if use_real and repeat_interleave_real:
         freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
         freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
         return freqs_cos, freqs_sin
+    elif use_real:
+        freqs_cos = torch.cat([freqs.cos(), freqs.cos()], dim=-1)  # [S, D]
+        freqs_sin = torch.cat([freqs.sin(), freqs.sin()], dim=-1)  # [S, D]
+        return freqs_cos, freqs_sin
     else:
         freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64  # [S, D/2]
         return freqs_cis
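With `use_real=True` the reworked `get_1d_rotary_pos_embed` returns a (cos, sin) pair instead of a complex tensor, and `ntk_factor`/`linear_factor` are the two context-extrapolation knobs (theta rescaling and frequency division) used by Lumina. A quick sketch:

    import torch
    from diffusers.models.embeddings import get_1d_rotary_pos_embed

    cos, sin = get_1d_rotary_pos_embed(dim=128, pos=77, use_real=True)
    print(cos.shape, sin.shape)   # torch.Size([77, 128]) torch.Size([77, 128])

    freqs_cis = get_1d_rotary_pos_embed(dim=128, pos=77)
    print(freqs_cis.shape)        # torch.Size([77, 64]), complex64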
@@ -310,6 +495,8 @@ def get_1d_rotary_pos_embed(dim: int, pos: Union[np.ndarray, int], theta: float
 def apply_rotary_emb(
     x: torch.Tensor,
     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+    use_real: bool = True,
+    use_real_unbind_dim: int = -1,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
@@ -325,16 +512,32 @@ def apply_rotary_emb(
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
     """
-
-
-
-
+    if use_real:
+        cos, sin = freqs_cis  # [S, D]
+        cos = cos[None, None]
+        sin = sin[None, None]
+        cos, sin = cos.to(x.device), sin.to(x.device)
+
+        if use_real_unbind_dim == -1:
+            # Use for example in Lumina
+            x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)  # [B, S, H, D//2]
+            x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+        elif use_real_unbind_dim == -2:
+            # Use for example in Stable Audio
+            x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2)  # [B, S, H, D//2]
+            x_rotated = torch.cat([-x_imag, x_real], dim=-1)
+        else:
+            raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")

-
-
-
+        out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
+
+        return out
+    else:
+        x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
+        freqs_cis = freqs_cis.unsqueeze(2)
+        x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)

-
+        return x_out.type_as(x)


 class TimestepEmbedding(nn.Module):
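`apply_rotary_emb` now consumes that (cos, sin) pair directly; `use_real_unbind_dim=-1` rotates interleaved channel pairs (the Lumina layout) while `-2` rotates a half-split layout (the Stable Audio layout). A sketch, assuming the usual [batch, heads, seq, head_dim] attention layout so the [1, 1, S, D] cos/sin tables broadcast:

    import torch
    from diffusers.models.embeddings import apply_rotary_emb, get_1d_rotary_pos_embed

    q = torch.randn(2, 8, 77, 128)  # [batch, heads, seq, head_dim]
    freqs = get_1d_rotary_pos_embed(dim=128, pos=77, use_real=True)
    q_rot = apply_rotary_emb(q, freqs)  # defaults: use_real=True, use_real_unbind_dim=-1
    print(q_rot.shape)  # torch.Size([2, 8, 77, 128])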
@@ -386,11 +589,12 @@ class TimestepEmbedding(nn.Module):


 class Timesteps(nn.Module):
-    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float):
+    def __init__(self, num_channels: int, flip_sin_to_cos: bool, downscale_freq_shift: float, scale: int = 1):
         super().__init__()
         self.num_channels = num_channels
         self.flip_sin_to_cos = flip_sin_to_cos
         self.downscale_freq_shift = downscale_freq_shift
+        self.scale = scale

     def forward(self, timesteps):
         t_emb = get_timestep_embedding(
@@ -398,6 +602,7 @@ class Timesteps(nn.Module):
             self.num_channels,
             flip_sin_to_cos=self.flip_sin_to_cos,
             downscale_freq_shift=self.downscale_freq_shift,
+            scale=self.scale,
         )
         return t_emb

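`Timesteps` now forwards an optional `scale` into `get_timestep_embedding`, which, per the docstring added in the first hunk above, multiplies the sinusoidal embedding uniformly. Sketch:

    import torch
    from diffusers.models.embeddings import Timesteps

    proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0, scale=2)
    print(proj(torch.tensor([10, 500])).shape)  # torch.Size([2, 256]); values doubled vs. scale=1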
@@ -415,9 +620,10 @@ class GaussianFourierProjection(nn.Module):

         if set_W_to_weight:
             # to delete later
+            del self.weight
             self.W = nn.Parameter(torch.randn(embedding_size) * scale, requires_grad=False)
-
             self.weight = self.W
+            del self.W

     def forward(self, x):
         if self.log:
@@ -676,6 +882,30 @@ class CombinedTimestepTextProjEmbeddings(nn.Module):
         return conditioning


+class CombinedTimestepGuidanceTextProjEmbeddings(nn.Module):
+    def __init__(self, embedding_dim, pooled_projection_dim):
+        super().__init__()
+
+        self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+        self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.guidance_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)
+        self.text_embedder = PixArtAlphaTextProjection(pooled_projection_dim, embedding_dim, act_fn="silu")
+
+    def forward(self, timestep, guidance, pooled_projection):
+        timesteps_proj = self.time_proj(timestep)
+        timesteps_emb = self.timestep_embedder(timesteps_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+        guidance_proj = self.time_proj(guidance)
+        guidance_emb = self.guidance_embedder(guidance_proj.to(dtype=pooled_projection.dtype))  # (N, D)
+
+        time_guidance_emb = timesteps_emb + guidance_emb
+
+        pooled_projections = self.text_embedder(pooled_projection)
+        conditioning = time_guidance_emb + pooled_projections
+
+        return conditioning
+
+
 class HunyuanDiTAttentionPool(nn.Module):
     # Copied from https://github.com/Tencent/HunyuanDiT/blob/cb709308d92e6c7e8d59d0dff41b74d35088db6a/hydit/modules/poolers.py#L6

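The new `CombinedTimestepGuidanceTextProjEmbeddings` (used by the Flux transformer when guidance embedding is enabled) runs the timestep and the guidance scale through the same 256-channel sinusoidal projection, sums the two, and adds the projected pooled text embedding. Shape sketch:

    import torch
    from diffusers.models.embeddings import CombinedTimestepGuidanceTextProjEmbeddings

    cond = CombinedTimestepGuidanceTextProjEmbeddings(embedding_dim=3072, pooled_projection_dim=768)
    out = cond(
        timestep=torch.tensor([999.0]),
        guidance=torch.tensor([3.5]),
        pooled_projection=torch.randn(1, 768),
    )
    print(out.shape)  # torch.Size([1, 3072])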
@@ -717,18 +947,33 @@ class HunyuanDiTAttentionPool(nn.Module):


 class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
-    def __init__(
+    def __init__(
+        self,
+        embedding_dim,
+        pooled_projection_dim=1024,
+        seq_len=256,
+        cross_attention_dim=2048,
+        use_style_cond_and_image_meta_size=True,
+    ):
         super().__init__()

         self.time_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
         self.timestep_embedder = TimestepEmbedding(in_channels=256, time_embed_dim=embedding_dim)

+        self.size_proj = Timesteps(num_channels=256, flip_sin_to_cos=True, downscale_freq_shift=0)
+
         self.pooler = HunyuanDiTAttentionPool(
             seq_len, cross_attention_dim, num_heads=8, output_dim=pooled_projection_dim
         )
+
         # Here we use a default learned embedder layer for future extension.
-        self.
-
+        self.use_style_cond_and_image_meta_size = use_style_cond_and_image_meta_size
+        if use_style_cond_and_image_meta_size:
+            self.style_embedder = nn.Embedding(1, embedding_dim)
+            extra_in_dim = 256 * 6 + embedding_dim + pooled_projection_dim
+        else:
+            extra_in_dim = pooled_projection_dim
+
         self.extra_embedder = PixArtAlphaTextProjection(
             in_features=extra_in_dim,
             hidden_size=embedding_dim * 4,
@@ -743,21 +988,59 @@ class HunyuanCombinedTimestepTextSizeStyleEmbedding(nn.Module):
         # extra condition1: text
         pooled_projections = self.pooler(encoder_hidden_states)  # (N, 1024)

-
-
-
-
+        if self.use_style_cond_and_image_meta_size:
+            # extra condition2: image meta size embedding
+            image_meta_size = self.size_proj(image_meta_size.view(-1))
+            image_meta_size = image_meta_size.to(dtype=hidden_dtype)
+            image_meta_size = image_meta_size.view(-1, 6 * 256)  # (N, 1536)
+
+            # extra condition3: style embedding
+            style_embedding = self.style_embedder(style)  # (N, embedding_dim)

-
-
+            # Concatenate all extra vectors
+            extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
+        else:
+            extra_cond = torch.cat([pooled_projections], dim=1)

-        # Concatenate all extra vectors
-        extra_cond = torch.cat([pooled_projections, image_meta_size, style_embedding], dim=1)
         conditioning = timesteps_emb + self.extra_embedder(extra_cond)  # [B, D]

         return conditioning


+class LuminaCombinedTimestepCaptionEmbedding(nn.Module):
+    def __init__(self, hidden_size=4096, cross_attention_dim=2048, frequency_embedding_size=256):
+        super().__init__()
+        self.time_proj = Timesteps(
+            num_channels=frequency_embedding_size, flip_sin_to_cos=True, downscale_freq_shift=0.0
+        )
+
+        self.timestep_embedder = TimestepEmbedding(in_channels=frequency_embedding_size, time_embed_dim=hidden_size)
+
+        self.caption_embedder = nn.Sequential(
+            nn.LayerNorm(cross_attention_dim),
+            nn.Linear(
+                cross_attention_dim,
+                hidden_size,
+                bias=True,
+            ),
+        )
+
+    def forward(self, timestep, caption_feat, caption_mask):
+        # timestep embedding:
+        time_freq = self.time_proj(timestep)
+        time_embed = self.timestep_embedder(time_freq.to(dtype=self.timestep_embedder.linear_1.weight.dtype))
+
+        # caption condition embedding:
+        caption_mask_float = caption_mask.float().unsqueeze(-1)
+        caption_feats_pool = (caption_feat * caption_mask_float).sum(dim=1) / caption_mask_float.sum(dim=1)
+        caption_feats_pool = caption_feats_pool.to(caption_feat)
+        caption_embed = self.caption_embedder(caption_feats_pool)
+
+        conditioning = time_embed + caption_embed
+
+        return conditioning
+
+
 class TextTimeEmbedding(nn.Module):
     def __init__(self, encoder_dim: int, time_embed_dim: int, num_heads: int = 64):
         super().__init__()
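`LuminaCombinedTimestepCaptionEmbedding` mean-pools the caption features under the attention mask before projecting, so padding tokens do not dilute the conditioning vector. Sketch with toy shapes:

    import torch
    from diffusers.models.embeddings import LuminaCombinedTimestepCaptionEmbedding

    cond = LuminaCombinedTimestepCaptionEmbedding(hidden_size=4096, cross_attention_dim=2048)
    timestep = torch.tensor([0.5])
    caption_feat = torch.randn(1, 256, 2048)            # [batch, seq, cross_attention_dim]
    caption_mask = torch.ones(1, 256, dtype=torch.bool)
    print(cond(timestep, caption_feat, caption_mask).shape)  # torch.Size([1, 4096])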
@@ -980,7 +1263,7 @@ class GLIGENTextBoundingboxProjection(nn.Module):

             objs = self.linears(torch.cat([positive_embeddings, xyxy_embedding], dim=-1))

-        # positionet with text and image
+        # positionet with text and image information
         else:
             phrases_masks = phrases_masks.unsqueeze(-1)
             image_masks = image_masks.unsqueeze(-1)
@@ -1252,7 +1535,7 @@ class MultiIPAdapterImageProjection(nn.Module):
         if not isinstance(image_embeds, list):
             deprecation_message = (
                 "You have passed a tensor as `image_embeds`.This is deprecated and will be removed in a future release."
-                " Please make sure to update your script to pass `image_embeds` as a list of tensors to
+                " Please make sure to update your script to pass `image_embeds` as a list of tensors to suppress this warning."
             )
             deprecate("image_embeds not a list", "1.0.0", deprecation_message, standard_warn=False)
             image_embeds = [image_embeds.unsqueeze(1)]
diffusers/models/model_loading_utils.py
CHANGED
@@ -191,7 +191,6 @@ def _fetch_index_file(
     cache_dir,
     variant,
     force_download,
-    resume_download,
     proxies,
     local_files_only,
     token,
@@ -216,12 +215,11 @@ def _fetch_index_file(
             weights_name=index_file_in_repo,
             cache_dir=cache_dir,
             force_download=force_download,
-            resume_download=resume_download,
             proxies=proxies,
             local_files_only=local_files_only,
             token=token,
             revision=revision,
-            subfolder=
+            subfolder=None,
             user_agent=user_agent,
             commit_hash=commit_hash,
         )
diffusers/models/modeling_flax_utils.py
CHANGED
@@ -245,9 +245,7 @@ class FlaxModelMixin(PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
+
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -296,7 +294,6 @@ class FlaxModelMixin(PushToHubMixin):
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         from_pt = kwargs.pop("from_pt", False)
-        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         token = kwargs.pop("token", None)
@@ -316,7 +313,6 @@ class FlaxModelMixin(PushToHubMixin):
            cache_dir=cache_dir,
            return_unused_kwargs=True,
            force_download=force_download,
-           resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
@@ -362,7 +358,6 @@ class FlaxModelMixin(PushToHubMixin):
                cache_dir=cache_dir,
                force_download=force_download,
                proxies=proxies,
-               resume_download=resume_download,
                local_files_only=local_files_only,
                token=token,
                user_agent=user_agent,
diffusers/models/modeling_utils.py
CHANGED
@@ -434,9 +434,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download:
-                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
-                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -518,7 +515,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
        ignore_mismatched_sizes = kwargs.pop("ignore_mismatched_sizes", False)
        force_download = kwargs.pop("force_download", False)
        from_flax = kwargs.pop("from_flax", False)
-       resume_download = kwargs.pop("resume_download", None)
        proxies = kwargs.pop("proxies", None)
        output_loading_info = kwargs.pop("output_loading_info", False)
        local_files_only = kwargs.pop("local_files_only", None)
@@ -619,7 +615,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=force_download,
-           resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
@@ -641,7 +636,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            cache_dir=cache_dir,
            variant=variant,
            force_download=force_download,
-           resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,
@@ -663,7 +657,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                weights_name=FLAX_WEIGHTS_NAME,
                cache_dir=cache_dir,
                force_download=force_download,
-               resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
@@ -685,7 +678,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                    index_file,
                    cache_dir=cache_dir,
                    proxies=proxies,
-                   resume_download=resume_download,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
@@ -700,7 +692,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                    weights_name=_add_variant(SAFETENSORS_WEIGHTS_NAME, variant),
                    cache_dir=cache_dir,
                    force_download=force_download,
-                   resume_download=resume_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
@@ -724,7 +715,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                weights_name=_add_variant(WEIGHTS_NAME, variant),
                cache_dir=cache_dir,
                force_download=force_download,
-               resume_download=resume_download,
                proxies=proxies,
                local_files_only=local_files_only,
                token=token,
@@ -783,7 +773,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            try:
                accelerate.load_checkpoint_and_dispatch(
                    model,
-                   model_file if not is_sharded else
+                   model_file if not is_sharded else index_file,
                    device_map,
                    max_memory=max_memory,
                    offload_folder=offload_folder,
@@ -813,13 +803,13 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                model._temp_convert_self_to_deprecated_attention_blocks()
                accelerate.load_checkpoint_and_dispatch(
                    model,
-                   model_file if not is_sharded else
+                   model_file if not is_sharded else index_file,
                    device_map,
                    max_memory=max_memory,
                    offload_folder=offload_folder,
                    offload_state_dict=offload_state_dict,
                    dtype=torch_dtype,
-
+                   force_hooks=force_hook,
                    strict=True,
                )
                model._undo_temp_convert_self_to_deprecated_attention_blocks()
@@ -1169,7 +1159,7 @@ class LegacyModelMixin(ModelMixin):
    @classmethod
    @validate_hf_hub_args
    def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
-        # To prevent
+        # To prevent dependency import problem.
        from .model_loading_utils import _fetch_remapped_cls_from_config

        # Create a copy of the kwargs so that we don't mess with the keyword arguments in the downstream calls.
@@ -1177,7 +1167,6 @@ class LegacyModelMixin(ModelMixin):

        cache_dir = kwargs.pop("cache_dir", None)
        force_download = kwargs.pop("force_download", False)
-       resume_download = kwargs.pop("resume_download", None)
        proxies = kwargs.pop("proxies", None)
        local_files_only = kwargs.pop("local_files_only", None)
        token = kwargs.pop("token", None)
@@ -1200,7 +1189,6 @@ class LegacyModelMixin(ModelMixin):
            return_unused_kwargs=True,
            return_commit_hash=True,
            force_download=force_download,
-           resume_download=resume_download,
            proxies=proxies,
            local_files_only=local_files_only,
            token=token,