diffusers 0.27.1__py3-none-any.whl → 0.28.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +18 -1
- diffusers/callbacks.py +156 -0
- diffusers/commands/env.py +110 -6
- diffusers/configuration_utils.py +16 -11
- diffusers/dependency_versions_table.py +2 -1
- diffusers/image_processor.py +158 -45
- diffusers/loaders/__init__.py +2 -5
- diffusers/loaders/autoencoder.py +4 -4
- diffusers/loaders/controlnet.py +4 -4
- diffusers/loaders/ip_adapter.py +80 -22
- diffusers/loaders/lora.py +134 -20
- diffusers/loaders/lora_conversion_utils.py +46 -43
- diffusers/loaders/peft.py +4 -3
- diffusers/loaders/single_file.py +401 -170
- diffusers/loaders/single_file_model.py +290 -0
- diffusers/loaders/single_file_utils.py +616 -672
- diffusers/loaders/textual_inversion.py +41 -20
- diffusers/loaders/unet.py +168 -115
- diffusers/loaders/unet_loader_utils.py +163 -0
- diffusers/models/__init__.py +2 -0
- diffusers/models/activations.py +11 -3
- diffusers/models/attention.py +10 -11
- diffusers/models/attention_processor.py +367 -148
- diffusers/models/autoencoders/autoencoder_asym_kl.py +14 -16
- diffusers/models/autoencoders/autoencoder_kl.py +18 -19
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +11 -11
- diffusers/models/autoencoders/autoencoder_tiny.py +16 -16
- diffusers/models/autoencoders/consistency_decoder_vae.py +36 -11
- diffusers/models/autoencoders/vae.py +23 -24
- diffusers/models/controlnet.py +12 -9
- diffusers/models/controlnet_flax.py +4 -4
- diffusers/models/controlnet_xs.py +1915 -0
- diffusers/models/downsampling.py +17 -18
- diffusers/models/embeddings.py +147 -24
- diffusers/models/model_loading_utils.py +149 -0
- diffusers/models/modeling_flax_pytorch_utils.py +2 -1
- diffusers/models/modeling_flax_utils.py +4 -4
- diffusers/models/modeling_pytorch_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +118 -98
- diffusers/models/resnet.py +18 -23
- diffusers/models/transformer_temporal.py +3 -3
- diffusers/models/transformers/dual_transformer_2d.py +4 -4
- diffusers/models/transformers/prior_transformer.py +7 -7
- diffusers/models/transformers/t5_film_transformer.py +17 -19
- diffusers/models/transformers/transformer_2d.py +272 -156
- diffusers/models/transformers/transformer_temporal.py +10 -10
- diffusers/models/unets/unet_1d.py +5 -5
- diffusers/models/unets/unet_1d_blocks.py +29 -29
- diffusers/models/unets/unet_2d.py +6 -6
- diffusers/models/unets/unet_2d_blocks.py +137 -128
- diffusers/models/unets/unet_2d_condition.py +20 -15
- diffusers/models/unets/unet_2d_condition_flax.py +6 -5
- diffusers/models/unets/unet_3d_blocks.py +79 -77
- diffusers/models/unets/unet_3d_condition.py +13 -9
- diffusers/models/unets/unet_i2vgen_xl.py +14 -13
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +114 -14
- diffusers/models/unets/unet_spatio_temporal_condition.py +15 -14
- diffusers/models/unets/unet_stable_cascade.py +16 -13
- diffusers/models/upsampling.py +17 -20
- diffusers/models/vq_model.py +16 -15
- diffusers/pipelines/__init__.py +25 -3
- diffusers/pipelines/amused/pipeline_amused.py +12 -12
- diffusers/pipelines/amused/pipeline_amused_img2img.py +14 -12
- diffusers/pipelines/amused/pipeline_amused_inpaint.py +13 -11
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +24 -46
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +1284 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +82 -72
- diffusers/pipelines/animatediff/pipeline_output.py +3 -2
- diffusers/pipelines/audioldm/pipeline_audioldm.py +14 -14
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +54 -35
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +120 -36
- diffusers/pipelines/auto_pipeline.py +21 -17
- diffusers/pipelines/blip_diffusion/blip_image_processing.py +1 -1
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +5 -5
- diffusers/pipelines/blip_diffusion/modeling_ctx_clip.py +1 -1
- diffusers/pipelines/blip_diffusion/pipeline_blip_diffusion.py +2 -2
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +5 -5
- diffusers/pipelines/controlnet/multicontrolnet.py +4 -8
- diffusers/pipelines/controlnet/pipeline_controlnet.py +87 -52
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +2 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +50 -43
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +52 -40
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +80 -47
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +147 -49
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +89 -55
- diffusers/pipelines/controlnet_xs/__init__.py +68 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +911 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +1115 -0
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +14 -28
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +18 -33
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +21 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +20 -36
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +23 -39
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +17 -32
- diffusers/pipelines/deprecated/alt_diffusion/modeling_roberta_series.py +11 -11
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +43 -20
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +36 -18
- diffusers/pipelines/deprecated/repaint/pipeline_repaint.py +2 -2
- diffusers/pipelines/deprecated/spectrogram_diffusion/pipeline_spectrogram_diffusion.py +7 -7
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +12 -12
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +18 -21
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +20 -15
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +30 -25
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +69 -59
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion.py +13 -13
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_dual_guided.py +10 -5
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_image_variation.py +11 -6
- diffusers/pipelines/deprecated/versatile_diffusion/pipeline_versatile_diffusion_text_to_image.py +10 -5
- diffusers/pipelines/deprecated/vq_diffusion/pipeline_vq_diffusion.py +5 -5
- diffusers/pipelines/dit/pipeline_dit.py +3 -0
- diffusers/pipelines/free_init_utils.py +39 -38
- diffusers/pipelines/i2vgen_xl/pipeline_i2vgen_xl.py +33 -48
- diffusers/pipelines/kandinsky/pipeline_kandinsky.py +8 -8
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +23 -20
- diffusers/pipelines/kandinsky/pipeline_kandinsky_img2img.py +11 -11
- diffusers/pipelines/kandinsky/pipeline_kandinsky_inpaint.py +12 -12
- diffusers/pipelines/kandinsky/pipeline_kandinsky_prior.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +32 -29
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_controlnet_img2img.py +10 -10
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_img2img.py +6 -6
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_inpainting.py +8 -8
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior.py +7 -7
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_prior_emb2emb.py +6 -6
- diffusers/pipelines/kandinsky3/convert_kandinsky3_unet.py +3 -3
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +20 -33
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +24 -35
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +48 -30
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +50 -28
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +11 -11
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +61 -67
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +70 -69
- diffusers/pipelines/ledits_pp/pipeline_output.py +2 -2
- diffusers/pipelines/marigold/__init__.py +50 -0
- diffusers/pipelines/marigold/marigold_image_processing.py +561 -0
- diffusers/pipelines/marigold/pipeline_marigold_depth.py +813 -0
- diffusers/pipelines/marigold/pipeline_marigold_normals.py +690 -0
- diffusers/pipelines/musicldm/pipeline_musicldm.py +14 -14
- diffusers/pipelines/paint_by_example/pipeline_paint_by_example.py +17 -12
- diffusers/pipelines/pia/pipeline_pia.py +39 -125
- diffusers/pipelines/pipeline_flax_utils.py +4 -4
- diffusers/pipelines/pipeline_loading_utils.py +268 -23
- diffusers/pipelines/pipeline_utils.py +266 -37
- diffusers/pipelines/pixart_alpha/__init__.py +8 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +65 -75
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +880 -0
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +10 -5
- diffusers/pipelines/shap_e/pipeline_shap_e.py +3 -3
- diffusers/pipelines/shap_e/pipeline_shap_e_img2img.py +14 -14
- diffusers/pipelines/shap_e/renderer.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +36 -22
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_combined.py +23 -19
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +33 -32
- diffusers/pipelines/stable_diffusion/__init__.py +0 -1
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +18 -11
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +73 -39
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +24 -17
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_image_variation.py +13 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +66 -36
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +82 -46
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +123 -28
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +6 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +16 -16
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +24 -19
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +37 -31
- diffusers/pipelines/stable_diffusion/safety_checker.py +2 -1
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +23 -15
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +44 -42
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +23 -18
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +19 -14
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +20 -15
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +24 -19
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +65 -32
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +274 -38
- diffusers/pipelines/stable_diffusion_safe/pipeline_stable_diffusion_safe.py +10 -5
- diffusers/pipelines/stable_diffusion_safe/safety_checker.py +1 -1
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +92 -25
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +88 -44
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +108 -56
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +96 -51
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +45 -25
- diffusers/pipelines/stable_diffusion_xl/watermark.py +9 -3
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +110 -57
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +59 -30
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +71 -42
- diffusers/pipelines/text_to_video_synthesis/pipeline_output.py +3 -2
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +18 -41
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +21 -85
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +28 -19
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +39 -33
- diffusers/pipelines/unclip/pipeline_unclip.py +6 -6
- diffusers/pipelines/unclip/pipeline_unclip_image_variation.py +6 -6
- diffusers/pipelines/unidiffuser/modeling_text_decoder.py +1 -1
- diffusers/pipelines/unidiffuser/modeling_uvit.py +9 -9
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +23 -23
- diffusers/pipelines/wuerstchen/modeling_paella_vq_model.py +5 -5
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_common.py +5 -10
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +4 -6
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen.py +4 -4
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_combined.py +12 -12
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +10 -10
- diffusers/schedulers/__init__.py +2 -2
- diffusers/schedulers/deprecated/__init__.py +1 -1
- diffusers/schedulers/deprecated/scheduling_karras_ve.py +25 -25
- diffusers/schedulers/scheduling_amused.py +5 -5
- diffusers/schedulers/scheduling_consistency_decoder.py +11 -11
- diffusers/schedulers/scheduling_consistency_models.py +23 -25
- diffusers/schedulers/scheduling_ddim.py +22 -24
- diffusers/schedulers/scheduling_ddim_flax.py +2 -1
- diffusers/schedulers/scheduling_ddim_inverse.py +16 -16
- diffusers/schedulers/scheduling_ddim_parallel.py +28 -30
- diffusers/schedulers/scheduling_ddpm.py +20 -22
- diffusers/schedulers/scheduling_ddpm_flax.py +7 -3
- diffusers/schedulers/scheduling_ddpm_parallel.py +26 -28
- diffusers/schedulers/scheduling_ddpm_wuerstchen.py +14 -14
- diffusers/schedulers/scheduling_deis_multistep.py +46 -42
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +107 -77
- diffusers/schedulers/scheduling_dpmsolver_multistep_flax.py +2 -2
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +46 -46
- diffusers/schedulers/scheduling_dpmsolver_sde.py +26 -22
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +90 -65
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +78 -53
- diffusers/schedulers/scheduling_edm_euler.py +53 -30
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +26 -28
- diffusers/schedulers/scheduling_euler_discrete.py +163 -67
- diffusers/schedulers/scheduling_heun_discrete.py +60 -38
- diffusers/schedulers/scheduling_ipndm.py +8 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +22 -18
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +22 -18
- diffusers/schedulers/scheduling_karras_ve_flax.py +6 -6
- diffusers/schedulers/scheduling_lcm.py +21 -23
- diffusers/schedulers/scheduling_lms_discrete.py +27 -25
- diffusers/schedulers/scheduling_pndm.py +20 -20
- diffusers/schedulers/scheduling_repaint.py +20 -20
- diffusers/schedulers/scheduling_sasolver.py +55 -54
- diffusers/schedulers/scheduling_sde_ve.py +19 -19
- diffusers/schedulers/scheduling_tcd.py +39 -30
- diffusers/schedulers/scheduling_unclip.py +15 -15
- diffusers/schedulers/scheduling_unipc_multistep.py +115 -41
- diffusers/schedulers/scheduling_utils.py +14 -5
- diffusers/schedulers/scheduling_utils_flax.py +3 -3
- diffusers/schedulers/scheduling_vq_diffusion.py +10 -10
- diffusers/training_utils.py +56 -1
- diffusers/utils/__init__.py +7 -0
- diffusers/utils/doc_utils.py +1 -0
- diffusers/utils/dummy_pt_objects.py +30 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +90 -0
- diffusers/utils/dynamic_modules_utils.py +24 -11
- diffusers/utils/hub_utils.py +3 -2
- diffusers/utils/import_utils.py +91 -0
- diffusers/utils/loading_utils.py +2 -2
- diffusers/utils/logging.py +1 -1
- diffusers/utils/peft_utils.py +32 -5
- diffusers/utils/state_dict_utils.py +11 -2
- diffusers/utils/testing_utils.py +71 -6
- diffusers/utils/torch_utils.py +1 -0
- diffusers/video_processor.py +113 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/METADATA +7 -7
- diffusers-0.28.0.dist-info/RECORD +414 -0
- diffusers-0.27.1.dist-info/RECORD +0 -399
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/LICENSE +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/WHEEL +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.27.1.dist-info → diffusers-0.28.0.dist-info}/top_level.txt +0 -0
diffusers/__init__.py
CHANGED
@@ -1,4 +1,4 @@
|
|
1
|
-
__version__ = "0.
|
1
|
+
__version__ = "0.28.0"
|
2
2
|
|
3
3
|
from typing import TYPE_CHECKING
|
4
4
|
|
@@ -27,6 +27,7 @@ from .utils import (
|
|
27
27
|
|
28
28
|
_import_structure = {
|
29
29
|
"configuration_utils": ["ConfigMixin"],
|
30
|
+
"loaders": ["FromOriginalModelMixin"],
|
30
31
|
"models": [],
|
31
32
|
"pipelines": [],
|
32
33
|
"schedulers": [],
|
@@ -80,6 +81,7 @@ else:
|
|
80
81
|
"AutoencoderTiny",
|
81
82
|
"ConsistencyDecoderVAE",
|
82
83
|
"ControlNetModel",
|
84
|
+
"ControlNetXSAdapter",
|
83
85
|
"I2VGenXLUNet",
|
84
86
|
"Kandinsky3UNet",
|
85
87
|
"ModelMixin",
|
@@ -94,6 +96,7 @@ else:
|
|
94
96
|
"UNet2DConditionModel",
|
95
97
|
"UNet2DModel",
|
96
98
|
"UNet3DConditionModel",
|
99
|
+
"UNetControlNetXSModel",
|
97
100
|
"UNetMotionModel",
|
98
101
|
"UNetSpatioTemporalConditionModel",
|
99
102
|
"UVit2DModel",
|
@@ -214,6 +217,7 @@ else:
|
|
214
217
|
"AmusedInpaintPipeline",
|
215
218
|
"AmusedPipeline",
|
216
219
|
"AnimateDiffPipeline",
|
220
|
+
"AnimateDiffSDXLPipeline",
|
217
221
|
"AnimateDiffVideoToVideoPipeline",
|
218
222
|
"AudioLDM2Pipeline",
|
219
223
|
"AudioLDM2ProjectionModel",
|
@@ -255,10 +259,13 @@ else:
|
|
255
259
|
"LDMTextToImagePipeline",
|
256
260
|
"LEditsPPPipelineStableDiffusion",
|
257
261
|
"LEditsPPPipelineStableDiffusionXL",
|
262
|
+
"MarigoldDepthPipeline",
|
263
|
+
"MarigoldNormalsPipeline",
|
258
264
|
"MusicLDMPipeline",
|
259
265
|
"PaintByExamplePipeline",
|
260
266
|
"PIAPipeline",
|
261
267
|
"PixArtAlphaPipeline",
|
268
|
+
"PixArtSigmaPipeline",
|
262
269
|
"SemanticStableDiffusionPipeline",
|
263
270
|
"ShapEImg2ImgPipeline",
|
264
271
|
"ShapEPipeline",
|
@@ -270,6 +277,7 @@ else:
|
|
270
277
|
"StableDiffusionControlNetImg2ImgPipeline",
|
271
278
|
"StableDiffusionControlNetInpaintPipeline",
|
272
279
|
"StableDiffusionControlNetPipeline",
|
280
|
+
"StableDiffusionControlNetXSPipeline",
|
273
281
|
"StableDiffusionDepth2ImgPipeline",
|
274
282
|
"StableDiffusionDiffEditPipeline",
|
275
283
|
"StableDiffusionGLIGENPipeline",
|
@@ -293,6 +301,7 @@ else:
|
|
293
301
|
"StableDiffusionXLControlNetImg2ImgPipeline",
|
294
302
|
"StableDiffusionXLControlNetInpaintPipeline",
|
295
303
|
"StableDiffusionXLControlNetPipeline",
|
304
|
+
"StableDiffusionXLControlNetXSPipeline",
|
296
305
|
"StableDiffusionXLImg2ImgPipeline",
|
297
306
|
"StableDiffusionXLInpaintPipeline",
|
298
307
|
"StableDiffusionXLInstructPix2PixPipeline",
|
@@ -474,6 +483,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
474
483
|
AutoencoderTiny,
|
475
484
|
ConsistencyDecoderVAE,
|
476
485
|
ControlNetModel,
|
486
|
+
ControlNetXSAdapter,
|
477
487
|
I2VGenXLUNet,
|
478
488
|
Kandinsky3UNet,
|
479
489
|
ModelMixin,
|
@@ -487,6 +497,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
487
497
|
UNet2DConditionModel,
|
488
498
|
UNet2DModel,
|
489
499
|
UNet3DConditionModel,
|
500
|
+
UNetControlNetXSModel,
|
490
501
|
UNetMotionModel,
|
491
502
|
UNetSpatioTemporalConditionModel,
|
492
503
|
UVit2DModel,
|
@@ -588,6 +599,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
588
599
|
AmusedInpaintPipeline,
|
589
600
|
AmusedPipeline,
|
590
601
|
AnimateDiffPipeline,
|
602
|
+
AnimateDiffSDXLPipeline,
|
591
603
|
AnimateDiffVideoToVideoPipeline,
|
592
604
|
AudioLDM2Pipeline,
|
593
605
|
AudioLDM2ProjectionModel,
|
@@ -627,10 +639,13 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
627
639
|
LDMTextToImagePipeline,
|
628
640
|
LEditsPPPipelineStableDiffusion,
|
629
641
|
LEditsPPPipelineStableDiffusionXL,
|
642
|
+
MarigoldDepthPipeline,
|
643
|
+
MarigoldNormalsPipeline,
|
630
644
|
MusicLDMPipeline,
|
631
645
|
PaintByExamplePipeline,
|
632
646
|
PIAPipeline,
|
633
647
|
PixArtAlphaPipeline,
|
648
|
+
PixArtSigmaPipeline,
|
634
649
|
SemanticStableDiffusionPipeline,
|
635
650
|
ShapEImg2ImgPipeline,
|
636
651
|
ShapEPipeline,
|
@@ -642,6 +657,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
642
657
|
StableDiffusionControlNetImg2ImgPipeline,
|
643
658
|
StableDiffusionControlNetInpaintPipeline,
|
644
659
|
StableDiffusionControlNetPipeline,
|
660
|
+
StableDiffusionControlNetXSPipeline,
|
645
661
|
StableDiffusionDepth2ImgPipeline,
|
646
662
|
StableDiffusionDiffEditPipeline,
|
647
663
|
StableDiffusionGLIGENPipeline,
|
@@ -665,6 +681,7 @@ if TYPE_CHECKING or DIFFUSERS_SLOW_IMPORT:
|
|
665
681
|
StableDiffusionXLControlNetImg2ImgPipeline,
|
666
682
|
StableDiffusionXLControlNetInpaintPipeline,
|
667
683
|
StableDiffusionXLControlNetPipeline,
|
684
|
+
StableDiffusionXLControlNetXSPipeline,
|
668
685
|
StableDiffusionXLImg2ImgPipeline,
|
669
686
|
StableDiffusionXLInpaintPipeline,
|
670
687
|
StableDiffusionXLInstructPix2PixPipeline,
|
diffusers/callbacks.py
ADDED
@@ -0,0 +1,156 @@
|
|
1
|
+
from typing import Any, Dict, List
|
2
|
+
|
3
|
+
from .configuration_utils import ConfigMixin, register_to_config
|
4
|
+
from .utils import CONFIG_NAME
|
5
|
+
|
6
|
+
|
7
|
+
class PipelineCallback(ConfigMixin):
    """
    Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing
    custom callbacks and ensures that all callbacks have a consistent interface.

    Please implement the following:
        `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to
        include variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
        `callback_fn`: This method defines the core functionality of your callback.

    Args:
        cutoff_step_ratio (`float`, *optional*, defaults to 1.0):
            Fraction (in `[0.0, 1.0]`) of the pipeline's timesteps after which the callback acts. Mutually exclusive
            with `cutoff_step_index`; pass `cutoff_step_ratio=None` when using the index instead.
        cutoff_step_index (`int`, *optional*):
            Explicit step index at which the callback acts. Mutually exclusive with `cutoff_step_ratio`.

    Raises:
        ValueError: If both or neither of the cutoff arguments are given, or if `cutoff_step_ratio` is not a number
            in `[0.0, 1.0]`.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
        super().__init__()

        # Exactly one of the two cutoff specifications must be set.
        if (cutoff_step_ratio is None and cutoff_step_index is None) or (
            cutoff_step_ratio is not None and cutoff_step_index is not None
        ):
            raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")

        # Accept ints such as 0 or 1 as valid ratios, but not bools (which are ints in Python).
        if cutoff_step_ratio is not None and (
            isinstance(cutoff_step_ratio, bool)
            or not isinstance(cutoff_step_ratio, (int, float))
            or not (0.0 <= cutoff_step_ratio <= 1.0)
        ):
            raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")

    @property
    def tensor_inputs(self) -> List[str]:
        """Names of the pipeline tensors this callback reads/writes; subclasses must override."""
        raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """Core callback logic; subclasses must override. Returns the (possibly modified) `callback_kwargs`."""
        raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """Forward the pipeline's step-end call to `callback_fn`."""
        return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
|
44
|
+
|
45
|
+
|
46
|
+
class MultiPipelineCallbacks:
    """
    Wrap several `PipelineCallback` objects so that a pipeline can treat them as one callback.

    The wrapped callbacks run sequentially; each one receives the `callback_kwargs` returned by the previous one.
    """

    def __init__(self, callbacks: List[PipelineCallback]):
        self.callbacks = callbacks

    @property
    def tensor_inputs(self) -> List[str]:
        """Concatenation, in callback order, of every wrapped callback's `tensor_inputs`."""
        names: List[str] = []
        for cb in self.callbacks:
            names.extend(cb.tensor_inputs)
        return names

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """
        Run each wrapped callback in turn, threading `callback_kwargs` through them, and return the final result.
        """
        result = callback_kwargs
        for cb in self.callbacks:
            result = cb(pipeline, step_index, timestep, result)
        return result
|
67
|
+
|
68
|
+
|
69
|
+
class SDCFGCutoffCallback(PipelineCallback):
    """
    Callback function for Stable Diffusion Pipelines. After a certain number of steps (set by `cutoff_step_ratio` or
    `cutoff_step_index`), this callback will disable the CFG.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
    """

    tensor_inputs = ["prompt_embeds"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """
        At the cutoff step, drop the unconditional half of `prompt_embeds` and set the guidance scale to 0.0.

        Args:
            pipeline: The running pipeline. `num_timesteps` is read when the ratio is used; `_guidance_scale` is
                written at the cutoff step.
            step_index (`int`): Index of the denoising step that just finished.
            timestep: Current timestep (unused).
            callback_kwargs (`dict`): Tensors exposed by the pipeline; must contain `"prompt_embeds"`.

        Returns:
            `dict`: `callback_kwargs`, with `"prompt_embeds"` replaced at the cutoff step.
        """
        cutoff_step_ratio = self.config.cutoff_step_ratio
        cutoff_step_index = self.config.cutoff_step_index

        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
        cutoff_step = (
            cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
        )

        if step_index == cutoff_step:
            prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
            # NOTE(review): assumes the pipeline stacked the CFG batch as [negative, positive] along dim 0 (the
            # previous `[-1:]` already relied on the conditional embeddings coming last). Keep the whole positive
            # half rather than a single row, so batches with several prompts / images per prompt still have one
            # embedding per latent. For a single sample this is identical to `[-1:]`.
            prompt_embeds = prompt_embeds[prompt_embeds.shape[0] // 2 :]

            pipeline._guidance_scale = 0.0

            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
        return callback_kwargs
|
96
|
+
|
97
|
+
|
98
|
+
class SDXLCFGCutoffCallback(PipelineCallback):
    """
    Callback function for Stable Diffusion XL Pipelines. After a certain number of steps (set by `cutoff_step_ratio`
    or `cutoff_step_index`), this callback will disable the CFG.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
    """

    tensor_inputs = ["prompt_embeds", "add_text_embeds", "add_time_ids"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """
        At the cutoff step, keep only the conditional parts of the SDXL conditioning tensors and set the guidance
        scale to 0.0.

        Args:
            pipeline: The running pipeline. `num_timesteps` is read when the ratio is used; `_guidance_scale` is
                written at the cutoff step.
            step_index (`int`): Index of the denoising step that just finished.
            timestep: Current timestep (unused).
            callback_kwargs (`dict`): Tensors exposed by the pipeline; must contain `"prompt_embeds"`,
                `"add_text_embeds"` and `"add_time_ids"`.

        Returns:
            `dict`: `callback_kwargs`, with the three tensors replaced at the cutoff step.
        """
        cutoff_step_ratio = self.config.cutoff_step_ratio
        cutoff_step_index = self.config.cutoff_step_index

        # Use cutoff_step_index if it's not None, otherwise use cutoff_step_ratio
        cutoff_step = (
            cutoff_step_index if cutoff_step_index is not None else int(pipeline.num_timesteps * cutoff_step_ratio)
        )

        if step_index == cutoff_step:
            # NOTE(review): assumes the CFG batch is stacked as [negative, positive] along dim 0 (the previous
            # `[-1:]` already relied on the conditional rows coming last). Keep the whole positive half instead of a
            # single row so that batch_size / num_images_per_prompt > 1 keeps one embedding per latent.
            prompt_embeds = callback_kwargs[self.tensor_inputs[0]]
            prompt_embeds = prompt_embeds[prompt_embeds.shape[0] // 2 :]

            add_text_embeds = callback_kwargs[self.tensor_inputs[1]]
            add_text_embeds = add_text_embeds[add_text_embeds.shape[0] // 2 :]

            add_time_ids = callback_kwargs[self.tensor_inputs[2]]
            # NOTE(review): assumes add_time_ids is 2-D, that its last row is a conditional one, and that all
            # conditional samples share that row. Its rows may be interleaved per sample, so halving is not safe
            # here. Repeat the last row once per remaining sample (identical to `[-1:]` for a single sample) —
            # confirm against the SDXL pipeline's time-id construction.
            add_time_ids = add_time_ids[-1:].repeat(prompt_embeds.shape[0], 1)

            pipeline._guidance_scale = 0.0

            callback_kwargs[self.tensor_inputs[0]] = prompt_embeds
            callback_kwargs[self.tensor_inputs[1]] = add_text_embeds
            callback_kwargs[self.tensor_inputs[2]] = add_time_ids
        return callback_kwargs
|
133
|
+
|
134
|
+
|
135
|
+
class IPAdapterScaleCutoffCallback(PipelineCallback):
    """
    Callback for any pipeline that inherits `IPAdapterMixin`. Once the cutoff step (given by `cutoff_step_ratio` or
    `cutoff_step_index`) is reached, the IP Adapter scale is set to `0.0`.

    Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step.
    """

    tensor_inputs = []

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """
        Call `pipeline.set_ip_adapter_scale(0.0)` when `step_index` equals the cutoff step.

        Returns:
            `dict`: `callback_kwargs`, unchanged.
        """
        # An explicit index wins; otherwise derive the step from the ratio and the pipeline's step count.
        target_step = self.config.cutoff_step_index
        if target_step is None:
            target_step = int(pipeline.num_timesteps * self.config.cutoff_step_ratio)

        if step_index == target_step:
            pipeline.set_ip_adapter_scale(0.0)
        return callback_kwargs
|
diffusers/commands/env.py
CHANGED
@@ -13,12 +13,25 @@
|
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
15
|
import platform
|
16
|
+
import subprocess
|
16
17
|
from argparse import ArgumentParser
|
17
18
|
|
18
19
|
import huggingface_hub
|
19
20
|
|
20
21
|
from .. import __version__ as version
|
21
|
-
from ..utils import
|
22
|
+
from ..utils import (
|
23
|
+
is_accelerate_available,
|
24
|
+
is_bitsandbytes_available,
|
25
|
+
is_flax_available,
|
26
|
+
is_google_colab,
|
27
|
+
is_notebook,
|
28
|
+
is_peft_available,
|
29
|
+
is_safetensors_available,
|
30
|
+
is_torch_available,
|
31
|
+
is_transformers_available,
|
32
|
+
is_xformers_available,
|
33
|
+
)
|
34
|
+
from ..utils.testing_utils import get_python_version
|
22
35
|
from . import BaseDiffusersCLICommand
|
23
36
|
|
24
37
|
|
@@ -28,13 +41,19 @@ def info_command_factory(_):
|
|
28
41
|
|
29
42
|
class EnvironmentCommand(BaseDiffusersCLICommand):
|
30
43
|
@staticmethod
|
31
|
-
def register_subcommand(parser: ArgumentParser):
|
44
|
+
def register_subcommand(parser: ArgumentParser) -> None:
|
32
45
|
download_parser = parser.add_parser("env")
|
33
46
|
download_parser.set_defaults(func=info_command_factory)
|
34
47
|
|
35
|
-
def run(self):
|
48
|
+
def run(self) -> dict:
|
36
49
|
hub_version = huggingface_hub.__version__
|
37
50
|
|
51
|
+
safetensors_version = "not installed"
|
52
|
+
if is_safetensors_available():
|
53
|
+
import safetensors
|
54
|
+
|
55
|
+
safetensors_version = safetensors.__version__
|
56
|
+
|
38
57
|
pt_version = "not installed"
|
39
58
|
pt_cuda_available = "NA"
|
40
59
|
if is_torch_available():
|
@@ -43,6 +62,20 @@ class EnvironmentCommand(BaseDiffusersCLICommand):
|
|
43
62
|
pt_version = torch.__version__
|
44
63
|
pt_cuda_available = torch.cuda.is_available()
|
45
64
|
|
65
|
+
flax_version = "not installed"
|
66
|
+
jax_version = "not installed"
|
67
|
+
jaxlib_version = "not installed"
|
68
|
+
jax_backend = "NA"
|
69
|
+
if is_flax_available():
|
70
|
+
import flax
|
71
|
+
import jax
|
72
|
+
import jaxlib
|
73
|
+
|
74
|
+
flax_version = flax.__version__
|
75
|
+
jax_version = jax.__version__
|
76
|
+
jaxlib_version = jaxlib.__version__
|
77
|
+
jax_backend = jax.lib.xla_bridge.get_backend().platform
|
78
|
+
|
46
79
|
transformers_version = "not installed"
|
47
80
|
if is_transformers_available():
|
48
81
|
import transformers
|
@@ -55,21 +88,92 @@ class EnvironmentCommand(BaseDiffusersCLICommand):
|
|
55
88
|
|
56
89
|
accelerate_version = accelerate.__version__
|
57
90
|
|
91
|
+
peft_version = "not installed"
|
92
|
+
if is_peft_available():
|
93
|
+
import peft
|
94
|
+
|
95
|
+
peft_version = peft.__version__
|
96
|
+
|
97
|
+
bitsandbytes_version = "not installed"
|
98
|
+
if is_bitsandbytes_available():
|
99
|
+
import bitsandbytes
|
100
|
+
|
101
|
+
bitsandbytes_version = bitsandbytes.__version__
|
102
|
+
|
58
103
|
xformers_version = "not installed"
|
59
104
|
if is_xformers_available():
|
60
105
|
import xformers
|
61
106
|
|
62
107
|
xformers_version = xformers.__version__
|
63
108
|
|
109
|
+
if get_python_version() >= (3, 10):
|
110
|
+
platform_info = f"{platform.freedesktop_os_release().get('PRETTY_NAME', None)} - {platform.platform()}"
|
111
|
+
else:
|
112
|
+
platform_info = platform.platform()
|
113
|
+
|
114
|
+
is_notebook_str = "Yes" if is_notebook() else "No"
|
115
|
+
|
116
|
+
is_google_colab_str = "Yes" if is_google_colab() else "No"
|
117
|
+
|
118
|
+
accelerator = "NA"
|
119
|
+
if platform.system() in {"Linux", "Windows"}:
|
120
|
+
try:
|
121
|
+
sp = subprocess.Popen(
|
122
|
+
["nvidia-smi", "--query-gpu=gpu_name,memory.total", "--format=csv,noheader"],
|
123
|
+
stdout=subprocess.PIPE,
|
124
|
+
stderr=subprocess.PIPE,
|
125
|
+
)
|
126
|
+
out_str, _ = sp.communicate()
|
127
|
+
out_str = out_str.decode("utf-8")
|
128
|
+
|
129
|
+
if len(out_str) > 0:
|
130
|
+
accelerator = out_str.strip() + " VRAM"
|
131
|
+
except FileNotFoundError:
|
132
|
+
pass
|
133
|
+
elif platform.system() == "Darwin": # Mac OS
|
134
|
+
try:
|
135
|
+
sp = subprocess.Popen(
|
136
|
+
["system_profiler", "SPDisplaysDataType"],
|
137
|
+
stdout=subprocess.PIPE,
|
138
|
+
stderr=subprocess.PIPE,
|
139
|
+
)
|
140
|
+
out_str, _ = sp.communicate()
|
141
|
+
out_str = out_str.decode("utf-8")
|
142
|
+
|
143
|
+
start = out_str.find("Chipset Model:")
|
144
|
+
if start != -1:
|
145
|
+
start += len("Chipset Model:")
|
146
|
+
end = out_str.find("\n", start)
|
147
|
+
accelerator = out_str[start:end].strip()
|
148
|
+
|
149
|
+
start = out_str.find("VRAM (Total):")
|
150
|
+
if start != -1:
|
151
|
+
start += len("VRAM (Total):")
|
152
|
+
end = out_str.find("\n", start)
|
153
|
+
accelerator += " VRAM: " + out_str[start:end].strip()
|
154
|
+
except FileNotFoundError:
|
155
|
+
pass
|
156
|
+
else:
|
157
|
+
print("It seems you are running an unusual OS. Could you fill in the accelerator manually?")
|
158
|
+
|
64
159
|
info = {
|
65
|
-
"
|
66
|
-
"Platform":
|
160
|
+
"🤗 Diffusers version": version,
|
161
|
+
"Platform": platform_info,
|
162
|
+
"Running on a notebook?": is_notebook_str,
|
163
|
+
"Running on Google Colab?": is_google_colab_str,
|
67
164
|
"Python version": platform.python_version(),
|
68
165
|
"PyTorch version (GPU?)": f"{pt_version} ({pt_cuda_available})",
|
166
|
+
"Flax version (CPU?/GPU?/TPU?)": f"{flax_version} ({jax_backend})",
|
167
|
+
"Jax version": jax_version,
|
168
|
+
"JaxLib version": jaxlib_version,
|
69
169
|
"Huggingface_hub version": hub_version,
|
70
170
|
"Transformers version": transformers_version,
|
71
171
|
"Accelerate version": accelerate_version,
|
172
|
+
"PEFT version": peft_version,
|
173
|
+
"Bitsandbytes version": bitsandbytes_version,
|
174
|
+
"Safetensors version": safetensors_version,
|
72
175
|
"xFormers version": xformers_version,
|
176
|
+
"Accelerator": accelerator,
|
73
177
|
"Using GPU in script?": "<fill in>",
|
74
178
|
"Using distributed or parallel set-up in script?": "<fill in>",
|
75
179
|
}
|
@@ -80,5 +184,5 @@ class EnvironmentCommand(BaseDiffusersCLICommand):
|
|
80
184
|
return info
|
81
185
|
|
82
186
|
@staticmethod
|
83
|
-
def format_dict(d):
|
187
|
+
def format_dict(d: dict) -> str:
|
84
188
|
return "\n".join([f"- {prop}: {val}" for prop, val in d.items()]) + "\n"
|
diffusers/configuration_utils.py
CHANGED
@@ -13,7 +13,8 @@
|
|
13
13
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
14
14
|
# See the License for the specific language governing permissions and
|
15
15
|
# limitations under the License.
|
16
|
-
"""
|
16
|
+
"""ConfigMixin base class and utilities."""
|
17
|
+
|
17
18
|
import dataclasses
|
18
19
|
import functools
|
19
20
|
import importlib
|
@@ -309,9 +310,9 @@ class ConfigMixin:
|
|
309
310
|
force_download (`bool`, *optional*, defaults to `False`):
|
310
311
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
311
312
|
cached versions if they exist.
|
312
|
-
resume_download
|
313
|
-
|
314
|
-
|
313
|
+
resume_download:
|
314
|
+
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
315
|
+
of Diffusers.
|
315
316
|
proxies (`Dict[str, str]`, *optional*):
|
316
317
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
317
318
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
@@ -339,8 +340,10 @@ class ConfigMixin:
|
|
339
340
|
|
340
341
|
"""
|
341
342
|
cache_dir = kwargs.pop("cache_dir", None)
|
343
|
+
local_dir = kwargs.pop("local_dir", None)
|
344
|
+
local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
|
342
345
|
force_download = kwargs.pop("force_download", False)
|
343
|
-
resume_download = kwargs.pop("resume_download",
|
346
|
+
resume_download = kwargs.pop("resume_download", None)
|
344
347
|
proxies = kwargs.pop("proxies", None)
|
345
348
|
token = kwargs.pop("token", None)
|
346
349
|
local_files_only = kwargs.pop("local_files_only", False)
|
@@ -363,13 +366,13 @@ class ConfigMixin:
|
|
363
366
|
if os.path.isfile(pretrained_model_name_or_path):
|
364
367
|
config_file = pretrained_model_name_or_path
|
365
368
|
elif os.path.isdir(pretrained_model_name_or_path):
|
366
|
-
if os.path.isfile(
|
367
|
-
# Load from a PyTorch checkpoint
|
368
|
-
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
|
369
|
-
elif subfolder is not None and os.path.isfile(
|
369
|
+
if subfolder is not None and os.path.isfile(
|
370
370
|
os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
|
371
371
|
):
|
372
372
|
config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
|
373
|
+
elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
|
374
|
+
# Load from a PyTorch checkpoint
|
375
|
+
config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
|
373
376
|
else:
|
374
377
|
raise EnvironmentError(
|
375
378
|
f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
|
@@ -389,6 +392,8 @@ class ConfigMixin:
|
|
389
392
|
user_agent=user_agent,
|
390
393
|
subfolder=subfolder,
|
391
394
|
revision=revision,
|
395
|
+
local_dir=local_dir,
|
396
|
+
local_dir_use_symlinks=local_dir_use_symlinks,
|
392
397
|
)
|
393
398
|
except RepositoryNotFoundError:
|
394
399
|
raise EnvironmentError(
|
@@ -449,8 +454,8 @@ class ConfigMixin:
|
|
449
454
|
return outputs
|
450
455
|
|
451
456
|
@staticmethod
|
452
|
-
def _get_init_keys(
|
453
|
-
return set(dict(inspect.signature(
|
457
|
+
def _get_init_keys(input_class):
|
458
|
+
return set(dict(inspect.signature(input_class.__init__).parameters).keys())
|
454
459
|
|
455
460
|
@classmethod
|
456
461
|
def extract_init_dict(cls, config_dict, **kwargs):
|
@@ -3,7 +3,7 @@
|
|
3
3
|
# 2. run `make deps_table_update`
|
4
4
|
deps = {
|
5
5
|
"Pillow": "Pillow",
|
6
|
-
"accelerate": "accelerate>=0.
|
6
|
+
"accelerate": "accelerate>=0.29.3",
|
7
7
|
"compel": "compel==0.1.8",
|
8
8
|
"datasets": "datasets",
|
9
9
|
"filelock": "filelock",
|
@@ -42,4 +42,5 @@ deps = {
|
|
42
42
|
"torchvision": "torchvision",
|
43
43
|
"transformers": "transformers>=4.25.1",
|
44
44
|
"urllib3": "urllib3<=2.0.0",
|
45
|
+
"black": "black",
|
45
46
|
}
|