diffusers 0.30.2__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +38 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +287 -85
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +238 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +58 -36
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +40 -7
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +6 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +32 -34
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +837 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +825 -0
- diffusers/pipelines/cogvideo/pipeline_output.py +20 -0
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +195 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/RECORD +173 -147
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/WHEEL +1 -1
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.2.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
@@ -12,16 +12,21 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Optional, Union
+from typing import Callable, Dict, List, Optional, Tuple, Union
 
 import torch
+import torch.nn as nn
 
 from ..models.attention import BasicTransformerBlock, FreeNoiseTransformerBlock
+from ..models.resnet import Downsample2D, ResnetBlock2D, Upsample2D
+from ..models.transformers.transformer_2d import Transformer2DModel
 from ..models.unets.unet_motion_model import (
+    AnimateDiffTransformer3D,
     CrossAttnDownBlockMotion,
     DownBlockMotion,
     UpBlockMotion,
 )
+from ..pipelines.pipeline_utils import DiffusionPipeline
 from ..utils import logging
 from ..utils.torch_utils import randn_tensor
 
@@ -29,6 +34,114 @@ from ..utils.torch_utils import randn_tensor
 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
 
 
+class SplitInferenceModule(nn.Module):
+    r"""
+    A wrapper module class that splits inputs along a specified dimension before performing a forward pass.
+
+    This module is useful when you need to perform inference on large tensors in a memory-efficient way by breaking
+    them into smaller chunks, processing each chunk separately, and then reassembling the results.
+
+    Args:
+        module (`nn.Module`):
+            The underlying PyTorch module that will be applied to each chunk of split inputs.
+        split_size (`int`, defaults to `1`):
+            The size of each chunk after splitting the input tensor.
+        split_dim (`int`, defaults to `0`):
+            The dimension along which the input tensors are split.
+        input_kwargs_to_split (`List[str]`, defaults to `["hidden_states"]`):
+            A list of keyword arguments (strings) that represent the input tensors to be split.
+
+    Workflow:
+    1. The keyword arguments specified in `input_kwargs_to_split` are split into smaller chunks using
+       `torch.split()` along the dimension `split_dim` and with a chunk size of `split_size`.
+    2. The `module` is invoked once for each split with both the split inputs and any unchanged arguments
+       that were passed.
+    3. The output tensors from each split are concatenated back together along `split_dim` before returning.
+
+    Example:
+    ```python
+    >>> import torch
+    >>> import torch.nn as nn
+
+    >>> model = nn.Linear(1000, 1000)
+    >>> split_module = SplitInferenceModule(model, split_size=2, split_dim=0, input_kwargs_to_split=["input"])
+
+    >>> input_tensor = torch.randn(42, 1000)
+    >>> # Will split the tensor into 21 slices of shape [2, 1000].
+    >>> output = split_module(input=input_tensor)
+    ```
+
+    It is also possible to nest `SplitInferenceModule` across different split dimensions for more complex
+    multi-dimensional splitting.
+    """
+
+    def __init__(
+        self,
+        module: nn.Module,
+        split_size: int = 1,
+        split_dim: int = 0,
+        input_kwargs_to_split: List[str] = ["hidden_states"],
+    ) -> None:
+        super().__init__()
+
+        self.module = module
+        self.split_size = split_size
+        self.split_dim = split_dim
+        self.input_kwargs_to_split = set(input_kwargs_to_split)
+
+    def forward(self, *args, **kwargs) -> Union[torch.Tensor, Tuple[torch.Tensor]]:
+        r"""Forward method for the `SplitInferenceModule`.
+
+        This method processes the input by splitting specified keyword arguments along a given dimension, running the
+        underlying module on each split, and then concatenating the results. The splitting is controlled by the
+        `split_size` and `split_dim` parameters specified during initialization.
+
+        Args:
+            *args (`Any`):
+                Positional arguments that are passed directly to the `module` without modification.
+            **kwargs (`Dict[str, torch.Tensor]`):
+                Keyword arguments passed to the underlying `module`. Only keyword arguments whose names match the
+                entries in `input_kwargs_to_split` and are of type `torch.Tensor` will be split. The remaining keyword
+                arguments are passed unchanged.
+
+        Returns:
+            `Union[torch.Tensor, Tuple[torch.Tensor]]`:
+                The outputs obtained from `SplitInferenceModule` are the same as if the underlying module was inferred
+                without it.
+                - If the underlying module returns a single tensor, the result will be a single concatenated tensor
+                  along the same `split_dim` after processing all splits.
+                - If the underlying module returns a tuple of tensors, each element of the tuple will be concatenated
+                  along the `split_dim` across all splits, and the final result will be a tuple of concatenated tensors.
+        """
+        split_inputs = {}
+
+        # 1. Split inputs that were specified during initialization and also present in passed kwargs
+        for key in list(kwargs.keys()):
+            if key not in self.input_kwargs_to_split or not torch.is_tensor(kwargs[key]):
+                continue
+            split_inputs[key] = torch.split(kwargs[key], self.split_size, self.split_dim)
+            kwargs.pop(key)
+
+        # 2. Invoke forward pass across each split
+        results = []
+        for split_input in zip(*split_inputs.values()):
+            inputs = dict(zip(split_inputs.keys(), split_input))
+            inputs.update(kwargs)
+
+            intermediate_tensor_or_tensor_tuple = self.module(*args, **inputs)
+            results.append(intermediate_tensor_or_tensor_tuple)
+
+        # 3. Concatenate split restuls to obtain final outputs
+        if isinstance(results[0], torch.Tensor):
+            return torch.cat(results, dim=self.split_dim)
+        elif isinstance(results[0], tuple):
+            return tuple([torch.cat(x, dim=self.split_dim) for x in zip(*results)])
+        else:
+            raise ValueError(
+                "In order to use the SplitInferenceModule, it is necessary for the underlying `module` to either return a torch.Tensor or a tuple of torch.Tensor's."
+            )
+
+
 class AnimateDiffFreeNoiseMixin:
     r"""Mixin class for [FreeNoise](https://arxiv.org/abs/2310.15169)."""
 
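The class docstring above notes that `SplitInferenceModule` can be nested across different split dimensions, but gives no example. A minimal sketch of what that looks like, assuming the class is importable from `diffusers.pipelines.free_noise_utils` (the path referenced later in this diff); the module and tensor sizes are illustrative:

```python
import torch
import torch.nn as nn

from diffusers.pipelines.free_noise_utils import SplitInferenceModule

# nn.ReLU is element-wise, so it can safely be chunked along any dimension.
model = nn.ReLU()

# Inner wrapper splits columns (dim 1), outer wrapper splits rows (dim 0).
inner = SplitInferenceModule(model, split_size=4, split_dim=1, input_kwargs_to_split=["input"])
outer = SplitInferenceModule(inner, split_size=2, split_dim=0, input_kwargs_to_split=["input"])

x = torch.randn(8, 16)
y = outer(input=x)  # processed in 2x4 chunks, reassembled back to shape (8, 16)
assert torch.equal(y, model(x))
```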
@@ -69,6 +182,9 @@ class AnimateDiffFreeNoiseMixin:
                     motion_module.transformer_blocks[i].load_state_dict(
                         basic_transfomer_block.state_dict(), strict=True
                     )
+                    motion_module.transformer_blocks[i].set_chunk_feed_forward(
+                        basic_transfomer_block._chunk_size, basic_transfomer_block._chunk_dim
+                    )
 
     def _disable_free_noise_in_block(self, block: Union[CrossAttnDownBlockMotion, DownBlockMotion, UpBlockMotion]):
         r"""Helper function to disable FreeNoise in transformer blocks."""
@@ -97,6 +213,145 @@ class AnimateDiffFreeNoiseMixin:
                     motion_module.transformer_blocks[i].load_state_dict(
                         free_noise_transfomer_block.state_dict(), strict=True
                     )
+                    motion_module.transformer_blocks[i].set_chunk_feed_forward(
+                        free_noise_transfomer_block._chunk_size, free_noise_transfomer_block._chunk_dim
+                    )
+
+    def _check_inputs_free_noise(
+        self,
+        prompt,
+        negative_prompt,
+        prompt_embeds,
+        negative_prompt_embeds,
+        num_frames,
+    ) -> None:
+        if not isinstance(prompt, (str, dict)):
+            raise ValueError(f"Expected `prompt` to have type `str` or `dict` but found {type(prompt)=}")
+
+        if negative_prompt is not None:
+            if not isinstance(negative_prompt, (str, dict)):
+                raise ValueError(
+                    f"Expected `negative_prompt` to have type `str` or `dict` but found {type(negative_prompt)=}"
+                )
+
+        if prompt_embeds is not None or negative_prompt_embeds is not None:
+            raise ValueError("`prompt_embeds` and `negative_prompt_embeds` is not supported in FreeNoise yet.")
+
+        frame_indices = [isinstance(x, int) for x in prompt.keys()]
+        frame_prompts = [isinstance(x, str) for x in prompt.values()]
+        min_frame = min(list(prompt.keys()))
+        max_frame = max(list(prompt.keys()))
+
+        if not all(frame_indices):
+            raise ValueError("Expected integer keys in `prompt` dict for FreeNoise.")
+        if not all(frame_prompts):
+            raise ValueError("Expected str values in `prompt` dict for FreeNoise.")
+        if min_frame != 0:
+            raise ValueError("The minimum frame index in `prompt` dict must be 0 as a starting prompt is necessary.")
+        if max_frame >= num_frames:
+            raise ValueError(
+                f"The maximum frame index in `prompt` dict must be lesser than {num_frames=} and follow 0-based indexing."
+            )
+
+    def _encode_prompt_free_noise(
+        self,
+        prompt: Union[str, Dict[int, str]],
+        num_frames: int,
+        device: torch.device,
+        num_videos_per_prompt: int,
+        do_classifier_free_guidance: bool,
+        negative_prompt: Optional[Union[str, Dict[int, str]]] = None,
+        prompt_embeds: Optional[torch.Tensor] = None,
+        negative_prompt_embeds: Optional[torch.Tensor] = None,
+        lora_scale: Optional[float] = None,
+        clip_skip: Optional[int] = None,
+    ) -> torch.Tensor:
+        if negative_prompt is None:
+            negative_prompt = ""
+
+        # Ensure that we have a dictionary of prompts
+        if isinstance(prompt, str):
+            prompt = {0: prompt}
+        if isinstance(negative_prompt, str):
+            negative_prompt = {0: negative_prompt}
+
+        self._check_inputs_free_noise(prompt, negative_prompt, prompt_embeds, negative_prompt_embeds, num_frames)
+
+        # Sort the prompts based on frame indices
+        prompt = dict(sorted(prompt.items()))
+        negative_prompt = dict(sorted(negative_prompt.items()))
+
+        # Ensure that we have a prompt for the last frame index
+        prompt[num_frames - 1] = prompt[list(prompt.keys())[-1]]
+        negative_prompt[num_frames - 1] = negative_prompt[list(negative_prompt.keys())[-1]]
+
+        frame_indices = list(prompt.keys())
+        frame_prompts = list(prompt.values())
+        frame_negative_indices = list(negative_prompt.keys())
+        frame_negative_prompts = list(negative_prompt.values())
+
+        # Generate and interpolate positive prompts
+        prompt_embeds, _ = self.encode_prompt(
+            prompt=frame_prompts,
+            device=device,
+            num_images_per_prompt=num_videos_per_prompt,
+            do_classifier_free_guidance=False,
+            negative_prompt=None,
+            prompt_embeds=None,
+            negative_prompt_embeds=None,
+            lora_scale=lora_scale,
+            clip_skip=clip_skip,
+        )
+
+        shape = (num_frames, *prompt_embeds.shape[1:])
+        prompt_interpolation_embeds = prompt_embeds.new_zeros(shape)
+
+        for i in range(len(frame_indices) - 1):
+            start_frame = frame_indices[i]
+            end_frame = frame_indices[i + 1]
+            start_tensor = prompt_embeds[i].unsqueeze(0)
+            end_tensor = prompt_embeds[i + 1].unsqueeze(0)
+
+            prompt_interpolation_embeds[start_frame : end_frame + 1] = self._free_noise_prompt_interpolation_callback(
+                start_frame, end_frame, start_tensor, end_tensor
+            )
+
+        # Generate and interpolate negative prompts
+        negative_prompt_embeds = None
+        negative_prompt_interpolation_embeds = None
+
+        if do_classifier_free_guidance:
+            _, negative_prompt_embeds = self.encode_prompt(
+                prompt=[""] * len(frame_negative_prompts),
+                device=device,
+                num_images_per_prompt=num_videos_per_prompt,
+                do_classifier_free_guidance=True,
+                negative_prompt=frame_negative_prompts,
+                prompt_embeds=None,
+                negative_prompt_embeds=None,
+                lora_scale=lora_scale,
+                clip_skip=clip_skip,
+            )
+
+            negative_prompt_interpolation_embeds = negative_prompt_embeds.new_zeros(shape)
+
+            for i in range(len(frame_negative_indices) - 1):
+                start_frame = frame_negative_indices[i]
+                end_frame = frame_negative_indices[i + 1]
+                start_tensor = negative_prompt_embeds[i].unsqueeze(0)
+                end_tensor = negative_prompt_embeds[i + 1].unsqueeze(0)
+
+                negative_prompt_interpolation_embeds[
+                    start_frame : end_frame + 1
+                ] = self._free_noise_prompt_interpolation_callback(start_frame, end_frame, start_tensor, end_tensor)
+
+        prompt_embeds = prompt_interpolation_embeds
+        negative_prompt_embeds = negative_prompt_interpolation_embeds
+
+        if do_classifier_free_guidance:
+            prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+        return prompt_embeds, negative_prompt_embeds
 
     def _prepare_latents_free_noise(
         self,
@@ -172,12 +427,29 @@ class AnimateDiffFreeNoiseMixin:
         latents = latents[:, :, :num_frames]
         return latents
 
+    def _lerp(
+        self, start_index: int, end_index: int, start_tensor: torch.Tensor, end_tensor: torch.Tensor
+    ) -> torch.Tensor:
+        num_indices = end_index - start_index + 1
+        interpolated_tensors = []
+
+        for i in range(num_indices):
+            alpha = i / (num_indices - 1)
+            interpolated_tensor = (1 - alpha) * start_tensor + alpha * end_tensor
+            interpolated_tensors.append(interpolated_tensor)
+
+        interpolated_tensors = torch.cat(interpolated_tensors)
+        return interpolated_tensors
+
     def enable_free_noise(
         self,
         context_length: Optional[int] = 16,
         context_stride: int = 4,
         weighting_scheme: str = "pyramid",
         noise_type: str = "shuffle_context",
+        prompt_interpolation_callback: Optional[
+            Callable[[DiffusionPipeline, int, int, torch.Tensor, torch.Tensor], torch.Tensor]
+        ] = None,
     ) -> None:
         r"""
         Enable long video generation using FreeNoise.
@@ -195,13 +467,27 @@ class AnimateDiffFreeNoiseMixin:
             weighting_scheme (`str`, defaults to `pyramid`):
                 Weighting scheme for averaging latents after accumulation in FreeNoise blocks. The following weighting
                 schemes are supported currently:
+                - "flat"
+                    Performs weighting averaging with a flat weight pattern: [1, 1, 1, 1, 1].
                 - "pyramid"
-
+                    Performs weighted averaging with a pyramid like weight pattern: [1, 2, 3, 2, 1].
+                - "delayed_reverse_sawtooth"
+                    Performs weighted averaging with low weights for earlier frames and high-to-low weights for
+                    later frames: [0.01, 0.01, 3, 2, 1].
             noise_type (`str`, defaults to "shuffle_context"):
-
+                Must be one of ["shuffle_context", "repeat_context", "random"].
+                - "shuffle_context"
+                    Shuffles a fixed batch of `context_length` latents to create a final latent of size
+                    `num_frames`. This is usually the best setting for most generation scenarious. However, there
+                    might be visible repetition noticeable in the kinds of motion/animation generated.
+                - "repeated_context"
+                    Repeats a fixed batch of `context_length` latents to create a final latent of size
+                    `num_frames`.
+                - "random"
+                    The final latents are random without any repetition.
         """
 
-        allowed_weighting_scheme = ["pyramid"]
+        allowed_weighting_scheme = ["flat", "pyramid", "delayed_reverse_sawtooth"]
         allowed_noise_type = ["shuffle_context", "repeat_context", "random"]
 
         if context_length > self.motion_adapter.config.motion_max_seq_length:
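The new `prompt_interpolation_callback` hook defaults to the linear `_lerp` added above. As invoked in `_encode_prompt_free_noise`, the callback receives `(start_frame, end_frame, start_tensor, end_tensor)` and must return interpolated embeddings covering frames `start_frame` through `end_frame` inclusive. A purely illustrative cosine-easing variant, not taken from the diff:

```python
import math
import torch

def cosine_interpolation(start_index, end_index, start_tensor, end_tensor):
    # Same contract as AnimateDiffFreeNoiseMixin._lerp, but with cosine easing between keys.
    num_indices = end_index - start_index + 1
    outputs = []
    for i in range(num_indices):
        t = i / max(num_indices - 1, 1)
        alpha = 0.5 - 0.5 * math.cos(math.pi * t)
        outputs.append((1 - alpha) * start_tensor + alpha * end_tensor)
    return torch.cat(outputs)

# pipe.enable_free_noise(prompt_interpolation_callback=cosine_interpolation)
```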
@@ -219,18 +505,92 @@ class AnimateDiffFreeNoiseMixin:
         self._free_noise_context_stride = context_stride
         self._free_noise_weighting_scheme = weighting_scheme
         self._free_noise_noise_type = noise_type
+        self._free_noise_prompt_interpolation_callback = prompt_interpolation_callback or self._lerp
+
+        if hasattr(self.unet.mid_block, "motion_modules"):
+            blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
+        else:
+            blocks = [*self.unet.down_blocks, *self.unet.up_blocks]
 
-        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
         for block in blocks:
             self._enable_free_noise_in_block(block)
 
     def disable_free_noise(self) -> None:
+        r"""Disable the FreeNoise sampling mechanism."""
         self._free_noise_context_length = None
 
+        if hasattr(self.unet.mid_block, "motion_modules"):
+            blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
+        else:
+            blocks = [*self.unet.down_blocks, *self.unet.up_blocks]
+
         blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
         for block in blocks:
             self._disable_free_noise_in_block(block)
 
+    def _enable_split_inference_motion_modules_(
+        self, motion_modules: List[AnimateDiffTransformer3D], spatial_split_size: int
+    ) -> None:
+        for motion_module in motion_modules:
+            motion_module.proj_in = SplitInferenceModule(motion_module.proj_in, spatial_split_size, 0, ["input"])
+
+            for i in range(len(motion_module.transformer_blocks)):
+                motion_module.transformer_blocks[i] = SplitInferenceModule(
+                    motion_module.transformer_blocks[i],
+                    spatial_split_size,
+                    0,
+                    ["hidden_states", "encoder_hidden_states"],
+                )
+
+            motion_module.proj_out = SplitInferenceModule(motion_module.proj_out, spatial_split_size, 0, ["input"])
+
+    def _enable_split_inference_attentions_(
+        self, attentions: List[Transformer2DModel], temporal_split_size: int
+    ) -> None:
+        for i in range(len(attentions)):
+            attentions[i] = SplitInferenceModule(
+                attentions[i], temporal_split_size, 0, ["hidden_states", "encoder_hidden_states"]
+            )
+
+    def _enable_split_inference_resnets_(self, resnets: List[ResnetBlock2D], temporal_split_size: int) -> None:
+        for i in range(len(resnets)):
+            resnets[i] = SplitInferenceModule(resnets[i], temporal_split_size, 0, ["input_tensor", "temb"])
+
+    def _enable_split_inference_samplers_(
+        self, samplers: Union[List[Downsample2D], List[Upsample2D]], temporal_split_size: int
+    ) -> None:
+        for i in range(len(samplers)):
+            samplers[i] = SplitInferenceModule(samplers[i], temporal_split_size, 0, ["hidden_states"])
+
+    def enable_free_noise_split_inference(self, spatial_split_size: int = 256, temporal_split_size: int = 16) -> None:
+        r"""
+        Enable FreeNoise memory optimizations by utilizing
+        [`~diffusers.pipelines.free_noise_utils.SplitInferenceModule`] across different intermediate modeling blocks.
+
+        Args:
+            spatial_split_size (`int`, defaults to `256`):
+                The split size across spatial dimensions for internal blocks. This is used in facilitating split
+                inference across the effective batch dimension (`[B x H x W, F, C]`) of intermediate tensors in motion
+                modeling blocks.
+            temporal_split_size (`int`, defaults to `16`):
+                The split size across temporal dimensions for internal blocks. This is used in facilitating split
+                inference across the effective batch dimension (`[B x F, H x W, C]`) of intermediate tensors in spatial
+                attention, resnets, downsampling and upsampling blocks.
+        """
+        # TODO(aryan): Discuss on what's the best way to provide more control to users
+        blocks = [*self.unet.down_blocks, self.unet.mid_block, *self.unet.up_blocks]
+        for block in blocks:
+            if getattr(block, "motion_modules", None) is not None:
+                self._enable_split_inference_motion_modules_(block.motion_modules, spatial_split_size)
+            if getattr(block, "attentions", None) is not None:
+                self._enable_split_inference_attentions_(block.attentions, temporal_split_size)
+            if getattr(block, "resnets", None) is not None:
+                self._enable_split_inference_resnets_(block.resnets, temporal_split_size)
+            if getattr(block, "downsamplers", None) is not None:
+                self._enable_split_inference_samplers_(block.downsamplers, temporal_split_size)
+            if getattr(block, "upsamplers", None) is not None:
+                self._enable_split_inference_samplers_(block.upsamplers, temporal_split_size)
+
     @property
     def free_noise_enabled(self):
         return hasattr(self, "_free_noise_context_length") and self._free_noise_context_length is not None
@@ -125,9 +125,21 @@ def get_resize_crop_region_for_grid(src, tgt_size):
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
-    """
-
-
+    r"""
+    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+    Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+    Args:
+        noise_cfg (`torch.Tensor`):
+            The predicted noise tensor for the guided diffusion process.
+        noise_pred_text (`torch.Tensor`):
+            The predicted noise tensor for the text-guided diffusion process.
+        guidance_rescale (`float`, *optional*, defaults to 0.0):
+            A rescale factor applied to the noise predictions.
+
+    Returns:
+        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
     """
     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
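For reference, this rescale_noise_cfg hunk (and the identical one near the end of this diff) only replaces the docstring; the computation is unchanged. A sketch of the full function body, reconstructed from the trailing context lines and Section 3.4 of the referenced paper rather than copied from this diff:

```python
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    # Per-sample std of the text-conditioned prediction and of the CFG-combined prediction.
    std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # Rescale the CFG result towards the statistics of the text prediction (fixes overexposure).
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # Blend with the original CFG result so images do not become overly flat.
    noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
    return noise_cfg
```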
@@ -547,7 +547,7 @@ class KandinskyV22Img2ImgCombinedPipeline(DiffusionPipeline):
         negative_image_embeds = prior_outputs[1]
 
         prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
-        image = [image] if isinstance(prompt, PIL.Image.Image) else image
+        image = [image] if isinstance(image, PIL.Image.Image) else image
 
         if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
             prompt = (image_embeds.shape[0] // len(prompt)) * prompt
@@ -813,7 +813,7 @@ class KandinskyV22InpaintCombinedPipeline(DiffusionPipeline):
         negative_image_embeds = prior_outputs[1]
 
         prompt = [prompt] if not isinstance(prompt, (list, tuple)) else prompt
-        image = [image] if isinstance(prompt, PIL.Image.Image) else image
+        image = [image] if isinstance(image, PIL.Image.Image) else image
         mask_image = [mask_image] if isinstance(mask_image, PIL.Image.Image) else mask_image
 
         if len(prompt) < image_embeds.shape[0] and image_embeds.shape[0] % len(prompt) == 0:
@@ -70,7 +70,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
@@ -89,7 +89,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
@@ -564,14 +564,16 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
         if denoising_start is None:
             init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
             t_start = max(num_inference_steps - init_timestep, 0)
-        else:
-            t_start = 0
 
-        timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+            timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+            if hasattr(self.scheduler, "set_begin_index"):
+                self.scheduler.set_begin_index(t_start * self.scheduler.order)
 
-        # Strength is irrelevant if we directly request a timestep to start at;
-        # that is, strength is determined by the denoising_start instead.
-        if denoising_start is not None:
+            return timesteps, num_inference_steps - t_start
+
+        else:
+            # Strength is irrelevant if we directly request a timestep to start at;
+            # that is, strength is determined by the denoising_start instead.
             discrete_timestep_cutoff = int(
                 round(
                     self.scheduler.config.num_train_timesteps
@@ -579,7 +581,7 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
                 )
             )
 
-            num_inference_steps = (timesteps < discrete_timestep_cutoff).sum().item()
+            num_inference_steps = (self.scheduler.timesteps < discrete_timestep_cutoff).sum().item()
             if self.scheduler.order == 2 and num_inference_steps % 2 == 0:
                 # if the scheduler is a 2nd order scheduler we might have to do +1
                 # because `num_inference_steps` might be even given that every timestep
@@ -590,11 +592,12 @@ class KolorsImg2ImgPipeline(DiffusionPipeline, StableDiffusionMixin, StableDiffu
                 num_inference_steps = num_inference_steps + 1
 
             # because t_n+1 >= t_n, we slice the timesteps starting from the end
-            timesteps = timesteps[-num_inference_steps:]
+            t_start = len(self.scheduler.timesteps) - num_inference_steps
+            timesteps = self.scheduler.timesteps[t_start:]
+            if hasattr(self.scheduler, "set_begin_index"):
+                self.scheduler.set_begin_index(t_start)
             return timesteps, num_inference_steps
 
-        return timesteps, num_inference_steps - t_start
-
     # Copied from diffusers.pipelines.stable_diffusion_xl.pipeline_stable_diffusion_xl_img2img.StableDiffusionXLImg2ImgPipeline.prepare_latents
     def prepare_latents(
         self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None, add_noise=True
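A small worked example of the strength branch in the fixed `get_timesteps` above (illustrative numbers only; no scheduler object involved):

```python
# With 30 inference steps and strength 0.6, img2img skips the first 40% of the schedule.
num_inference_steps = 30
strength = 0.6

init_timestep = min(int(num_inference_steps * strength), num_inference_steps)  # 18
t_start = max(num_inference_steps - init_timestep, 0)                          # 12

# The fixed branch now slices scheduler.timesteps at t_start * scheduler.order and, when
# supported, calls scheduler.set_begin_index(t_start * scheduler.order) so that schedulers
# tracking an internal begin index stay aligned with the sliced timesteps.
print(init_timestep, t_start)  # 18 12
```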
@@ -277,6 +277,7 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
         padding_strategy: PaddingStrategy = PaddingStrategy.DO_NOT_PAD,
         pad_to_multiple_of: Optional[int] = None,
         return_attention_mask: Optional[bool] = None,
+        padding_side: Optional[bool] = None,
     ) -> dict:
         """
         Pad encoded inputs (on left/right and up to predefined length or max length in the batch)
@@ -298,6 +299,9 @@ class ChatGLMTokenizer(PreTrainedTokenizer):
             pad_to_multiple_of: (optional) Integer if set will pad the sequence to a multiple of the provided value.
                 This is especially useful to enable the use of Tensor Core on NVIDIA hardware with compute capability
                 `>= 7.5` (Volta).
+            padding_side (`str`, *optional*):
+                The side on which the model should have padding applied. Should be selected between ['right', 'left'].
+                Default value is picked from the class attribute of the same name.
             return_attention_mask:
                 (optional) Set to False to avoid returning attention mask (default: set to model specifics)
         """
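The new `padding_side` argument mirrors the left/right behaviour described in the docstring. A tiny illustration of the difference (the token ids are made up):

```python
ids = [101, 7592, 2088]        # hypothetical token ids
pad_id, max_length = 0, 6

right_padded = ids + [pad_id] * (max_length - len(ids))  # [101, 7592, 2088, 0, 0, 0]
left_padded = [pad_id] * (max_length - len(ids)) + ids   # [0, 0, 0, 101, 7592, 2088]
```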
@@ -66,7 +66,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
@@ -70,7 +70,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
@@ -56,7 +56,7 @@ EXAMPLE_DOC_STRING = """
         >>> from diffusers.utils import export_to_gif
 
         >>> # You can replace the checkpoint id with "maxin-cn/Latte-1" too.
-        >>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
+        >>> pipe = LattePipeline.from_pretrained("maxin-cn/Latte-1", torch_dtype=torch.float16)
         >>> # Enable memory optimizations.
         >>> pipe.enable_model_cpu_offload()
 
@@ -76,7 +76,7 @@ def retrieve_timesteps(
     sigmas: Optional[List[float]] = None,
     **kwargs,
 ):
-    """
+    r"""
     Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
     custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
 
@@ -234,9 +234,21 @@ class LEDITSCrossAttnProcessor:
 
 # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.rescale_noise_cfg
 def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
-    """
-
-
+    r"""
+    Rescales `noise_cfg` tensor based on `guidance_rescale` to improve image quality and fix overexposure. Based on
+    Section 3.4 from [Common Diffusion Noise Schedules and Sample Steps are
+    Flawed](https://arxiv.org/pdf/2305.08891.pdf).
+
+    Args:
+        noise_cfg (`torch.Tensor`):
+            The predicted noise tensor for the guided diffusion process.
+        noise_pred_text (`torch.Tensor`):
+            The predicted noise tensor for the text-guided diffusion process.
+        guidance_rescale (`float`, *optional*, defaults to 0.0):
+            A rescale factor applied to the noise predictions.
+
+    Returns:
+        noise_cfg (`torch.Tensor`): The rescaled noise prediction tensor.
     """
     std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
     std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)