diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/models/modeling_utils.py

@@ -15,6 +15,7 @@
 # limitations under the License.

 import copy
+import functools
 import inspect
 import itertools
 import json
@@ -42,6 +43,7 @@ from ..quantizers.quantization_config import QuantizationMethod
 from ..utils import (
     CONFIG_NAME,
     FLAX_WEIGHTS_NAME,
+    HF_ENABLE_PARALLEL_LOADING,
     SAFE_WEIGHTS_INDEX_NAME,
     SAFETENSORS_WEIGHTS_NAME,
     WEIGHTS_INDEX_NAME,
@@ -62,12 +64,15 @@ from ..utils.hub_utils import (
     load_or_create_model_card,
     populate_model_card,
 )
+from ..utils.torch_utils import empty_device_cache
 from .model_loading_utils import (
+    _caching_allocator_warmup,
     _determine_device_map,
+    _expand_device_map,
     _fetch_index_file,
     _fetch_index_file_legacy,
[two removed import names were not captured in the source view]
+    _load_shard_file,
+    _load_shard_files_with_threadpool,
     load_state_dict,
 )
@@ -168,7 +173,11 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:

     for name, param in parameter.named_parameters():
         last_dtype = param.dtype
-        if hasattr(parameter, "_keep_in_fp32_modules") and any(m in name for m in parameter._keep_in_fp32_modules):
+        if (
+            hasattr(parameter, "_keep_in_fp32_modules")
+            and parameter._keep_in_fp32_modules
+            and any(m in name for m in parameter._keep_in_fp32_modules)
+        ):
             continue

         if param.is_floating_point():
@@ -200,34 +209,6 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
     return last_tuple[1].dtype


-def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
-    """
-    Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
-    checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
-    parameters.
-
-    """
-    if model_to_load.device.type == "meta":
-        return False
-
-    if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
-        return False
-
-    # Some models explicitly do not support param buffer assignment
-    if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
-        logger.debug(
-            f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
-        )
-        return False
-
-    # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
-    first_key = next(iter(model_to_load.state_dict().keys()))
-    if start_prefix + first_key in state_dict:
-        return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
-
-    return False
-
-
 @contextmanager
 def no_init_weights():
     """
@@ -266,6 +247,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
     _keep_in_fp32_modules = None
     _skip_layerwise_casting_patterns = None
     _supports_group_offloading = True
+    _repeated_blocks = []

     def __init__(self):
         super().__init__()
@@ -601,6 +583,60 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
             offload_to_disk_path=offload_to_disk_path,
         )

+    def set_attention_backend(self, backend: str) -> None:
+        """
+        Set the attention backend for the model.
+
+        Args:
+            backend (`str`):
+                The name of the backend to set. Must be one of the available backends defined in
+                `AttentionBackendName`. Available backends can be found in
+                `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
+                attention as backend.
+        """
+        from .attention import AttentionModuleMixin
+        from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
+
+        # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
+        from .attention_processor import Attention, MochiAttention
+
+        logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+
+        backend = backend.lower()
+        available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+        if backend not in available_backends:
+            raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+        backend = AttentionBackendName(backend)
+        _check_attention_backend_requirements(backend)
+
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if processor is None or not hasattr(processor, "_attention_backend"):
+                continue
+            processor._attention_backend = backend
+
+    def reset_attention_backend(self) -> None:
+        """
+        Resets the attention backend for the model. Following calls to `forward` will use the environment default or
+        the torch native scaled dot product attention.
+        """
+        from .attention import AttentionModuleMixin
+        from .attention_processor import Attention, MochiAttention
+
+        logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+
+        attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+        for module in self.modules():
+            if not isinstance(module, attention_classes):
+                continue
+            processor = module.processor
+            if processor is None or not hasattr(processor, "_attention_backend"):
+                continue
+            processor._attention_backend = None
+
     def save_pretrained(
         self,
         save_directory: Union[str, os.PathLike],
@@ -880,8 +916,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

         <Tip>

-        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
-        `huggingface-cli login`. You can also activate the special
+        To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
+        auth login`. You can also activate the special
         ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
         firewalled environment.

@@ -925,6 +961,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
         disable_mmap = kwargs.pop("disable_mmap", False)

+        is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
+        if is_parallel_loading_enabled and not low_cpu_mem_usage:
+            raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")
+
         if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
             torch_dtype = torch.float32
             logger.warning(
@@ -1260,6 +1300,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 hf_quantizer=hf_quantizer,
                 keep_in_fp32_modules=keep_in_fp32_modules,
                 dduf_entries=dduf_entries,
+                is_parallel_loading_enabled=is_parallel_loading_enabled,
             )
             loading_info = {
                 "missing_keys": missing_keys,
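
Note: `HF_ENABLE_PARALLEL_LOADING` is resolved from the environment when diffusers is imported, so the opt-in has to happen before the import. A hedged sketch, assuming the flag accepts the usual truthy values ("1", "yes", "true"):

import os

os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"  # must be set before importing diffusers

from diffusers import DiffusionPipeline  # noqa: E402

# Parallel shard loading only combines with low_cpu_mem_usage (the default);
# turning that off raises NotImplementedError, as the hunk above shows.
pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0")
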
@@ -1404,6 +1445,39 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         else:
             return super().float(*args)

+    def compile_repeated_blocks(self, *args, **kwargs):
+        """
+        Compiles *only* the frequently repeated sub-modules of a model (e.g. the Transformer layers) instead of
+        compiling the entire model. This technique—often called **regional compilation** (see the PyTorch recipe
+        https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) can reduce end-to-end compile time
+        substantially, while preserving the runtime speed-ups you would expect from a full `torch.compile`.
+
+        The set of sub-modules to compile is discovered by the presence of **`_repeated_blocks`** attribute in the
+        model definition. Define this attribute on your model subclass as a list/tuple of class names (strings). Every
+        module whose class name matches will be compiled.
+
+        Once discovered, each matching sub-module is compiled by calling `submodule.compile(*args, **kwargs)`. Any
+        positional or keyword arguments you supply to `compile_repeated_blocks` are forwarded verbatim to
+        `torch.compile`.
+        """
+        repeated_blocks = getattr(self, "_repeated_blocks", None)
+
+        if not repeated_blocks:
+            raise ValueError(
+                "`_repeated_blocks` attribute is empty. "
+                f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. "
+            )
+        has_compiled_region = False
+        for submod in self.modules():
+            if submod.__class__.__name__ in repeated_blocks:
+                submod.compile(*args, **kwargs)
+                has_compiled_region = True
+
+        if not has_compiled_region:
+            raise ValueError(
+                f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
+            )
+
     @classmethod
     def _load_pretrained_model(
         cls,
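
Note: a usage sketch for the regional-compilation helper added above. The repository id below is a hypothetical placeholder; any model class that defines `_repeated_blocks` (for example `ChromaTransformer2DModel` later in this diff) qualifies.

import torch
from diffusers import ChromaTransformer2DModel

# "your-org/chroma-checkpoint" is a placeholder, not a real repository.
transformer = ChromaTransformer2DModel.from_pretrained(
    "your-org/chroma-checkpoint", subfolder="transformer", torch_dtype=torch.bfloat16
).to("cuda")

# Compiles only the repeated ChromaTransformerBlock / ChromaSingleTransformerBlock modules;
# extra arguments are forwarded verbatim to torch.compile.
transformer.compile_repeated_blocks(fullgraph=True)
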
@@ -1422,6 +1496,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
         offload_state_dict: Optional[bool] = None,
         offload_folder: Optional[Union[str, os.PathLike]] = None,
         dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
+        is_parallel_loading_enabled: Optional[bool] = False,
     ):
         model_state_dict = model.state_dict()
         expected_keys = list(model_state_dict.keys())
@@ -1436,8 +1511,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
                 unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

         mismatched_keys = []
-
-        assign_to_params_buffers = None
         error_msgs = []

         # Deal with offload
@@ -1448,80 +1521,67 @@
                     " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
                     " offers the weights in this format."
                 )
-
+            else:
                 os.makedirs(offload_folder, exist_ok=True)
             if offload_state_dict is None:
                 offload_state_dict = True

+        # If a device map has been used, we can speedup the load time by warming up the device caching allocator.
+        # If we don't warmup, each tensor allocation on device calls to the allocator for memory (effectively, a
+        # lot of individual calls to device malloc). We can, however, preallocate the memory required by the
+        # tensors using their expected shape and not performing any initialization of the memory (empty data).
+        # When the actual device allocations happen, the allocator already has a pool of unused device memory
+        # that it can re-use for faster loading of the model.
+        if device_map is not None:
+            expanded_device_map = _expand_device_map(device_map, expected_keys)
+            _caching_allocator_warmup(model, expanded_device_map, dtype, hf_quantizer)
+
         offload_index = {} if device_map is not None and "disk" in device_map.values() else None
+        state_dict_folder, state_dict_index = None, None
         if offload_state_dict:
             state_dict_folder = tempfile.mkdtemp()
             state_dict_index = {}
-        else:
-            state_dict_folder = None
-            state_dict_index = None

         if state_dict is not None:
             # load_state_dict will manage the case where we pass a dict instead of a file
             # if state dict is not None, it means that we don't need to read the files from resolved_model_file also
             resolved_model_file = [state_dict]

[removed lines 1469-1488 were not captured in the source view]
-                        if (
-                            model_key in model_state_dict
-                            and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
-                        ):
-                            mismatched_keys.append(
-                                (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
-                            )
-                            del state_dict[checkpoint_key]
-                return mismatched_keys
-
-            mismatched_keys += _find_mismatched_keys(
-                state_dict,
-                model_state_dict,
-                loaded_keys,
-                ignore_mismatched_sizes,
-            )
+        # Prepare the loading function sharing the attributes shared between them.
+        load_fn = functools.partial(
+            _load_shard_files_with_threadpool if is_parallel_loading_enabled else _load_shard_file,
+            model=model,
+            model_state_dict=model_state_dict,
+            device_map=device_map,
+            dtype=dtype,
+            hf_quantizer=hf_quantizer,
+            keep_in_fp32_modules=keep_in_fp32_modules,
+            dduf_entries=dduf_entries,
+            loaded_keys=loaded_keys,
+            unexpected_keys=unexpected_keys,
+            offload_index=offload_index,
+            offload_folder=offload_folder,
+            state_dict_index=state_dict_index,
+            state_dict_folder=state_dict_folder,
+            ignore_mismatched_sizes=ignore_mismatched_sizes,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+        )

[removed lines 1506-1513 were not captured in the source view]
-                unexpected_keys=unexpected_keys,
-                offload_folder=offload_folder,
-                offload_index=offload_index,
-                state_dict_index=state_dict_index,
-                state_dict_folder=state_dict_folder,
-            )
-        else:
-            if assign_to_params_buffers is None:
-                assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+        if is_parallel_loading_enabled:
+            offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(resolved_model_file)
+            error_msgs += _error_msgs
+            mismatched_keys += _mismatched_keys
+        else:
+            shard_files = resolved_model_file
+            if len(resolved_model_file) > 1:
+                shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")

[removed line 1524 was not captured in the source view]
+            for shard_file in shard_files:
+                offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
+                error_msgs += _error_msgs
+                mismatched_keys += _mismatched_keys
+
+        empty_device_cache()

         if offload_index is not None and len(offload_index) > 0:
             save_offload_index(offload_index, offload_folder)
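
Note: the comments in the hunk above describe the caching-allocator warmup in general terms. Below is a standalone sketch of the same idea, deliberately independent of the private `_expand_device_map` / `_caching_allocator_warmup` helpers (the function name and shapes are illustrative).

import math

import torch


def warmup_allocator(param_shapes, device, dtype=torch.bfloat16):
    # Reserve one large block up front; the tensor is dropped immediately, but the
    # device caching allocator keeps the memory pooled, so later per-tensor
    # allocations during weight loading reuse it instead of calling device malloc.
    total_elems = sum(math.prod(shape) for shape in param_shapes)
    torch.empty(total_elems, dtype=dtype, device=device)


warmup_allocator([(4096, 4096), (4096, 11008)], device="cuda" if torch.cuda.is_available() else "cpu")
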
@@ -1858,4 +1918,9 @@ class LegacyModelMixin(ModelMixin):
         # resolve remapping
         remapped_class = _fetch_remapped_cls_from_config(config, cls)

-        return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
+        if remapped_class is cls:
+            return super(LegacyModelMixin, remapped_class).from_pretrained(
+                pretrained_model_name_or_path, **kwargs_copy
+            )
+        else:
+            return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
diffusers/models/transformers/__init__.py

@@ -30,7 +30,9 @@ if is_torch_available():
     from .transformer_lumina2 import Lumina2Transformer2DModel
     from .transformer_mochi import MochiTransformer3DModel
     from .transformer_omnigen import OmniGenTransformer2DModel
+    from .transformer_qwenimage import QwenImageTransformer2DModel
     from .transformer_sd3 import SD3Transformer2DModel
+    from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
     from .transformer_temporal import TransformerTemporalModel
     from .transformer_wan import WanTransformer3DModel
     from .transformer_wan_vace import WanVACETransformer3DModel
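
Note: both classes registered here appear to be re-exported from the top-level package as well (the +98-line change to diffusers/__init__.py and the new dummy-object files suggest so), which would make the following import sketch work:

from diffusers import QwenImageTransformer2DModel, SkyReelsV2Transformer3DModel

print(QwenImageTransformer2DModel.__name__, SkyReelsV2Transformer3DModel.__name__)
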
diffusers/models/transformers/transformer_chroma.py

@@ -24,19 +24,13 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe
 from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
 from ...utils.import_utils import is_torch_npu_available
 from ...utils.torch_utils import maybe_allow_in_graph
-from ..attention import FeedForward
-from ..attention_processor import (
-    Attention,
-    AttentionProcessor,
-    FluxAttnProcessor2_0,
-    FluxAttnProcessor2_0_NPU,
-    FusedFluxAttnProcessor2_0,
-)
+from ..attention import AttentionMixin, FeedForward
 from ..cache_utils import CacheMixin
 from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
 from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
+from .transformer_flux import FluxAttention, FluxAttnProcessor


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -223,6 +217,8 @@ class ChromaSingleTransformerBlock(nn.Module):
         self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

         if is_torch_npu_available():
+            from ..attention_processor import FluxAttnProcessor2_0_NPU
+
             deprecation_message = (
                 "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
                 "should be set explicitly using the `set_attn_processor` method."
@@ -230,17 +226,15 @@ class ChromaSingleTransformerBlock(nn.Module):
             deprecate("npu_processor", "0.34.0", deprecation_message)
             processor = FluxAttnProcessor2_0_NPU()
         else:
-            processor = FluxAttnProcessor2_0()
+            processor = FluxAttnProcessor()

-        self.attn = Attention(
+        self.attn = FluxAttention(
             query_dim=dim,
-            cross_attention_dim=None,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             bias=True,
             processor=processor,
-            qk_norm="rms_norm",
             eps=1e-6,
             pre_only=True,
         )
@@ -292,17 +286,15 @@ class ChromaTransformerBlock(nn.Module):
         self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
         self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)

-        self.attn = Attention(
+        self.attn = FluxAttention(
             query_dim=dim,
-            cross_attention_dim=None,
             added_kv_proj_dim=dim,
             dim_head=attention_head_dim,
             heads=num_attention_heads,
             out_dim=dim,
             context_pre_only=False,
             bias=True,
-            processor=FluxAttnProcessor2_0(),
-            qk_norm=qk_norm,
+            processor=FluxAttnProcessor(),
             eps=eps,
         )

@@ -376,7 +368,13 @@ class ChromaTransformerBlock(nn.Module):


 class ChromaTransformer2DModel(
-    ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+    ModelMixin,
+    ConfigMixin,
+    PeftAdapterMixin,
+    FromOriginalModelMixin,
+    FluxTransformer2DLoadersMixin,
+    CacheMixin,
+    AttentionMixin,
 ):
     """
     The Transformer model introduced in Flux, modified for Chroma.
@@ -407,6 +405,7 @@ class ChromaTransformer2DModel(

     _supports_gradient_checkpointing = True
     _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
+    _repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
     _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

     @register_to_config
@@ -474,106 +473,6 @@ class ChromaTransformer2DModel(

         self.gradient_checkpointing = False

-    @property
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
-    def attn_processors(self) -> Dict[str, AttentionProcessor]:
-        r"""
-        Returns:
-            `dict` of attention processors: A dictionary containing all attention processors used in the model with
-            indexed by its weight name.
-        """
-        # set recursively
-        processors = {}
-
-        def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
-            if hasattr(module, "get_processor"):
-                processors[f"{name}.processor"] = module.get_processor()
-
-            for sub_name, child in module.named_children():
-                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
-            return processors
-
-        for name, module in self.named_children():
-            fn_recursive_add_processors(name, module, processors)
-
-        return processors
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
-    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
-        r"""
-        Sets the attention processor to use to compute attention.
-
-        Parameters:
-            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
-                The instantiated processor class or a dictionary of processor classes that will be set as the processor
-                for **all** `Attention` layers.
-
-                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
-                processor. This is strongly recommended when setting trainable attention processors.
-
-        """
-        count = len(self.attn_processors.keys())
-
-        if isinstance(processor, dict) and len(processor) != count:
-            raise ValueError(
-                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
-                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
-            )
-
-        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
-            if hasattr(module, "set_processor"):
-                if not isinstance(processor, dict):
-                    module.set_processor(processor)
-                else:
-                    module.set_processor(processor.pop(f"{name}.processor"))
-
-            for sub_name, child in module.named_children():
-                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
-        for name, module in self.named_children():
-            fn_recursive_attn_processor(name, module, processor)
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
-    def fuse_qkv_projections(self):
-        """
-        Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
-        are fused. For cross-attention modules, key and value projection matrices are fused.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-        """
-        self.original_attn_processors = None
-
-        for _, attn_processor in self.attn_processors.items():
-            if "Added" in str(attn_processor.__class__.__name__):
-                raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
-        self.original_attn_processors = self.attn_processors
-
-        for module in self.modules():
-            if isinstance(module, Attention):
-                module.fuse_projections(fuse=True)
-
-        self.set_attn_processor(FusedFluxAttnProcessor2_0())
-
-    # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
-    def unfuse_qkv_projections(self):
-        """Disables the fused QKV projection if enabled.
-
-        <Tip warning={true}>
-
-        This API is 🧪 experimental.
-
-        </Tip>
-
-        """
-        if self.original_attn_processors is not None:
-            self.set_attn_processor(self.original_attn_processors)
-
     def forward(
         self,
         hidden_states: torch.Tensor,
diffusers/models/transformers/transformer_cogview4.py

@@ -21,13 +21,14 @@ import torch.nn.functional as F
 from ...configuration_utils import ConfigMixin, register_to_config
 from ...loaders import PeftAdapterMixin
 from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+from ...utils.torch_utils import maybe_allow_in_graph
 from ..attention import FeedForward
 from ..attention_processor import Attention
 from ..cache_utils import CacheMixin
 from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
 from ..modeling_outputs import Transformer2DModelOutput
 from ..modeling_utils import ModelMixin
-from ..normalization import AdaLayerNormContinuous
+from ..normalization import LayerNorm, RMSNorm


 logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
@@ -453,6 +454,7 @@ class CogView4TrainingAttnProcessor:
         return hidden_states, encoder_hidden_states


+@maybe_allow_in_graph
 class CogView4TransformerBlock(nn.Module):
     def __init__(
         self,
@@ -582,6 +584,38 @@ class CogView4RotaryPosEmbed(nn.Module):
         return (freqs.cos(), freqs.sin())


+class CogView4AdaLayerNormContinuous(nn.Module):
+    """
+    CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
+    Linear on conditioning embedding.
+    """
+
+    def __init__(
+        self,
+        embedding_dim: int,
+        conditioning_embedding_dim: int,
+        elementwise_affine: bool = True,
+        eps: float = 1e-5,
+        bias: bool = True,
+        norm_type: str = "layer_norm",
+    ):
+        super().__init__()
+        self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+        if norm_type == "layer_norm":
+            self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+        elif norm_type == "rms_norm":
+            self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+        else:
+            raise ValueError(f"unknown norm_type {norm_type}")
+
+    def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+        # *** NO SiLU here ***
+        emb = self.linear(conditioning_embedding.to(x.dtype))
+        scale, shift = torch.chunk(emb, 2, dim=1)
+        x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+        return x
+
+
 class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
     r"""
     Args:
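
Note: a self-contained numerical sketch of the computation the docstring above describes, LN(x) -> Linear(cond) -> chunk -> affine with no activation on the conditioning embedding; the shapes are illustrative only.

import torch
import torch.nn as nn

batch, seq_len, dim, cond_dim = 2, 16, 64, 128
x = torch.randn(batch, seq_len, dim)
cond = torch.randn(batch, cond_dim)

linear = nn.Linear(cond_dim, 2 * dim)
norm = nn.LayerNorm(dim, elementwise_affine=False)

scale, shift = linear(cond).chunk(2, dim=1)  # note: no SiLU before the Linear
out = norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
print(out.shape)  # torch.Size([2, 16, 64])
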
@@ -664,7 +698,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
         )

         # 4. Output projection
-        self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+        self.norm_out = CogView4AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
         self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)

         self.gradient_checkpointing = False