diffusers 0.27.2__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +19 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -18
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +18 -18
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -39
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +20 -26
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +42 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +103 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +23 -23
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +86 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +75 -54
- diffusers/schedulers/scheduling_edm_euler.py +50 -31
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +23 -29
- diffusers/schedulers/scheduling_euler_discrete.py +160 -68
- diffusers/schedulers/scheduling_heun_discrete.py +57 -39
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +19 -19
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +19 -19
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +24 -26
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +111 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/METADATA +47 -47
- diffusers-0.28.0.dist-info/RECORD +414 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/WHEEL +1 -1
- diffusers-0.27.2.dist-info/RECORD +0 -399
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.2.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/models/downsampling.py
CHANGED
@@ -102,7 +102,6 @@ class Downsample2D(nn.Module):
         self.padding = padding
         stride = 2
         self.name = name
-        conv_cls = nn.Conv2d

         if norm_type == "ln_norm":
             self.norm = nn.LayerNorm(channels, eps, elementwise_affine)
@@ -114,7 +113,7 @@ class Downsample2D(nn.Module):
             raise ValueError(f"unknown norm_type: {norm_type}")

         if use_conv:
-            conv = conv_cls(
+            conv = nn.Conv2d(
                 self.channels, self.out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias
             )
         else:
@@ -130,7 +129,7 @@ class Downsample2D(nn.Module):
         else:
             self.conv = conv

-    def forward(self, hidden_states: torch.FloatTensor, *args, **kwargs) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor, *args, **kwargs) -> torch.Tensor:
         if len(args) > 0 or kwargs.get("scale", None) is not None:
             deprecation_message = "The `scale` argument is deprecated and will be ignored. Please remove it, as passing it will raise an error in the future. `scale` should directly be passed while calling the underlying pipeline component i.e., via `cross_attention_kwargs`."
             deprecate("scale", "1.0.0", deprecation_message)
@@ -181,24 +180,24 @@ class FirDownsample2D(nn.Module):

     def _downsample_2d(
         self,
-        hidden_states: torch.FloatTensor,
-        weight: Optional[torch.FloatTensor] = None,
-        kernel: Optional[torch.FloatTensor] = None,
+        hidden_states: torch.Tensor,
+        weight: Optional[torch.Tensor] = None,
+        kernel: Optional[torch.Tensor] = None,
         factor: int = 2,
         gain: float = 1,
-    ) -> torch.FloatTensor:
+    ) -> torch.Tensor:
         """Fused `Conv2d()` followed by `downsample_2d()`.
         Padding is performed only once at the beginning, not between the operations. The fused op is considerably more
         efficient than performing the same calculation using standard TensorFlow ops. It supports gradients of
         arbitrary order.

         Args:
-            hidden_states (`torch.FloatTensor`):
+            hidden_states (`torch.Tensor`):
                 Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-            weight (`torch.FloatTensor`, *optional*):
+            weight (`torch.Tensor`, *optional*):
                 Weight tensor of the shape `[filterH, filterW, inChannels, outChannels]`. Grouped convolution can be
                 performed by `inChannels = x.shape[0] // numGroups`.
-            kernel (`torch.FloatTensor`, *optional*):
+            kernel (`torch.Tensor`, *optional*):
                 FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
                 corresponds to average pooling.
             factor (`int`, *optional*, default to `2`):
@@ -207,7 +206,7 @@ class FirDownsample2D(nn.Module):
                 Scaling factor for signal magnitude.

         Returns:
-            output (`torch.FloatTensor`):
+            output (`torch.Tensor`):
                 Tensor of the shape `[N, C, H // factor, W // factor]` or `[N, H // factor, W // factor, C]`, and same
                 datatype as `x`.
         """
@@ -245,7 +244,7 @@ class FirDownsample2D(nn.Module):

         return output

-    def forward(self, hidden_states: torch.FloatTensor) -> torch.FloatTensor:
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
         if self.use_conv:
             downsample_input = self._downsample_2d(hidden_states, weight=self.Conv2d_0.weight, kernel=self.fir_kernel)
             hidden_states = downsample_input + self.Conv2d_0.bias.reshape(1, -1, 1, 1)
@@ -287,11 +286,11 @@ class KDownsample2D(nn.Module):


 def downsample_2d(
-    hidden_states: torch.FloatTensor,
-    kernel: Optional[torch.FloatTensor] = None,
+    hidden_states: torch.Tensor,
+    kernel: Optional[torch.Tensor] = None,
     factor: int = 2,
     gain: float = 1,
-) -> torch.FloatTensor:
+) -> torch.Tensor:
     r"""Downsample2D a batch of 2D images with the given filter.
     Accepts a batch of 2D images of the shape `[N, C, H, W]` or `[N, H, W, C]` and downsamples each image with the
     given filter. The filter is normalized so that if the input pixels are constant, they will be scaled by the
@@ -299,9 +298,9 @@ def downsample_2d(
     shape is a multiple of the downsampling factor.

     Args:
-        hidden_states (`torch.FloatTensor`)
+        hidden_states (`torch.Tensor`)
            Input tensor of the shape `[N, C, H, W]` or `[N, H, W, C]`.
-        kernel (`torch.FloatTensor`, *optional*):
+        kernel (`torch.Tensor`, *optional*):
            FIR filter of the shape `[firH, firW]` or `[firN]` (separable). The default is `[1] * factor`, which
            corresponds to average pooling.
        factor (`int`, *optional*, default to `2`):
@@ -310,7 +309,7 @@ def downsample_2d(
            Scaling factor for signal magnitude.

    Returns:
-        output (`torch.FloatTensor`):
+        output (`torch.Tensor`):
            Tensor of the shape `[N, C, H // factor, W // factor]`
    """
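Note on the `scale` deprecation above: a LoRA scale no longer flows into each block's `forward(..., scale=...)`; it is picked up from the pipeline's `cross_attention_kwargs` instead. A minimal sketch of the new calling convention (model and LoRA IDs are placeholders, not taken from this diff):

import torch
from diffusers import StableDiffusionPipeline

pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16
).to("cuda")
pipe.load_lora_weights("path/to/lora")  # placeholder LoRA weights

# The scale is passed once at the pipeline level; it is no longer threaded
# through Downsample2D.forward(..., scale=...).
image = pipe(
    "an astronaut riding a horse",
    cross_attention_kwargs={"scale": 0.8},
).images[0]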
diffusers/models/embeddings.py
CHANGED
@@ -199,9 +199,8 @@ class TimestepEmbedding(nn.Module):
         sample_proj_bias=True,
     ):
         super().__init__()
-        linear_cls = nn.Linear

-        self.linear_1 = linear_cls(in_channels, time_embed_dim, sample_proj_bias)
+        self.linear_1 = nn.Linear(in_channels, time_embed_dim, sample_proj_bias)

         if cond_proj_dim is not None:
             self.cond_proj = nn.Linear(cond_proj_dim, in_channels, bias=False)
@@ -214,7 +213,7 @@ class TimestepEmbedding(nn.Module):
             time_embed_dim_out = out_dim
         else:
             time_embed_dim_out = time_embed_dim
-        self.linear_2 = linear_cls(time_embed_dim, time_embed_dim_out, sample_proj_bias)
+        self.linear_2 = nn.Linear(time_embed_dim, time_embed_dim_out, sample_proj_bias)

         if post_act_fn is None:
             self.post_act = None
@@ -425,7 +424,7 @@ class TextImageProjection(nn.Module):
         self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
         self.text_proj = nn.Linear(text_embed_dim, cross_attention_dim)

-    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         batch_size = text_embeds.shape[0]

         # image
@@ -451,7 +450,7 @@ class ImageProjection(nn.Module):
         self.image_embeds = nn.Linear(image_embed_dim, self.num_image_text_embeds * cross_attention_dim)
         self.norm = nn.LayerNorm(cross_attention_dim)

-    def forward(self, image_embeds: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor):
         batch_size = image_embeds.shape[0]

         # image
@@ -469,10 +468,26 @@ class IPAdapterFullImageProjection(nn.Module):
         self.ff = FeedForward(image_embed_dim, cross_attention_dim, mult=1, activation_fn="gelu")
         self.norm = nn.LayerNorm(cross_attention_dim)

-    def forward(self, image_embeds: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor):
         return self.norm(self.ff(image_embeds))


+class IPAdapterFaceIDImageProjection(nn.Module):
+    def __init__(self, image_embed_dim=1024, cross_attention_dim=1024, mult=1, num_tokens=1):
+        super().__init__()
+        from .attention import FeedForward
+
+        self.num_tokens = num_tokens
+        self.cross_attention_dim = cross_attention_dim
+        self.ff = FeedForward(image_embed_dim, cross_attention_dim * num_tokens, mult=mult, activation_fn="gelu")
+        self.norm = nn.LayerNorm(cross_attention_dim)
+
+    def forward(self, image_embeds: torch.Tensor):
+        x = self.ff(image_embeds)
+        x = x.reshape(-1, self.num_tokens, self.cross_attention_dim)
+        return self.norm(x)
+
+
 class CombinedTimestepLabelEmbeddings(nn.Module):
     def __init__(self, num_classes, embedding_dim, class_dropout_prob=0.1):
         super().__init__()
@@ -515,7 +530,7 @@ class TextImageTimeEmbedding(nn.Module):
         self.text_norm = nn.LayerNorm(time_embed_dim)
         self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)

-    def forward(self, text_embeds: torch.FloatTensor, image_embeds: torch.FloatTensor):
+    def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
         # text
         time_text_embeds = self.text_proj(text_embeds)
         time_text_embeds = self.text_norm(time_text_embeds)
@@ -532,7 +547,7 @@ class ImageTimeEmbedding(nn.Module):
         self.image_proj = nn.Linear(image_embed_dim, time_embed_dim)
         self.image_norm = nn.LayerNorm(time_embed_dim)

-    def forward(self, image_embeds: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor):
         # image
         time_image_embeds = self.image_proj(image_embeds)
         time_image_embeds = self.image_norm(time_image_embeds)
@@ -562,7 +577,7 @@ class ImageHintTimeEmbedding(nn.Module):
             nn.Conv2d(256, 4, 3, padding=1),
         )

-    def forward(self, image_embeds: torch.FloatTensor, hint: torch.FloatTensor):
+    def forward(self, image_embeds: torch.Tensor, hint: torch.Tensor):
         # image
         time_image_embeds = self.image_proj(image_embeds)
         time_image_embeds = self.image_norm(time_image_embeds)
@@ -795,17 +810,15 @@ class IPAdapterPlusImageProjection(nn.Module):
     """Resampler of IP-Adapter Plus.

     Args:
-    ----
-        embed_dims (int): The feature dimension. Defaults to 768.
-        output_dims (int): The number of output channels, that is the same
-            number of the channels in the `unet.config.cross_attention_dim`.
-            Defaults to 1024.
-        hidden_dims (int): The number of hidden channels. Defaults to 1280.
-        depth (int): The number of blocks. Defaults to 8.
-        dim_head (int): The number of head channels. Defaults to 64.
-        heads (int): Parallel attention heads. Defaults to 16.
-        num_queries (int): The number of queries. Defaults to 8.
-        ffn_ratio (float): The expansion ratio of feedforward network hidden
+        embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
+        that is the same
+            number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
+        hidden_dims (int):
+            The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
+        to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
+        Defaults to 16. num_queries (int):
+            The number of queries. Defaults to 8. ffn_ratio (float): The expansion ratio
+        of feedforward network hidden
             layer channels. Defaults to 4.
     """

@@ -855,11 +868,8 @@ class IPAdapterPlusImageProjection(nn.Module):
         """Forward pass.

         Args:
-        ----
             x (torch.Tensor): Input Tensor.
-
         Returns:
-        -------
             torch.Tensor: Output Tensor.
         """
         latents = self.latents.repeat(x.size(0), 1, 1)
@@ -879,12 +889,125 @@ class IPAdapterPlusImageProjection(nn.Module):
         return self.norm_out(latents)


+class IPAdapterPlusImageProjectionBlock(nn.Module):
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        dim_head: int = 64,
+        heads: int = 16,
+        ffn_ratio: float = 4,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.ln0 = nn.LayerNorm(embed_dims)
+        self.ln1 = nn.LayerNorm(embed_dims)
+        self.attn = Attention(
+            query_dim=embed_dims,
+            dim_head=dim_head,
+            heads=heads,
+            out_bias=False,
+        )
+        self.ff = nn.Sequential(
+            nn.LayerNorm(embed_dims),
+            FeedForward(embed_dims, embed_dims, activation_fn="gelu", mult=ffn_ratio, bias=False),
+        )
+
+    def forward(self, x, latents, residual):
+        encoder_hidden_states = self.ln0(x)
+        latents = self.ln1(latents)
+        encoder_hidden_states = torch.cat([encoder_hidden_states, latents], dim=-2)
+        latents = self.attn(latents, encoder_hidden_states) + residual
+        latents = self.ff(latents) + latents
+        return latents
+
+
+class IPAdapterFaceIDPlusImageProjection(nn.Module):
+    """FacePerceiverResampler of IP-Adapter Plus.
+
+    Args:
+        embed_dims (int): The feature dimension. Defaults to 768. output_dims (int): The number of output channels,
+        that is the same
+            number of the channels in the `unet.config.cross_attention_dim`. Defaults to 1024.
+        hidden_dims (int):
+            The number of hidden channels. Defaults to 1280. depth (int): The number of blocks. Defaults
+        to 8. dim_head (int): The number of head channels. Defaults to 64. heads (int): Parallel attention heads.
+        Defaults to 16. num_tokens (int): Number of tokens num_queries (int): The number of queries. Defaults to 8.
+        ffn_ratio (float): The expansion ratio of feedforward network hidden
+        layer channels. Defaults to 4.
+        ffproj_ratio (float): The expansion ratio of feedforward network hidden
+        layer channels (for ID embeddings). Defaults to 4.
+    """
+
+    def __init__(
+        self,
+        embed_dims: int = 768,
+        output_dims: int = 768,
+        hidden_dims: int = 1280,
+        id_embeddings_dim: int = 512,
+        depth: int = 4,
+        dim_head: int = 64,
+        heads: int = 16,
+        num_tokens: int = 4,
+        num_queries: int = 8,
+        ffn_ratio: float = 4,
+        ffproj_ratio: int = 2,
+    ) -> None:
+        super().__init__()
+        from .attention import FeedForward
+
+        self.num_tokens = num_tokens
+        self.embed_dim = embed_dims
+        self.clip_embeds = None
+        self.shortcut = False
+        self.shortcut_scale = 1.0
+
+        self.proj = FeedForward(id_embeddings_dim, embed_dims * num_tokens, activation_fn="gelu", mult=ffproj_ratio)
+        self.norm = nn.LayerNorm(embed_dims)
+
+        self.proj_in = nn.Linear(hidden_dims, embed_dims)
+
+        self.proj_out = nn.Linear(embed_dims, output_dims)
+        self.norm_out = nn.LayerNorm(output_dims)
+
+        self.layers = nn.ModuleList(
+            [IPAdapterPlusImageProjectionBlock(embed_dims, dim_head, heads, ffn_ratio) for _ in range(depth)]
+        )
+
+    def forward(self, id_embeds: torch.Tensor) -> torch.Tensor:
+        """Forward pass.
+
+        Args:
+            id_embeds (torch.Tensor): Input Tensor (ID embeds).
+        Returns:
+            torch.Tensor: Output Tensor.
+        """
+        id_embeds = id_embeds.to(self.clip_embeds.dtype)
+        id_embeds = self.proj(id_embeds)
+        id_embeds = id_embeds.reshape(-1, self.num_tokens, self.embed_dim)
+        id_embeds = self.norm(id_embeds)
+        latents = id_embeds
+
+        clip_embeds = self.proj_in(self.clip_embeds)
+        x = clip_embeds.reshape(-1, clip_embeds.shape[2], clip_embeds.shape[3])
+
+        for block in self.layers:
+            residual = latents
+            latents = block(x, latents, residual)
+
+        latents = self.proj_out(latents)
+        out = self.norm_out(latents)
+        if self.shortcut:
+            out = id_embeds + self.shortcut_scale * out
+        return out
+
+
 class MultiIPAdapterImageProjection(nn.Module):
     def __init__(self, IPAdapterImageProjectionLayers: Union[List[nn.Module], Tuple[nn.Module]]):
         super().__init__()
         self.image_projection_layers = nn.ModuleList(IPAdapterImageProjectionLayers)

-    def forward(self, image_embeds: List[torch.FloatTensor]):
+    def forward(self, image_embeds: List[torch.Tensor]):
         projected_image_embeds = []

         # currently, we accept `image_embeds` as
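The new `IPAdapterFaceIDImageProjection` above expands a single face-ID embedding into `num_tokens` cross-attention tokens. A quick shape check, with arbitrarily chosen dimensions:

import torch
from diffusers.models.embeddings import IPAdapterFaceIDImageProjection

proj = IPAdapterFaceIDImageProjection(
    image_embed_dim=512, cross_attention_dim=768, mult=1, num_tokens=4
)
id_embeds = torch.randn(2, 512)  # (batch, image_embed_dim)
tokens = proj(id_embeds)  # (2, 4, 768): num_tokens tokens of width cross_attention_dim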
diffusers/models/model_loading_utils.py
ADDED
@@ -0,0 +1,149 @@
+# coding=utf-8
+# Copyright 2024 The HuggingFace Inc. team.
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+import os
+from collections import OrderedDict
+from typing import List, Optional, Union
+
+import safetensors
+import torch
+
+from ..utils import (
+    SAFETENSORS_FILE_EXTENSION,
+    is_accelerate_available,
+    is_torch_version,
+    logging,
+)
+
+
+logger = logging.get_logger(__name__)
+
+
+if is_accelerate_available():
+    from accelerate import infer_auto_device_map
+    from accelerate.utils import get_balanced_memory, get_max_memory, set_module_tensor_to_device
+
+
+# Adapted from `transformers` (see modeling_utils.py)
+def _determine_device_map(model: torch.nn.Module, device_map, max_memory, torch_dtype):
+    if isinstance(device_map, str):
+        no_split_modules = model._get_no_split_modules(device_map)
+        device_map_kwargs = {"no_split_module_classes": no_split_modules}
+
+        if device_map != "sequential":
+            max_memory = get_balanced_memory(
+                model,
+                dtype=torch_dtype,
+                low_zero=(device_map == "balanced_low_0"),
+                max_memory=max_memory,
+                **device_map_kwargs,
+            )
+        else:
+            max_memory = get_max_memory(max_memory)
+
+        device_map_kwargs["max_memory"] = max_memory
+        device_map = infer_auto_device_map(model, dtype=torch_dtype, **device_map_kwargs)
+
+    return device_map
+
+
+def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[str] = None):
+    """
+    Reads a checkpoint file, returning properly formatted errors if they arise.
+    """
+    try:
+        file_extension = os.path.basename(checkpoint_file).split(".")[-1]
+        if file_extension == SAFETENSORS_FILE_EXTENSION:
+            return safetensors.torch.load_file(checkpoint_file, device="cpu")
+        else:
+            weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
+            return torch.load(
+                checkpoint_file,
+                map_location="cpu",
+                **weights_only_kwarg,
+            )
+    except Exception as e:
+        try:
+            with open(checkpoint_file) as f:
+                if f.read().startswith("version"):
+                    raise OSError(
+                        "You seem to have cloned a repository without having git-lfs installed. Please install "
+                        "git-lfs and run `git lfs install` followed by `git lfs pull` in the folder "
+                        "you cloned."
+                    )
+                else:
+                    raise ValueError(
+                        f"Unable to locate the file {checkpoint_file} which is necessary to load this pretrained "
+                        "model. Make sure you have saved the model properly."
+                    ) from e
+        except (UnicodeDecodeError, ValueError):
+            raise OSError(
+                f"Unable to load weights from checkpoint file for '{checkpoint_file}' " f"at '{checkpoint_file}'. "
+            )
+
+
+def load_model_dict_into_meta(
+    model,
+    state_dict: OrderedDict,
+    device: Optional[Union[str, torch.device]] = None,
+    dtype: Optional[Union[str, torch.dtype]] = None,
+    model_name_or_path: Optional[str] = None,
+) -> List[str]:
+    device = device or torch.device("cpu")
+    dtype = dtype or torch.float32
+
+    accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
+
+    unexpected_keys = []
+    empty_state_dict = model.state_dict()
+    for param_name, param in state_dict.items():
+        if param_name not in empty_state_dict:
+            unexpected_keys.append(param_name)
+            continue
+
+        if empty_state_dict[param_name].shape != param.shape:
+            model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
+            raise ValueError(
+                f"Cannot load {model_name_or_path_str}because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+            )
+
+        if accepts_dtype:
+            set_module_tensor_to_device(model, param_name, device, value=param, dtype=dtype)
+        else:
+            set_module_tensor_to_device(model, param_name, device, value=param)
+    return unexpected_keys
+
+
+def _load_state_dict_into_model(model_to_load, state_dict: OrderedDict) -> List[str]:
+    # Convert old format to new format if needed from a PyTorch state_dict
+    # copy state_dict so _load_from_state_dict can modify it
+    state_dict = state_dict.copy()
+    error_msgs = []
+
+    # PyTorch's `_load_from_state_dict` does not copy parameters in a module's descendants
+    # so we need to apply the function recursively.
+    def load(module: torch.nn.Module, prefix: str = ""):
+        args = (state_dict, prefix, {}, True, [], [], error_msgs)
+        module._load_from_state_dict(*args)
+
+        for name, child in module._modules.items():
+            if child is not None:
+                load(child, prefix + name + ".")
+
+    load(model_to_load)
+
+    return error_msgs
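The `load_state_dict` helper above dispatches on file extension: `.safetensors` checkpoints go through `safetensors.torch.load_file`, everything else through `torch.load` on CPU (with `weights_only=True` on torch >= 1.13). A small sketch; the checkpoint path is a placeholder:

from diffusers.models.model_loading_utils import load_state_dict

state_dict = load_state_dict("unet/diffusion_pytorch_model.safetensors")
print(f"loaded {len(state_dict)} tensors")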
diffusers/models/modeling_flax_pytorch_utils.py
CHANGED
@@ -12,7 +12,8 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch - Flax general utilities."""
+"""PyTorch - Flax general utilities."""
+
 import re

 import jax.numpy as jnp
diffusers/models/modeling_flax_utils.py
CHANGED
@@ -245,9 +245,9 @@ class FlaxModelMixin(PushToHubMixin):
             force_download (`bool`, *optional*, defaults to `False`):
                 Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                 cached versions if they exist.
-            resume_download (`bool`, *optional*, defaults to `False`):
-                Whether or not to resume downloading the model weights and configuration files. If set to `False`, any
-                incompletely downloaded files are deleted.
+            resume_download:
+                Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
+                of Diffusers.
             proxies (`Dict[str, str]`, *optional*):
                 A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
@@ -296,7 +296,7 @@ class FlaxModelMixin(PushToHubMixin):
         cache_dir = kwargs.pop("cache_dir", None)
         force_download = kwargs.pop("force_download", False)
         from_pt = kwargs.pop("from_pt", False)
-        resume_download = kwargs.pop("resume_download", False)
+        resume_download = kwargs.pop("resume_download", None)
         proxies = kwargs.pop("proxies", None)
         local_files_only = kwargs.pop("local_files_only", False)
         token = kwargs.pop("token", None)
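Since `resume_download` is now deprecated and ignored (downloads resume by default), callers can simply drop the kwarg. A sketch against the Flax loader changed above; the model ID is illustrative and `flax` must be installed:

from diffusers import FlaxUNet2DConditionModel

# before: from_pretrained(..., resume_download=True)
# now: omit the kwarg; passing it only triggers a deprecation path upstream
unet, unet_params = FlaxUNet2DConditionModel.from_pretrained(
    "runwayml/stable-diffusion-v1-5", subfolder="unet", from_pt=True
)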
diffusers/models/modeling_pytorch_flax_utils.py
CHANGED
@@ -12,7 +12,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-""" PyTorch - Flax general utilities."""
+"""PyTorch - Flax general utilities."""

 from pickle import UnpicklingError

|