diffusers 0.31.0__py3-none-any.whl → 0.32.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +66 -5
- diffusers/callbacks.py +56 -3
- diffusers/configuration_utils.py +1 -1
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +25 -17
- diffusers/loaders/__init__.py +22 -3
- diffusers/loaders/ip_adapter.py +538 -15
- diffusers/loaders/lora_base.py +124 -118
- diffusers/loaders/lora_conversion_utils.py +318 -3
- diffusers/loaders/lora_pipeline.py +1688 -368
- diffusers/loaders/peft.py +379 -0
- diffusers/loaders/single_file_model.py +71 -4
- diffusers/loaders/single_file_utils.py +519 -9
- diffusers/loaders/textual_inversion.py +3 -3
- diffusers/loaders/transformer_flux.py +181 -0
- diffusers/loaders/transformer_sd3.py +89 -0
- diffusers/loaders/unet.py +17 -4
- diffusers/models/__init__.py +47 -14
- diffusers/models/activations.py +22 -9
- diffusers/models/attention.py +13 -4
- diffusers/models/attention_flax.py +1 -1
- diffusers/models/attention_processor.py +2059 -281
- diffusers/models/autoencoders/__init__.py +5 -0
- diffusers/models/autoencoders/autoencoder_dc.py +620 -0
- diffusers/models/autoencoders/autoencoder_kl.py +2 -1
- diffusers/models/autoencoders/autoencoder_kl_allegro.py +1149 -0
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +36 -27
- diffusers/models/autoencoders/autoencoder_kl_hunyuan_video.py +1176 -0
- diffusers/models/autoencoders/autoencoder_kl_ltx.py +1338 -0
- diffusers/models/autoencoders/autoencoder_kl_mochi.py +1166 -0
- diffusers/models/autoencoders/autoencoder_kl_temporal_decoder.py +3 -10
- diffusers/models/autoencoders/autoencoder_tiny.py +4 -2
- diffusers/models/autoencoders/vae.py +18 -5
- diffusers/models/controlnet.py +47 -802
- diffusers/models/controlnet_flux.py +29 -495
- diffusers/models/controlnet_sd3.py +25 -379
- diffusers/models/controlnet_sparsectrl.py +46 -718
- diffusers/models/controlnets/__init__.py +23 -0
- diffusers/models/controlnets/controlnet.py +872 -0
- diffusers/models/{controlnet_flax.py → controlnets/controlnet_flax.py} +5 -5
- diffusers/models/controlnets/controlnet_flux.py +536 -0
- diffusers/models/{controlnet_hunyuan.py → controlnets/controlnet_hunyuan.py} +7 -7
- diffusers/models/controlnets/controlnet_sd3.py +489 -0
- diffusers/models/controlnets/controlnet_sparsectrl.py +788 -0
- diffusers/models/controlnets/controlnet_union.py +832 -0
- diffusers/models/{controlnet_xs.py → controlnets/controlnet_xs.py} +14 -13
- diffusers/models/controlnets/multicontrolnet.py +183 -0
- diffusers/models/embeddings.py +838 -43
- diffusers/models/model_loading_utils.py +88 -6
- diffusers/models/modeling_flax_utils.py +1 -1
- diffusers/models/modeling_utils.py +74 -28
- diffusers/models/normalization.py +78 -13
- diffusers/models/transformers/__init__.py +5 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +2 -2
- diffusers/models/transformers/cogvideox_transformer_3d.py +46 -11
- diffusers/models/transformers/dit_transformer_2d.py +1 -1
- diffusers/models/transformers/latte_transformer_3d.py +4 -4
- diffusers/models/transformers/pixart_transformer_2d.py +1 -1
- diffusers/models/transformers/sana_transformer.py +488 -0
- diffusers/models/transformers/stable_audio_transformer.py +1 -1
- diffusers/models/transformers/transformer_2d.py +1 -1
- diffusers/models/transformers/transformer_allegro.py +422 -0
- diffusers/models/transformers/transformer_cogview3plus.py +1 -1
- diffusers/models/transformers/transformer_flux.py +30 -9
- diffusers/models/transformers/transformer_hunyuan_video.py +789 -0
- diffusers/models/transformers/transformer_ltx.py +469 -0
- diffusers/models/transformers/transformer_mochi.py +499 -0
- diffusers/models/transformers/transformer_sd3.py +105 -17
- diffusers/models/transformers/transformer_temporal.py +1 -1
- diffusers/models/unets/unet_1d_blocks.py +1 -1
- diffusers/models/unets/unet_2d.py +8 -1
- diffusers/models/unets/unet_2d_blocks.py +88 -21
- diffusers/models/unets/unet_2d_condition.py +1 -1
- diffusers/models/unets/unet_3d_blocks.py +9 -7
- diffusers/models/unets/unet_motion_model.py +5 -5
- diffusers/models/unets/unet_spatio_temporal_condition.py +23 -0
- diffusers/models/unets/unet_stable_cascade.py +2 -2
- diffusers/models/unets/uvit_2d.py +1 -1
- diffusers/models/upsampling.py +8 -0
- diffusers/pipelines/__init__.py +34 -0
- diffusers/pipelines/allegro/__init__.py +48 -0
- diffusers/pipelines/allegro/pipeline_allegro.py +938 -0
- diffusers/pipelines/allegro/pipeline_output.py +23 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +8 -2
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +1 -1
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +0 -6
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +8 -8
- diffusers/pipelines/audioldm2/modeling_audioldm2.py +3 -3
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -8
- diffusers/pipelines/auto_pipeline.py +53 -6
- diffusers/pipelines/blip_diffusion/modeling_blip2.py +1 -1
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +50 -22
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +51 -20
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +69 -21
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +47 -21
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +1 -1
- diffusers/pipelines/controlnet/__init__.py +86 -80
- diffusers/pipelines/controlnet/multicontrolnet.py +7 -178
- diffusers/pipelines/controlnet/pipeline_controlnet.py +11 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +1 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +3 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +1790 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +1501 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +1627 -0
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +5 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +53 -19
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +7 -7
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +31 -8
- diffusers/pipelines/flux/__init__.py +13 -1
- diffusers/pipelines/flux/modeling_flux.py +47 -0
- diffusers/pipelines/flux/pipeline_flux.py +204 -29
- diffusers/pipelines/flux/pipeline_flux_control.py +889 -0
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +945 -0
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1141 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +49 -27
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +40 -30
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +78 -56
- diffusers/pipelines/flux/pipeline_flux_fill.py +969 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +33 -27
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +36 -29
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +492 -0
- diffusers/pipelines/flux/pipeline_output.py +16 -0
- diffusers/pipelines/hunyuan_video/__init__.py +48 -0
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +687 -0
- diffusers/pipelines/hunyuan_video/pipeline_output.py +20 -0
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +5 -1
- diffusers/pipelines/kandinsky/pipeline_kandinsky_combined.py +9 -9
- diffusers/pipelines/kolors/text_encoder.py +2 -2
- diffusers/pipelines/latent_diffusion/pipeline_latent_diffusion.py +1 -1
- diffusers/pipelines/ltx/__init__.py +50 -0
- diffusers/pipelines/ltx/pipeline_ltx.py +789 -0
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +885 -0
- diffusers/pipelines/ltx/pipeline_output.py +20 -0
- diffusers/pipelines/lumina/pipeline_lumina.py +1 -8
- diffusers/pipelines/mochi/__init__.py +48 -0
- diffusers/pipelines/mochi/pipeline_mochi.py +748 -0
- diffusers/pipelines/mochi/pipeline_output.py +20 -0
- diffusers/pipelines/pag/__init__.py +7 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1 -3
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +6 -13
- diffusers/pipelines/pag/pipeline_pag_sana.py +886 -0
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +6 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3_img2img.py +1058 -0
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +3 -0
- diffusers/pipelines/pag/pipeline_pag_sd_inpaint.py +1356 -0
- diffusers/pipelines/pipeline_flax_utils.py +1 -1
- diffusers/pipelines/pipeline_loading_utils.py +25 -4
- diffusers/pipelines/pipeline_utils.py +35 -6
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +6 -13
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +6 -13
- diffusers/pipelines/sana/__init__.py +47 -0
- diffusers/pipelines/sana/pipeline_output.py +21 -0
- diffusers/pipelines/sana/pipeline_sana.py +884 -0
- diffusers/pipelines/stable_audio/pipeline_stable_audio.py +12 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +18 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +216 -20
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +62 -9
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +57 -8
- diffusers/pipelines/stable_diffusion_gligen/pipeline_stable_diffusion_gligen_text_image.py +11 -1
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +0 -8
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +0 -8
- diffusers/pipelines/unidiffuser/modeling_uvit.py +2 -2
- diffusers/pipelines/wuerstchen/modeling_wuerstchen_prior.py +1 -1
- diffusers/quantizers/auto.py +14 -1
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +4 -1
- diffusers/quantizers/gguf/__init__.py +1 -0
- diffusers/quantizers/gguf/gguf_quantizer.py +159 -0
- diffusers/quantizers/gguf/utils.py +456 -0
- diffusers/quantizers/quantization_config.py +280 -2
- diffusers/quantizers/torchao/__init__.py +15 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +285 -0
- diffusers/schedulers/scheduling_ddpm.py +2 -6
- diffusers/schedulers/scheduling_ddpm_parallel.py +2 -6
- diffusers/schedulers/scheduling_deis_multistep.py +28 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +35 -9
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +35 -8
- diffusers/schedulers/scheduling_dpmsolver_sde.py +4 -4
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +48 -10
- diffusers/schedulers/scheduling_euler_discrete.py +4 -4
- diffusers/schedulers/scheduling_flow_match_euler_discrete.py +153 -6
- diffusers/schedulers/scheduling_heun_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +4 -4
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +4 -4
- diffusers/schedulers/scheduling_lcm.py +2 -6
- diffusers/schedulers/scheduling_lms_discrete.py +4 -4
- diffusers/schedulers/scheduling_repaint.py +1 -1
- diffusers/schedulers/scheduling_sasolver.py +28 -9
- diffusers/schedulers/scheduling_tcd.py +2 -6
- diffusers/schedulers/scheduling_unipc_multistep.py +53 -8
- diffusers/training_utils.py +16 -2
- diffusers/utils/__init__.py +5 -0
- diffusers/utils/constants.py +1 -0
- diffusers/utils/dummy_pt_objects.py +180 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +270 -0
- diffusers/utils/dynamic_modules_utils.py +3 -3
- diffusers/utils/hub_utils.py +31 -39
- diffusers/utils/import_utils.py +67 -0
- diffusers/utils/peft_utils.py +3 -0
- diffusers/utils/testing_utils.py +56 -1
- diffusers/utils/torch_utils.py +3 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/METADATA +69 -69
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/RECORD +214 -162
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/WHEEL +1 -1
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/LICENSE +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.31.0.dist-info → diffusers-0.32.0.dist-info}/top_level.txt +0 -0
@@ -17,6 +17,7 @@
 import importlib
 import inspect
 import os
+from array import array
 from collections import OrderedDict
 from pathlib import Path
 from typing import List, Optional, Union
@@ -25,8 +26,8 @@ import safetensors
 import torch
 from huggingface_hub.utils import EntryNotFoundError
 
-from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
+    GGUF_FILE_EXTENSION,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_FILE_EXTENSION,
     WEIGHTS_INDEX_NAME,
@@ -34,6 +35,8 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_gguf_available,
+    is_torch_available,
     is_torch_version,
     logging,
 )
@@ -140,6 +143,8 @@ def load_state_dict(checkpoint_file: Union[str, os.PathLike], variant: Optional[
     file_extension = os.path.basename(checkpoint_file).split(".")[-1]
     if file_extension == SAFETENSORS_FILE_EXTENSION:
         return safetensors.torch.load_file(checkpoint_file, device="cpu")
+    elif file_extension == GGUF_FILE_EXTENSION:
+        return load_gguf_checkpoint(checkpoint_file)
     else:
         weights_only_kwarg = {"weights_only": True} if is_torch_version(">=", "1.13") else {}
         return torch.load(
@@ -176,11 +181,12 @@ def load_model_dict_into_meta(
     hf_quantizer=None,
     keep_in_fp32_modules=None,
 ) -> List[str]:
+    if device is not None and not isinstance(device, (str, torch.device)):
+        raise ValueError(f"Expected device to have type `str` or `torch.device`, but got {type(device)=}.")
     if hf_quantizer is None:
         device = device or torch.device("cpu")
     dtype = dtype or torch.float32
     is_quantized = hf_quantizer is not None
-    is_quant_method_bnb = getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
 
     accepts_dtype = "dtype" in set(inspect.signature(set_module_tensor_to_device).parameters.keys())
     empty_state_dict = model.state_dict()
@@ -211,17 +217,18 @@ def load_model_dict_into_meta(
         set_module_kwargs["dtype"] = dtype
 
         # bnb params are flattened.
+        # gguf quants have a different shape based on the type of quantization applied
         if empty_state_dict[param_name].shape != param.shape:
             if (
-                is_quant_method_bnb
+                is_quantized
                 and hf_quantizer.pre_quantized
                 and hf_quantizer.check_if_quantized_param(model, param, param_name, state_dict, param_device=device)
             ):
-                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name].shape, param.shape)
-            elif not is_quant_method_bnb:
+                hf_quantizer.check_quantized_param_shape(param_name, empty_state_dict[param_name], param)
+            else:
                 model_name_or_path_str = f"{model_name_or_path} " if model_name_or_path is not None else ""
                 raise ValueError(
-                    f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name]}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
+                    f"Cannot load {model_name_or_path_str} because {param_name} expected shape {empty_state_dict[param_name].shape}, but got {param.shape}. If you want to instead overwrite randomly initialized weights, please make sure to pass both `low_cpu_mem_usage=False` and `ignore_mismatched_sizes=True`. For more information, see also: https://github.com/huggingface/diffusers/issues/1619#issuecomment-1345604389 as an example."
                 )
 
         if is_quantized and (
@@ -396,3 +403,78 @@ def _fetch_index_file_legacy(
         index_file = None
 
     return index_file
+
+
+def _gguf_parse_value(_value, data_type):
+    if not isinstance(data_type, list):
+        data_type = [data_type]
+    if len(data_type) == 1:
+        data_type = data_type[0]
+        array_data_type = None
+    else:
+        if data_type[0] != 9:
+            raise ValueError("Received multiple types, therefore expected the first type to indicate an array.")
+        data_type, array_data_type = data_type
+
+    if data_type in [0, 1, 2, 3, 4, 5, 10, 11]:
+        _value = int(_value[0])
+    elif data_type in [6, 12]:
+        _value = float(_value[0])
+    elif data_type in [7]:
+        _value = bool(_value[0])
+    elif data_type in [8]:
+        _value = array("B", list(_value)).tobytes().decode()
+    elif data_type in [9]:
+        _value = _gguf_parse_value(_value, array_data_type)
+    return _value
+
+
+def load_gguf_checkpoint(gguf_checkpoint_path, return_tensors=False):
+    """
+    Load a GGUF file and return a dictionary of parsed parameters containing tensors, the parsed tokenizer and config
+    attributes.
+
+    Args:
+        gguf_checkpoint_path (`str`):
+            The path the to GGUF file to load
+        return_tensors (`bool`, defaults to `True`):
+            Whether to read the tensors from the file and return them. Not doing so is faster and only loads the
+            metadata in memory.
+    """
+
+    if is_gguf_available() and is_torch_available():
+        import gguf
+        from gguf import GGUFReader
+
+        from ..quantizers.gguf.utils import SUPPORTED_GGUF_QUANT_TYPES, GGUFParameter
+    else:
+        logger.error(
+            "Loading a GGUF checkpoint in PyTorch, requires both PyTorch and GGUF>=0.10.0 to be installed. Please see "
+            "https://pytorch.org/ and https://github.com/ggerganov/llama.cpp/tree/master/gguf-py for installation instructions."
+        )
+        raise ImportError("Please install torch and gguf>=0.10.0 to load a GGUF checkpoint in PyTorch.")
+
+    reader = GGUFReader(gguf_checkpoint_path)
+
+    parsed_parameters = {}
+    for tensor in reader.tensors:
+        name = tensor.name
+        quant_type = tensor.tensor_type
+
+        # if the tensor is a torch supported dtype do not use GGUFParameter
+        is_gguf_quant = quant_type not in [gguf.GGMLQuantizationType.F32, gguf.GGMLQuantizationType.F16]
+        if is_gguf_quant and quant_type not in SUPPORTED_GGUF_QUANT_TYPES:
+            _supported_quants_str = "\n".join([str(type) for type in SUPPORTED_GGUF_QUANT_TYPES])
+            raise ValueError(
+                (
+                    f"{name} has a quantization type: {str(quant_type)} which is unsupported."
+                    "\n\nCurrently the following quantization types are supported: \n\n"
+                    f"{_supported_quants_str}"
+                    "\n\nTo request support for this quantization type please open an issue here: https://github.com/huggingface/diffusers"
+                )
+            )
+
+        weights = torch.from_numpy(tensor.data.copy())
+        parsed_parameters[name] = GGUFParameter(weights, quant_type=quant_type) if is_gguf_quant else weights
+
+    return parsed_parameters
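
A minimal usage sketch of the GGUF loading path added above, assuming the `GGUFQuantizationConfig` entry point that ships in 0.32.0; the checkpoint URL is illustrative only:

    import torch
    from diffusers import FluxTransformer2DModel, GGUFQuantizationConfig

    # load_gguf_checkpoint() parses the GGUF file; quantized tensors come back as GGUFParameter
    # objects and are dequantized on the fly by the GGUF quantizer backend.
    transformer = FluxTransformer2DModel.from_single_file(
        "https://huggingface.co/city96/FLUX.1-dev-gguf/blob/main/flux1-dev-Q2_K.gguf",  # illustrative URL
        quantization_config=GGUFQuantizationConfig(compute_dtype=torch.bfloat16),
        torch_dtype=torch.bfloat16,
    )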
@@ -530,7 +530,7 @@ class FlaxModelMixin(PushToHubMixin):
 
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
-            private = kwargs.pop("private", False)
+            private = kwargs.pop("private", None)
             create_pr = kwargs.pop("create_pr", False)
             token = kwargs.pop("token", None)
             repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
@@ -99,21 +99,39 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
 
 
 def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    """
+    Returns the first found floating dtype in parameters if there is one, otherwise returns the last dtype it found.
+    """
+    last_dtype = None
+    for param in parameter.parameters():
+        last_dtype = param.dtype
+        if param.is_floating_point():
+            return param.dtype
+
+    for buffer in parameter.buffers():
+        last_dtype = buffer.dtype
+        if buffer.is_floating_point():
+            return buffer.dtype
+
+    if last_dtype is not None:
+        # if no floating dtype was found return whatever the first dtype is
+        return last_dtype
+
+    # For nn.DataParallel compatibility in PyTorch > 1.5
+    def find_tensor_attributes(module: nn.Module) -> List[Tuple[str, Tensor]]:
+        tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+        return tuples
+
+    gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+    last_tuple = None
+    for tuple in gen:
+        last_tuple = tuple
+        if tuple[1].is_floating_point():
+            return tuple[1].dtype
+
+    if last_tuple is not None:
+        # fallback to the last dtype
+        return last_tuple[1].dtype
 
 
 class ModelMixin(torch.nn.Module, PushToHubMixin):
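
The rewritten `get_parameter_dtype` above prefers the first floating-point dtype instead of blindly taking the first parameter, which matters once quantized (non-floating) weights are present. A hedged sketch with an illustrative module, not taken from the diff:

    import torch
    import torch.nn as nn
    from diffusers.models.modeling_utils import get_parameter_dtype

    class Toy(nn.Module):
        def __init__(self):
            super().__init__()
            # e.g. a pre-quantized weight stored as int8 (non-floating)
            self.qweight = nn.Parameter(torch.zeros(4, 4, dtype=torch.int8), requires_grad=False)
            self.norm = nn.LayerNorm(4)  # float32 parameters

    print(get_parameter_dtype(Toy()))  # torch.float32: the first floating dtype wins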
@@ -208,6 +226,35 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         """
         self.set_use_npu_flash_attention(False)
 
+    def set_use_xla_flash_attention(
+        self, use_xla_flash_attention: bool, partition_spec: Optional[Callable] = None
+    ) -> None:
+        # Recursively walk through all the children.
+        # Any children which exposes the set_use_xla_flash_attention method
+        # gets the message
+        def fn_recursive_set_flash_attention(module: torch.nn.Module):
+            if hasattr(module, "set_use_xla_flash_attention"):
+                module.set_use_xla_flash_attention(use_xla_flash_attention, partition_spec)
+
+            for child in module.children():
+                fn_recursive_set_flash_attention(child)
+
+        for module in self.children():
+            if isinstance(module, torch.nn.Module):
+                fn_recursive_set_flash_attention(module)
+
+    def enable_xla_flash_attention(self, partition_spec: Optional[Callable] = None):
+        r"""
+        Enable the flash attention pallals kernel for torch_xla.
+        """
+        self.set_use_xla_flash_attention(True, partition_spec)
+
+    def disable_xla_flash_attention(self):
+        r"""
+        Disable the flash attention pallals kernel for torch_xla.
+        """
+        self.set_use_xla_flash_attention(False)
+
     def set_use_memory_efficient_attention_xformers(
         self, valid: bool, attention_op: Optional[Callable] = None
     ) -> None:
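
The new hooks above only forward a flag to any attention processor that implements `set_use_xla_flash_attention`; a minimal sketch, where `model` stands in for any loaded `ModelMixin` subclass:

    # Hedged sketch: toggle the torch_xla flash-attention (Pallas) kernel on a loaded model.
    model.enable_xla_flash_attention(partition_spec=None)  # propagate the flag to attention processors
    model.disable_xla_flash_attention()                    # restore the default attention path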
@@ -338,7 +385,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         if push_to_hub:
             commit_message = kwargs.pop("commit_message", None)
-            private = kwargs.pop("private", False)
+            private = kwargs.pop("private", None)
             create_pr = kwargs.pop("create_pr", False)
             token = kwargs.pop("token", None)
             repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
@@ -671,10 +718,12 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             hf_quantizer = None
 
         if hf_quantizer is not None:
-            if device_map is not None:
+            is_bnb_quantization_method = hf_quantizer.quantization_config.quant_method.value == "bitsandbytes"
+            if is_bnb_quantization_method and device_map is not None:
                 raise NotImplementedError(
-                    "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
+                    "Currently, `device_map` is automatically inferred for quantized bitsandbytes models. Support for providing `device_map` as an input will be added in the future."
                 )
+
             hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
             torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
 
@@ -771,7 +820,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 revision=revision,
                 subfolder=subfolder or "",
             )
-            if hf_quantizer is not None:
+            if hf_quantizer is not None and is_bnb_quantization_method:
                 model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
                 logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
                 is_sharded = False
@@ -829,14 +878,11 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
            if device_map is None and not is_sharded:
                # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
                # It would error out during the `validate_environment()` call above in the absence of cuda.
-               is_quant_method_bnb = (
-                   getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
-               )
                if hf_quantizer is None:
                    param_device = "cpu"
                # TODO (sayakpaul, SunMarc): remove this after model loading refactor
-               elif is_quant_method_bnb:
-                   param_device = torch.cuda.current_device()
+               else:
+                   param_device = torch.device(torch.cuda.current_device())
                state_dict = load_state_dict(model_file, variant=variant)
                model._convert_deprecated_attention_blocks(state_dict)
 
@@ -1010,14 +1056,14 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                     dtype_present_in_args = True
                     break
 
-
-        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+        if getattr(self, "is_quantized", False):
             if dtype_present_in_args:
                 raise ValueError(
-                    "
-                    "
+                    "Casting a quantized model to a new `dtype` is unsupported. To set the dtype of unquantized layers, please "
+                    "use the `torch_dtype` argument when loading the model using `from_pretrained` or `from_single_file`"
                 )
 
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
             if getattr(self, "is_loaded_in_8bit", False):
                 raise ValueError(
                     "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
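
A short sketch of what the widened check above means in practice; `quantized_model` is an illustrative stand-in for any quantized diffusers model:

    # With the new `is_quantized` guard, dtype casts are rejected for every quantization
    # backend (bitsandbytes, GGUF, torchao), not just bitsandbytes.
    quantized_model.to("cuda")          # still allowed: device move only
    quantized_model.to(torch.float16)   # raises ValueError: casting a quantized model is unsupported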
@@ -22,10 +22,7 @@ import torch.nn.functional as F
 
 from ..utils import is_torch_version
 from .activations import get_activation
-from .embeddings import (
-    CombinedTimestepLabelEmbeddings,
-    PixArtAlphaCombinedTimestepSizeEmbeddings,
-)
+from .embeddings import CombinedTimestepLabelEmbeddings, PixArtAlphaCombinedTimestepSizeEmbeddings
 
 
 class AdaLayerNorm(nn.Module):
@@ -266,6 +263,7 @@ class AdaLayerNormSingle(nn.Module):
         hidden_dtype: Optional[torch.dtype] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         # No modulation happening here.
+        added_cond_kwargs = added_cond_kwargs or {"resolution": None, "aspect_ratio": None}
         embedded_timestep = self.emb(timestep, **added_cond_kwargs, batch_size=batch_size, hidden_dtype=hidden_dtype)
         return self.linear(self.silu(embedded_timestep)), embedded_timestep
 
@@ -358,20 +356,21 @@ class LuminaLayerNormContinuous(nn.Module):
         out_dim: Optional[int] = None,
     ):
         super().__init__()
+
         # AdaLN
         self.silu = nn.SiLU()
         self.linear_1 = nn.Linear(conditioning_embedding_dim, embedding_dim, bias=bias)
+
         if norm_type == "layer_norm":
             self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
         else:
             raise ValueError(f"unknown norm_type {norm_type}")
-
+
+        self.linear_2 = None
         if out_dim is not None:
-            self.linear_2 = nn.Linear(
-                embedding_dim,
-                out_dim,
-                bias=bias,
-            )
+            self.linear_2 = nn.Linear(embedding_dim, out_dim, bias=bias)
 
     def forward(
         self,
@@ -486,20 +485,24 @@ else:
 
 
 class RMSNorm(nn.Module):
-    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True, bias: bool = False):
         super().__init__()
 
         self.eps = eps
+        self.elementwise_affine = elementwise_affine
 
         if isinstance(dim, numbers.Integral):
             dim = (dim,)
 
         self.dim = torch.Size(dim)
 
+        self.weight = None
+        self.bias = None
+
         if elementwise_affine:
             self.weight = nn.Parameter(torch.ones(dim))
-        else:
-            self.weight = None
+            if bias:
+                self.bias = nn.Parameter(torch.zeros(dim))
 
     def forward(self, hidden_states):
         input_dtype = hidden_states.dtype
@@ -511,12 +514,44 @@ class RMSNorm(nn.Module):
             if self.weight.dtype in [torch.float16, torch.bfloat16]:
                 hidden_states = hidden_states.to(self.weight.dtype)
             hidden_states = hidden_states * self.weight
+            if self.bias is not None:
+                hidden_states = hidden_states + self.bias
         else:
             hidden_states = hidden_states.to(input_dtype)
 
         return hidden_states
 
 
+# TODO: (Dhruv) This can be replaced with regular RMSNorm in Mochi once `_keep_in_fp32_modules` is supported
+# for sharded checkpoints, see: https://github.com/huggingface/diffusers/issues/10013
+class MochiRMSNorm(nn.Module):
+    def __init__(self, dim, eps: float, elementwise_affine: bool = True):
+        super().__init__()
+
+        self.eps = eps
+
+        if isinstance(dim, numbers.Integral):
+            dim = (dim,)
+
+        self.dim = torch.Size(dim)
+
+        if elementwise_affine:
+            self.weight = nn.Parameter(torch.ones(dim))
+        else:
+            self.weight = None
+
+    def forward(self, hidden_states):
+        input_dtype = hidden_states.dtype
+        variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
+        hidden_states = hidden_states * torch.rsqrt(variance + self.eps)
+
+        if self.weight is not None:
+            hidden_states = hidden_states * self.weight
+        hidden_states = hidden_states.to(input_dtype)
+
+        return hidden_states
+
+
 class GlobalResponseNorm(nn.Module):
     # Taken from https://github.com/facebookresearch/ConvNeXt-V2/blob/3608f67cc1dae164790c5d0aead7bf2d73d9719b/models/utils.py#L105
     def __init__(self, dim):
@@ -528,3 +563,33 @@ class GlobalResponseNorm(nn.Module):
         gx = torch.norm(x, p=2, dim=(1, 2), keepdim=True)
         nx = gx / (gx.mean(dim=-1, keepdim=True) + 1e-6)
         return self.gamma * (x * nx) + self.beta + x
+
+
+class LpNorm(nn.Module):
+    def __init__(self, p: int = 2, dim: int = -1, eps: float = 1e-12):
+        super().__init__()
+
+        self.p = p
+        self.dim = dim
+        self.eps = eps
+
+    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
+        return F.normalize(hidden_states, p=self.p, dim=self.dim, eps=self.eps)
+
+
+def get_normalization(
+    norm_type: str = "batch_norm",
+    num_features: Optional[int] = None,
+    eps: float = 1e-5,
+    elementwise_affine: bool = True,
+    bias: bool = True,
+) -> nn.Module:
+    if norm_type == "rms_norm":
+        norm = RMSNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
+    elif norm_type == "layer_norm":
+        norm = nn.LayerNorm(num_features, eps=eps, elementwise_affine=elementwise_affine, bias=bias)
+    elif norm_type == "batch_norm":
+        norm = nn.BatchNorm2d(num_features, eps=eps, affine=elementwise_affine)
+    else:
+        raise ValueError(f"{norm_type=} is not supported.")
+    return norm
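
A minimal sketch of the helpers added above, assuming they are imported from `diffusers.models.normalization`:

    import torch
    from diffusers.models.normalization import get_normalization

    # RMSNorm now accepts an optional learnable bias; get_normalization is a small factory
    # over rms_norm / layer_norm / batch_norm.
    norm = get_normalization("rms_norm", num_features=64, eps=1e-5, elementwise_affine=True, bias=True)
    out = norm(torch.randn(2, 16, 64))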
@@ -11,10 +11,15 @@ if is_torch_available():
     from .lumina_nextdit2d import LuminaNextDiT2DModel
     from .pixart_transformer_2d import PixArtTransformer2DModel
     from .prior_transformer import PriorTransformer
+    from .sana_transformer import SanaTransformer2DModel
     from .stable_audio_transformer import StableAudioDiTModel
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
+    from .transformer_allegro import AllegroTransformer3DModel
     from .transformer_cogview3plus import CogView3PlusTransformer2DModel
     from .transformer_flux import FluxTransformer2DModel
+    from .transformer_hunyuan_video import HunyuanVideoTransformer3DModel
+    from .transformer_ltx import LTXVideoTransformer3DModel
+    from .transformer_mochi import MochiTransformer3DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel
@@ -466,7 +466,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
 
         # MMDiT blocks.
         for index_block, block in enumerate(self.joint_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -497,7 +497,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         combined_hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
 
         for index_block, block in enumerate(self.single_transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module, return_dict=None):
                     def custom_forward(*inputs):
@@ -170,6 +170,8 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             Whether to flip the sin to cos in the time embedding.
         time_embed_dim (`int`, defaults to `512`):
             Output dimension of timestep embeddings.
+        ofs_embed_dim (`int`, defaults to `512`):
+            Output dimension of "ofs" embeddings used in CogVideoX-5b-I2B in version 1.5
         text_embed_dim (`int`, defaults to `4096`):
             Input dimension of text embeddings from the text encoder.
         num_layers (`int`, defaults to `30`):
@@ -177,7 +179,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         dropout (`float`, defaults to `0.0`):
             The dropout probability to use.
         attention_bias (`bool`, defaults to `True`):
-            Whether
+            Whether to use bias in the attention projection layers.
         sample_width (`int`, defaults to `90`):
             The width of the input latents.
         sample_height (`int`, defaults to `60`):
@@ -198,7 +200,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         timestep_activation_fn (`str`, defaults to `"silu"`):
             Activation function to use when generating the timestep embeddings.
         norm_elementwise_affine (`bool`, defaults to `True`):
-            Whether
+            Whether to use elementwise affine in normalization layers.
         norm_eps (`float`, defaults to `1e-5`):
             The epsilon value to use in normalization layers.
         spatial_interpolation_scale (`float`, defaults to `1.875`):
@@ -219,6 +221,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         flip_sin_to_cos: bool = True,
         freq_shift: int = 0,
         time_embed_dim: int = 512,
+        ofs_embed_dim: Optional[int] = None,
         text_embed_dim: int = 4096,
         num_layers: int = 30,
         dropout: float = 0.0,
@@ -227,6 +230,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         sample_height: int = 60,
         sample_frames: int = 49,
         patch_size: int = 2,
+        patch_size_t: Optional[int] = None,
         temporal_compression_ratio: int = 4,
         max_text_seq_length: int = 226,
         activation_fn: str = "gelu-approximate",
@@ -237,6 +241,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         temporal_interpolation_scale: float = 1.0,
         use_rotary_positional_embeddings: bool = False,
         use_learned_positional_embeddings: bool = False,
+        patch_bias: bool = True,
     ):
         super().__init__()
         inner_dim = num_attention_heads * attention_head_dim
@@ -251,10 +256,11 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         # 1. Patch embedding
         self.patch_embed = CogVideoXPatchEmbed(
             patch_size=patch_size,
+            patch_size_t=patch_size_t,
             in_channels=in_channels,
             embed_dim=inner_dim,
             text_embed_dim=text_embed_dim,
-            bias=True,
+            bias=patch_bias,
             sample_width=sample_width,
             sample_height=sample_height,
             sample_frames=sample_frames,
@@ -267,10 +273,19 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         )
         self.embedding_dropout = nn.Dropout(dropout)
 
-        # 2. Time embeddings
+        # 2. Time embeddings and ofs embedding(Only CogVideoX1.5-5B I2V have)
+
         self.time_proj = Timesteps(inner_dim, flip_sin_to_cos, freq_shift)
         self.time_embedding = TimestepEmbedding(inner_dim, time_embed_dim, timestep_activation_fn)
 
+        self.ofs_proj = None
+        self.ofs_embedding = None
+        if ofs_embed_dim:
+            self.ofs_proj = Timesteps(ofs_embed_dim, flip_sin_to_cos, freq_shift)
+            self.ofs_embedding = TimestepEmbedding(
+                ofs_embed_dim, ofs_embed_dim, timestep_activation_fn
+            )  # same as time embeddings, for ofs
+
         # 3. Define spatio-temporal transformers blocks
         self.transformer_blocks = nn.ModuleList(
             [
@@ -298,7 +313,15 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             norm_eps=norm_eps,
             chunk_dim=1,
         )
-        self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels)
+
+        if patch_size_t is None:
+            # For CogVideox 1.0
+            output_dim = patch_size * patch_size * out_channels
+        else:
+            # For CogVideoX 1.5
+            output_dim = patch_size * patch_size * patch_size_t * out_channels
+
+        self.proj_out = nn.Linear(inner_dim, output_dim)
 
         self.gradient_checkpointing = False
 
@@ -411,6 +434,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         encoder_hidden_states: torch.Tensor,
         timestep: Union[int, float, torch.LongTensor],
         timestep_cond: Optional[torch.Tensor] = None,
+        ofs: Optional[Union[int, float, torch.LongTensor]] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
         attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
@@ -442,6 +466,12 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         t_emb = t_emb.to(dtype=hidden_states.dtype)
         emb = self.time_embedding(t_emb, timestep_cond)
 
+        if self.ofs_embedding is not None:
+            ofs_emb = self.ofs_proj(ofs)
+            ofs_emb = ofs_emb.to(dtype=hidden_states.dtype)
+            ofs_emb = self.ofs_embedding(ofs_emb)
+            emb = emb + ofs_emb
+
         # 2. Patch embedding
         hidden_states = self.patch_embed(encoder_hidden_states, hidden_states)
         hidden_states = self.embedding_dropout(hidden_states)
@@ -452,7 +482,7 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
 
         # 3. Transformer blocks
         for i, block in enumerate(self.transformer_blocks):
-            if self.training and self.gradient_checkpointing:
+            if torch.is_grad_enabled() and self.gradient_checkpointing:
 
                 def create_custom_forward(module):
                     def custom_forward(*inputs):
@@ -491,12 +521,17 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
         hidden_states = self.proj_out(hidden_states)
 
         # 5. Unpatchify
-        # Note: we use `-1` instead of `channels`:
-        # - It is okay to `channels` use for CogVideoX-2b and CogVideoX-5b (number of input channels is equal to output channels)
-        # - However, for CogVideoX-5b-I2V also takes concatenated input image latents (number of input channels is twice the output channels)
         p = self.config.patch_size
-        output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
-        output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+        p_t = self.config.patch_size_t
+
+        if p_t is None:
+            output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
+            output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
+        else:
+            output = hidden_states.reshape(
+                batch_size, (num_frames + p_t - 1) // p_t, height // p, width // p, -1, p_t, p, p
+            )
+            output = output.permute(0, 1, 5, 4, 2, 6, 3, 7).flatten(6, 7).flatten(4, 5).flatten(1, 2)
 
         if USE_PEFT_BACKEND:
             # remove `lora_scale` from each PEFT layer
|
@@ -184,7 +184,7 @@ class DiTTransformer2DModel(ModelMixin, ConfigMixin):
|
|
184
184
|
|
185
185
|
# 2. Blocks
|
186
186
|
for block in self.transformer_blocks:
|
187
|
-
if
|
187
|
+
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
188
188
|
|
189
189
|
def create_custom_forward(module, return_dict=None):
|
190
190
|
def custom_forward(*inputs):
|