diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +34 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +170 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +35 -6
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
@@ -44,21 +44,22 @@ from ..configuration_utils import ConfigMixin
|
|
44
44
|
from ..models import AutoencoderKL
|
45
45
|
from ..models.attention_processor import FusedAttnProcessor2_0
|
46
46
|
from ..models.modeling_utils import _LOW_CPU_MEM_USAGE_DEFAULT, ModelMixin
|
47
|
+
from ..quantizers.bitsandbytes.utils import _check_bnb_status
|
47
48
|
from ..schedulers.scheduling_utils import SCHEDULER_CONFIG_NAME
|
48
49
|
from ..utils import (
|
49
50
|
CONFIG_NAME,
|
50
51
|
DEPRECATED_REVISION_ARGS,
|
51
52
|
BaseOutput,
|
52
53
|
PushToHubMixin,
|
53
|
-
deprecate,
|
54
54
|
is_accelerate_available,
|
55
55
|
is_accelerate_version,
|
56
56
|
is_torch_npu_available,
|
57
57
|
is_torch_version,
|
58
|
+
is_transformers_version,
|
58
59
|
logging,
|
59
60
|
numpy_to_pil,
|
60
61
|
)
|
61
|
-
from ..utils.hub_utils import load_or_create_model_card, populate_model_card
|
62
|
+
from ..utils.hub_utils import _check_legacy_sharding_variant_format, load_or_create_model_card, populate_model_card
|
62
63
|
from ..utils.torch_utils import is_compiled_module
|
63
64
|
|
64
65
|
|
@@ -72,11 +73,16 @@ from .pipeline_loading_utils import (
|
|
72
73
|
CUSTOM_PIPELINE_FILE_NAME,
|
73
74
|
LOADABLE_CLASSES,
|
74
75
|
_fetch_class_library_tuple,
|
76
|
+
_get_custom_components_and_folders,
|
75
77
|
_get_custom_pipeline_class,
|
76
78
|
_get_final_device_map,
|
79
|
+
_get_ignore_patterns,
|
77
80
|
_get_pipeline_class,
|
81
|
+
_identify_model_variants,
|
82
|
+
_maybe_raise_warning_for_inpainting,
|
83
|
+
_resolve_custom_pipeline_and_cls,
|
78
84
|
_unwrap_model,
|
79
|
-
|
85
|
+
_update_init_kwargs_with_connected_pipeline,
|
80
86
|
load_sub_model,
|
81
87
|
maybe_raise_or_warn,
|
82
88
|
variant_compatible_siblings,
|
@@ -185,6 +191,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
185
191
|
save_directory: Union[str, os.PathLike],
|
186
192
|
safe_serialization: bool = True,
|
187
193
|
variant: Optional[str] = None,
|
194
|
+
max_shard_size: Optional[Union[int, str]] = None,
|
188
195
|
push_to_hub: bool = False,
|
189
196
|
**kwargs,
|
190
197
|
):
|
@@ -200,6 +207,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
200
207
|
Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
|
201
208
|
variant (`str`, *optional*):
|
202
209
|
If specified, weights are saved in the format `pytorch_model.<variant>.bin`.
|
210
|
+
max_shard_size (`int` or `str`, defaults to `None`):
|
211
|
+
The maximum size for a checkpoint before being sharded. Checkpoints shard will then be each of size
|
212
|
+
lower than this size. If expressed as a string, needs to be digits followed by a unit (like `"5GB"`).
|
213
|
+
If expressed as an integer, the unit is bytes. Note that this limit will be decreased after a certain
|
214
|
+
period of time (starting from Oct 2024) to allow users to upgrade to the latest version of `diffusers`.
|
215
|
+
This is to establish a common default size for this argument across different libraries in the Hugging
|
216
|
+
Face ecosystem (`transformers`, and `accelerate`, for example).
|
203
217
|
push_to_hub (`bool`, *optional*, defaults to `False`):
|
204
218
|
Whether or not to push your model to the Hugging Face model hub after saving it. You can specify the
|
205
219
|
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
@@ -274,12 +288,16 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
274
288
|
save_method_signature = inspect.signature(save_method)
|
275
289
|
save_method_accept_safe = "safe_serialization" in save_method_signature.parameters
|
276
290
|
save_method_accept_variant = "variant" in save_method_signature.parameters
|
291
|
+
save_method_accept_max_shard_size = "max_shard_size" in save_method_signature.parameters
|
277
292
|
|
278
293
|
save_kwargs = {}
|
279
294
|
if save_method_accept_safe:
|
280
295
|
save_kwargs["safe_serialization"] = safe_serialization
|
281
296
|
if save_method_accept_variant:
|
282
297
|
save_kwargs["variant"] = variant
|
298
|
+
if save_method_accept_max_shard_size and max_shard_size is not None:
|
299
|
+
# max_shard_size is expected to not be None in ModelMixin
|
300
|
+
save_kwargs["max_shard_size"] = max_shard_size
|
283
301
|
|
284
302
|
save_method(os.path.join(save_directory, pipeline_component_name), **save_kwargs)
|
285
303
|
|
@@ -416,18 +434,23 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
416
434
|
|
417
435
|
is_offloaded = pipeline_is_offloaded or pipeline_is_sequentially_offloaded
|
418
436
|
for module in modules:
|
419
|
-
|
437
|
+
_, is_loaded_in_4bit_bnb, is_loaded_in_8bit_bnb = _check_bnb_status(module)
|
420
438
|
|
421
|
-
if
|
439
|
+
if (is_loaded_in_4bit_bnb or is_loaded_in_8bit_bnb) and dtype is not None:
|
422
440
|
logger.warning(
|
423
|
-
f"The module '{module.__class__.__name__}' has been loaded in 8bit and conversion to {dtype} is not
|
441
|
+
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` {'4bit' if is_loaded_in_4bit_bnb else '8bit'} and conversion to {dtype} is not supported. Module is still in {'4bit' if is_loaded_in_4bit_bnb else '8bit'} precision."
|
424
442
|
)
|
425
443
|
|
426
|
-
if
|
444
|
+
if is_loaded_in_8bit_bnb and device is not None:
|
427
445
|
logger.warning(
|
428
|
-
f"The module '{module.__class__.__name__}' has been loaded in 8bit and moving it to {
|
446
|
+
f"The module '{module.__class__.__name__}' has been loaded in `bitsandbytes` 8bit and moving it to {device} via `.to()` is not supported. Module is still on {module.device}."
|
429
447
|
)
|
430
|
-
|
448
|
+
|
449
|
+
# This can happen for `transformer` models. CPU placement was added in
|
450
|
+
# https://github.com/huggingface/transformers/pull/33122. So, we guard this accordingly.
|
451
|
+
if is_loaded_in_4bit_bnb and device is not None and is_transformers_version(">", "4.44.0"):
|
452
|
+
module.to(device=device)
|
453
|
+
elif not is_loaded_in_4bit_bnb and not is_loaded_in_8bit_bnb:
|
431
454
|
module.to(device, dtype)
|
432
455
|
|
433
456
|
if (
|
@@ -622,6 +645,9 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
622
645
|
>>> pipeline.scheduler = scheduler
|
623
646
|
```
|
624
647
|
"""
|
648
|
+
# Copy the kwargs to re-use during loading connected pipeline.
|
649
|
+
kwargs_copied = kwargs.copy()
|
650
|
+
|
625
651
|
cache_dir = kwargs.pop("cache_dir", None)
|
626
652
|
force_download = kwargs.pop("force_download", False)
|
627
653
|
proxies = kwargs.pop("proxies", None)
|
@@ -716,39 +742,43 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
716
742
|
else:
|
717
743
|
cached_folder = pretrained_model_name_or_path
|
718
744
|
|
745
|
+
# The variant filenames can have the legacy sharding checkpoint format that we check and throw
|
746
|
+
# a warning if detected.
|
747
|
+
if variant is not None and _check_legacy_sharding_variant_format(folder=cached_folder, variant=variant):
|
748
|
+
warn_msg = (
|
749
|
+
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
|
750
|
+
"Please check your files carefully:\n\n"
|
751
|
+
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
|
752
|
+
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
|
753
|
+
"If you find any files in the deprecated format:\n"
|
754
|
+
"1. Remove all existing checkpoint files for this variant.\n"
|
755
|
+
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
|
756
|
+
"This will ensure you're using the most up-to-date and compatible checkpoint format."
|
757
|
+
)
|
758
|
+
logger.warning(warn_msg)
|
759
|
+
|
719
760
|
config_dict = cls.load_config(cached_folder)
|
720
761
|
|
721
762
|
# pop out "_ignore_files" as it is only needed for download
|
722
763
|
config_dict.pop("_ignore_files", None)
|
723
764
|
|
724
765
|
# 2. Define which model components should load variants
|
725
|
-
# We retrieve the information by matching whether variant
|
726
|
-
#
|
727
|
-
|
728
|
-
|
729
|
-
|
730
|
-
|
731
|
-
|
732
|
-
variant_exists = is_folder and any(
|
733
|
-
p.split(".")[1].startswith(variant) for p in os.listdir(folder_path)
|
734
|
-
)
|
735
|
-
if variant_exists:
|
736
|
-
model_variants[folder] = variant
|
766
|
+
# We retrieve the information by matching whether variant model checkpoints exist in the subfolders.
|
767
|
+
# Example: `diffusion_pytorch_model.safetensors` -> `diffusion_pytorch_model.fp16.safetensors`
|
768
|
+
# with variant being `"fp16"`.
|
769
|
+
model_variants = _identify_model_variants(folder=cached_folder, variant=variant, config=config_dict)
|
770
|
+
if len(model_variants) == 0 and variant is not None:
|
771
|
+
error_message = f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
772
|
+
raise ValueError(error_message)
|
737
773
|
|
738
774
|
# 3. Load the pipeline class, if using custom module then load it from the hub
|
739
775
|
# if we load from explicit class, let's use it
|
740
|
-
custom_class_name =
|
741
|
-
|
742
|
-
|
743
|
-
elif isinstance(config_dict["_class_name"], (list, tuple)) and os.path.isfile(
|
744
|
-
os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
|
745
|
-
):
|
746
|
-
custom_pipeline = os.path.join(cached_folder, f"{config_dict['_class_name'][0]}.py")
|
747
|
-
custom_class_name = config_dict["_class_name"][1]
|
748
|
-
|
776
|
+
custom_pipeline, custom_class_name = _resolve_custom_pipeline_and_cls(
|
777
|
+
folder=cached_folder, config=config_dict, custom_pipeline=custom_pipeline
|
778
|
+
)
|
749
779
|
pipeline_class = _get_pipeline_class(
|
750
780
|
cls,
|
751
|
-
config_dict,
|
781
|
+
config=config_dict,
|
752
782
|
load_connected_pipeline=load_connected_pipeline,
|
753
783
|
custom_pipeline=custom_pipeline,
|
754
784
|
class_name=custom_class_name,
|
@@ -760,23 +790,13 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
760
790
|
raise NotImplementedError("`device_map` is not yet supported for connected pipelines.")
|
761
791
|
|
762
792
|
# DEPRECATED: To be removed in 1.0.0
|
763
|
-
|
764
|
-
|
765
|
-
|
766
|
-
|
767
|
-
|
768
|
-
|
769
|
-
|
770
|
-
deprecation_message = (
|
771
|
-
"You are using a legacy checkpoint for inpainting with Stable Diffusion, therefore we are loading the"
|
772
|
-
f" {StableDiffusionInpaintPipelineLegacy} class instead of {StableDiffusionInpaintPipeline}. For"
|
773
|
-
" better inpainting results, we strongly suggest using Stable Diffusion's official inpainting"
|
774
|
-
" checkpoint: https://huggingface.co/runwayml/stable-diffusion-inpainting instead or adapting your"
|
775
|
-
f" checkpoint {pretrained_model_name_or_path} to the format of"
|
776
|
-
" https://huggingface.co/runwayml/stable-diffusion-inpainting. Note that we do not actively maintain"
|
777
|
-
" the {StableDiffusionInpaintPipelineLegacy} class and will likely remove it in version 1.0.0."
|
778
|
-
)
|
779
|
-
deprecate("StableDiffusionInpaintPipelineLegacy", "1.0.0", deprecation_message, standard_warn=False)
|
793
|
+
# we are deprecating the `StableDiffusionInpaintPipelineLegacy` pipeline which gets loaded
|
794
|
+
# when a user requests for a `StableDiffusionInpaintPipeline` with `diffusers` version being <= 0.5.1.
|
795
|
+
_maybe_raise_warning_for_inpainting(
|
796
|
+
pipeline_class=pipeline_class,
|
797
|
+
pretrained_model_name_or_path=pretrained_model_name_or_path,
|
798
|
+
config=config_dict,
|
799
|
+
)
|
780
800
|
|
781
801
|
# 4. Define expected modules given pipeline signature
|
782
802
|
# and define non-None initialized modules (=`init_kwargs`)
|
@@ -787,7 +807,6 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
787
807
|
expected_modules, optional_kwargs = cls._get_signature_keys(pipeline_class)
|
788
808
|
passed_class_obj = {k: kwargs.pop(k) for k in expected_modules if k in kwargs}
|
789
809
|
passed_pipe_kwargs = {k: kwargs.pop(k) for k in optional_kwargs if k in kwargs}
|
790
|
-
|
791
810
|
init_dict, unused_kwargs, _ = pipeline_class.extract_init_dict(config_dict, **kwargs)
|
792
811
|
|
793
812
|
# define init kwargs and make sure that optional component modules are filtered out
|
@@ -847,6 +866,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
847
866
|
# 7. Load each module in the pipeline
|
848
867
|
current_device_map = None
|
849
868
|
for name, (library_name, class_name) in logging.tqdm(init_dict.items(), desc="Loading pipeline components..."):
|
869
|
+
# 7.1 device_map shenanigans
|
850
870
|
if final_device_map is not None and len(final_device_map) > 0:
|
851
871
|
component_device = final_device_map.get(name, None)
|
852
872
|
if component_device is not None:
|
@@ -854,15 +874,15 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
854
874
|
else:
|
855
875
|
current_device_map = None
|
856
876
|
|
857
|
-
# 7.
|
877
|
+
# 7.2 - now that JAX/Flax is an official framework of the library, we might load from Flax names
|
858
878
|
class_name = class_name[4:] if class_name.startswith("Flax") else class_name
|
859
879
|
|
860
|
-
# 7.
|
880
|
+
# 7.3 Define all importable classes
|
861
881
|
is_pipeline_module = hasattr(pipelines, library_name)
|
862
882
|
importable_classes = ALL_IMPORTABLE_CLASSES
|
863
883
|
loaded_sub_model = None
|
864
884
|
|
865
|
-
# 7.
|
885
|
+
# 7.4 Use passed sub model or load class_name from library_name
|
866
886
|
if name in passed_class_obj:
|
867
887
|
# if the model is in a pipeline module, then we load it from the pipeline
|
868
888
|
# check that passed_class_obj has correct parent class
|
@@ -893,6 +913,7 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
893
913
|
variant=variant,
|
894
914
|
low_cpu_mem_usage=low_cpu_mem_usage,
|
895
915
|
cached_folder=cached_folder,
|
916
|
+
use_safetensors=use_safetensors,
|
896
917
|
)
|
897
918
|
logger.info(
|
898
919
|
f"Loaded {name} as {class_name} from `{name}` subfolder of {pretrained_model_name_or_path}."
|
@@ -900,56 +921,17 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
900
921
|
|
901
922
|
init_kwargs[name] = loaded_sub_model # UNet(...), # DiffusionSchedule(...)
|
902
923
|
|
924
|
+
# 8. Handle connected pipelines.
|
903
925
|
if pipeline_class._load_connected_pipes and os.path.isfile(os.path.join(cached_folder, "README.md")):
|
904
|
-
|
905
|
-
|
906
|
-
|
907
|
-
|
908
|
-
|
909
|
-
|
910
|
-
|
911
|
-
"token": token,
|
912
|
-
"revision": revision,
|
913
|
-
"torch_dtype": torch_dtype,
|
914
|
-
"custom_pipeline": custom_pipeline,
|
915
|
-
"custom_revision": custom_revision,
|
916
|
-
"provider": provider,
|
917
|
-
"sess_options": sess_options,
|
918
|
-
"device_map": device_map,
|
919
|
-
"max_memory": max_memory,
|
920
|
-
"offload_folder": offload_folder,
|
921
|
-
"offload_state_dict": offload_state_dict,
|
922
|
-
"low_cpu_mem_usage": low_cpu_mem_usage,
|
923
|
-
"variant": variant,
|
924
|
-
"use_safetensors": use_safetensors,
|
925
|
-
}
|
926
|
-
|
927
|
-
def get_connected_passed_kwargs(prefix):
|
928
|
-
connected_passed_class_obj = {
|
929
|
-
k.replace(f"{prefix}_", ""): w for k, w in passed_class_obj.items() if k.split("_")[0] == prefix
|
930
|
-
}
|
931
|
-
connected_passed_pipe_kwargs = {
|
932
|
-
k.replace(f"{prefix}_", ""): w for k, w in passed_pipe_kwargs.items() if k.split("_")[0] == prefix
|
933
|
-
}
|
934
|
-
|
935
|
-
connected_passed_kwargs = {**connected_passed_class_obj, **connected_passed_pipe_kwargs}
|
936
|
-
return connected_passed_kwargs
|
937
|
-
|
938
|
-
connected_pipes = {
|
939
|
-
prefix: DiffusionPipeline.from_pretrained(
|
940
|
-
repo_id, **load_kwargs.copy(), **get_connected_passed_kwargs(prefix)
|
941
|
-
)
|
942
|
-
for prefix, repo_id in connected_pipes.items()
|
943
|
-
if repo_id is not None
|
944
|
-
}
|
945
|
-
|
946
|
-
for prefix, connected_pipe in connected_pipes.items():
|
947
|
-
# add connected pipes to `init_kwargs` with <prefix>_<component_name>, e.g. "prior_text_encoder"
|
948
|
-
init_kwargs.update(
|
949
|
-
{"_".join([prefix, name]): component for name, component in connected_pipe.components.items()}
|
950
|
-
)
|
926
|
+
init_kwargs = _update_init_kwargs_with_connected_pipeline(
|
927
|
+
init_kwargs=init_kwargs,
|
928
|
+
passed_pipe_kwargs=passed_pipe_kwargs,
|
929
|
+
passed_class_objs=passed_class_obj,
|
930
|
+
folder=cached_folder,
|
931
|
+
**kwargs_copied,
|
932
|
+
)
|
951
933
|
|
952
|
-
#
|
934
|
+
# 9. Potentially add passed objects if expected
|
953
935
|
missing_modules = set(expected_modules) - set(init_kwargs.keys())
|
954
936
|
passed_modules = list(passed_class_obj.keys())
|
955
937
|
optional_modules = pipeline_class._optional_components
|
@@ -1065,9 +1047,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1065
1047
|
hook = None
|
1066
1048
|
for model_str in self.model_cpu_offload_seq.split("->"):
|
1067
1049
|
model = all_model_components.pop(model_str, None)
|
1050
|
+
|
1068
1051
|
if not isinstance(model, torch.nn.Module):
|
1069
1052
|
continue
|
1070
1053
|
|
1054
|
+
# This is because the model would already be placed on a CUDA device.
|
1055
|
+
_, _, is_loaded_in_8bit_bnb = _check_bnb_status(model)
|
1056
|
+
if is_loaded_in_8bit_bnb:
|
1057
|
+
logger.info(
|
1058
|
+
f"Skipping the hook placement for the {model.__class__.__name__} as it is loaded in `bitsandbytes` 8bit."
|
1059
|
+
)
|
1060
|
+
continue
|
1061
|
+
|
1071
1062
|
_, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
|
1072
1063
|
self._all_hooks.append(hook)
|
1073
1064
|
|
@@ -1295,6 +1286,22 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1295
1286
|
model_info_call_error = e # save error to reraise it if model is not cached locally
|
1296
1287
|
|
1297
1288
|
if not local_files_only:
|
1289
|
+
filenames = {sibling.rfilename for sibling in info.siblings}
|
1290
|
+
if variant is not None and _check_legacy_sharding_variant_format(filenames=filenames, variant=variant):
|
1291
|
+
warn_msg = (
|
1292
|
+
f"Warning: The repository contains sharded checkpoints for variant '{variant}' maybe in a deprecated format. "
|
1293
|
+
"Please check your files carefully:\n\n"
|
1294
|
+
"- Correct format example: diffusion_pytorch_model.fp16-00003-of-00003.safetensors\n"
|
1295
|
+
"- Deprecated format example: diffusion_pytorch_model-00001-of-00002.fp16.safetensors\n\n"
|
1296
|
+
"If you find any files in the deprecated format:\n"
|
1297
|
+
"1. Remove all existing checkpoint files for this variant.\n"
|
1298
|
+
"2. Re-obtain the correct files by running `save_pretrained()`.\n\n"
|
1299
|
+
"This will ensure you're using the most up-to-date and compatible checkpoint format."
|
1300
|
+
)
|
1301
|
+
logger.warning(warn_msg)
|
1302
|
+
|
1303
|
+
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
1304
|
+
|
1298
1305
|
config_file = hf_hub_download(
|
1299
1306
|
pretrained_model_name,
|
1300
1307
|
cls.config_name,
|
@@ -1308,52 +1315,18 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1308
1315
|
config_dict = cls._dict_from_json_file(config_file)
|
1309
1316
|
ignore_filenames = config_dict.pop("_ignore_files", [])
|
1310
1317
|
|
1311
|
-
# retrieve all folder_names that contain relevant files
|
1312
|
-
folder_names = [k for k, v in config_dict.items() if isinstance(v, list) and k != "_class_name"]
|
1313
|
-
|
1314
|
-
filenames = {sibling.rfilename for sibling in info.siblings}
|
1315
|
-
model_filenames, variant_filenames = variant_compatible_siblings(filenames, variant=variant)
|
1316
|
-
|
1317
|
-
diffusers_module = importlib.import_module(__name__.split(".")[0])
|
1318
|
-
pipelines = getattr(diffusers_module, "pipelines")
|
1319
|
-
|
1320
|
-
# optionally create a custom component <> custom file mapping
|
1321
|
-
custom_components = {}
|
1322
|
-
for component in folder_names:
|
1323
|
-
module_candidate = config_dict[component][0]
|
1324
|
-
|
1325
|
-
if module_candidate is None or not isinstance(module_candidate, str):
|
1326
|
-
continue
|
1327
|
-
|
1328
|
-
# We compute candidate file path on the Hub. Do not use `os.path.join`.
|
1329
|
-
candidate_file = f"{component}/{module_candidate}.py"
|
1330
|
-
|
1331
|
-
if candidate_file in filenames:
|
1332
|
-
custom_components[component] = module_candidate
|
1333
|
-
elif module_candidate not in LOADABLE_CLASSES and not hasattr(pipelines, module_candidate):
|
1334
|
-
raise ValueError(
|
1335
|
-
f"{candidate_file} as defined in `model_index.json` does not exist in {pretrained_model_name} and is not a module in 'diffusers/pipelines'."
|
1336
|
-
)
|
1337
|
-
|
1338
|
-
if len(variant_filenames) == 0 and variant is not None:
|
1339
|
-
deprecation_message = (
|
1340
|
-
f"You are trying to load the model files of the `variant={variant}`, but no such modeling files are available."
|
1341
|
-
f"The default model files: {model_filenames} will be loaded instead. Make sure to not load from `variant={variant}`"
|
1342
|
-
"if such variant modeling files are not available. Doing so will lead to an error in v0.24.0 as defaulting to non-variant"
|
1343
|
-
"modeling files is deprecated."
|
1344
|
-
)
|
1345
|
-
deprecate("no variant default", "0.24.0", deprecation_message, standard_warn=False)
|
1346
|
-
|
1347
1318
|
# remove ignored filenames
|
1348
1319
|
model_filenames = set(model_filenames) - set(ignore_filenames)
|
1349
1320
|
variant_filenames = set(variant_filenames) - set(ignore_filenames)
|
1350
1321
|
|
1351
|
-
# if the whole pipeline is cached we don't have to ping the Hub
|
1352
1322
|
if revision in DEPRECATED_REVISION_ARGS and version.parse(
|
1353
1323
|
version.parse(__version__).base_version
|
1354
1324
|
) >= version.parse("0.22.0"):
|
1355
1325
|
warn_deprecated_model_variant(pretrained_model_name, token, variant, revision, model_filenames)
|
1356
1326
|
|
1327
|
+
custom_components, folder_names = _get_custom_components_and_folders(
|
1328
|
+
pretrained_model_name, config_dict, filenames, variant_filenames, variant
|
1329
|
+
)
|
1357
1330
|
model_folder_names = {os.path.split(f)[0] for f in model_filenames if os.path.split(f)[0] in folder_names}
|
1358
1331
|
|
1359
1332
|
custom_class_name = None
|
@@ -1413,49 +1386,19 @@ class DiffusionPipeline(ConfigMixin, PushToHubMixin):
|
|
1413
1386
|
expected_components, _ = cls._get_signature_keys(pipeline_class)
|
1414
1387
|
passed_components = [k for k in expected_components if k in kwargs]
|
1415
1388
|
|
1416
|
-
|
1417
|
-
|
1418
|
-
|
1419
|
-
|
1420
|
-
|
1421
|
-
|
1422
|
-
|
1423
|
-
|
1424
|
-
|
1425
|
-
|
1426
|
-
|
1427
|
-
|
1428
|
-
|
1429
|
-
model_filenames, variant=variant, passed_components=passed_components
|
1430
|
-
):
|
1431
|
-
ignore_patterns = ["*.bin", "*.msgpack"]
|
1432
|
-
|
1433
|
-
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
|
1434
|
-
if not use_onnx:
|
1435
|
-
ignore_patterns += ["*.onnx", "*.pb"]
|
1436
|
-
|
1437
|
-
safetensors_variant_filenames = {f for f in variant_filenames if f.endswith(".safetensors")}
|
1438
|
-
safetensors_model_filenames = {f for f in model_filenames if f.endswith(".safetensors")}
|
1439
|
-
if (
|
1440
|
-
len(safetensors_variant_filenames) > 0
|
1441
|
-
and safetensors_model_filenames != safetensors_variant_filenames
|
1442
|
-
):
|
1443
|
-
logger.warning(
|
1444
|
-
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(safetensors_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(safetensors_model_filenames - safetensors_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
1445
|
-
)
|
1446
|
-
else:
|
1447
|
-
ignore_patterns = ["*.safetensors", "*.msgpack"]
|
1448
|
-
|
1449
|
-
use_onnx = use_onnx if use_onnx is not None else pipeline_class._is_onnx
|
1450
|
-
if not use_onnx:
|
1451
|
-
ignore_patterns += ["*.onnx", "*.pb"]
|
1452
|
-
|
1453
|
-
bin_variant_filenames = {f for f in variant_filenames if f.endswith(".bin")}
|
1454
|
-
bin_model_filenames = {f for f in model_filenames if f.endswith(".bin")}
|
1455
|
-
if len(bin_variant_filenames) > 0 and bin_model_filenames != bin_variant_filenames:
|
1456
|
-
logger.warning(
|
1457
|
-
f"\nA mixture of {variant} and non-{variant} filenames will be loaded.\nLoaded {variant} filenames:\n[{', '.join(bin_variant_filenames)}]\nLoaded non-{variant} filenames:\n[{', '.join(bin_model_filenames - bin_variant_filenames)}\nIf this behavior is not expected, please check your folder structure."
|
1458
|
-
)
|
1389
|
+
# retrieve all patterns that should not be downloaded and error out when needed
|
1390
|
+
ignore_patterns = _get_ignore_patterns(
|
1391
|
+
passed_components,
|
1392
|
+
model_folder_names,
|
1393
|
+
model_filenames,
|
1394
|
+
variant_filenames,
|
1395
|
+
use_safetensors,
|
1396
|
+
from_flax,
|
1397
|
+
allow_pickle,
|
1398
|
+
use_onnx,
|
1399
|
+
pipeline_class._is_onnx,
|
1400
|
+
variant,
|
1401
|
+
)
|
1459
1402
|
|
1460
1403
|
# Don't download any objects that are passed
|
1461
1404
|
allow_patterns = [
|
@@ -178,7 +178,7 @@ def retrieve_timesteps(
|
|
178
178
|
sigmas: Optional[List[float]] = None,
|
179
179
|
**kwargs,
|
180
180
|
):
|
181
|
-
"""
|
181
|
+
r"""
|
182
182
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
183
183
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
184
184
|
|
@@ -122,7 +122,7 @@ def retrieve_timesteps(
|
|
122
122
|
sigmas: Optional[List[float]] = None,
|
123
123
|
**kwargs,
|
124
124
|
):
|
125
|
-
"""
|
125
|
+
r"""
|
126
126
|
Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
|
127
127
|
custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
|
128
128
|
|
@@ -281,6 +281,16 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
|
281
281
|
def num_timesteps(self):
|
282
282
|
return self._num_timesteps
|
283
283
|
|
284
|
+
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
|
285
|
+
s = torch.tensor([0.008])
|
286
|
+
clamp_range = [0, 1]
|
287
|
+
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
|
288
|
+
var = alphas_cumprod[t]
|
289
|
+
var = var.clamp(*clamp_range)
|
290
|
+
s, min_var = s.to(var.device), min_var.to(var.device)
|
291
|
+
ratio = (((var * min_var) ** 0.5).acos() / (torch.pi * 0.5)) * (1 + s) - s
|
292
|
+
return ratio
|
293
|
+
|
284
294
|
@torch.no_grad()
|
285
295
|
@replace_example_docstring(EXAMPLE_DOC_STRING)
|
286
296
|
def __call__(
|
@@ -434,10 +444,30 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
|
434
444
|
batch_size, image_embeddings, num_images_per_prompt, dtype, device, generator, latents, self.scheduler
|
435
445
|
)
|
436
446
|
|
447
|
+
if isinstance(self.scheduler, DDPMWuerstchenScheduler):
|
448
|
+
timesteps = timesteps[:-1]
|
449
|
+
else:
|
450
|
+
if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
|
451
|
+
self.scheduler.config.clip_sample = False # disample sample clipping
|
452
|
+
logger.warning(" set `clip_sample` to be False")
|
453
|
+
|
437
454
|
# 6. Run denoising loop
|
438
|
-
self.
|
439
|
-
|
440
|
-
|
455
|
+
if hasattr(self.scheduler, "betas"):
|
456
|
+
alphas = 1.0 - self.scheduler.betas
|
457
|
+
alphas_cumprod = torch.cumprod(alphas, dim=0)
|
458
|
+
else:
|
459
|
+
alphas_cumprod = []
|
460
|
+
|
461
|
+
self._num_timesteps = len(timesteps)
|
462
|
+
for i, t in enumerate(self.progress_bar(timesteps)):
|
463
|
+
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
|
464
|
+
if len(alphas_cumprod) > 0:
|
465
|
+
timestep_ratio = self.get_timestep_ratio_conditioning(t.long().cpu(), alphas_cumprod)
|
466
|
+
timestep_ratio = timestep_ratio.expand(latents.size(0)).to(dtype).to(device)
|
467
|
+
else:
|
468
|
+
timestep_ratio = t.float().div(self.scheduler.timesteps[-1]).expand(latents.size(0)).to(dtype)
|
469
|
+
else:
|
470
|
+
timestep_ratio = t.expand(latents.size(0)).to(dtype)
|
441
471
|
|
442
472
|
# 7. Denoise latents
|
443
473
|
predicted_latents = self.decoder(
|
@@ -454,6 +484,8 @@ class StableCascadeDecoderPipeline(DiffusionPipeline):
|
|
454
484
|
predicted_latents = torch.lerp(predicted_latents_uncond, predicted_latents_text, self.guidance_scale)
|
455
485
|
|
456
486
|
# 9. Renoise latents to next timestep
|
487
|
+
if not isinstance(self.scheduler, DDPMWuerstchenScheduler):
|
488
|
+
timestep_ratio = t
|
457
489
|
latents = self.scheduler.step(
|
458
490
|
model_output=predicted_latents,
|
459
491
|
timestep=timestep_ratio,
|
@@ -353,7 +353,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
|
|
353
353
|
return self._num_timesteps
|
354
354
|
|
355
355
|
def get_timestep_ratio_conditioning(self, t, alphas_cumprod):
|
356
|
-
s = torch.tensor([0.
|
356
|
+
s = torch.tensor([0.008])
|
357
357
|
clamp_range = [0, 1]
|
358
358
|
min_var = torch.cos(s / (1 + s) * torch.pi * 0.5) ** 2
|
359
359
|
var = alphas_cumprod[t]
|
@@ -557,7 +557,7 @@ class StableCascadePriorPipeline(DiffusionPipeline):
|
|
557
557
|
if isinstance(self.scheduler, DDPMWuerstchenScheduler):
|
558
558
|
timesteps = timesteps[:-1]
|
559
559
|
else:
|
560
|
-
if self.scheduler.config.clip_sample:
|
560
|
+
if hasattr(self.scheduler.config, "clip_sample") and self.scheduler.config.clip_sample:
|
561
561
|
self.scheduler.config.clip_sample = False # disample sample clipping
|
562
562
|
logger.warning(" set `clip_sample` to be False")
|
563
563
|
# 6. Run denoising loop
|