diffusers 0.30.3__py3-none-any.whl → 0.31.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +34 -2
- diffusers/configuration_utils.py +12 -0
- diffusers/dependency_versions_table.py +1 -1
- diffusers/image_processor.py +257 -54
- diffusers/loaders/__init__.py +2 -0
- diffusers/loaders/ip_adapter.py +5 -1
- diffusers/loaders/lora_base.py +14 -7
- diffusers/loaders/lora_conversion_utils.py +332 -0
- diffusers/loaders/lora_pipeline.py +707 -41
- diffusers/loaders/peft.py +1 -0
- diffusers/loaders/single_file_utils.py +81 -4
- diffusers/loaders/textual_inversion.py +2 -0
- diffusers/loaders/unet.py +39 -8
- diffusers/models/__init__.py +4 -0
- diffusers/models/adapter.py +53 -53
- diffusers/models/attention.py +86 -10
- diffusers/models/attention_processor.py +169 -133
- diffusers/models/autoencoders/autoencoder_kl.py +71 -11
- diffusers/models/autoencoders/autoencoder_kl_cogvideox.py +187 -88
- diffusers/models/controlnet_flux.py +536 -0
- diffusers/models/controlnet_sd3.py +7 -3
- diffusers/models/controlnet_sparsectrl.py +0 -1
- diffusers/models/embeddings.py +170 -61
- diffusers/models/embeddings_flax.py +23 -9
- diffusers/models/model_loading_utils.py +182 -14
- diffusers/models/modeling_utils.py +283 -46
- diffusers/models/normalization.py +79 -0
- diffusers/models/transformers/__init__.py +1 -0
- diffusers/models/transformers/auraflow_transformer_2d.py +1 -0
- diffusers/models/transformers/cogvideox_transformer_3d.py +23 -2
- diffusers/models/transformers/pixart_transformer_2d.py +9 -1
- diffusers/models/transformers/transformer_cogview3plus.py +386 -0
- diffusers/models/transformers/transformer_flux.py +161 -44
- diffusers/models/transformers/transformer_sd3.py +7 -1
- diffusers/models/unets/unet_2d_condition.py +8 -8
- diffusers/models/unets/unet_motion_model.py +41 -63
- diffusers/models/upsampling.py +6 -6
- diffusers/pipelines/__init__.py +35 -6
- diffusers/pipelines/animatediff/__init__.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff.py +45 -21
- diffusers/pipelines/animatediff/pipeline_animatediff_controlnet.py +44 -20
- diffusers/pipelines/animatediff/pipeline_animatediff_sdxl.py +18 -4
- diffusers/pipelines/animatediff/pipeline_animatediff_sparsectrl.py +2 -0
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video.py +104 -66
- diffusers/pipelines/animatediff/pipeline_animatediff_video2video_controlnet.py +1341 -0
- diffusers/pipelines/aura_flow/pipeline_aura_flow.py +1 -1
- diffusers/pipelines/auto_pipeline.py +39 -8
- diffusers/pipelines/cogvideo/__init__.py +2 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +30 -17
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +794 -0
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +41 -31
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +42 -29
- diffusers/pipelines/cogview3/__init__.py +47 -0
- diffusers/pipelines/cogview3/pipeline_cogview3plus.py +674 -0
- diffusers/pipelines/cogview3/pipeline_output.py +21 -0
- diffusers/pipelines/controlnet/pipeline_controlnet.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_img2img.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint.py +8 -0
- diffusers/pipelines/controlnet/pipeline_controlnet_inpaint_sd_xl.py +36 -13
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl.py +9 -1
- diffusers/pipelines/controlnet/pipeline_controlnet_sd_xl_img2img.py +8 -1
- diffusers/pipelines/controlnet_hunyuandit/pipeline_hunyuandit_controlnet.py +17 -3
- diffusers/pipelines/controlnet_sd3/__init__.py +4 -0
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet.py +3 -1
- diffusers/pipelines/controlnet_sd3/pipeline_stable_diffusion_3_controlnet_inpainting.py +1153 -0
- diffusers/pipelines/ddpm/pipeline_ddpm.py +2 -2
- diffusers/pipelines/deepfloyd_if/pipeline_output.py +6 -5
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion.py +16 -4
- diffusers/pipelines/deprecated/alt_diffusion/pipeline_alt_diffusion_img2img.py +1 -1
- diffusers/pipelines/deprecated/versatile_diffusion/modeling_text_unet.py +1 -1
- diffusers/pipelines/flux/__init__.py +10 -0
- diffusers/pipelines/flux/pipeline_flux.py +53 -20
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +984 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +988 -0
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1182 -0
- diffusers/pipelines/flux/pipeline_flux_img2img.py +850 -0
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1015 -0
- diffusers/pipelines/free_noise_utils.py +365 -5
- diffusers/pipelines/hunyuandit/pipeline_hunyuandit.py +15 -3
- diffusers/pipelines/kandinsky2_2/pipeline_kandinsky2_2_combined.py +2 -2
- diffusers/pipelines/kolors/pipeline_kolors.py +1 -1
- diffusers/pipelines/kolors/pipeline_kolors_img2img.py +14 -11
- diffusers/pipelines/kolors/tokenizer.py +4 -0
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_img2img.py +1 -1
- diffusers/pipelines/latent_consistency_models/pipeline_latent_consistency_text2img.py +1 -1
- diffusers/pipelines/latte/pipeline_latte.py +2 -2
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion.py +15 -3
- diffusers/pipelines/ledits_pp/pipeline_leditspp_stable_diffusion_xl.py +15 -3
- diffusers/pipelines/lumina/pipeline_lumina.py +2 -2
- diffusers/pipelines/pag/__init__.py +6 -0
- diffusers/pipelines/pag/pag_utils.py +8 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_inpaint.py +1544 -0
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl.py +2 -2
- diffusers/pipelines/pag/pipeline_pag_controlnet_sd_xl_img2img.py +1685 -0
- diffusers/pipelines/pag/pipeline_pag_hunyuandit.py +17 -5
- diffusers/pipelines/pag/pipeline_pag_kolors.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_pixart_sigma.py +1 -1
- diffusers/pipelines/pag/pipeline_pag_sd.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_3.py +12 -3
- diffusers/pipelines/pag/pipeline_pag_sd_animatediff.py +5 -1
- diffusers/pipelines/pag/pipeline_pag_sd_img2img.py +1091 -0
- diffusers/pipelines/pag/pipeline_pag_sd_xl.py +18 -6
- diffusers/pipelines/pag/pipeline_pag_sd_xl_img2img.py +31 -16
- diffusers/pipelines/pag/pipeline_pag_sd_xl_inpaint.py +42 -19
- diffusers/pipelines/pia/pipeline_pia.py +2 -0
- diffusers/pipelines/pipeline_loading_utils.py +225 -27
- diffusers/pipelines/pipeline_utils.py +123 -180
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +1 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +1 -1
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade.py +35 -3
- diffusers/pipelines/stable_cascade/pipeline_stable_cascade_prior.py +2 -2
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +28 -6
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_latent_upscale.py +241 -81
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +12 -3
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_img2img.py +20 -4
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3_inpaint.py +3 -3
- diffusers/pipelines/stable_diffusion_k_diffusion/pipeline_stable_diffusion_k_diffusion.py +1 -1
- diffusers/pipelines/stable_diffusion_ldm3d/pipeline_stable_diffusion_ldm3d.py +16 -4
- diffusers/pipelines/stable_diffusion_panorama/pipeline_stable_diffusion_panorama.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl.py +16 -4
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_img2img.py +29 -14
- diffusers/pipelines/stable_diffusion_xl/pipeline_stable_diffusion_xl_inpaint.py +29 -14
- diffusers/pipelines/stable_video_diffusion/pipeline_stable_video_diffusion.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_adapter.py +1 -1
- diffusers/pipelines/t2i_adapter/pipeline_stable_diffusion_xl_adapter.py +16 -4
- diffusers/pipelines/text_to_video_synthesis/pipeline_text_to_video_zero_sdxl.py +15 -3
- diffusers/quantizers/__init__.py +16 -0
- diffusers/quantizers/auto.py +126 -0
- diffusers/quantizers/base.py +233 -0
- diffusers/quantizers/bitsandbytes/__init__.py +2 -0
- diffusers/quantizers/bitsandbytes/bnb_quantizer.py +558 -0
- diffusers/quantizers/bitsandbytes/utils.py +306 -0
- diffusers/quantizers/quantization_config.py +391 -0
- diffusers/schedulers/scheduling_ddim.py +4 -1
- diffusers/schedulers/scheduling_ddim_cogvideox.py +4 -1
- diffusers/schedulers/scheduling_ddim_parallel.py +4 -1
- diffusers/schedulers/scheduling_ddpm.py +4 -1
- diffusers/schedulers/scheduling_ddpm_parallel.py +4 -1
- diffusers/schedulers/scheduling_deis_multistep.py +78 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +82 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep_inverse.py +80 -1
- diffusers/schedulers/scheduling_dpmsolver_sde.py +125 -10
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +82 -1
- diffusers/schedulers/scheduling_edm_euler.py +8 -6
- diffusers/schedulers/scheduling_euler_ancestral_discrete.py +4 -1
- diffusers/schedulers/scheduling_euler_discrete.py +92 -7
- diffusers/schedulers/scheduling_flow_match_heun_discrete.py +4 -5
- diffusers/schedulers/scheduling_heun_discrete.py +114 -8
- diffusers/schedulers/scheduling_k_dpm_2_ancestral_discrete.py +116 -11
- diffusers/schedulers/scheduling_k_dpm_2_discrete.py +110 -8
- diffusers/schedulers/scheduling_lms_discrete.py +76 -1
- diffusers/schedulers/scheduling_sasolver.py +78 -1
- diffusers/schedulers/scheduling_unclip.py +4 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +78 -1
- diffusers/training_utils.py +48 -18
- diffusers/utils/__init__.py +2 -1
- diffusers/utils/dummy_pt_objects.py +60 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +165 -0
- diffusers/utils/hub_utils.py +16 -4
- diffusers/utils/import_utils.py +31 -8
- diffusers/utils/loading_utils.py +28 -4
- diffusers/utils/peft_utils.py +3 -3
- diffusers/utils/testing_utils.py +59 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/METADATA +7 -6
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/RECORD +172 -149
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/LICENSE +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/WHEEL +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.30.3.dist-info → diffusers-0.31.0.dist-info}/top_level.txt +0 -0
diffusers/models/modeling_utils.py

```diff
@@ -14,13 +14,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import copy
 import inspect
 import itertools
 import json
 import os
 import re
 from collections import OrderedDict
-from functools import partial
+from functools import partial, wraps
 from pathlib import Path
 from typing import Any, Callable, List, Optional, Tuple, Union
 
```
```diff
@@ -31,6 +32,8 @@ from huggingface_hub.utils import validate_hf_hub_args
 from torch import Tensor, nn
 
 from .. import __version__
+from ..quantizers import DiffusersAutoQuantizer, DiffusersQuantizer
+from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
```
```diff
@@ -43,6 +46,8 @@ from ..utils import (
     _get_model_file,
     deprecate,
     is_accelerate_available,
+    is_bitsandbytes_available,
+    is_bitsandbytes_version,
     is_torch_version,
     logging,
 )
```
```diff
@@ -54,7 +59,9 @@ from ..utils.hub_utils import (
 from .model_loading_utils import (
     _determine_device_map,
     _fetch_index_file,
+    _fetch_index_file_legacy,
     _load_state_dict_into_model,
+    _merge_sharded_checkpoints,
     load_model_dict_into_meta,
     load_state_dict,
 )
```
```diff
@@ -93,24 +100,20 @@ def get_parameter_device(parameter: torch.nn.Module) -> torch.device:
 
 def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     try:
-        params = tuple(parameter.parameters())
-        if len(params) > 0:
-            return params[0].dtype
-
-        buffers = tuple(parameter.buffers())
-        if len(buffers) > 0:
-            return buffers[0].dtype
-
+        return next(parameter.parameters()).dtype
     except StopIteration:
-        # For torch.nn.DataParallel compatibility in PyTorch 1.5
+        try:
+            return next(parameter.buffers()).dtype
+        except StopIteration:
+            # For torch.nn.DataParallel compatibility in PyTorch 1.5
 
-        def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
-            tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
-            return tuples
+            def find_tensor_attributes(module: torch.nn.Module) -> List[Tuple[str, Tensor]]:
+                tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
+                return tuples
 
-        gen = parameter._named_members(get_members_fn=find_tensor_attributes)
-        first_tuple = next(gen)
-        return first_tuple[1].dtype
+            gen = parameter._named_members(get_members_fn=find_tensor_attributes)
+            first_tuple = next(gen)
+            return first_tuple[1].dtype
 
 
 class ModelMixin(torch.nn.Module, PushToHubMixin):
```
```diff
@@ -128,6 +131,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _supports_gradient_checkpointing = False
     _keys_to_ignore_on_load_unexpected = None
     _no_split_modules = None
+    _keep_in_fp32_modules = None
 
     def __init__(self):
         super().__init__()
```
```diff
@@ -311,13 +315,24 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             logger.error(f"Provided path ({save_directory}) should be a directory, not a file")
             return
 
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+        if hf_quantizer is not None:
+            quantization_serializable = (
+                hf_quantizer is not None
+                and isinstance(hf_quantizer, DiffusersQuantizer)
+                and hf_quantizer.is_serializable
+            )
+            if not quantization_serializable:
+                raise ValueError(
+                    f"The model is quantized with {hf_quantizer.quantization_config.quant_method} and is not serializable - check out the warnings from"
+                    " the logger on the traceback to understand the reason why the quantized model is not serializable."
+                )
+
         weights_name = SAFETENSORS_WEIGHTS_NAME if safe_serialization else WEIGHTS_NAME
         weights_name = _add_variant(weights_name, variant)
-        weight_name_split = weights_name.split(".")
-        if len(weight_name_split) in [2, 3]:
-            weights_name_pattern = weight_name_split[0] + "{suffix}." + ".".join(weight_name_split[1:])
-        else:
-            raise ValueError(f"Invalid {weights_name} provided.")
+        weights_name_pattern = weights_name.replace(".bin", "{suffix}.bin").replace(
+            ".safetensors", "{suffix}.safetensors"
+        )
 
         os.makedirs(save_directory, exist_ok=True)
 
```
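The guard above makes `save_pretrained` fail fast instead of writing a checkpoint that could not be reloaded. A minimal sketch of the observable behavior (path hypothetical; whether saving succeeds depends on `hf_quantizer.is_serializable` for the installed bitsandbytes):

```python
# Sketch: `model.hf_quantizer` is set by `from_pretrained` (see below).
try:
    model.save_pretrained("./quantized-checkpoint")  # hypothetical path
except ValueError as err:
    # Raised by the new guard when the active quant method is not serializable.
    print(err)
```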
```diff
@@ -407,6 +422,18 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 create_pr=create_pr,
             )
 
+    def dequantize(self):
+        """
+        Potentially dequantize the model in case it has been quantized by a quantization method that support
+        dequantization.
+        """
+        hf_quantizer = getattr(self, "hf_quantizer", None)
+
+        if hf_quantizer is None:
+            raise ValueError("You need to first quantize your model in order to dequantize it")
+
+        return hf_quantizer.dequantize(self)
+
     @classmethod
     @validate_hf_hub_args
     def from_pretrained(cls, pretrained_model_name_or_path: Optional[Union[str, os.PathLike]], **kwargs):
```
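The new `dequantize()` helper simply delegates to the quantizer attached at load time and raises on non-quantized models. A hedged usage sketch:

```python
# Assumes `model` was loaded with a `quantization_config` (see the
# `from_pretrained` changes below), so `model.hf_quantizer` exists;
# otherwise this raises ValueError.
model = model.dequantize()  # materialize higher-precision weights again
```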
```diff
@@ -529,6 +556,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT)
         variant = kwargs.pop("variant", None)
         use_safetensors = kwargs.pop("use_safetensors", None)
+        quantization_config = kwargs.pop("quantization_config", None)
 
         allow_pickle = False
         if use_safetensors is None:
```
```diff
@@ -623,26 +651,85 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             user_agent=user_agent,
             **kwargs,
         )
+        # no in-place modification of the original config.
+        config = copy.deepcopy(config)
+
+        # determine initial quantization config.
+        #######################################
+        pre_quantized = "quantization_config" in config and config["quantization_config"] is not None
+        if pre_quantized or quantization_config is not None:
+            if pre_quantized:
+                config["quantization_config"] = DiffusersAutoQuantizer.merge_quantization_configs(
+                    config["quantization_config"], quantization_config
+                )
+            else:
+                config["quantization_config"] = quantization_config
+            hf_quantizer = DiffusersAutoQuantizer.from_config(
+                config["quantization_config"], pre_quantized=pre_quantized
+            )
+        else:
+            hf_quantizer = None
+
+        if hf_quantizer is not None:
+            if device_map is not None:
+                raise NotImplementedError(
+                    "Currently, `device_map` is automatically inferred for quantized models. Support for providing `device_map` as an input will be added in the future."
+                )
+            hf_quantizer.validate_environment(torch_dtype=torch_dtype, from_flax=from_flax, device_map=device_map)
+            torch_dtype = hf_quantizer.update_torch_dtype(torch_dtype)
+
+            # In order to ensure popular quantization methods are supported. Can be disable with `disable_telemetry`
+            user_agent["quant"] = hf_quantizer.quantization_config.quant_method.value
+
+            # Force-set to `True` for more mem efficiency
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+                logger.info("Set `low_cpu_mem_usage` to True as `hf_quantizer` is not None.")
+            elif not low_cpu_mem_usage:
+                raise ValueError("`low_cpu_mem_usage` cannot be False or None when using quantization.")
+
+        # Check if `_keep_in_fp32_modules` is not None
+        use_keep_in_fp32_modules = (cls._keep_in_fp32_modules is not None) and (
+            (torch_dtype == torch.float16) or hasattr(hf_quantizer, "use_keep_in_fp32_modules")
+        )
+        if use_keep_in_fp32_modules:
+            keep_in_fp32_modules = cls._keep_in_fp32_modules
+            if not isinstance(keep_in_fp32_modules, list):
+                keep_in_fp32_modules = [keep_in_fp32_modules]
+
+            if low_cpu_mem_usage is None:
+                low_cpu_mem_usage = True
+                logger.info("Set `low_cpu_mem_usage` to True as `_keep_in_fp32_modules` is not None.")
+            elif not low_cpu_mem_usage:
+                raise ValueError("`low_cpu_mem_usage` cannot be False when `keep_in_fp32_modules` is True.")
+        else:
+            keep_in_fp32_modules = []
+        #######################################
 
         # Determine if we're loading from a directory of sharded checkpoints.
         is_sharded = False
         index_file = None
         is_local = os.path.isdir(pretrained_model_name_or_path)
-        index_file = _fetch_index_file(
-            is_local=is_local,
-            pretrained_model_name_or_path=pretrained_model_name_or_path,
-            subfolder=subfolder or "",
-            use_safetensors=use_safetensors,
-            cache_dir=cache_dir,
-            variant=variant,
-            force_download=force_download,
-            proxies=proxies,
-            local_files_only=local_files_only,
-            token=token,
-            revision=revision,
-            user_agent=user_agent,
-            commit_hash=commit_hash,
-        )
+        index_file_kwargs = {
+            "is_local": is_local,
+            "pretrained_model_name_or_path": pretrained_model_name_or_path,
+            "subfolder": subfolder or "",
+            "use_safetensors": use_safetensors,
+            "cache_dir": cache_dir,
+            "variant": variant,
+            "force_download": force_download,
+            "proxies": proxies,
+            "local_files_only": local_files_only,
+            "token": token,
+            "revision": revision,
+            "user_agent": user_agent,
+            "commit_hash": commit_hash,
+        }
+        index_file = _fetch_index_file(**index_file_kwargs)
+        # In case the index file was not found we still have to consider the legacy format.
+        # this becomes applicable when the variant is not None.
+        if variant is not None and (index_file is None or not os.path.exists(index_file)):
+            index_file = _fetch_index_file_legacy(**index_file_kwargs)
         if index_file is not None and index_file.is_file():
             is_sharded = True
 
```
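This block, together with the new `quantization_config` kwarg, is the entry point for the `diffusers/quantizers` package added in this release. A usage sketch in the spirit of the 0.31 release notes (checkpoint name illustrative; a CUDA device and `bitsandbytes` are assumed):

```python
import torch
from diffusers import BitsAndBytesConfig, FluxTransformer2DModel

# 4-bit NF4 weights via bitsandbytes, computed in bf16.
nf4_config = BitsAndBytesConfig(
    load_in_4bit=True,
    bnb_4bit_quant_type="nf4",
    bnb_4bit_compute_dtype=torch.bfloat16,
)
transformer = FluxTransformer2DModel.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    subfolder="transformer",
    quantization_config=nf4_config,
    torch_dtype=torch.bfloat16,
)
```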
```diff
@@ -684,6 +771,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 revision=revision,
                 subfolder=subfolder or "",
             )
+            if hf_quantizer is not None:
+                model_file = _merge_sharded_checkpoints(sharded_ckpt_cached_folder, sharded_metadata)
+                logger.info("Merged sharded checkpoints as `hf_quantizer` is not None.")
+                is_sharded = False
 
         elif use_safetensors and not is_sharded:
             try:
```
```diff
@@ -729,13 +820,30 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             with accelerate.init_empty_weights():
                 model = cls.from_config(config, **unused_kwargs)
 
+            if hf_quantizer is not None:
+                hf_quantizer.preprocess_model(
+                    model=model, device_map=device_map, keep_in_fp32_modules=keep_in_fp32_modules
+                )
+
             # if device_map is None, load the state dict and move the params from meta device to the cpu
             if device_map is None and not is_sharded:
-                param_device = "cpu"
+                # `torch.cuda.current_device()` is fine here when `hf_quantizer` is not None.
+                # It would error out during the `validate_environment()` call above in the absence of cuda.
+                is_quant_method_bnb = (
+                    getattr(model, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES
+                )
+                if hf_quantizer is None:
+                    param_device = "cpu"
+                # TODO (sayakpaul, SunMarc): remove this after model loading refactor
+                elif is_quant_method_bnb:
+                    param_device = torch.cuda.current_device()
                 state_dict = load_state_dict(model_file, variant=variant)
                 model._convert_deprecated_attention_blocks(state_dict)
+
                 # move the params from meta device to cpu
                 missing_keys = set(model.state_dict().keys()) - set(state_dict.keys())
+                if hf_quantizer is not None:
+                    missing_keys = hf_quantizer.update_missing_keys(model, missing_keys, prefix="")
                 if len(missing_keys) > 0:
                     raise ValueError(
                         f"Cannot load {cls} from {pretrained_model_name_or_path} because the following keys are"
```
```diff
@@ -750,6 +858,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                         device=param_device,
                         dtype=torch_dtype,
                         model_name_or_path=pretrained_model_name_or_path,
+                        hf_quantizer=hf_quantizer,
+                        keep_in_fp32_modules=keep_in_fp32_modules,
                     )
 
                 if cls._keys_to_ignore_on_load_unexpected is not None:
```
```diff
@@ -765,7 +875,9 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             # Load weights and dispatch according to the device_map
             # by default the device_map is None and the weights are loaded on the CPU
             force_hook = True
-            device_map = _determine_device_map(model, device_map, max_memory, torch_dtype)
+            device_map = _determine_device_map(
+                model, device_map, max_memory, torch_dtype, keep_in_fp32_modules, hf_quantizer
+            )
             if device_map is None and is_sharded:
                 # we load the parameters on the cpu
                 device_map = {"": "cpu"}
```
```diff
@@ -843,14 +955,25 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             "error_msgs": error_msgs,
         }
 
+        if hf_quantizer is not None:
+            hf_quantizer.postprocess_model(model)
+            model.hf_quantizer = hf_quantizer
+
         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             raise ValueError(
                 f"{torch_dtype} needs to be of type `torch.dtype`, e.g. `torch.float16`, but is {type(torch_dtype)}."
             )
-        elif torch_dtype is not None:
+        # When using `use_keep_in_fp32_modules` if we do a global `to()` here, then we will
+        # completely lose the effectivity of `use_keep_in_fp32_modules`.
+        elif torch_dtype is not None and hf_quantizer is None and not use_keep_in_fp32_modules:
             model = model.to(torch_dtype)
 
-        model.register_to_config(_name_or_path=pretrained_model_name_or_path)
+        if hf_quantizer is not None:
+            # We also make sure to purge `_pre_quantization_dtype` when we serialize
+            # the model config because `_pre_quantization_dtype` is `torch.dtype`, not JSON serializable.
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path, _pre_quantization_dtype=torch_dtype)
+        else:
+            model.register_to_config(_name_or_path=pretrained_model_name_or_path)
 
         # Set model in evaluation mode to deactivate DropOut modules by default
         model.eval()
```
```diff
@@ -859,6 +982,76 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
 
         return model
 
+    # Adapted from `transformers`.
+    @wraps(torch.nn.Module.cuda)
+    def cuda(self, *args, **kwargs):
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `8-bit` quantized models. "
+                    " Please use the model as it is, since the model has already been set to the correct devices."
+                )
+            elif is_bitsandbytes_version("<", "0.43.2"):
+                raise ValueError(
+                    "Calling `cuda()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        return super().cuda(*args, **kwargs)
+
+    # Adapted from `transformers`.
+    @wraps(torch.nn.Module.to)
+    def to(self, *args, **kwargs):
+        dtype_present_in_args = "dtype" in kwargs
+
+        if not dtype_present_in_args:
+            for arg in args:
+                if isinstance(arg, torch.dtype):
+                    dtype_present_in_args = True
+                    break
+
+        # Checks if the model has been loaded in 4-bit or 8-bit with BNB
+        if getattr(self, "quantization_method", None) == QuantizationMethod.BITS_AND_BYTES:
+            if dtype_present_in_args:
+                raise ValueError(
+                    "You cannot cast a bitsandbytes model in a new `dtype`. Make sure to load the model using `from_pretrained` using the"
+                    " desired `dtype` by passing the correct `torch_dtype` argument."
+                )
+
+            if getattr(self, "is_loaded_in_8bit", False):
+                raise ValueError(
+                    "`.to` is not supported for `8-bit` bitsandbytes models. Please use the model as it is, since the"
+                    " model has already been set to the correct devices and casted to the correct `dtype`."
+                )
+            elif is_bitsandbytes_version("<", "0.43.2"):
+                raise ValueError(
+                    "Calling `to()` is not supported for `4-bit` quantized models with the installed version of bitsandbytes. "
+                    f"The current device is `{self.device}`. If you intended to move the model, please install bitsandbytes >= 0.43.2."
+                )
+        return super().to(*args, **kwargs)
+
+    # Taken from `transformers`.
+    def half(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.half()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().half(*args)
+
+    # Taken from `transformers`.
+    def float(self, *args):
+        # Checks if the model is quantized
+        if getattr(self, "is_quantized", False):
+            raise ValueError(
+                "`.float()` is not supported for quantized model. Please use the model as it is, since the"
+                " model has already been cast to the correct `dtype`."
+            )
+        else:
+            return super().float(*args)
+
     @classmethod
     def _load_pretrained_model(
         cls,
```
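Net effect of the four overrides: device moves stay legal for 4-bit bitsandbytes models on `bitsandbytes >= 0.43.2`, while any dtype cast on a quantized model is rejected. Continuing the hypothetical NF4 `transformer` from the earlier sketch:

```python
import torch

transformer.to("cuda")  # device move: allowed for 4-bit with recent bitsandbytes
try:
    transformer.to(torch.float16)  # dtype cast on a bnb-quantized model
except ValueError as err:
    print(err)  # rejected by the `to()` override above; `.half()`/`.float()` behave likewise
```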
```diff
@@ -1041,19 +1234,63 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         859520964
         ```
         """
+        is_loaded_in_4bit = getattr(self, "is_loaded_in_4bit", False)
+
+        if is_loaded_in_4bit:
+            if is_bitsandbytes_available():
+                import bitsandbytes as bnb
+            else:
+                raise ValueError(
+                    "bitsandbytes is not installed but it seems that the model has been loaded in 4bit precision, something went wrong"
+                    " make sure to install bitsandbytes with `pip install bitsandbytes`. You also need a GPU. "
+                )
 
         if exclude_embeddings:
             embedding_param_names = [
-                f"{name}.weight"
-                for name, module_type in self.named_modules()
-                if isinstance(module_type, torch.nn.Embedding)
+                f"{name}.weight" for name, module_type in self.named_modules() if isinstance(module_type, nn.Embedding)
             ]
-            non_embedding_parameters = [
+            total_parameters = [
                 parameter for name, parameter in self.named_parameters() if name not in embedding_param_names
             ]
-            return sum(p.numel() for p in non_embedding_parameters if p.requires_grad or not only_trainable)
         else:
-            return sum(p.numel() for p in self.parameters() if p.requires_grad or not only_trainable)
+            total_parameters = list(self.parameters())
+
+        total_numel = []
+
+        for param in total_parameters:
+            if param.requires_grad or not only_trainable:
+                # For 4bit models, we need to multiply the number of parameters by 2 as half of the parameters are
+                # used for the 4bit quantization (uint8 tensors are stored)
+                if is_loaded_in_4bit and isinstance(param, bnb.nn.Params4bit):
+                    if hasattr(param, "element_size"):
+                        num_bytes = param.element_size()
+                    elif hasattr(param, "quant_storage"):
+                        num_bytes = param.quant_storage.itemsize
+                    else:
+                        num_bytes = 1
+                    total_numel.append(param.numel() * 2 * num_bytes)
+                else:
+                    total_numel.append(param.numel())
+
+        return sum(total_numel)
+
+    def get_memory_footprint(self, return_buffers=True):
+        r"""
+        Get the memory footprint of a model. This will return the memory footprint of the current model in bytes.
+        Useful to benchmark the memory footprint of the current model and design some tests. Solution inspired from the
+        PyTorch discussions: https://discuss.pytorch.org/t/gpu-memory-that-model-uses/56822/2
+
+        Arguments:
+            return_buffers (`bool`, *optional*, defaults to `True`):
+                Whether to return the size of the buffer tensors in the computation of the memory footprint. Buffers
+                are tensors that do not require gradients and not registered as parameters. E.g. mean and std in batch
+                norm layers. Please see: https://discuss.pytorch.org/t/what-pytorch-means-by-buffers/120266/2
+        """
+        mem = sum([param.nelement() * param.element_size() for param in self.parameters()])
+        if return_buffers:
+            mem_bufs = sum([buf.nelement() * buf.element_size() for buf in self.buffers()])
+            mem = mem + mem_bufs
+        return mem
 
     def _convert_deprecated_attention_blocks(self, state_dict: OrderedDict) -> None:
         deprecated_attention_block_paths = []
```
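Both helpers, ported from `transformers`, are useful for before/after comparisons when quantizing: `num_parameters` corrects for the packed uint8 storage of 4-bit weights, and for an NF4 checkpoint `get_memory_footprint` should land at very roughly a quarter of the bf16 figure (an expectation, not a guarantee):

```python
# Continuing the hypothetical NF4 `transformer` from above.
print(f"{transformer.get_memory_footprint() / 1024**3:.2f} GiB")
print(transformer.num_parameters(exclude_embeddings=True))
```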
diffusers/models/normalization.py

```diff
@@ -97,6 +97,40 @@ class FP32LayerNorm(nn.LayerNorm):
         ).to(origin_dtype)
 
 
+class SD35AdaLayerNormZeroX(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (AdaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, norm_type: str = "layer_norm", bias: bool = True) -> None:
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 9 * embedding_dim, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False, eps=1e-6)
+        else:
+            raise ValueError(f"Unsupported `norm_type` ({norm_type}) provided. Supported ones are: 'layer_norm'.")
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, ...]:
+        emb = self.linear(self.silu(emb))
+        shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp, shift_msa2, scale_msa2, gate_msa2 = emb.chunk(
+            9, dim=1
+        )
+        norm_hidden_states = self.norm(hidden_states)
+        hidden_states = norm_hidden_states * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        norm_hidden_states2 = norm_hidden_states * (1 + scale_msa2[:, None]) + shift_msa2[:, None]
+        return hidden_states, gate_msa, shift_mlp, scale_mlp, gate_mlp, norm_hidden_states2, gate_msa2
+
+
 class AdaLayerNormZero(nn.Module):
     r"""
     Norm layer adaptive layer norm zero (adaLN-Zero).
```
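`SD35AdaLayerNormZeroX` backs the dual-attention blocks introduced with SD3.5: one conditioning vector is projected to nine modulation tensors, yielding a second modulated stream (`norm_hidden_states2`) alongside the usual AdaLN-Zero outputs. A toy shape check (sizes arbitrary):

```python
import torch
from diffusers.models.normalization import SD35AdaLayerNormZeroX

norm = SD35AdaLayerNormZeroX(embedding_dim=32)
hidden_states = torch.randn(2, 8, 32)  # (batch, tokens, dim)
temb = torch.randn(2, 32)              # conditioning embedding
outputs = norm(hidden_states, emb=temb)
assert len(outputs) == 7
assert outputs[0].shape == outputs[5].shape == hidden_states.shape
```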
```diff
@@ -355,6 +389,51 @@ class LuminaLayerNormContinuous(nn.Module):
         return x
 
 
+class CogView3PlusAdaLayerNormZeroTextImage(nn.Module):
+    r"""
+    Norm layer adaptive layer norm zero (adaLN-Zero).
+
+    Parameters:
+        embedding_dim (`int`): The size of each embedding vector.
+        num_embeddings (`int`): The size of the embeddings dictionary.
+    """
+
+    def __init__(self, embedding_dim: int, dim: int):
+        super().__init__()
+
+        self.silu = nn.SiLU()
+        self.linear = nn.Linear(embedding_dim, 12 * dim, bias=True)
+        self.norm_x = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+        self.norm_c = nn.LayerNorm(dim, elementwise_affine=False, eps=1e-5)
+
+    def forward(
+        self,
+        x: torch.Tensor,
+        context: torch.Tensor,
+        emb: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+        emb = self.linear(self.silu(emb))
+        (
+            shift_msa,
+            scale_msa,
+            gate_msa,
+            shift_mlp,
+            scale_mlp,
+            gate_mlp,
+            c_shift_msa,
+            c_scale_msa,
+            c_gate_msa,
+            c_shift_mlp,
+            c_scale_mlp,
+            c_gate_mlp,
+        ) = emb.chunk(12, dim=1)
+        normed_x = self.norm_x(x)
+        normed_context = self.norm_c(context)
+        x = normed_x * (1 + scale_msa[:, None]) + shift_msa[:, None]
+        context = normed_context * (1 + c_scale_msa[:, None]) + c_shift_msa[:, None]
+        return x, gate_msa, shift_mlp, scale_mlp, gate_mlp, context, c_gate_msa, c_shift_mlp, c_scale_mlp, c_gate_mlp
+
+
 class CogVideoXLayerNormZero(nn.Module):
     def __init__(
         self,
```
diffusers/models/transformers/__init__.py

```diff
@@ -14,6 +14,7 @@ if is_torch_available():
     from .stable_audio_transformer import StableAudioDiTModel
     from .t5_film_transformer import T5FilmDecoder
     from .transformer_2d import Transformer2DModel
+    from .transformer_cogview3plus import CogView3PlusTransformer2DModel
     from .transformer_flux import FluxTransformer2DModel
     from .transformer_sd3 import SD3Transformer2DModel
     from .transformer_temporal import TransformerTemporalModel
```
diffusers/models/transformers/auraflow_transformer_2d.py

```diff
@@ -274,6 +274,7 @@ class AuraFlowTransformer2DModel(ModelMixin, ConfigMixin):
         pos_embed_max_size (`int`, defaults to 4096): Maximum positions to embed from the image latents.
     """
 
+    _no_split_modules = ["AuraFlowJointTransformerBlock", "AuraFlowSingleTransformerBlock", "AuraFlowPatchEmbed"]
     _supports_gradient_checkpointing = True
 
     @register_to_config
```
diffusers/models/transformers/cogvideox_transformer_3d.py

```diff
@@ -19,7 +19,8 @@ import torch
 from torch import nn
 
 from ...configuration_utils import ConfigMixin, register_to_config
-from ...utils import is_torch_version, logging
+from ...loaders import PeftAdapterMixin
+from ...utils import USE_PEFT_BACKEND, is_torch_version, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import Attention, FeedForward
 from ..attention_processor import AttentionProcessor, CogVideoXAttnProcessor2_0, FusedCogVideoXAttnProcessor2_0
```
```diff
@@ -152,7 +153,7 @@ class CogVideoXBlock(nn.Module):
         return hidden_states, encoder_hidden_states
 
 
-class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
+class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
     """
     A Transformer model for video-like data in [CogVideoX](https://github.com/THUDM/CogVideo).
 
```
```diff
@@ -411,8 +412,24 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         timestep: Union[int, float, torch.LongTensor],
         timestep_cond: Optional[torch.Tensor] = None,
         image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        attention_kwargs: Optional[Dict[str, Any]] = None,
         return_dict: bool = True,
     ):
+        if attention_kwargs is not None:
+            attention_kwargs = attention_kwargs.copy()
+            lora_scale = attention_kwargs.pop("scale", 1.0)
+        else:
+            lora_scale = 1.0
+
+        if USE_PEFT_BACKEND:
+            # weight the lora layers by setting `lora_scale` for each PEFT layer
+            scale_lora_layers(self, lora_scale)
+        else:
+            if attention_kwargs is not None and attention_kwargs.get("scale", None) is not None:
+                logger.warning(
+                    "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
+                )
+
         batch_size, num_frames, channels, height, width = hidden_states.shape
 
         # 1. Time embedding
```
```diff
@@ -481,6 +498,10 @@ class CogVideoXTransformer3DModel(ModelMixin, ConfigMixin):
         output = hidden_states.reshape(batch_size, num_frames, height // p, width // p, -1, p, p)
         output = output.permute(0, 1, 4, 2, 5, 3, 6).flatten(5, 6).flatten(3, 4)
 
+        if USE_PEFT_BACKEND:
+            # remove `lora_scale` from each PEFT layer
+            unscale_lora_layers(self, lora_scale)
+
         if not return_dict:
             return (output,)
         return Transformer2DModelOutput(sample=output)
```
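With `PeftAdapterMixin` on the model and the `scale` plumbing above, a LoRA loaded on CogVideoX can now be weighted per call. A hedged sketch (repo id hypothetical; assumes the pipeline forwards `attention_kwargs` to the transformer, as the 0.31 CogVideoX pipelines do):

```python
# `pipe` is a CogVideoXPipeline with the PEFT backend installed.
pipe.load_lora_weights("some-user/cogvideox-lora")  # hypothetical repo id
video = pipe(
    prompt="a drone shot gliding over a rocky coastline",
    attention_kwargs={"scale": 0.8},  # popped as `lora_scale` in the forward above
).frames[0]
```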
diffusers/models/transformers/pixart_transformer_2d.py

```diff
@@ -19,7 +19,7 @@ from torch import nn
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...utils import is_torch_version, logging
 from ..attention import BasicTransformerBlock
-from ..attention_processor import Attention, AttentionProcessor, FusedAttnProcessor2_0
+from ..attention_processor import Attention, AttentionProcessor, AttnProcessor, FusedAttnProcessor2_0
 from ..embeddings import PatchEmbed, PixArtAlphaTextProjection
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
```
```diff
@@ -247,6 +247,14 @@ class PixArtTransformer2DModel(ModelMixin, ConfigMixin):
         for name, module in self.named_children():
             fn_recursive_attn_processor(name, module, processor)
 
+    def set_default_attn_processor(self):
+        """
+        Disables custom attention processors and sets the default attention implementation.
+
+        Safe to just use `AttnProcessor()` as PixArt doesn't have any exotic attention processors in default model.
+        """
+        self.set_attn_processor(AttnProcessor())
+
     # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections
     def fuse_qkv_projections(self):
         """
```
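`set_default_attn_processor` gives PixArt the same one-call reset the UNets already have, e.g. to drop a custom processor installed for debugging:

```python
from diffusers import PixArtTransformer2DModel

# Checkpoint name illustrative; any PixArt transformer works the same way.
transformer = PixArtTransformer2DModel.from_pretrained(
    "PixArt-alpha/PixArt-XL-2-1024-MS", subfolder="transformer"
)
transformer.set_default_attn_processor()  # plain AttnProcessor everywhere
```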