diffusers 0.29.2__py3-none-any.whl → 0.30.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +94 -3
- diffusers/commands/env.py +1 -5
- diffusers/configuration_utils.py +4 -9
- diffusers/dependency_versions_table.py +2 -2
- diffusers/image_processor.py +1 -2
- diffusers/loaders/__init__.py +17 -2
- diffusers/loaders/ip_adapter.py +10 -7
- diffusers/loaders/lora_base.py +752 -0
- diffusers/loaders/lora_pipeline.py +2222 -0
- diffusers/loaders/peft.py +213 -5
- diffusers/loaders/single_file.py +1 -12
- diffusers/loaders/single_file_model.py +31 -10
- diffusers/loaders/single_file_utils.py +262 -2
- diffusers/loaders/textual_inversion.py +1 -6
- diffusers/loaders/unet.py +23 -208
- diffusers/models/__init__.py +20 -0
- diffusers/models/activations.py +22 -0
- diffusers/models/attention.py +386 -7
- diffusers/models/attention_processor.py +1795 -629
- diffusers/models/autoencoders/__init__.py +2 -0
- diffusers/models/autoencoders/autoencoder_kl.py +14 -3
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +1035 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +1 -1
- diffusers/models/autoencoders/autoencoder_oobleck.py +464 -0
- diffusers/models/autoencoders/autoencoder_tiny.py +1 -0
- diffusers/models/autoencoders/consistency_decoder_vae.py +1 -1
- diffusers/models/autoencoders/vq_model.py +4 -4
- diffusers/models/controlnet.py +2 -3
- diffusers/models/controlnet_hunyuan.py +401 -0
- diffusers/models/controlnet_sd3.py +11 -11
- diffusers/models/controlnet_sparsectrl.py +789 -0
- diffusers/models/controlnet_xs.py +40 -10
- diffusers/models/downsampling.py +68 -0
- diffusers/models/embeddings.py +319 -36
- diffusers/models/model_loading_utils.py +1 -3
- diffusers/models/modeling_flax_utils.py +1 -6
- diffusers/models/modeling_utils.py +4 -16
- diffusers/models/normalization.py +203 -12
- diffusers/models/transformers/__init__.py +6 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +527 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +345 -0
- diffusers/models/transformers/hunyuan_transformer_2d.py +19 -15
- diffusers/models/transformers/latte_transformer_3d.py +327 -0
- diffusers/models/transformers/lumina_nextdit2d.py +340 -0
- diffusers/models/transformers/pixart_transformer_2d.py +102 -1
- diffusers/models/transformers/prior_transformer.py +1 -1
- diffusers/models/transformers/stable_audio_transformer.py +458 -0
- diffusers/models/transformers/transformer_flux.py +455 -0
- diffusers/models/transformers/transformer_sd3.py +18 -4
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d_condition.py +8 -1
- diffusers/models/unets/unet_3d_blocks.py +51 -920
- diffusers/models/unets/unet_3d_condition.py +4 -1
- diffusers/models/unets/unet_i2vgen_xl.py +4 -1
- diffusers/models/unets/unet_kandinsky3.py +1 -1
- diffusers/models/unets/unet_motion_model.py +1330 -84
- diffusers/models/unets/unet_spatio_temporal_condition.py +1 -1
- diffusers/models/unets/unet_stable_cascade.py +1 -3
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +64 -0
- diffusers/models/vq_model.py +8 -4
- diffusers/optimization.py +1 -1
- diffusers/pipelines/__init__.py +100 -3
- diffusers/pipelines/animatediff/__init__.py +4 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +50 -40
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +1076 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +17 -27
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1008 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +51 -38
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +1 -1
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +1 -0
- diffusers/pipelines/aura_flow/__init__.py +48 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +591 -0
- diffusers/pipelines/auto_pipeline.py +97 -19
- diffusers/pipelines/cogvideo/__init__.py +48 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +687 -0
- diffusers/pipelines/consistency_models/pipeline_consistency_models.py +1 -1
- diffusers/pipelines/controlnet/pipeline_controlnet.py +24 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +31 -30
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +24 -153
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +19 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +18 -28
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +29 -32
- diffusers/pipelines/controlnet/pipeline_flax_controlnet.py +2 -2
- diffusers/pipelines/controlnet_hunyuandit/__init__.py +48 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +1042 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +35 -0
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs.py +10 -6
- diffusers/pipelines/controlnet_xs/pipeline_controlnet_xs_sd_xl.py +0 -4
- diffusers/pipelines/deepfloyd_if/pipeline_if.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_img2img_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_inpainting_superresolution.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_if_superresolution.py +2 -2
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +11 -6
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +11 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_inpaint_legacy.py +6 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_model_editing.py +10 -10
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_paradigms.py +10 -6
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_stable_diffusion_pix2pix_zero.py +3 -3
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +749 -0
- diffusers/pipelines/flux/pipeline_output.py +21 -0
- diffusers/pipelines/free_init_utils.py +2 -0
- diffusers/pipelines/free_noise_utils.py +236 -0
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3.py +2 -2
- diffusers/pipelines/kandinsky3/pipeline_kandinsky3_img2img.py +2 -2
- diffusers/pipelines/kolors/__init__.py +54 -0
- diffusers/pipelines/kolors/pipeline_kolors.py +1070 -0
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +1247 -0
- diffusers/pipelines/kolors/pipeline_output.py +21 -0
- diffusers/pipelines/kolors/text_encoder.py +889 -0
- diffusers/pipelines/kolors/tokenizer.py +334 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +30 -29
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +23 -29
- diffusers/pipelines/latte/__init__.py +48 -0
- diffusers/pipelines/latte/pipeline_latte.py +881 -0
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +4 -4
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +0 -4
- diffusers/pipelines/lumina/__init__.py +48 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +897 -0
- diffusers/pipelines/pag/__init__.py +67 -0
- diffusers/pipelines/pag/pag_utils.py +237 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1329 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1612 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +953 -0
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1136 -0
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +872 -0
- diffusers/pipelines/pag/pipeline_pag_sd.py +1050 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +985 -0
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +862 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +1333 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +1529 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +1753 -0
- diffusers/pipelines/pia/pipeline_pia.py +30 -37
- diffusers/pipelines/pipeline_flax_utils.py +4 -9
- diffusers/pipelines/pipeline_loading_utils.py +0 -3
- diffusers/pipelines/pipeline_utils.py +2 -14
- diffusers/pipelines/semantic_stable_diffusion/pipeline_semantic_stable_diffusion.py +0 -1
- diffusers/pipelines/stable_audio/__init__.py +50 -0
- diffusers/pipelines/stable_audio/modeling_stable_audio.py +158 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +745 -0
- diffusers/pipelines/stable_diffusion/convert_from_ckpt.py +2 -0
- diffusers/pipelines/stable_diffusion/pipeline_flax_stable_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +23 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_depth2img.py +15 -8
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +30 -29
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +23 -152
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_instruct_pix2pix.py +8 -4
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_upscale.py +11 -11
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip.py +8 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_unclip_img2img.py +6 -6
- diffusers/pipelines/stable_diffusion_3/__init__.py +2 -0
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +34 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +33 -7
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +1201 -0
- diffusers/pipelines/stable_diffusion_attend_and_excite/pipeline_stable_diffusion_attend_and_excite.py +3 -3
- diffusers/pipelines/stable_diffusion_diffedit/pipeline_stable_diffusion_diffedit.py +6 -6
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen.py +5 -5
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +5 -5
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +6 -6
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_xl_k_diffusion.py +0 -4
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +23 -29
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +27 -29
- diffusers/pipelines/stable_diffusion_sag/pipeline_stable_diffusion_sag.py +3 -3
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +17 -27
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +26 -29
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +17 -145
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_instruct_pix2pix.py +0 -4
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +6 -6
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +18 -28
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_synth_img2img.py +8 -6
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero.py +6 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +0 -4
- diffusers/pipelines/unidiffuser/pipeline_unidiffuser.py +3 -3
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/pipelines/wuerstchen/pipeline_wuerstchen_prior.py +5 -4
- diffusers/schedulers/__init__.py +8 -0
- diffusers/schedulers/scheduling_cosine_dpmsolver_multistep.py +572 -0
- diffusers/schedulers/scheduling_ddim.py +1 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +449 -0
- diffusers/schedulers/scheduling_ddpm.py +1 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +1 -1
- diffusers/schedulers/scheduling_deis_multistep.py +2 -2
- diffusers/schedulers/scheduling_dpm_cogvideox.py +489 -0
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +1 -1
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +64 -19
- diffusers/schedulers/scheduling_edm_dpmsolver_multistep.py +2 -2
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +63 -39
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +321 -0
- diffusers/schedulers/scheduling_ipndm.py +1 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +1 -1
- diffusers/schedulers/scheduling_utils.py +1 -3
- diffusers/schedulers/scheduling_utils_flax.py +1 -3
- diffusers/training_utils.py +99 -14
- diffusers/utils/__init__.py +2 -2
- diffusers/utils/dummy_pt_objects.py +210 -0
- diffusers/utils/dummy_torch_and_torchsde_objects.py +15 -0
- diffusers/utils/dummy_torch_and_transformers_and_sentencepiece_objects.py +47 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +315 -0
- diffusers/utils/dynamic_modules_utils.py +1 -11
- diffusers/utils/export_utils.py +1 -4
- diffusers/utils/hub_utils.py +45 -42
- diffusers/utils/import_utils.py +19 -16
- diffusers/utils/loading_utils.py +76 -3
- diffusers/utils/testing_utils.py +11 -8
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/METADATA +73 -83
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/RECORD +217 -164
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/WHEEL +1 -1
- diffusers/loaders/autoencoder.py +0 -146
- diffusers/loaders/controlnet.py +0 -136
- diffusers/loaders/lora.py +0 -1728
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/LICENSE +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.29.2.dist-info → diffusers-0.30.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,321 @@
|
|
1
|
+
# Copyright 2024 Stability AI, Katherine Crowson and The HuggingFace Team. All rights reserved.
|
2
|
+
#
|
3
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
# you may not use this file except in compliance with the License.
|
5
|
+
# You may obtain a copy of the License at
|
6
|
+
#
|
7
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
#
|
9
|
+
# Unless required by applicable law or agreed to in writing, software
|
10
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
# See the License for the specific language governing permissions and
|
13
|
+
# limitations under the License.
|
14
|
+
|
15
|
+
from dataclasses import dataclass
|
16
|
+
from typing import Optional, Tuple, Union
|
17
|
+
|
18
|
+
import numpy as np
|
19
|
+
import torch
|
20
|
+
|
21
|
+
from ..configuration_utils import ConfigMixin, register_to_config
|
22
|
+
from ..utils import BaseOutput, logging
|
23
|
+
from ..utils.torch_utils import randn_tensor
|
24
|
+
from .scheduling_utils import SchedulerMixin
|
25
|
+
|
26
|
+
|
27
|
+
# Module-level logger keyed by this module's name.
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
28
|
+
|
29
|
+
|
30
|
+
@dataclass
class FlowMatchHeunDiscreteSchedulerOutput(BaseOutput):
    """
    Container for the value returned by the scheduler's `step` method.

    Args:
        prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
            The sample computed for the previous timestep `(x_{t-1})`. Feed it back to the model as the input of the
            next iteration of the denoising loop.
    """

    prev_sample: torch.FloatTensor
|
42
|
+
|
43
|
+
|
44
|
+
class FlowMatchHeunDiscreteScheduler(SchedulerMixin, ConfigMixin):
    """
    Second-order (Heun) scheduler for flow-matching models.

    Every denoising step is split into two `step` calls: a first-order prediction followed by a second-order
    correction that averages both derivatives. For that reason `set_timesteps` repeats every timestep except the
    first one twice.

    This model inherits from [`SchedulerMixin`] and [`ConfigMixin`]. Check the superclass documentation for the generic
    methods the library implements for all schedulers such as loading and saving.

    Args:
        num_train_timesteps (`int`, defaults to 1000):
            The number of diffusion steps to train the model.
        shift (`float`, defaults to 1.0):
            The shift value for the timestep schedule.
    """

    _compatibles = []
    order = 2

    @register_to_config
    def __init__(
        self,
        num_train_timesteps: int = 1000,
        shift: float = 1.0,
    ):
        # Training timesteps in descending order: num_train_timesteps, ..., 1.
        timesteps = np.linspace(1, num_train_timesteps, num_train_timesteps, dtype=np.float32)[::-1].copy()
        timesteps = torch.from_numpy(timesteps).to(dtype=torch.float32)

        # Normalise to (0, 1] and apply the timestep shift.
        sigmas = timesteps / num_train_timesteps
        sigmas = shift * sigmas / (1 + (shift - 1) * sigmas)

        self.timesteps = sigmas * num_train_timesteps

        self._step_index = None
        self._begin_index = None

        self.sigmas = sigmas.to("cpu")  # to avoid too much CPU/GPU communication
        self.sigma_min = self.sigmas[-1].item()
        self.sigma_max = self.sigmas[0].item()

    @property
    def step_index(self):
        """
        The index counter for current timestep. It will increase 1 after each scheduler step.
        """
        return self._step_index

    @property
    def begin_index(self):
        """
        The index for the first timestep. It should be set from pipeline with `set_begin_index` method.
        """
        return self._begin_index

    # Copied from diffusers.schedulers.scheduling_dpmsolver_multistep.DPMSolverMultistepScheduler.set_begin_index
    def set_begin_index(self, begin_index: int = 0):
        """
        Sets the begin index for the scheduler. This function should be run from pipeline before the inference.

        Args:
            begin_index (`int`):
                The begin index for the scheduler.
        """
        self._begin_index = begin_index

    def scale_noise(
        self,
        sample: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        noise: Optional[torch.FloatTensor] = None,
    ) -> torch.FloatTensor:
        """
        Forward process in flow-matching: blends `sample` with `noise` according to the sigma of the current step.

        Args:
            sample (`torch.FloatTensor`):
                The input sample.
            timestep (`float` or `torch.FloatTensor`):
                The current timestep in the diffusion chain. Only used to initialise the step index when it has not
                been set yet.
            noise (`torch.FloatTensor`, *optional*):
                The noise to mix into `sample`. When `None`, standard Gaussian noise with the shape, dtype and device
                of `sample` is drawn.

        Returns:
            `torch.FloatTensor`:
                A scaled input sample, computed as `sigma * noise + (1 - sigma) * sample`.
        """
        if self.step_index is None:
            self._init_step_index(timestep)

        if noise is None:
            # `noise` is declared optional; multiplying by `None` below used to raise a `TypeError`.
            noise = randn_tensor(sample.shape, dtype=sample.dtype, device=sample.device, generator=None)

        sigma = self.sigmas[self.step_index]
        sample = sigma * noise + (1.0 - sigma) * sample

        return sample

    def _sigma_to_t(self, sigma):
        """Convert a sigma into the matching timestep by scaling with `num_train_timesteps`."""
        return sigma * self.config.num_train_timesteps

    def set_timesteps(self, num_inference_steps: int, device: Optional[Union[str, torch.device]] = None):
        """
        Sets the discrete timesteps used for the diffusion chain (to be run before inference).

        Args:
            num_inference_steps (`int`):
                The number of diffusion steps used when generating samples with a pre-trained model.
            device (`str` or `torch.device`, *optional*):
                The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        """
        self.num_inference_steps = num_inference_steps

        timesteps = np.linspace(
            self._sigma_to_t(self.sigma_max), self._sigma_to_t(self.sigma_min), num_inference_steps
        )

        sigmas = timesteps / self.config.num_train_timesteps
        sigmas = self.config.shift * sigmas / (1 + (self.config.shift - 1) * sigmas)
        sigmas = torch.from_numpy(sigmas).to(dtype=torch.float32, device=device)

        # Every timestep but the first is visited twice (first-order prediction, then second-order correction).
        timesteps = sigmas * self.config.num_train_timesteps
        timesteps = torch.cat([timesteps[:1], timesteps[1:].repeat_interleave(2)])
        self.timesteps = timesteps.to(device=device)

        # Append the terminal sigma 0 and duplicate the interior sigmas so they line up with `self.timesteps`.
        sigmas = torch.cat([sigmas, torch.zeros(1, device=sigmas.device)])
        self.sigmas = torch.cat([sigmas[:1], sigmas[1:-1].repeat_interleave(2), sigmas[-1:]])

        # empty dt, derivative and the sample cached between the two halves of a Heun step
        self.prev_derivative = None
        self.dt = None
        self.sample = None

        self._step_index = None
        self._begin_index = None

    def index_for_timestep(self, timestep, schedule_timesteps=None):
        """
        Return the position of `timestep` inside `schedule_timesteps` (defaults to `self.timesteps`).

        Raises an `IndexError` when `timestep` does not occur in the schedule.
        """
        if schedule_timesteps is None:
            schedule_timesteps = self.timesteps

        indices = (schedule_timesteps == timestep).nonzero()

        # The sigma index that is taken for the **very** first `step`
        # is always the second index (or the last index if there is only 1)
        # This way we can ensure we don't accidentally skip a sigma in
        # case we start in the middle of the denoising schedule (e.g. for image-to-image)
        pos = 1 if len(indices) > 1 else 0

        return indices[pos].item()

    def _init_step_index(self, timestep):
        """Initialise `_step_index` from `begin_index` when it is set, otherwise by looking up `timestep`."""
        if self.begin_index is None:
            if isinstance(timestep, torch.Tensor):
                timestep = timestep.to(self.timesteps.device)
            self._step_index = self.index_for_timestep(timestep)
        else:
            self._step_index = self._begin_index

    @property
    def state_in_first_order(self):
        """Whether the next `step` call is the first-order half of a Heun step (no `dt` is stored)."""
        return self.dt is None

    def step(
        self,
        model_output: torch.FloatTensor,
        timestep: Union[float, torch.FloatTensor],
        sample: torch.FloatTensor,
        s_churn: float = 0.0,
        s_tmin: float = 0.0,
        s_tmax: float = float("inf"),
        s_noise: float = 1.0,
        generator: Optional[torch.Generator] = None,
        return_dict: bool = True,
    ) -> Union[FlowMatchHeunDiscreteSchedulerOutput, Tuple]:
        """
        Predict the sample from the previous timestep. Calls alternate between a first-order prediction and a
        second-order (Heun) correction that averages the stored and the new derivative.

        Args:
            model_output (`torch.FloatTensor`):
                The direct output from learned diffusion model.
            timestep (`float` or `torch.FloatTensor`):
                The current discrete timestep in the diffusion chain. Must be one of `scheduler.timesteps`, not an
                integer index.
            sample (`torch.FloatTensor`):
                A current instance of a sample created by the diffusion process.
            s_churn (`float`, defaults to 0.0):
                Amount of churn. It sets `gamma = min(s_churn / (len(sigmas) - 1), sqrt(2) - 1)`, the relative
                increase applied to sigma before the step; extra noise is only added when `gamma > 0`.
            s_tmin (`float`, defaults to 0.0):
                Lower bound of the sigma range in which churn is applied.
            s_tmax (`float`, defaults to `inf`):
                Upper bound of the sigma range in which churn is applied.
            s_noise (`float`, defaults to 1.0):
                Scaling factor for noise added to the sample.
            generator (`torch.Generator`, *optional*):
                A random number generator.
            return_dict (`bool`):
                Whether or not to return a [`FlowMatchHeunDiscreteSchedulerOutput`] or tuple.

        Returns:
            [`FlowMatchHeunDiscreteSchedulerOutput`] or `tuple`:
                If return_dict is `True`, [`FlowMatchHeunDiscreteSchedulerOutput`] is returned, otherwise a tuple is
                returned where the first element is the sample tensor.

        Raises:
            `ValueError`: If `timestep` is an integer or an integer tensor.
        """

        if isinstance(timestep, (int, torch.IntTensor, torch.LongTensor)):
            raise ValueError(
                (
                    "Passing integer indices (e.g. from `enumerate(timesteps)`) as timesteps to"
                    " `FlowMatchHeunDiscreteScheduler.step()` is not supported. Make sure to pass"
                    " one of the `scheduler.timesteps` as a timestep."
                ),
            )

        if self.step_index is None:
            self._init_step_index(timestep)

        # Upcast to avoid precision issues when computing prev_sample
        sample = sample.to(torch.float32)

        if self.state_in_first_order:
            sigma = self.sigmas[self.step_index]
            sigma_next = self.sigmas[self.step_index + 1]
        else:
            # 2nd order / Heun's method
            sigma = self.sigmas[self.step_index - 1]
            sigma_next = self.sigmas[self.step_index]

        gamma = min(s_churn / (len(self.sigmas) - 1), 2**0.5 - 1) if s_tmin <= sigma <= s_tmax else 0.0

        # Drawn unconditionally so the generator advances the same way whether or not churn is active.
        noise = randn_tensor(
            model_output.shape, dtype=model_output.dtype, device=model_output.device, generator=generator
        )

        eps = noise * s_noise
        sigma_hat = sigma * (gamma + 1)

        if gamma > 0:
            sample = sample + eps * (sigma_hat**2 - sigma**2) ** 0.5

        if self.state_in_first_order:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma
            # 2. convert to an ODE derivative for 1st order
            derivative = (sample - denoised) / sigma_hat
            # 3. Delta timestep
            dt = sigma_next - sigma_hat

            # store for 2nd order step
            self.prev_derivative = derivative
            self.dt = dt
            self.sample = sample
        else:
            # 1. compute predicted original sample (x_0) from sigma-scaled predicted noise
            denoised = sample - model_output * sigma_next
            # 2. 2nd order / Heun's method
            derivative = (sample - denoised) / sigma_next
            derivative = 0.5 * (self.prev_derivative + derivative)

            # 3. take prev timestep & sample
            dt = self.dt
            sample = self.sample

            # free dt and derivative
            # Note, this puts the scheduler in "first order mode"
            self.prev_derivative = None
            self.dt = None
            self.sample = None

        prev_sample = sample + derivative * dt
        # Cast sample back to model compatible dtype
        prev_sample = prev_sample.to(model_output.dtype)

        # upon completion increase step index by one
        self._step_index += 1

        if not return_dict:
            return (prev_sample,)

        return FlowMatchHeunDiscreteSchedulerOutput(prev_sample=prev_sample)

    def __len__(self):
        return self.config.num_train_timesteps
|
@@ -138,7 +138,7 @@ class IPNDMScheduler(SchedulerMixin, ConfigMixin):
|
|
138
138
|
def step(
|
139
139
|
self,
|
140
140
|
model_output: torch.Tensor,
|
141
|
-
timestep: int,
|
141
|
+
timestep: Union[int, torch.Tensor],
|
142
142
|
sample: torch.Tensor,
|
143
143
|
return_dict: bool = True,
|
144
144
|
) -> Union[SchedulerOutput, Tuple]:
|
@@ -822,7 +822,7 @@ class UniPCMultistepScheduler(SchedulerMixin, ConfigMixin):
|
|
822
822
|
def step(
|
823
823
|
self,
|
824
824
|
model_output: torch.Tensor,
|
825
|
-
timestep: int,
|
825
|
+
timestep: Union[int, torch.Tensor],
|
826
826
|
sample: torch.Tensor,
|
827
827
|
return_dict: bool = True,
|
828
828
|
) -> Union[SchedulerOutput, Tuple]:
|
@@ -121,9 +121,7 @@ class SchedulerMixin(PushToHubMixin):
|
|
121
121
|
force_download (`bool`, *optional*, defaults to `False`):
|
122
122
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
123
123
|
cached versions if they exist.
|
124
|
-
|
125
|
-
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
126
|
-
of Diffusers.
|
124
|
+
|
127
125
|
proxies (`Dict[str, str]`, *optional*):
|
128
126
|
A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
|
129
127
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
@@ -102,9 +102,7 @@ class FlaxSchedulerMixin(PushToHubMixin):
|
|
102
102
|
force_download (`bool`, *optional*, defaults to `False`):
|
103
103
|
Whether or not to force the (re-)download of the model weights and configuration files, overriding the
|
104
104
|
cached versions if they exist.
|
105
|
-
|
106
|
-
Deprecated and ignored. All downloads are now resumed by default when possible. Will be removed in v1
|
107
|
-
of Diffusers.
|
105
|
+
|
108
106
|
proxies (`Dict[str, str]`, *optional*):
|
109
107
|
A dictionary of proxy servers to use by protocol or endpoint, e.g., `{'http': 'foo.bar:3128',
|
110
108
|
'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
|
diffusers/training_utils.py
CHANGED
@@ -1,5 +1,6 @@
|
|
1
1
|
import contextlib
|
2
2
|
import copy
|
3
|
+
import math
|
3
4
|
import random
|
4
5
|
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
5
6
|
|
@@ -220,6 +221,44 @@ def _set_state_dict_into_text_encoder(
|
|
220
221
|
set_peft_model_state_dict(text_encoder, text_encoder_state_dict, adapter_name="default")
|
221
222
|
|
222
223
|
|
224
|
+
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: Optional[float] = None,
    logit_std: Optional[float] = None,
    mode_scale: Optional[float] = None,
) -> torch.Tensor:
    """Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.

    Args:
        weighting_scheme: `"logit_normal"` draws from a normal distribution and squashes the draw with a sigmoid;
            `"mode"` reshapes a uniform draw using `mode_scale`; any other value yields a plain uniform draw.
        batch_size: Number of values to sample.
        logit_mean: Mean of the normal distribution. Required for `"logit_normal"`.
        logit_std: Standard deviation of the normal distribution. Required for `"logit_normal"`.
        mode_scale: Scale of the mode correction term. Required for `"mode"`.

    Returns:
        A CPU tensor of shape `(batch_size,)`.

    Raises:
        TypeError: If a parameter required by the selected `weighting_scheme` is `None`.
    """
    if weighting_scheme == "logit_normal":
        if logit_mean is None or logit_std is None:
            # Fail early with a readable message instead of an opaque overload error from `torch.normal`.
            raise TypeError(
                "`logit_mean` and `logit_std` are required when `weighting_scheme` is 'logit_normal'."
            )
        # See 3.1 in the SD3 paper ($rf/lognorm(0.00,1.00)$).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device="cpu")
        u = torch.sigmoid(u)
    elif weighting_scheme == "mode":
        if mode_scale is None:
            raise TypeError("`mode_scale` is required when `weighting_scheme` is 'mode'.")
        u = torch.rand(size=(batch_size,), device="cpu")
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device="cpu")
    return u
|
243
|
+
|
244
|
+
|
245
|
+
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas: Optional[torch.Tensor] = None) -> torch.Tensor:
    """Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://arxiv.org/abs/2403.03206v1.

    Args:
        weighting_scheme: `"sigma_sqrt"` weights by `sigmas**-2` (cast to float32); `"cosmap"` weights by
            `2 / (pi * (1 - 2 * sigmas + 2 * sigmas**2))`; any other value yields a weight of one everywhere.
        sigmas: Noise levels to compute the weights for. Required even though it defaults to `None`.

    Returns:
        A tensor of weights with the same shape as `sigmas`.

    Raises:
        TypeError: If `sigmas` is `None`.
    """
    if sigmas is None:
        # Every branch below needs `sigmas`; fail with a readable message instead of an operand-type error.
        raise TypeError("`sigmas` must be a tensor of noise levels, but got `None`.")

    if weighting_scheme == "sigma_sqrt":
        weighting = (sigmas**-2.0).float()
    elif weighting_scheme == "cosmap":
        bot = 1 - 2 * sigmas + 2 * sigmas**2
        weighting = 2 / (math.pi * bot)
    else:
        weighting = torch.ones_like(sigmas)
    return weighting
|
260
|
+
|
261
|
+
|
223
262
|
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
|
224
263
|
class EMAModel:
|
225
264
|
"""
|
@@ -235,6 +274,7 @@ class EMAModel:
|
|
235
274
|
use_ema_warmup: bool = False,
|
236
275
|
inv_gamma: Union[float, int] = 1.0,
|
237
276
|
power: Union[float, int] = 2 / 3,
|
277
|
+
foreach: bool = False,
|
238
278
|
model_cls: Optional[Any] = None,
|
239
279
|
model_config: Dict[str, Any] = None,
|
240
280
|
**kwargs,
|
@@ -249,6 +289,7 @@ class EMAModel:
|
|
249
289
|
inv_gamma (float):
|
250
290
|
Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
|
251
291
|
power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
|
292
|
+
foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
|
252
293
|
device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
|
253
294
|
weights will be stored on CPU.
|
254
295
|
|
@@ -303,16 +344,17 @@ class EMAModel:
|
|
303
344
|
self.power = power
|
304
345
|
self.optimization_step = 0
|
305
346
|
self.cur_decay_value = None # set in `step()`
|
347
|
+
self.foreach = foreach
|
306
348
|
|
307
349
|
self.model_cls = model_cls
|
308
350
|
self.model_config = model_config
|
309
351
|
|
310
352
|
@classmethod
|
311
|
-
def from_pretrained(cls, path, model_cls) -> "EMAModel":
|
353
|
+
def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
|
312
354
|
_, ema_kwargs = model_cls.load_config(path, return_unused_kwargs=True)
|
313
355
|
model = model_cls.from_pretrained(path)
|
314
356
|
|
315
|
-
ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config)
|
357
|
+
ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)
|
316
358
|
|
317
359
|
ema_model.load_state_dict(ema_kwargs)
|
318
360
|
return ema_model
|
@@ -379,15 +421,37 @@ class EMAModel:
|
|
379
421
|
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
|
380
422
|
import deepspeed
|
381
423
|
|
382
|
-
|
424
|
+
if self.foreach:
|
383
425
|
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
|
384
|
-
context_manager = deepspeed.zero.GatheredParameters(
|
426
|
+
context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)
|
385
427
|
|
386
428
|
with context_manager():
|
387
|
-
if param.requires_grad
|
388
|
-
|
389
|
-
|
390
|
-
|
429
|
+
params_grad = [param for param in parameters if param.requires_grad]
|
430
|
+
s_params_grad = [
|
431
|
+
s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
|
432
|
+
]
|
433
|
+
|
434
|
+
if len(params_grad) < len(parameters):
|
435
|
+
torch._foreach_copy_(
|
436
|
+
[s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
|
437
|
+
[param for param in parameters if not param.requires_grad],
|
438
|
+
non_blocking=True,
|
439
|
+
)
|
440
|
+
|
441
|
+
torch._foreach_sub_(
|
442
|
+
s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
|
443
|
+
)
|
444
|
+
|
445
|
+
else:
|
446
|
+
for s_param, param in zip(self.shadow_params, parameters):
|
447
|
+
if is_transformers_available() and transformers.deepspeed.is_deepspeed_zero3_enabled():
|
448
|
+
context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)
|
449
|
+
|
450
|
+
with context_manager():
|
451
|
+
if param.requires_grad:
|
452
|
+
s_param.sub_(one_minus_decay * (s_param - param))
|
453
|
+
else:
|
454
|
+
s_param.copy_(param)
|
391
455
|
|
392
456
|
def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
|
393
457
|
"""
|
@@ -399,10 +463,24 @@ class EMAModel:
|
|
399
463
|
`ExponentialMovingAverage` was initialized will be used.
|
400
464
|
"""
|
401
465
|
parameters = list(parameters)
|
402
|
-
|
403
|
-
|
466
|
+
if self.foreach:
|
467
|
+
torch._foreach_copy_(
|
468
|
+
[param.data for param in parameters],
|
469
|
+
[s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
|
470
|
+
)
|
471
|
+
else:
|
472
|
+
for s_param, param in zip(self.shadow_params, parameters):
|
473
|
+
param.data.copy_(s_param.to(param.device).data)
|
404
474
|
|
405
|
-
def
|
475
|
+
def pin_memory(self) -> None:
    r"""
    Replace every entry of `self.shadow_params` with the result of calling `pin_memory()` on it.

    Intended for non-blocking transfers when the EMA params are offloaded to the host.
    """
    pinned_shadows = [shadow.pin_memory() for shadow in self.shadow_params]
    self.shadow_params = pinned_shadows
|
482
|
+
|
483
|
+
def to(self, device=None, dtype=None, non_blocking=False) -> None:
|
406
484
|
r"""Move internal buffers of the ExponentialMovingAverage to `device`.
|
407
485
|
|
408
486
|
Args:
|
@@ -410,7 +488,9 @@ class EMAModel:
|
|
410
488
|
"""
|
411
489
|
# .to() on the tensors handles None correctly
|
412
490
|
self.shadow_params = [
|
413
|
-
p.to(device=device, dtype=dtype
|
491
|
+
p.to(device=device, dtype=dtype, non_blocking=non_blocking)
|
492
|
+
if p.is_floating_point()
|
493
|
+
else p.to(device=device, non_blocking=non_blocking)
|
414
494
|
for p in self.shadow_params
|
415
495
|
]
|
416
496
|
|
@@ -454,8 +534,13 @@ class EMAModel:
|
|
454
534
|
"""
|
455
535
|
if self.temp_stored_params is None:
|
456
536
|
raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights " "to `restore()`")
|
457
|
-
|
458
|
-
|
537
|
+
if self.foreach:
|
538
|
+
torch._foreach_copy_(
|
539
|
+
[param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
|
540
|
+
)
|
541
|
+
else:
|
542
|
+
for c_param, param in zip(self.temp_stored_params, parameters):
|
543
|
+
param.data.copy_(c_param.data)
|
459
544
|
|
460
545
|
# Better memory-wise.
|
461
546
|
self.temp_stored_params = None
|
diffusers/utils/__init__.py
CHANGED
@@ -73,12 +73,12 @@ from .import_utils import (
|
|
73
73
|
is_librosa_available,
|
74
74
|
is_matplotlib_available,
|
75
75
|
is_note_seq_available,
|
76
|
-
is_notebook,
|
77
76
|
is_onnx_available,
|
78
77
|
is_peft_available,
|
79
78
|
is_peft_version,
|
80
79
|
is_safetensors_available,
|
81
80
|
is_scipy_available,
|
81
|
+
is_sentencepiece_available,
|
82
82
|
is_tensorboard_available,
|
83
83
|
is_timm_available,
|
84
84
|
is_torch_available,
|
@@ -94,7 +94,7 @@ from .import_utils import (
|
|
94
94
|
is_xformers_available,
|
95
95
|
requires_backends,
|
96
96
|
)
|
97
|
-
from .loading_utils import load_image
|
97
|
+
from .loading_utils import load_image, load_video
|
98
98
|
from .logging import get_logger
|
99
99
|
from .outputs import BaseOutput
|
100
100
|
from .peft_utils import (
|