diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -13,6 +13,7 @@
|
|
13
13
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14
14
|
# See the License for the specific language governing permissions and
|
15
15
|
# limitations under the License.
|
16
|
+
import enum
|
16
17
|
import fnmatch
|
17
18
|
import importlib
|
18
19
|
import inspect
|
@@ -44,39 +45,44 @@ from ..configuration_utils import ConfigMixin
|
|
44
45
|
from ..models import AutoencoderKL
|
45
46
|
from ..models.attention_processor import FusedAttnProcessor2_0
|
46
47
|
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
|
48
|
+
from ..quantizers.bitsandbytes.utils import _check_bnb_status
|
47
49
|
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
48
50
|
from ..utils import (
|
49
51
|
CONFIG_NAME,
|
50
52
|
DEPRECATED_REVISION_ARGS,
|
51
53
|
BaseOutput,
|
52
54
|
PushToHubMixin,
|
53
|
-
deprecate,
|
54
55
|
is_accelerate_available,
|
55
56
|
is_accelerate_version,
|
56
57
|
is_torch_npu_available,
|
57
58
|
is_torch_version,
|
59
|
+
is_transformers_version,
|
58
60
|
logging,
|
59
61
|
numpy_to_pil,
|
60
62
|
)
|
61
|
-
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
|
63
|
+
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
|
62
64
|
from ..utils.torch_utils import is_compiled_module
|
63
65
|
|
64
66
|
|
65
67
|
if is_torch_npu_available():
|
66
68
|
import torch_npu # noqa: F401
|
67
69
|
|
68
|
-
|
69
70
|
from .pipeline_loading_utils import (
|
70
71
|
ALL_IMPORTABLE_CLASSES,
|
71
72
|
CONNECTED_PIPES_KEYS,
|
72
73
|
CUSTOM_PIPELINE_FILE_NAME,
|
73
74
|
LOADABLE_CLASSES,
|
74
75
|
_fetch_class_library_tuple,
|
76
|
+
_get_custom_components_and_folders,
|
75
77
|
_get_custom_pipeline_class,
|
76
78
|
_get_final_device_map,
|
79
|
+
_get_ignore_patterns,
|
77
80
|
_get_pipeline_class,
|
81
|
+
_identify_model_variants,
|
82
|
+
_maybe_raise_warning_for_inpainting,
|
83
|
+
_resolve_custom_pipeline_and_cls,
|
78
84
|
_unwrap_model,
|
79
|
-
|
85
|
+
_update_init_kwargs_with_connected_pipeline,
|
80
86
|
load_sub_model,
|
81
87
|
maybe_raise_or_warn,
|
82
88
|
variant_compatible_siblings,
|
@@ -185,6 +191,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
185
191
|
save_directory: Union[str, os.PathLike],
|
186
192
|
safe_serialization: bool = True,
|
187
193
|
variant: Optional[str] = None,
|
194
|
+
max_shard_size: Optional[Union[int, str]] = None,
|
188
195
|
push_to_hub: bool = False,
|
189
196
|
**kwargs,
|
190
197
|
):
|
@@ -200,6 +207,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
200
207
|
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
201
208
|
variant (`str`, *optional*):
|
202
209
|
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
210
|
+
max_shard_size (`int` or `str`, defaults to `None`):
|
211
|
+
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
|
212
|
+
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
|
213
|
+
If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
|
214
|
+
period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
|
215
|
+
This is to establish a common default size for this argument across different libraries in the Hugging
|
216
|
+
Face ecosystem (`transformers`, and `accelerate`, for example).
|
203
217
|
push_to_hub (`bool`, *optional*, defaults to `False`):
|
204
218
|
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
205
219
|
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
@@ -215,7 +229,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
215
229
|
|
216
230
|
if push_to_hub:
|
217
231
|
commit_message = kwargs.pop("commit_message", None)
|
218
|
-
private = kwargs.pop("private",
|
232
|
+
private = kwargs.pop("private", None)
|
219
233
|
create_pr = kwargs.pop("create_pr", False)
|
220
234
|
token = kwargs.pop("token", None)
|
221
235
|
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
@@ -274,12 +288,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
274
288
|
save_method_signature = inspect.signature(save_method)
|
275
289
|
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
|
276
290
|
save_method_accept_variant = "variant" in save_method_signature.parameters
|
291
|
+
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
|
277
292
|
|
278
293
|
save_kwargs = {}
|
279
294
|
if save_method_accept_safe:
|
280
295
|
save_kwargs["safe_serialization"] = safe_serialization
|
281
296
|
if save_method_accept_variant:
|
282
297
|
save_kwargs["variant"] = variant
|
298
|
+
if save_method_accept_max_shard_size and max_shard_size is not None:
|
299
|
+
# max_shard_size is expected to not be None in ModelMixin
|
300
|
+
save_kwargs["max_shard_size"] = max_shard_size
|
283
301
|
|
284
302
|
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
|
285
303
|
|
@@ -370,6 +388,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
370
388
|
)
|
371
389
|
|
372
390
|
device = device or device_arg
|
391
|
+
pipeline_has_bnb = any(any((_check_bnb_status(module))) for _, module in self.components.items())
|
373
392
|
|
374
393
|
# throw warning if pipeline is in "offloaded"-mode but user tries to manually set to GPU.
|
375
394
|
def module_is_sequentially_offloaded(module):
|
@@ -392,10 +411,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
392
411
|
pipeline_is_sequentially_offloaded = any(
|
393
412
|
module_is_sequentially_offloaded(module) for _, module in self.components.items()
|
394
413
|
)
|
395
|
-
if
|
396
|
-
|
397
|
-
|
398
|
-
|
414
|
+
if device and torch.device(device).type == "cuda":
|
415
|
+
if pipeline_is_sequentially_offloaded and not pipeline_has_bnb:
|
416
|
+
raise ValueError(
|
417
|
+
"It seems like you have activated sequential model offloading by calling `enable_sequential_cpu_offload`, but are now attempting to move the pipeline to GPU. This is not compatible with offloading. Please, move your pipeline `.to('cpu')` or consider removing the move altogether if you use sequential offloading."
|
418
|
+
)
|
419
|
+
# PR: https://github.com/huggingface/accelerate/pull/3223/
|
420
|
+
elif pipeline_has_bnb and is_accelerate_version("<", "1.1.0.dev0"):
|
421
|
+
raise ValueError(
|
422
|
+
"You are trying to call `.to('cuda')` on a pipeline that has models quantized with `bitsandbytes`. Your current `accelerate` installation does not support it. Please upgrade the installation."
|
423
|
+
)
|
399
424
|
|
400
425
|
is_pipeline_device_mapped = self.hf_device_map is not None and len(self.hf_device_map) > 1
|
401
426
|
if is_pipeline_device_mapped:
|
@@ -416,18 +441,23 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
416
441
|
|
417
442
|
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
|
418
443
|
for module in modules:
|
419
|
-
|
444
|
+
_, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
|
420
445
|
|
421
|
-
if
|
446
|
+
if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
|
422
447
|
logger.warning(
|
423
|
-
f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not
|
448
|
+
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
|
424
449
|
)
|
425
450
|
|
426
|
-
if
|
451
|
+
if is_loaded_in_8bit_bnb and device is not None:
|
427
452
|
logger.warning(
|
428
|
-
f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {
|
453
|
+
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
|
429
454
|
)
|
430
|
-
|
455
|
+
|
456
|
+
# This can happen for `transformer` models. CPU placement was added in
|
457
|
+
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
|
458
|
+
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
|
459
|
+
module.to(device=device)
|
460
|
+
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
|
431
461
|
module.to(device, dtype)
|
432
462
|
|
433
463
|
if (
|
@@ -622,6 +652,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
622
652
|
>>> pipeline.scheduler = scheduler
|
623
653
|
```
|
624
654
|
"""
|
655
|
+
# Copy the kwargs to re-use during loading connected pipeline.
|
656
|
+
kwargs_copied = kwargs.copy()
|
657
|
+
|
625
658
|
cache_dir = kwargs.pop("cache_dir", None)
|
626
659
|
force_download = kwargs.pop("force_download", False)
|
627
660
|
proxies = kwargs.pop("proxies", None)
|
@@ -716,39 +749,43 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
716
749
|
else:
|
717
750
|
cached_folder = pretrained_model_name_or_path
|
718
751
|
|
752
|
+
# The variant filenames can have the legacy sharding checkpoint format that we check and throw
|
753
|
+
# a warning if detected.
|
754
|
+
if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
|
755
|
+
warn_msg = (
|
756
|
+
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
|
757
|
+
"Please check your files carefully:\n\n"
|
758
|
+
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
|
759
|
+
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
|
760
|
+
"If you find any files in the deprecated format:\n"
|
761
|
+
"1. Remove all existing checkpoint files for this variant.\n"
|
762
|
+
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
|
763
|
+
"This will ensure you're using the most up-to-date and compatible checkpoint format."
|
764
|
+
)
|
765
|
+
logger.warning(warn_msg)
|
766
|
+
|
719
767
|
config_dict = cls.load_config(cached_folder)
|
720
768
|
|
721
769
|
# pop out "_ignore_files" as it is only needed for download
|
722
770
|
config_dict.pop("_ignore_files", None)
|
723
771
|
|
724
772
|
# 2. Define which model components should load variants
|
725
|
-
# We retrieve the information by matching whether variant
|
726
|
-
#
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
variant_exists = is_folder and any(
|
733
|
-
p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
|
734
|
-
)
|
735
|
-
if variant_exists:
|
736
|
-
model_variants[folder] = variant
|
773
|
+
# We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
|
774
|
+
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
|
775
|
+
# with variant being `"fp16"`.
|
776
|
+
model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
|
777
|
+
if len(model_variants) == 0 and variant is not None:
|
778
|
+
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
779
|
+
raise ValueError(error_message)
|
737
780
|
|
738
781
|
# 3. Load the pipeline class, if using custom module then load it from the hub
|
739
782
|
# if we load from explicit class, let's use it
|
740
|
-
custom_class_name =
|
741
|
-
|
742
|
-
|
743
|
-
elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
|
744
|
-
os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
|
745
|
-
):
|
746
|
-
custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
|
747
|
-
custom_class_name = config_dict["_class_name"][1]
|
748
|
-
|
783
|
+
custom_pipeline, custom_class_name = _resolve_custom_pipeline_and_cls(
|
784
|
+
folder=cached_folder, config=config_dict, custom_pipeline=custom_pipeline
|
785
|
+
)
|
749
786
|
pipeline_class = _get_pipeline_class(
|
750
787
|
cls,
|
751
|
-
config_dict,
|
788
|
+
config=config_dict,
|
752
789
|
load_connected_pipeline=load_connected_pipeline,
|
753
790
|
custom_pipeline=custom_pipeline,
|
754
791
|
class_name=custom_class_name,
|
@@ -760,23 +797,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
760
797
|
raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")
|
761
798
|
|
762
799
|
# DEPRECATED: To be removed in 1.0.0
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
deprecation_message = (
|
771
|
-
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
|
772
|
-
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
|
773
|
-
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
|
774
|
-
" checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
|
775
|
-
f" checkpoint {pretrained_model_name_or_path} to the format of"
|
776
|
-
" https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
|
777
|
-
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
|
778
|
-
)
|
779
|
-
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
|
800
|
+
# we are deprecating the `StableDiffusionInpaintPipelineLegacy` pipeline which gets loaded
|
801
|
+
# when a user requests for a `StableDiffusionInpaintPipeline` with `diffusers` version being <= 0.5.1.
|
802
|
+
_maybe_raise_warning_for_inpainting(
|
803
|
+
pipeline_class=pipeline_class,
|
804
|
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
805
|
+
config=config_dict,
|
806
|
+
)
|
780
807
|
|
781
808
|
# 4. Define expected modules given pipeline signature
|
782
809
|
# and define non-None initialized modules (=`init_kwargs`)
|
@@ -785,9 +812,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
785
812
|
# in this case they are already instantiated in `kwargs`
|
786
813
|
# extract them here
|
787
814
|
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
815
|
+
expected_types = pipeline_class._get_signature_types()
|
788
816
|
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
789
817
|
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
790
|
-
|
791
818
|
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
792
819
|
|
793
820
|
# define init kwargs and make sure that optional component modules are filtered out
|
@@ -808,6 +835,26 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
808
835
|
|
809
836
|
init_dict = {k: v for k, v in init_dict.items() if load_module(k, v)}
|
810
837
|
|
838
|
+
for key in init_dict.keys():
|
839
|
+
if key not in passed_class_obj:
|
840
|
+
continue
|
841
|
+
if "scheduler" in key:
|
842
|
+
continue
|
843
|
+
|
844
|
+
class_obj = passed_class_obj[key]
|
845
|
+
_expected_class_types = []
|
846
|
+
for expected_type in expected_types[key]:
|
847
|
+
if isinstance(expected_type, enum.EnumMeta):
|
848
|
+
_expected_class_types.extend(expected_type.__members__.keys())
|
849
|
+
else:
|
850
|
+
_expected_class_types.append(expected_type.__name__)
|
851
|
+
|
852
|
+
_is_valid_type = class_obj.__class__.__name__ in _expected_class_types
|
853
|
+
if not _is_valid_type:
|
854
|
+
logger.warning(
|
855
|
+
f"Expected types for {key}: {_expected_class_types}, got {class_obj.__class__.__name__}."
|
856
|
+
)
|
857
|
+
|
811
858
|
# Special case: safety_checker must be loaded separately when using `from_flax`
|
812
859
|
if from_flax and "safety_checker" in init_dict and "safety_checker" not in passed_class_obj:
|
813
860
|
raise NotImplementedError(
|
@@ -847,6 +894,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
847
894
|
# 7. Load each module in the pipeline
|
848
895
|
current_device_map = None
|
849
896
|
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
|
897
|
+
# 7.1 device_map shenanigans
|
850
898
|
if final_device_map is not None and len(final_device_map) > 0:
|
851
899
|
component_device = final_device_map.get(name, None)
|
852
900
|
if component_device is not None:
|
@@ -854,15 +902,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
854
902
|
else:
|
855
903
|
current_device_map = None
|
856
904
|
|
857
|
-
# 7.
|
905
|
+
# 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
858
906
|
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
859
907
|
|
860
|
-
# 7.
|
908
|
+
# 7.3 Define all importable classes
|
861
909
|
is_pipeline_module = hasattr(pipelines, library_name)
|
862
910
|
importable_classes = ALL_IMPORTABLE_CLASSES
|
863
911
|
loaded_sub_model = None
|
864
912
|
|
865
|
-
# 7.
|
913
|
+
# 7.4 Use passed sub model or load class_name from library_name
|
866
914
|
if name in passed_class_obj:
|
867
915
|
# if the model is in a pipeline module, then we load it from the pipeline
|
868
916
|
# check that passed_class_obj has correct parent class
|
@@ -893,6 +941,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
893
941
|
variant=variant,
|
894
942
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
895
943
|
cached_folder=cached_folder,
|
944
|
+
use_safetensors=use_safetensors,
|
896
945
|
)
|
897
946
|
logger.info(
|
898
947
|
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
|
@@ -900,56 +949,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
900
949
|
|
901
950
|
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
902
951
|
|
952
|
+
# 8. Handle connected pipelines.
|
903
953
|
if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
|
904
|
-
|
905
|
-
|
906
|
-
|
907
|
-
|
908
|
-
|
909
|
-
|
910
|
-
|
911
|
-
"token": token,
|
912
|
-
"revision": revision,
|
913
|
-
"torch_dtype": torch_dtype,
|
914
|
-
"custom_pipeline": custom_pipeline,
|
915
|
-
"custom_revision": custom_revision,
|
916
|
-
"provider": provider,
|
917
|
-
"sess_options": sess_options,
|
918
|
-
"device_map": device_map,
|
919
|
-
"max_memory": max_memory,
|
920
|
-
"offload_folder": offload_folder,
|
921
|
-
"offload_state_dict": offload_state_dict,
|
922
|
-
"low_cpu_mem_usage": low_cpu_mem_usage,
|
923
|
-
"variant": variant,
|
924
|
-
"use_safetensors": use_safetensors,
|
925
|
-
}
|
926
|
-
|
927
|
-
def get_connected_passed_kwargs(prefix):
|
928
|
-
connected_passed_class_obj = {
|
929
|
-
k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
|
930
|
-
}
|
931
|
-
connected_passed_pipe_kwargs = {
|
932
|
-
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
|
933
|
-
}
|
934
|
-
|
935
|
-
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
|
936
|
-
return connected_passed_kwargs
|
937
|
-
|
938
|
-
connected_pipes = {
|
939
|
-
prefix: DiffusionPipeline.from_pretrained(
|
940
|
-
repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
|
941
|
-
)
|
942
|
-
for prefix, repo_id in connected_pipes.items()
|
943
|
-
if repo_id is not None
|
944
|
-
}
|
945
|
-
|
946
|
-
for prefix, connected_pipe in connected_pipes.items():
|
947
|
-
# add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
|
948
|
-
init_kwargs.update(
|
949
|
-
{"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
|
950
|
-
)
|
954
|
+
init_kwargs = _update_init_kwargs_with_connected_pipeline(
|
955
|
+
init_kwargs=init_kwargs,
|
956
|
+
passed_pipe_kwargs=passed_pipe_kwargs,
|
957
|
+
passed_class_objs=passed_class_obj,
|
958
|
+
folder=cached_folder,
|
959
|
+
**kwargs_copied,
|
960
|
+
)
|
951
961
|
|
952
|
-
#
|
962
|
+
# 9. Potentially add passed objects if expected
|
953
963
|
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
954
964
|
passed_modules = list(passed_class_obj.keys())
|
955
965
|
optional_modules = pipeline_class._optional_components
|
@@ -1065,9 +1075,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1065
1075
|
hook = None
|
1066
1076
|
for model_str in self.model_cpu_offload_seq.split("->"):
|
1067
1077
|
model = all_model_components.pop(model_str, None)
|
1078
|
+
|
1068
1079
|
if not isinstance(model, torch.nn.Module):
|
1069
1080
|
continue
|
1070
1081
|
|
1082
|
+
# This is because the model would already be placed on a CUDA device.
|
1083
|
+
_, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
|
1084
|
+
if is_loaded_in_8bit_bnb:
|
1085
|
+
logger.info(
|
1086
|
+
f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
|
1087
|
+
)
|
1088
|
+
continue
|
1089
|
+
|
1071
1090
|
_, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
|
1072
1091
|
self._all_hooks.append(hook)
|
1073
1092
|
|
@@ -1295,6 +1314,22 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1295
1314
|
model_info_call_error = e # save error to reraise it if model is not cached locally
|
1296
1315
|
|
1297
1316
|
if not local_files_only:
|
1317
|
+
filenames = {sibling.rfilename for sibling in info.siblings}
|
1318
|
+
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
|
1319
|
+
warn_msg = (
|
1320
|
+
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
|
1321
|
+
"Please check your files carefully:\n\n"
|
1322
|
+
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
|
1323
|
+
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
|
1324
|
+
"If you find any files in the deprecated format:\n"
|
1325
|
+
"1. Remove all existing checkpoint files for this variant.\n"
|
1326
|
+
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
|
1327
|
+
"This will ensure you're using the most up-to-date and compatible checkpoint format."
|
1328
|
+
)
|
1329
|
+
logger.warning(warn_msg)
|
1330
|
+
|
1331
|
+
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
1332
|
+
|
1298
1333
|
config_file = hf_hub_download(
|
1299
1334
|
pretrained_model_name,
|
1300
1335
|
cls.config_name,
|
@@ -1308,52 +1343,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1308
1343
|
config_dict = cls._dict_from_json_file(config_file)
|
1309
1344
|
ignore_filenames = config_dict.pop("_ignore_files", [])
|
1310
1345
|
|
1311
|
-
# retrieve all folder_names that contain relevant files
|
1312
|
-
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
|
1313
|
-
|
1314
|
-
filenames = {sibling.rfilename for sibling in info.siblings}
|
1315
|
-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
1316
|
-
|
1317
|
-
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
1318
|
-
pipelines = getattr(diffusers_module, "pipelines")
|
1319
|
-
|
1320
|
-
# optionally create a custom component <> custom file mapping
|
1321
|
-
custom_components = {}
|
1322
|
-
for component in folder_names:
|
1323
|
-
module_candidate = config_dict[component][0]
|
1324
|
-
|
1325
|
-
if module_candidate is None or not isinstance(module_candidate, str):
|
1326
|
-
continue
|
1327
|
-
|
1328
|
-
# We compute candidate file path on the Hub. Do not use `os.path.join`.
|
1329
|
-
candidate_file = f"{component}/{module_candidate}.py"
|
1330
|
-
|
1331
|
-
if candidate_file in filenames:
|
1332
|
-
custom_components[component] = module_candidate
|
1333
|
-
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
|
1334
|
-
raise ValueError(
|
1335
|
-
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
|
1336
|
-
)
|
1337
|
-
|
1338
|
-
if len(variant_filenames) == 0 and variant is not None:
|
1339
|
-
deprecation_message = (
|
1340
|
-
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
1341
|
-
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
|
1342
|
-
"if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
|
1343
|
-
"modeling files is deprecated."
|
1344
|
-
)
|
1345
|
-
deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)
|
1346
|
-
|
1347
1346
|
# remove ignored filenames
|
1348
1347
|
model_filenames = set(model_filenames) - set(ignore_filenames)
|
1349
1348
|
variant_filenames = set(variant_filenames) - set(ignore_filenames)
|
1350
1349
|
|
1351
|
-
# if the whole pipeline is cached we don't have to ping the Hub
|
1352
1350
|
if revision in DEPRECATED_REVISION_ARGS and version.parse(
|
1353
1351
|
version.parse(__version__).base_version
|
1354
1352
|
) >= version.parse("0.22.0"):
|
1355
1353
|
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
|
1356
1354
|
|
1355
|
+
custom_components, folder_names = _get_custom_components_and_folders(
|
1356
|
+
pretrained_model_name, config_dict, filenames, variant_filenames, variant
|
1357
|
+
)
|
1357
1358
|
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
|
1358
1359
|
|
1359
1360
|
custom_class_name = None
|
@@ -1413,49 +1414,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1413
1414
|
expected_components, _ = cls._get_signature_keys(pipeline_class)
|
1414
1415
|
passed_components = [k for k in expected_components if k in kwargs]
|
1415
1416
|
|
1416
|
-
|
1417
|
-
|
1418
|
-
|
1419
|
-
|
1420
|
-
|
1421
|
-
|
1422
|
-
|
1423
|
-
|
1424
|
-
|
1425
|
-
|
1426
|
-
|
1427
|
-
|
1428
|
-
|
1429
|
-
model_filenames, variant=variant, passed_components=passed_components
|
1430
|
-
):
|
1431
|
-
ignore_patterns = ["*.bin", "*.msgpack"]
|
1432
|
-
|
1433
|
-
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
|
1434
|
-
if not use_onnx:
|
1435
|
-
ignore_patterns += ["*.onnx", "*.pb"]
|
1436
|
-
|
1437
|
-
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
|
1438
|
-
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
|
1439
|
-
if (
|
1440
|
-
len(safetensors_variant_filenames) > 0
|
1441
|
-
and safetensors_model_filenames != safetensors_variant_filenames
|
1442
|
-
):
|
1443
|
-
logger.warning(
|
1444
|
-
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
1445
|
-
)
|
1446
|
-
else:
|
1447
|
-
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
1448
|
-
|
1449
|
-
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
|
1450
|
-
if not use_onnx:
|
1451
|
-
ignore_patterns += ["*.onnx", "*.pb"]
|
1452
|
-
|
1453
|
-
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
|
1454
|
-
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
|
1455
|
-
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
1456
|
-
logger.warning(
|
1457
|
-
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
1458
|
-
)
|
1417
|
+
# retrieve all patterns that should not be downloaded and error out when needed
|
1418
|
+
ignore_patterns = _get_ignore_patterns(
|
1419
|
+
passed_components,
|
1420
|
+
model_folder_names,
|
1421
|
+
model_filenames,
|
1422
|
+
variant_filenames,
|
1423
|
+
use_safetensors,
|
1424
|
+
from_flax,
|
1425
|
+
allow_pickle,
|
1426
|
+
use_onnx,
|
1427
|
+
pipeline_class._is_onnx,
|
1428
|
+
variant,
|
1429
|
+
)
|
1459
1430
|
|
1460
1431
|
# Don't download any objects that are passed
|
1461
1432
|
allow_patterns = [
|
@@ -1609,6 +1580,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1609
1580
|
"""
|
1610
1581
|
return numpy_to_pil(images)
|
1611
1582
|
|
1583
|
+
@torch.compiler.disable
|
1612
1584
|
def progress_bar(self, iterable=None, total=None):
|
1613
1585
|
if not hasattr(self, "_progress_bar_config"):
|
1614
1586
|
self._progress_bar_config = {}
|
@@ -178,7 +178,7 @@ def retrieve_timesteps(
|
|
178
178
|
sigmas: Optional[List[float]] = None,
|
179
179
|
**kwargs,
|
180
180
|
):
|
181
|
-
"""
|
181
|
+
r"""
|
182
182
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
183
183
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
184
184
|
|
@@ -338,13 +338,6 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|
338
338
|
if device is None:
|
339
339
|
device = self._execution_device
|
340
340
|
|
341
|
-
if prompt is not None and isinstance(prompt, str):
|
342
|
-
batch_size = 1
|
343
|
-
elif prompt is not None and isinstance(prompt, list):
|
344
|
-
batch_size = len(prompt)
|
345
|
-
else:
|
346
|
-
batch_size = prompt_embeds.shape[0]
|
347
|
-
|
348
341
|
# See Section 3.1. of the paper.
|
349
342
|
max_length = max_sequence_length
|
350
343
|
|
@@ -389,12 +382,12 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|
389
382
|
# duplicate text embeddings and attention mask for each generation per prompt, using mps friendly method
|
390
383
|
prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
391
384
|
prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
392
|
-
prompt_attention_mask = prompt_attention_mask.
|
393
|
-
prompt_attention_mask = prompt_attention_mask.
|
385
|
+
prompt_attention_mask = prompt_attention_mask.repeat(1, num_images_per_prompt)
|
386
|
+
prompt_attention_mask = prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
|
394
387
|
|
395
388
|
# get unconditional embeddings for classifier free guidance
|
396
389
|
if do_classifier_free_guidance and negative_prompt_embeds is None:
|
397
|
-
uncond_tokens = [negative_prompt] *
|
390
|
+
uncond_tokens = [negative_prompt] * bs_embed if isinstance(negative_prompt, str) else negative_prompt
|
398
391
|
uncond_tokens = self._text_preprocessing(uncond_tokens, clean_caption=clean_caption)
|
399
392
|
max_length = prompt_embeds.shape[1]
|
400
393
|
uncond_input = self.tokenizer(
|
@@ -421,10 +414,10 @@ class PixArtAlphaPipeline(DiffusionPipeline):
|
|
421
414
|
negative_prompt_embeds = negative_prompt_embeds.to(dtype=dtype, device=device)
|
422
415
|
|
423
416
|
negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
|
424
|
-
negative_prompt_embeds = negative_prompt_embeds.view(
|
417
|
+
negative_prompt_embeds = negative_prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
|
425
418
|
|
426
|
-
negative_prompt_attention_mask = negative_prompt_attention_mask.
|
427
|
-
negative_prompt_attention_mask = negative_prompt_attention_mask.
|
419
|
+
negative_prompt_attention_mask = negative_prompt_attention_mask.repeat(1, num_images_per_prompt)
|
420
|
+
negative_prompt_attention_mask = negative_prompt_attention_mask.view(bs_embed * num_images_per_prompt, -1)
|
428
421
|
else:
|
429
422
|
negative_prompt_embeds = None
|
430
423
|
negative_prompt_attention_mask = None
|