diffusers 0.30.3__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +97 -4
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +13 -1
- diffusers/image_processor.py +282 -71
- diffusers/loaders/__init__.py +24 -3
- diffusers/loaders/ip_adapter.py +543 -16
- diffusers/loaders/lora_base.py +138 -125
- diffusers/loaders/lora_conversion_utils.py +647 -0
- diffusers/loaders/lora_pipeline.py +2216 -230
- diffusers/loaders/peft.py +380 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +597 -10
- diffusers/loaders/textual_inversion.py +5 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +56 -12
- diffusers/models/__init__.py +49 -12
- diffusers/models/activations.py +22 -9
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +98 -13
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2160 -346
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +73 -12
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +213 -105
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +70 -0
- diffusers/models/controlnet_sd3.py +26 -376
- diffusers/models/controlnet_sparsectrl.py +46 -719
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +996 -92
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +264 -14
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +334 -51
- diffusers/models/normalization.py +157 -13
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +3 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +69 -13
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +10 -2
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +189 -51
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +112 -18
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +9 -9
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +46 -68
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +14 -6
- diffusers/pipelines/__init__.py +69 -6
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +52 -22
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +3 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -72
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +2 -9
- diffusers/pipelines/auto_pipeline.py +88 -10
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +80 -39
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +108 -50
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +89 -50
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +20 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +9 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +37 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +12 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +9 -4
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +22 -4
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +56 -20
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +32 -9
- diffusers/pipelines/flux/__init__.py +23 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +256 -48
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1006 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +998 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1204 -0
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +856 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1022 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +20 -4
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +3 -10
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +13 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +2 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1543 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +3 -5
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1683 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +22 -6
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +7 -14
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +18 -9
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1094 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +250 -31
- diffusers/pipelines/pipeline_utils.py +158 -186
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +7 -14
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +7 -14
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +46 -9
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +228 -23
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +82 -13
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +60 -11
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -12
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -22
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -22
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +139 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +561 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +669 -0
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +6 -7
- diffusers/schedulers/scheduling_ddpm_parallel.py +6 -7
- diffusers/schedulers/scheduling_deis_multistep.py +102 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +113 -6
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +111 -5
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +126 -7
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +102 -6
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +127 -5
- diffusers/training_utils.py +63 -19
- diffusers/utils/__init__.py +7 -1
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +240 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +435 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +44 -40
- diffusers/utils/import_utils.py +98 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +6 -3
- diffusers/utils/testing_utils.py +115 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/METADATA +73 -72
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/RECORD +268 -193
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
diffusers/training_utils.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import contextlib
|
2
2
|
import copy
|
3
|
+
import gc
|
3
4
|
import math
|
4
5
|
import random
|
5
6
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
@@ -23,6 +24,9 @@ from .utils import (
|
|
23
24
|
if is_transformers_available():
|
24
25
|
import transformers
|
25
26
|
|
27
|
+
if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
|
28
|
+
import deepspeed
|
29
|
+
|
26
30
|
if is_peft_available():
|
27
31
|
from peft import set_peft_model_state_dict
|
28
32
|
|
@@ -35,9 +39,13 @@ if is_torch_npu_available():
|
|
35
39
|
|
36
40
|
def set_seed(seed: int):
|
37
41
|
"""
|
38
|
-
Args:
|
39
42
|
Helper function for reproducible behavior to set the seed in `random`, `numpy`, `torch`.
|
43
|
+
|
44
|
+
Args:
|
40
45
|
seed (`int`): The seed to set.
|
46
|
+
|
47
|
+
Returns:
|
48
|
+
`None`
|
41
49
|
"""
|
42
50
|
random.seed(seed)
|
43
51
|
np.random.seed(seed)
|
@@ -53,6 +61,17 @@ def compute_snr(noise_scheduler, timesteps):
|
|
53
61
|
"""
|
54
62
|
Computes SNR as per
|
55
63
|
https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849
|
64
|
+
for the given timesteps using the provided noise scheduler.
|
65
|
+
|
66
|
+
Args:
|
67
|
+
noise_scheduler (`NoiseScheduler`):
|
68
|
+
An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
|
69
|
+
the SNR values.
|
70
|
+
timesteps (`torch.Tensor`):
|
71
|
+
A tensor of timesteps for which the SNR is computed.
|
72
|
+
|
73
|
+
Returns:
|
74
|
+
`torch.Tensor`: A tensor containing the computed SNR values for each timestep.
|
56
75
|
"""
|
57
76
|
alphas_cumprod = noise_scheduler.alphas_cumprod
|
58
77
|
sqrt_alphas_cumprod = alphas_cumprod**0.5
|
@@ -193,6 +212,13 @@ def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
|
|
193
212
|
|
194
213
|
|
195
214
|
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
|
215
|
+
"""
|
216
|
+
Casts the training parameters of the model to the specified data type.
|
217
|
+
|
218
|
+
Args:
|
219
|
+
model: The PyTorch model whose parameters will be cast.
|
220
|
+
dtype: The data type to which the model parameters will be cast.
|
221
|
+
"""
|
196
222
|
if not isinstance(model, list):
|
197
223
|
model = [model]
|
198
224
|
for m in model:
|
@@ -224,7 +250,8 @@ def _set_state_dict_into_text_encoder(
|
|
224
250
|
def compute_density_for_timestep_sampling(
|
225
251
|
weighting_scheme: str, batch_size: int, logit_mean: float = None, logit_std: float = None, mode_scale: float = None
|
226
252
|
):
|
227
|
-
"""
|
253
|
+
"""
|
254
|
+
Compute the density for sampling the timesteps when doing SD3 training.
|
228
255
|
|
229
256
|
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
230
257
|
|
@@ -243,7 +270,8 @@ def compute_density_for_timestep_sampling(
|
|
243
270
|
|
244
271
|
|
245
272
|
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
246
|
-
"""
|
273
|
+
"""
|
274
|
+
Computes loss weighting scheme for SD3 training.
|
247
275
|
|
248
276
|
Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.
|
249
277
|
|
@@ -259,6 +287,20 @@ def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
|
|
259
287
|
return weighting
|
260
288
|
|
261
289
|
|
290
|
+
def free_memory():
|
291
|
+
"""
|
292
|
+
Runs garbage collection. Then clears the cache of the available accelerator.
|
293
|
+
"""
|
294
|
+
gc.collect()
|
295
|
+
|
296
|
+
if torch.cuda.is_available():
|
297
|
+
torch.cuda.empty_cache()
|
298
|
+
elif torch.backends.mps.is_available():
|
299
|
+
torch.mps.empty_cache()
|
300
|
+
elif is_torch_npu_available():
|
301
|
+
torch_npu.npu.empty_cache()
|
302
|
+
|
303
|
+
|
262
304
|
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
263
305
|
class EMAModel:
|
264
306
|
"""
|
@@ -351,7 +393,7 @@ class EMAModel:
|
|
351
393
|
|
352
394
|
@classmethod
|
353
395
|
def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
|
354
|
-
_, ema_kwargs = model_cls.
|
396
|
+
_, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
|
355
397
|
model = model_cls.from_pretrained(path)
|
356
398
|
|
357
399
|
ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
|
@@ -417,15 +459,13 @@ class EMAModel:
|
|
417
459
|
self.cur_decay_value = decay
|
418
460
|
one_minus_decay = 1 - decay
|
419
461
|
|
420
|
-
context_manager = contextlib.nullcontext
|
421
|
-
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
|
422
|
-
import deepspeed
|
462
|
+
context_manager = contextlib.nullcontext()
|
423
463
|
|
424
464
|
if self.foreach:
|
425
|
-
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
|
465
|
+
if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
|
426
466
|
context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
|
427
467
|
|
428
|
-
with context_manager
|
468
|
+
with context_manager:
|
429
469
|
params_grad = [param for param in parameters if param.requires_grad]
|
430
470
|
s_params_grad = [
|
431
471
|
s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
|
@@ -444,10 +484,10 @@ class EMAModel:
|
|
444
484
|
|
445
485
|
else:
|
446
486
|
for s_param, param in zip(self.shadow_params, parameters):
|
447
|
-
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
|
487
|
+
if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
|
448
488
|
context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
|
449
489
|
|
450
|
-
with context_manager
|
490
|
+
with context_manager:
|
451
491
|
if param.requires_grad:
|
452
492
|
s_param.sub_(one_minus_decay * (s_param - param))
|
453
493
|
else:
|
@@ -481,7 +521,8 @@ class EMAModel:
|
|
481
521
|
self.shadow_params = [p.pin_memory() for p in self.shadow_params]
|
482
522
|
|
483
523
|
def to(self, device=None, dtype=None, non_blocking=False) -> None:
|
484
|
-
r"""
|
524
|
+
r"""
|
525
|
+
Move internal buffers of the ExponentialMovingAverage to `device`.
|
485
526
|
|
486
527
|
Args:
|
487
528
|
device: like `device` argument to `torch.Tensor.to`
|
@@ -515,23 +556,25 @@ class EMAModel:
|
|
515
556
|
|
516
557
|
def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
517
558
|
r"""
|
559
|
+
Saves the current parameters for restoring later.
|
560
|
+
|
518
561
|
Args:
|
519
|
-
|
520
|
-
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
521
|
-
temporarily stored.
|
562
|
+
parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
|
522
563
|
"""
|
523
564
|
self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]
|
524
565
|
|
525
566
|
def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
526
567
|
r"""
|
527
|
-
|
528
|
-
|
529
|
-
affecting the original optimization process. Store the parameters before the `copy_to()` method. After
|
568
|
+
Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
|
569
|
+
without: affecting the original optimization process. Store the parameters before the `copy_to()` method. After
|
530
570
|
validation (or model saving), use this to restore the former parameters.
|
571
|
+
|
572
|
+
Args:
|
531
573
|
parameters: Iterable of `torch.nn.Parameter`; the parameters to be
|
532
574
|
updated with the stored parameters. If `None`, the parameters with which this
|
533
575
|
`ExponentialMovingAverage` was initialized will be used.
|
534
576
|
"""
|
577
|
+
|
535
578
|
if self.temp_stored_params is None:
|
536
579
|
raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
|
537
580
|
if self.foreach:
|
@@ -547,9 +590,10 @@ class EMAModel:
|
|
547
590
|
|
548
591
|
def load_state_dict(self, state_dict: dict) -> None:
|
549
592
|
r"""
|
550
|
-
Args:
|
551
593
|
Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to save the
|
552
594
|
ema state dict.
|
595
|
+
|
596
|
+
Args:
|
553
597
|
state_dict (dict): EMA state. Should be an object returned
|
554
598
|
from a call to :meth:`state_dict`.
|
555
599
|
"""
|
diffusers/utils/__init__.py
CHANGED
@@ -23,6 +23,7 @@ from .constants import (
|
|
23
23
|
DEPRECATED_REVISION_ARGS,
|
24
24
|
DIFFUSERS_DYNAMIC_MODULE_NAME,
|
25
25
|
FLAX_WEIGHTS_NAME,
|
26
|
+
GGUF_FILE_EXTENSION,
|
26
27
|
HF_MODULES_CACHE,
|
27
28
|
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
28
29
|
MIN_PEFT_VERSION,
|
@@ -62,9 +63,12 @@ from .import_utils import (
|
|
62
63
|
is_accelerate_available,
|
63
64
|
is_accelerate_version,
|
64
65
|
is_bitsandbytes_available,
|
66
|
+
is_bitsandbytes_version,
|
65
67
|
is_bs4_available,
|
66
68
|
is_flax_available,
|
67
69
|
is_ftfy_available,
|
70
|
+
is_gguf_available,
|
71
|
+
is_gguf_version,
|
68
72
|
is_google_colab,
|
69
73
|
is_inflect_available,
|
70
74
|
is_invisible_watermark_available,
|
@@ -85,6 +89,8 @@ from .import_utils import (
|
|
85
89
|
is_torch_npu_available,
|
86
90
|
is_torch_version,
|
87
91
|
is_torch_xla_available,
|
92
|
+
is_torch_xla_version,
|
93
|
+
is_torchao_available,
|
88
94
|
is_torchsde_available,
|
89
95
|
is_torchvision_available,
|
90
96
|
is_transformers_available,
|
@@ -94,7 +100,7 @@ from .import_utils import (
|
|
94
100
|
is_xformers_available,
|
95
101
|
requires_backends,
|
96
102
|
)
|
97
|
-
from .loading_utils import load_image, load_video
|
103
|
+
from .loading_utils import get_module_from_name, load_image, load_video
|
98
104
|
from .logging import get_logger
|
99
105
|
from .outputs import BaseOutput
|
100
106
|
from .peft_utils import (
|
diffusers/utils/constants.py
CHANGED
@@ -34,6 +34,7 @@ ONNX_WEIGHTS_NAME = "model.onnx"
|
|
34
34
|
SAFETENSORS_WEIGHTS_NAME = "diffusion_pytorch_model.safetensors"
|
35
35
|
SAFE_WEIGHTS_INDEX_NAME = "diffusion_pytorch_model.safetensors.index.json"
|
36
36
|
SAFETENSORS_FILE_EXTENSION = "safetensors"
|
37
|
+
GGUF_FILE_EXTENSION = "gguf"
|
37
38
|
ONNX_EXTERNAL_WEIGHTS_NAME = "weights.pb"
|
38
39
|
HUGGINGFACE_CO_RESOLVE_ENDPOINT = os.environ.get("HF_ENDPOINT", "https://huggingface.co")
|
39
40
|
DIFFUSERS_DYNAMIC_MODULE_NAME = "diffusers_modules"
|
@@ -2,6 +2,21 @@
|
|
2
2
|
from ..utils import DummyObject, requires_backends
|
3
3
|
|
4
4
|
|
5
|
+
class AllegroTransformer3DModel(metaclass=DummyObject):
|
6
|
+
_backends = ["torch"]
|
7
|
+
|
8
|
+
def __init__(self, *args, **kwargs):
|
9
|
+
requires_backends(self, ["torch"])
|
10
|
+
|
11
|
+
@classmethod
|
12
|
+
def from_config(cls, *args, **kwargs):
|
13
|
+
requires_backends(cls, ["torch"])
|
14
|
+
|
15
|
+
@classmethod
|
16
|
+
def from_pretrained(cls, *args, **kwargs):
|
17
|
+
requires_backends(cls, ["torch"])
|
18
|
+
|
19
|
+
|
5
20
|
class AsymmetricAutoencoderKL(metaclass=DummyObject):
|
6
21
|
_backends = ["torch"]
|
7
22
|
|
@@ -32,6 +47,21 @@ class AuraFlowTransformer2DModel(metaclass=DummyObject):
|
|
32
47
|
requires_backends(cls, ["torch"])
|
33
48
|
|
34
49
|
|
50
|
+
class AutoencoderDC(metaclass=DummyObject):
|
51
|
+
_backends = ["torch"]
|
52
|
+
|
53
|
+
def __init__(self, *args, **kwargs):
|
54
|
+
requires_backends(self, ["torch"])
|
55
|
+
|
56
|
+
@classmethod
|
57
|
+
def from_config(cls, *args, **kwargs):
|
58
|
+
requires_backends(cls, ["torch"])
|
59
|
+
|
60
|
+
@classmethod
|
61
|
+
def from_pretrained(cls, *args, **kwargs):
|
62
|
+
requires_backends(cls, ["torch"])
|
63
|
+
|
64
|
+
|
35
65
|
class AutoencoderKL(metaclass=DummyObject):
|
36
66
|
_backends = ["torch"]
|
37
67
|
|
@@ -47,6 +77,21 @@ class AutoencoderKL(metaclass=DummyObject):
|
|
47
77
|
requires_backends(cls, ["torch"])
|
48
78
|
|
49
79
|
|
80
|
+
class AutoencoderKLAllegro(metaclass=DummyObject):
|
81
|
+
_backends = ["torch"]
|
82
|
+
|
83
|
+
def __init__(self, *args, **kwargs):
|
84
|
+
requires_backends(self, ["torch"])
|
85
|
+
|
86
|
+
@classmethod
|
87
|
+
def from_config(cls, *args, **kwargs):
|
88
|
+
requires_backends(cls, ["torch"])
|
89
|
+
|
90
|
+
@classmethod
|
91
|
+
def from_pretrained(cls, *args, **kwargs):
|
92
|
+
requires_backends(cls, ["torch"])
|
93
|
+
|
94
|
+
|
50
95
|
class AutoencoderKLCogVideoX(metaclass=DummyObject):
|
51
96
|
_backends = ["torch"]
|
52
97
|
|
@@ -62,6 +107,51 @@ class AutoencoderKLCogVideoX(metaclass=DummyObject):
|
|
62
107
|
requires_backends(cls, ["torch"])
|
63
108
|
|
64
109
|
|
110
|
+
class AutoencoderKLHunyuanVideo(metaclass=DummyObject):
|
111
|
+
_backends = ["torch"]
|
112
|
+
|
113
|
+
def __init__(self, *args, **kwargs):
|
114
|
+
requires_backends(self, ["torch"])
|
115
|
+
|
116
|
+
@classmethod
|
117
|
+
def from_config(cls, *args, **kwargs):
|
118
|
+
requires_backends(cls, ["torch"])
|
119
|
+
|
120
|
+
@classmethod
|
121
|
+
def from_pretrained(cls, *args, **kwargs):
|
122
|
+
requires_backends(cls, ["torch"])
|
123
|
+
|
124
|
+
|
125
|
+
class AutoencoderKLLTXVideo(metaclass=DummyObject):
|
126
|
+
_backends = ["torch"]
|
127
|
+
|
128
|
+
def __init__(self, *args, **kwargs):
|
129
|
+
requires_backends(self, ["torch"])
|
130
|
+
|
131
|
+
@classmethod
|
132
|
+
def from_config(cls, *args, **kwargs):
|
133
|
+
requires_backends(cls, ["torch"])
|
134
|
+
|
135
|
+
@classmethod
|
136
|
+
def from_pretrained(cls, *args, **kwargs):
|
137
|
+
requires_backends(cls, ["torch"])
|
138
|
+
|
139
|
+
|
140
|
+
class AutoencoderKLMochi(metaclass=DummyObject):
|
141
|
+
_backends = ["torch"]
|
142
|
+
|
143
|
+
def __init__(self, *args, **kwargs):
|
144
|
+
requires_backends(self, ["torch"])
|
145
|
+
|
146
|
+
@classmethod
|
147
|
+
def from_config(cls, *args, **kwargs):
|
148
|
+
requires_backends(cls, ["torch"])
|
149
|
+
|
150
|
+
@classmethod
|
151
|
+
def from_pretrained(cls, *args, **kwargs):
|
152
|
+
requires_backends(cls, ["torch"])
|
153
|
+
|
154
|
+
|
65
155
|
class AutoencoderKLTemporalDecoder(metaclass=DummyObject):
|
66
156
|
_backends = ["torch"]
|
67
157
|
|
@@ -122,6 +212,21 @@ class CogVideoXTransformer3DModel(metaclass=DummyObject):
|
|
122
212
|
requires_backends(cls, ["torch"])
|
123
213
|
|
124
214
|
|
215
|
+
class CogView3PlusTransformer2DModel(metaclass=DummyObject):
|
216
|
+
_backends = ["torch"]
|
217
|
+
|
218
|
+
def __init__(self, *args, **kwargs):
|
219
|
+
requires_backends(self, ["torch"])
|
220
|
+
|
221
|
+
@classmethod
|
222
|
+
def from_config(cls, *args, **kwargs):
|
223
|
+
requires_backends(cls, ["torch"])
|
224
|
+
|
225
|
+
@classmethod
|
226
|
+
def from_pretrained(cls, *args, **kwargs):
|
227
|
+
requires_backends(cls, ["torch"])
|
228
|
+
|
229
|
+
|
125
230
|
class ConsistencyDecoderVAE(metaclass=DummyObject):
|
126
231
|
_backends = ["torch"]
|
127
232
|
|
@@ -152,6 +257,21 @@ class ControlNetModel(metaclass=DummyObject):
|
|
152
257
|
requires_backends(cls, ["torch"])
|
153
258
|
|
154
259
|
|
260
|
+
class ControlNetUnionModel(metaclass=DummyObject):
    """Dummy stand-in for ``ControlNetUnionModel``.

    Instantiation, ``from_config`` and ``from_pretrained`` all hand off to ``requires_backends``
    with the ``["torch"]`` backend list (presumably raising when PyTorch is unavailable — confirm
    against ``requires_backends``).
    """

    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
|
273
|
+
|
274
|
+
|
155
275
|
class ControlNetXSAdapter(metaclass=DummyObject):
|
156
276
|
_backends = ["torch"]
|
157
277
|
|
@@ -182,6 +302,36 @@ class DiTTransformer2DModel(metaclass=DummyObject):
|
|
182
302
|
requires_backends(cls, ["torch"])
|
183
303
|
|
184
304
|
|
305
|
+
class FluxControlNetModel(metaclass=DummyObject):
    """Dummy stand-in for ``FluxControlNetModel``.

    Instantiation, ``from_config`` and ``from_pretrained`` all hand off to ``requires_backends``
    with the ``["torch"]`` backend list (presumably raising when PyTorch is unavailable — confirm
    against ``requires_backends``).
    """

    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
|
318
|
+
|
319
|
+
|
320
|
+
class FluxMultiControlNetModel(metaclass=DummyObject):
    """Dummy stand-in for ``FluxMultiControlNetModel``.

    Instantiation, ``from_config`` and ``from_pretrained`` all hand off to ``requires_backends``
    with the ``["torch"]`` backend list (presumably raising when PyTorch is unavailable — confirm
    against ``requires_backends``).
    """

    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
|
333
|
+
|
334
|
+
|
185
335
|
class FluxTransformer2DModel(metaclass=DummyObject):
|
186
336
|
_backends = ["torch"]
|
187
337
|
|
@@ -242,6 +392,21 @@ class HunyuanDiT2DMultiControlNetModel(metaclass=DummyObject):
|
|
242
392
|
requires_backends(cls, ["torch"])
|
243
393
|
|
244
394
|
|
395
|
+
class HunyuanVideoTransformer3DModel(metaclass=DummyObject):
    """Dummy stand-in for ``HunyuanVideoTransformer3DModel``.

    Instantiation, ``from_config`` and ``from_pretrained`` all hand off to ``requires_backends``
    with the ``["torch"]`` backend list (presumably raising when PyTorch is unavailable — confirm
    against ``requires_backends``).
    """

    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
|
408
|
+
|
409
|
+
|
245
410
|
class I2VGenXLUNet(metaclass=DummyObject):
|
246
411
|
_backends = ["torch"]
|
247
412
|
|
@@ -287,6 +452,21 @@ class LatteTransformer3DModel(metaclass=DummyObject):
|
|
287
452
|
requires_backends(cls, ["torch"])
|
288
453
|
|
289
454
|
|
455
|
+
class LTXVideoTransformer3DModel(metaclass=DummyObject):
    """Dummy stand-in for ``LTXVideoTransformer3DModel``.

    Instantiation, ``from_config`` and ``from_pretrained`` all hand off to ``requires_backends``
    with the ``["torch"]`` backend list (presumably raising when PyTorch is unavailable — confirm
    against ``requires_backends``).
    """

    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
|
468
|
+
|
469
|
+
|
290
470
|
class LuminaNextDiT2DModel(metaclass=DummyObject):
|
291
471
|
_backends = ["torch"]
|
292
472
|
|
@@ -302,6 +482,21 @@ class LuminaNextDiT2DModel(metaclass=DummyObject):
|
|
302
482
|
requires_backends(cls, ["torch"])
|
303
483
|
|
304
484
|
|
485
|
+
class MochiTransformer3DModel(metaclass=DummyObject):
    """Dummy stand-in for ``MochiTransformer3DModel``.

    Instantiation, ``from_config`` and ``from_pretrained`` all hand off to ``requires_backends``
    with the ``["torch"]`` backend list (presumably raising when PyTorch is unavailable — confirm
    against ``requires_backends``).
    """

    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
|
498
|
+
|
499
|
+
|
305
500
|
class ModelMixin(metaclass=DummyObject):
|
306
501
|
_backends = ["torch"]
|
307
502
|
|
@@ -347,6 +542,21 @@ class MultiAdapter(metaclass=DummyObject):
|
|
347
542
|
requires_backends(cls, ["torch"])
|
348
543
|
|
349
544
|
|
545
|
+
class MultiControlNetModel(metaclass=DummyObject):
    """Dummy stand-in for ``MultiControlNetModel``.

    Instantiation, ``from_config`` and ``from_pretrained`` all hand off to ``requires_backends``
    with the ``["torch"]`` backend list (presumably raising when PyTorch is unavailable — confirm
    against ``requires_backends``).
    """

    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
|
558
|
+
|
559
|
+
|
350
560
|
class PixArtTransformer2DModel(metaclass=DummyObject):
|
351
561
|
_backends = ["torch"]
|
352
562
|
|
@@ -377,6 +587,21 @@ class PriorTransformer(metaclass=DummyObject):
|
|
377
587
|
requires_backends(cls, ["torch"])
|
378
588
|
|
379
589
|
|
590
|
+
class SanaTransformer2DModel(metaclass=DummyObject):
    """Dummy stand-in for ``SanaTransformer2DModel``.

    Instantiation, ``from_config`` and ``from_pretrained`` all hand off to ``requires_backends``
    with the ``["torch"]`` backend list (presumably raising when PyTorch is unavailable — confirm
    against ``requires_backends``).
    """

    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
|
603
|
+
|
604
|
+
|
380
605
|
class SD3ControlNetModel(metaclass=DummyObject):
|
381
606
|
_backends = ["torch"]
|
382
607
|
|
@@ -975,6 +1200,21 @@ class StableDiffusionMixin(metaclass=DummyObject):
|
|
975
1200
|
requires_backends(cls, ["torch"])
|
976
1201
|
|
977
1202
|
|
1203
|
+
class DiffusersQuantizer(metaclass=DummyObject):
    """Dummy stand-in for ``DiffusersQuantizer``.

    Instantiation, ``from_config`` and ``from_pretrained`` all hand off to ``requires_backends``
    with the ``["torch"]`` backend list (presumably raising when PyTorch is unavailable — confirm
    against ``requires_backends``).
    """

    _backends = ["torch"]

    def __init__(self, *_args, **_kwargs):
        requires_backends(self, ["torch"])

    @classmethod
    def from_config(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])

    @classmethod
    def from_pretrained(cls, *_args, **_kwargs):
        requires_backends(cls, ["torch"])
|
1216
|
+
|
1217
|
+
|
978
1218
|
class AmusedScheduler(metaclass=DummyObject):
|
979
1219
|
_backends = ["torch"]
|
980
1220
|
|