diffusers 0.34.0__py3-none-any.whl → 0.35.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +882 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.0.dist-info}/top_level.txt +0 -0
diffusers/models/modeling_utils.py

@@ -15,6 +15,7 @@
  # limitations under the License.

  import copy
+ import functools
  import inspect
  import itertools
  import json
@@ -42,6 +43,7 @@ from ..quantizers.quantization_config import QuantizationMethod
  from ..utils import (
  CONFIG_NAME,
  FLAX_WEIGHTS_NAME,
+ HF_ENABLE_PARALLEL_LOADING,
  SAFE_WEIGHTS_INDEX_NAME,
  SAFETENSORS_WEIGHTS_NAME,
  WEIGHTS_INDEX_NAME,
@@ -62,12 +64,15 @@ from ..utils.hub_utils import (
  load_or_create_model_card,
  populate_model_card,
  )
+ from ..utils.torch_utils import empty_device_cache
  from .model_loading_utils import (
+ _caching_allocator_warmup,
  _determine_device_map,
+ _expand_device_map,
  _fetch_index_file,
  _fetch_index_file_legacy,
- _load_state_dict_into_model,
- load_model_dict_into_meta,
+ _load_shard_file,
+ _load_shard_files_with_threadpool,
  load_state_dict,
  )

@@ -168,7 +173,11 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:

  for name, param in parameter.named_parameters():
  last_dtype = param.dtype
- if parameter._keep_in_fp32_modules and any(m in name for m in parameter._keep_in_fp32_modules):
+ if (
+ hasattr(parameter, "_keep_in_fp32_modules")
+ and parameter._keep_in_fp32_modules
+ and any(m in name for m in parameter._keep_in_fp32_modules)
+ ):
  continue

  if param.is_floating_point():
@@ -200,34 +209,6 @@ def get_parameter_dtype(parameter: torch.nn.Module) -> torch.dtype:
  return last_tuple[1].dtype


- def check_support_param_buffer_assignment(model_to_load, state_dict, start_prefix=""):
- """
- Checks if `model_to_load` supports param buffer assignment (such as when loading in empty weights) by first
- checking if the model explicitly disables it, then by ensuring that the state dict keys are a subset of the model's
- parameters.
-
- """
- if model_to_load.device.type == "meta":
- return False
-
- if len([key for key in state_dict if key.startswith(start_prefix)]) == 0:
- return False
-
- # Some models explicitly do not support param buffer assignment
- if not getattr(model_to_load, "_supports_param_buffer_assignment", True):
- logger.debug(
- f"{model_to_load.__class__.__name__} does not support param buffer assignment, loading will be slower"
- )
- return False
-
- # If the model does, the incoming `state_dict` and the `model_to_load` must be the same dtype
- first_key = next(iter(model_to_load.state_dict().keys()))
- if start_prefix + first_key in state_dict:
- return state_dict[start_prefix + first_key].dtype == model_to_load.state_dict()[first_key].dtype
-
- return False
-
-
  @contextmanager
  def no_init_weights():
  """
@@ -266,6 +247,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  _keep_in_fp32_modules = None
  _skip_layerwise_casting_patterns = None
  _supports_group_offloading = True
+ _repeated_blocks = []

  def __init__(self):
  super().__init__()
@@ -601,6 +583,60 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  offload_to_disk_path=offload_to_disk_path,
  )

+ def set_attention_backend(self, backend: str) -> None:
+ """
+ Set the attention backend for the model.
+
+ Args:
+ backend (`str`):
+ The name of the backend to set. Must be one of the available backends defined in
+ `AttentionBackendName`. Available backends can be found in
+ `diffusers.attention_dispatch.AttentionBackendName`. Defaults to torch native scaled dot product
+ attention as backend.
+ """
+ from .attention import AttentionModuleMixin
+ from .attention_dispatch import AttentionBackendName, _check_attention_backend_requirements
+
+ # TODO: the following will not be required when everything is refactored to AttentionModuleMixin
+ from .attention_processor import Attention, MochiAttention
+
+ logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+
+ backend = backend.lower()
+ available_backends = {x.value for x in AttentionBackendName.__members__.values()}
+ if backend not in available_backends:
+ raise ValueError(f"`{backend=}` must be one of the following: " + ", ".join(available_backends))
+ backend = AttentionBackendName(backend)
+ _check_attention_backend_requirements(backend)
+
+ attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+ for module in self.modules():
+ if not isinstance(module, attention_classes):
+ continue
+ processor = module.processor
+ if processor is None or not hasattr(processor, "_attention_backend"):
+ continue
+ processor._attention_backend = backend
+
+ def reset_attention_backend(self) -> None:
+ """
+ Resets the attention backend for the model. Following calls to `forward` will use the environment default or
+ the torch native scaled dot product attention.
+ """
+ from .attention import AttentionModuleMixin
+ from .attention_processor import Attention, MochiAttention
+
+ logger.warning("Attention backends are an experimental feature and the API may be subject to change.")
+
+ attention_classes = (Attention, MochiAttention, AttentionModuleMixin)
+ for module in self.modules():
+ if not isinstance(module, attention_classes):
+ continue
+ processor = module.processor
+ if processor is None or not hasattr(processor, "_attention_backend"):
+ continue
+ processor._attention_backend = None
+
  def save_pretrained(
  self,
  save_directory: Union[str, os.PathLike],
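
Usage note: the new `set_attention_backend` / `reset_attention_backend` methods above switch every attention processor on the model to one of the backends enumerated by `AttentionBackendName`, and back to the default. A minimal sketch of how this is meant to be called; the backend string "flash" and the checkpoint id are illustrative assumptions, not taken from this diff:

    import torch
    from diffusers import FluxTransformer2DModel

    # Illustrative checkpoint; any ModelMixin subclass gains these methods in 0.35.0.
    transformer = FluxTransformer2DModel.from_pretrained(
        "black-forest-labs/FLUX.1-dev", subfolder="transformer", torch_dtype=torch.bfloat16
    )
    transformer.set_attention_backend("flash")  # assumed name; must be a member of AttentionBackendName
    # ... run inference ...
    transformer.reset_attention_backend()  # back to torch SDPA / the environment default
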
@@ -880,8 +916,8 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):

  <Tip>

- To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with
- `huggingface-cli login`. You can also activate the special
+ To use private or [gated models](https://huggingface.co/docs/hub/models-gated#gated-models), log-in with `hf
+ auth login`. You can also activate the special
  ["offline-mode"](https://huggingface.co/diffusers/installation.html#offline-mode) to use this method in a
  firewalled environment.

@@ -925,6 +961,10 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)
  disable_mmap = kwargs.pop("disable_mmap", False)

+ is_parallel_loading_enabled = HF_ENABLE_PARALLEL_LOADING
+ if is_parallel_loading_enabled and not low_cpu_mem_usage:
+ raise NotImplementedError("Parallel loading is not supported when not using `low_cpu_mem_usage`.")
+
  if torch_dtype is not None and not isinstance(torch_dtype, torch.dtype):
  torch_dtype = torch.float32
  logger.warning(
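
Usage note: the hunk above gates parallel shard loading on the `HF_ENABLE_PARALLEL_LOADING` constant (imported from `..utils` earlier in this diff) and requires `low_cpu_mem_usage`. A sketch, assuming the constant mirrors an environment variable of the same name defined in `diffusers/utils/constants.py` (which also changed in this release) and that it must be set before `diffusers` is imported; the checkpoint id is illustrative:

    import os
    os.environ["HF_ENABLE_PARALLEL_LOADING"] = "yes"  # assumption: any truthy value enables the new path

    import torch
    from diffusers import AutoModel

    # Parallel loading only applies with low_cpu_mem_usage=True (the default).
    model = AutoModel.from_pretrained(
        "Qwen/Qwen-Image", subfolder="transformer", torch_dtype=torch.bfloat16
    )
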
@@ -1260,6 +1300,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  hf_quantizer=hf_quantizer,
  keep_in_fp32_modules=keep_in_fp32_modules,
  dduf_entries=dduf_entries,
+ is_parallel_loading_enabled=is_parallel_loading_enabled,
  )
  loading_info = {
  "missing_keys": missing_keys,
@@ -1404,6 +1445,39 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  else:
  return super().float(*args)

+ def compile_repeated_blocks(self, *args, **kwargs):
+ """
+ Compiles *only* the frequently repeated sub-modules of a model (e.g. the Transformer layers) instead of
+ compiling the entire model. This technique—often called **regional compilation** (see the PyTorch recipe
+ https://docs.pytorch.org/tutorials/recipes/regional_compilation.html) can reduce end-to-end compile time
+ substantially, while preserving the runtime speed-ups you would expect from a full `torch.compile`.
+
+ The set of sub-modules to compile is discovered by the presence of **`_repeated_blocks`** attribute in the
+ model definition. Define this attribute on your model subclass as a list/tuple of class names (strings). Every
+ module whose class name matches will be compiled.
+
+ Once discovered, each matching sub-module is compiled by calling `submodule.compile(*args, **kwargs)`. Any
+ positional or keyword arguments you supply to `compile_repeated_blocks` are forwarded verbatim to
+ `torch.compile`.
+ """
+ repeated_blocks = getattr(self, "_repeated_blocks", None)
+
+ if not repeated_blocks:
+ raise ValueError(
+ "`_repeated_blocks` attribute is empty. "
+ f"Set `_repeated_blocks` for the class `{self.__class__.__name__}` to benefit from faster compilation. "
+ )
+ has_compiled_region = False
+ for submod in self.modules():
+ if submod.__class__.__name__ in repeated_blocks:
+ submod.compile(*args, **kwargs)
+ has_compiled_region = True
+
+ if not has_compiled_region:
+ raise ValueError(
+ f"Regional compilation failed because {repeated_blocks} classes are not found in the model. "
+ )
+
  @classmethod
  def _load_pretrained_model(
  cls,
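
Usage note: `compile_repeated_blocks` above implements regional compilation, and any arguments are forwarded verbatim to `torch.compile` for each matching sub-module. A sketch using `ChromaTransformer2DModel`, which gains a `_repeated_blocks` entry later in this diff; the checkpoint id is illustrative:

    import torch
    from diffusers import ChromaTransformer2DModel

    transformer = ChromaTransformer2DModel.from_pretrained(
        "lodestones/Chroma", subfolder="transformer", torch_dtype=torch.bfloat16  # illustrative repo id
    )
    # Compiles only ChromaTransformerBlock / ChromaSingleTransformerBlock instances,
    # passing the kwargs straight through to torch.compile.
    transformer.compile_repeated_blocks(fullgraph=True)
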
@@ -1422,6 +1496,7 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  offload_state_dict: Optional[bool] = None,
  offload_folder: Optional[Union[str, os.PathLike]] = None,
  dduf_entries: Optional[Dict[str, DDUFEntry]] = None,
+ is_parallel_loading_enabled: Optional[bool] = False,
  ):
  model_state_dict = model.state_dict()
  expected_keys = list(model_state_dict.keys())
@@ -1436,8 +1511,6 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  unexpected_keys = [k for k in unexpected_keys if re.search(pat, k) is None]

  mismatched_keys = []
-
- assign_to_params_buffers = None
  error_msgs = []

  # Deal with offload
@@ -1448,80 +1521,67 @@ class ModelMixin(torch.nn.Module, PushToHubMixin):
  " for them. Alternatively, make sure you have `safetensors` installed if the model you are using"
  " offers the weights in this format."
  )
- if offload_folder is not None:
+ else:
  os.makedirs(offload_folder, exist_ok=True)
  if offload_state_dict is None:
  offload_state_dict = True

+ # If a device map has been used, we can speedup the load time by warming up the device caching allocator.
+ # If we don't warmup, each tensor allocation on device calls to the allocator for memory (effectively, a
+ # lot of individual calls to device malloc). We can, however, preallocate the memory required by the
+ # tensors using their expected shape and not performing any initialization of the memory (empty data).
+ # When the actual device allocations happen, the allocator already has a pool of unused device memory
+ # that it can re-use for faster loading of the model.
+ if device_map is not None:
+ expanded_device_map = _expand_device_map(device_map, expected_keys)
+ _caching_allocator_warmup(model, expanded_device_map, dtype, hf_quantizer)
+
  offload_index = {} if device_map is not None and "disk" in device_map.values() else None
+ state_dict_folder, state_dict_index = None, None
  if offload_state_dict:
  state_dict_folder = tempfile.mkdtemp()
  state_dict_index = {}
- else:
- state_dict_folder = None
- state_dict_index = None

  if state_dict is not None:
  # load_state_dict will manage the case where we pass a dict instead of a file
  # if state dict is not None, it means that we don't need to read the files from resolved_model_file also
  resolved_model_file = [state_dict]

- if len(resolved_model_file) > 1:
- resolved_model_file = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")
-
- for shard_file in resolved_model_file:
- state_dict = load_state_dict(shard_file, dduf_entries=dduf_entries)
-
- def _find_mismatched_keys(
- state_dict,
- model_state_dict,
- loaded_keys,
- ignore_mismatched_sizes,
- ):
- mismatched_keys = []
- if ignore_mismatched_sizes:
- for checkpoint_key in loaded_keys:
- model_key = checkpoint_key
- # If the checkpoint is sharded, we may not have the key here.
- if checkpoint_key not in state_dict:
- continue
-
- if (
- model_key in model_state_dict
- and state_dict[checkpoint_key].shape != model_state_dict[model_key].shape
- ):
- mismatched_keys.append(
- (checkpoint_key, state_dict[checkpoint_key].shape, model_state_dict[model_key].shape)
- )
- del state_dict[checkpoint_key]
- return mismatched_keys
-
- mismatched_keys += _find_mismatched_keys(
- state_dict,
- model_state_dict,
- loaded_keys,
- ignore_mismatched_sizes,
- )
+ # Prepare the loading function sharing the attributes shared between them.
+ load_fn = functools.partial(
+ _load_shard_files_with_threadpool if is_parallel_loading_enabled else _load_shard_file,
+ model=model,
+ model_state_dict=model_state_dict,
+ device_map=device_map,
+ dtype=dtype,
+ hf_quantizer=hf_quantizer,
+ keep_in_fp32_modules=keep_in_fp32_modules,
+ dduf_entries=dduf_entries,
+ loaded_keys=loaded_keys,
+ unexpected_keys=unexpected_keys,
+ offload_index=offload_index,
+ offload_folder=offload_folder,
+ state_dict_index=state_dict_index,
+ state_dict_folder=state_dict_folder,
+ ignore_mismatched_sizes=ignore_mismatched_sizes,
+ low_cpu_mem_usage=low_cpu_mem_usage,
+ )

- if low_cpu_mem_usage:
- offload_index, state_dict_index = load_model_dict_into_meta(
- model,
- state_dict,
- device_map=device_map,
- dtype=dtype,
- hf_quantizer=hf_quantizer,
- keep_in_fp32_modules=keep_in_fp32_modules,
- unexpected_keys=unexpected_keys,
- offload_folder=offload_folder,
- offload_index=offload_index,
- state_dict_index=state_dict_index,
- state_dict_folder=state_dict_folder,
- )
- else:
- if assign_to_params_buffers is None:
- assign_to_params_buffers = check_support_param_buffer_assignment(model, state_dict)
+ if is_parallel_loading_enabled:
+ offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(resolved_model_file)
+ error_msgs += _error_msgs
+ mismatched_keys += _mismatched_keys
+ else:
+ shard_files = resolved_model_file
+ if len(resolved_model_file) > 1:
+ shard_files = logging.tqdm(resolved_model_file, desc="Loading checkpoint shards")

- error_msgs += _load_state_dict_into_model(model, state_dict, assign_to_params_buffers)
+ for shard_file in shard_files:
+ offload_index, state_dict_index, _mismatched_keys, _error_msgs = load_fn(shard_file)
+ error_msgs += _error_msgs
+ mismatched_keys += _mismatched_keys
+
+ empty_device_cache()

  if offload_index is not None and len(offload_index) > 0:
  save_offload_index(offload_index, offload_folder)
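
The comment block in the hunk above describes the caching-allocator warmup: sum the bytes each device will eventually hold, allocate them once as uninitialized memory, and let the caching allocator re-serve that pool during the real per-tensor copies. A standalone sketch of the idea, not diffusers' actual `_caching_allocator_warmup`; the helper name and signature here are invented for illustration:

    from collections import defaultdict

    import torch


    def warmup_allocator(expanded_device_map, meta_state_dict, dtype):
        # expanded_device_map: parameter name -> torch.device; meta_state_dict: name -> meta tensor.
        bytes_per_device = defaultdict(int)
        for name, device in expanded_device_map.items():
            if device.type == "cuda":
                bytes_per_device[device] += meta_state_dict[name].numel() * dtype.itemsize
        for device, nbytes in bytes_per_device.items():
            # torch.empty reserves memory without initializing it; dropping the tensor
            # immediately returns the block to the caching allocator, so the later
            # per-tensor allocations are served from this pool instead of fresh mallocs.
            _ = torch.empty(nbytes, dtype=torch.uint8, device=device)
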
@@ -1858,4 +1918,9 @@ class LegacyModelMixin(ModelMixin):
  # resolve remapping
  remapped_class = _fetch_remapped_cls_from_config(config, cls)

- return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
+ if remapped_class is cls:
+ return super(LegacyModelMixin, remapped_class).from_pretrained(
+ pretrained_model_name_or_path, **kwargs_copy
+ )
+ else:
+ return remapped_class.from_pretrained(pretrained_model_name_or_path, **kwargs_copy)
diffusers/models/transformers/__init__.py

@@ -30,7 +30,9 @@ if is_torch_available():
  from .transformer_lumina2 import Lumina2Transformer2DModel
  from .transformer_mochi import MochiTransformer3DModel
  from .transformer_omnigen import OmniGenTransformer2DModel
+ from .transformer_qwenimage import QwenImageTransformer2DModel
  from .transformer_sd3 import SD3Transformer2DModel
+ from .transformer_skyreels_v2 import SkyReelsV2Transformer3DModel
  from .transformer_temporal import TransformerTemporalModel
  from .transformer_wan import WanTransformer3DModel
  from .transformer_wan_vace import WanVACETransformer3DModel
diffusers/models/transformers/transformer_chroma.py

@@ -24,19 +24,13 @@ from ...loaders import FluxTransformer2DLoadersMixin, FromOriginalModelMixin, Pe
  from ...utils import USE_PEFT_BACKEND, deprecate, logging, scale_lora_layers, unscale_lora_layers
  from ...utils.import_utils import is_torch_npu_available
  from ...utils.torch_utils import maybe_allow_in_graph
- from ..attention import FeedForward
- from ..attention_processor import (
- Attention,
- AttentionProcessor,
- FluxAttnProcessor2_0,
- FluxAttnProcessor2_0_NPU,
- FusedFluxAttnProcessor2_0,
- )
+ from ..attention import AttentionMixin, FeedForward
  from ..cache_utils import CacheMixin
  from ..embeddings import FluxPosEmbed, PixArtAlphaTextProjection, Timesteps, get_timestep_embedding
  from ..modeling_outputs import Transformer2DModelOutput
  from ..modeling_utils import ModelMixin
  from ..normalization import CombinedTimestepLabelEmbeddings, FP32LayerNorm, RMSNorm
+ from .transformer_flux import FluxAttention, FluxAttnProcessor


  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -223,6 +217,8 @@ class ChromaSingleTransformerBlock(nn.Module):
  self.proj_out = nn.Linear(dim + self.mlp_hidden_dim, dim)

  if is_torch_npu_available():
+ from ..attention_processor import FluxAttnProcessor2_0_NPU
+
  deprecation_message = (
  "Defaulting to FluxAttnProcessor2_0_NPU for NPU devices will be removed. Attention processors "
  "should be set explicitly using the `set_attn_processor` method."
@@ -230,17 +226,15 @@ class ChromaSingleTransformerBlock(nn.Module):
  deprecate("npu_processor", "0.34.0", deprecation_message)
  processor = FluxAttnProcessor2_0_NPU()
  else:
- processor = FluxAttnProcessor2_0()
+ processor = FluxAttnProcessor()

- self.attn = Attention(
+ self.attn = FluxAttention(
  query_dim=dim,
- cross_attention_dim=None,
  dim_head=attention_head_dim,
  heads=num_attention_heads,
  out_dim=dim,
  bias=True,
  processor=processor,
- qk_norm="rms_norm",
  eps=1e-6,
  pre_only=True,
  )
@@ -292,17 +286,15 @@ class ChromaTransformerBlock(nn.Module):
  self.norm1 = ChromaAdaLayerNormZeroPruned(dim)
  self.norm1_context = ChromaAdaLayerNormZeroPruned(dim)

- self.attn = Attention(
+ self.attn = FluxAttention(
  query_dim=dim,
- cross_attention_dim=None,
  added_kv_proj_dim=dim,
  dim_head=attention_head_dim,
  heads=num_attention_heads,
  out_dim=dim,
  context_pre_only=False,
  bias=True,
- processor=FluxAttnProcessor2_0(),
- qk_norm=qk_norm,
+ processor=FluxAttnProcessor(),
  eps=eps,
  )

@@ -376,7 +368,13 @@


  class ChromaTransformer2DModel(
- ModelMixin, ConfigMixin, PeftAdapterMixin, FromOriginalModelMixin, FluxTransformer2DLoadersMixin, CacheMixin
+ ModelMixin,
+ ConfigMixin,
+ PeftAdapterMixin,
+ FromOriginalModelMixin,
+ FluxTransformer2DLoadersMixin,
+ CacheMixin,
+ AttentionMixin,
  ):
  """
  The Transformer model introduced in Flux, modified for Chroma.
@@ -407,6 +405,7 @@ class ChromaTransformer2DModel(

  _supports_gradient_checkpointing = True
  _no_split_modules = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
+ _repeated_blocks = ["ChromaTransformerBlock", "ChromaSingleTransformerBlock"]
  _skip_layerwise_casting_patterns = ["pos_embed", "norm"]

  @register_to_config
@@ -474,106 +473,6 @@ class ChromaTransformer2DModel(

  self.gradient_checkpointing = False

- @property
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.attn_processors
- def attn_processors(self) -> Dict[str, AttentionProcessor]:
- r"""
- Returns:
- `dict` of attention processors: A dictionary containing all attention processors used in the model with
- indexed by its weight name.
- """
- # set recursively
- processors = {}
-
- def fn_recursive_add_processors(name: str, module: torch.nn.Module, processors: Dict[str, AttentionProcessor]):
- if hasattr(module, "get_processor"):
- processors[f"{name}.processor"] = module.get_processor()
-
- for sub_name, child in module.named_children():
- fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
-
- return processors
-
- for name, module in self.named_children():
- fn_recursive_add_processors(name, module, processors)
-
- return processors
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.set_attn_processor
- def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
- r"""
- Sets the attention processor to use to compute attention.
-
- Parameters:
- processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
- The instantiated processor class or a dictionary of processor classes that will be set as the processor
- for **all** `Attention` layers.
-
- If `processor` is a dict, the key needs to define the path to the corresponding cross attention
- processor. This is strongly recommended when setting trainable attention processors.
-
- """
- count = len(self.attn_processors.keys())
-
- if isinstance(processor, dict) and len(processor) != count:
- raise ValueError(
- f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
- f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
- )
-
- def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
- if hasattr(module, "set_processor"):
- if not isinstance(processor, dict):
- module.set_processor(processor)
- else:
- module.set_processor(processor.pop(f"{name}.processor"))
-
- for sub_name, child in module.named_children():
- fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
-
- for name, module in self.named_children():
- fn_recursive_attn_processor(name, module, processor)
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.fuse_qkv_projections with FusedAttnProcessor2_0->FusedFluxAttnProcessor2_0
- def fuse_qkv_projections(self):
- """
- Enables fused QKV projections. For self-attention modules, all projection matrices (i.e., query, key, value)
- are fused. For cross-attention modules, key and value projection matrices are fused.
-
- <Tip warning={true}>
-
- This API is 🧪 experimental.
-
- </Tip>
- """
- self.original_attn_processors = None
-
- for _, attn_processor in self.attn_processors.items():
- if "Added" in str(attn_processor.__class__.__name__):
- raise ValueError("`fuse_qkv_projections()` is not supported for models having added KV projections.")
-
- self.original_attn_processors = self.attn_processors
-
- for module in self.modules():
- if isinstance(module, Attention):
- module.fuse_projections(fuse=True)
-
- self.set_attn_processor(FusedFluxAttnProcessor2_0())
-
- # Copied from diffusers.models.unets.unet_2d_condition.UNet2DConditionModel.unfuse_qkv_projections
- def unfuse_qkv_projections(self):
- """Disables the fused QKV projection if enabled.
-
- <Tip warning={true}>
-
- This API is 🧪 experimental.
-
- </Tip>
-
- """
- if self.original_attn_processors is not None:
- self.set_attn_processor(self.original_attn_processors)
-
  def forward(
  self,
  hidden_states: torch.Tensor,
diffusers/models/transformers/transformer_cogview4.py

@@ -21,13 +21,14 @@ import torch.nn.functional as F
  from ...configuration_utils import ConfigMixin, register_to_config
  from ...loaders import PeftAdapterMixin
  from ...utils import USE_PEFT_BACKEND, logging, scale_lora_layers, unscale_lora_layers
+ from ...utils.torch_utils import maybe_allow_in_graph
  from ..attention import FeedForward
  from ..attention_processor import Attention
  from ..cache_utils import CacheMixin
  from ..embeddings import CogView3CombinedTimestepSizeEmbeddings
  from ..modeling_outputs import Transformer2DModelOutput
  from ..modeling_utils import ModelMixin
- from ..normalization import AdaLayerNormContinuous
+ from ..normalization import LayerNorm, RMSNorm


  logger = logging.get_logger(__name__) # pylint: disable=invalid-name
@@ -453,6 +454,7 @@ class CogView4TrainingAttnProcessor:
  return hidden_states, encoder_hidden_states


+ @maybe_allow_in_graph
  class CogView4TransformerBlock(nn.Module):
  def __init__(
  self,
@@ -582,6 +584,38 @@ class CogView4RotaryPosEmbed(nn.Module):
  return (freqs.cos(), freqs.sin())


+ class CogView4AdaLayerNormContinuous(nn.Module):
+ """
+ CogView4-only final AdaLN: LN(x) -> Linear(cond) -> chunk -> affine. Matches Megatron: **no activation** before the
+ Linear on conditioning embedding.
+ """
+
+ def __init__(
+ self,
+ embedding_dim: int,
+ conditioning_embedding_dim: int,
+ elementwise_affine: bool = True,
+ eps: float = 1e-5,
+ bias: bool = True,
+ norm_type: str = "layer_norm",
+ ):
+ super().__init__()
+ self.linear = nn.Linear(conditioning_embedding_dim, embedding_dim * 2, bias=bias)
+ if norm_type == "layer_norm":
+ self.norm = LayerNorm(embedding_dim, eps, elementwise_affine, bias)
+ elif norm_type == "rms_norm":
+ self.norm = RMSNorm(embedding_dim, eps, elementwise_affine)
+ else:
+ raise ValueError(f"unknown norm_type {norm_type}")
+
+ def forward(self, x: torch.Tensor, conditioning_embedding: torch.Tensor) -> torch.Tensor:
+ # *** NO SiLU here ***
+ emb = self.linear(conditioning_embedding.to(x.dtype))
+ scale, shift = torch.chunk(emb, 2, dim=1)
+ x = self.norm(x) * (1 + scale)[:, None, :] + shift[:, None, :]
+ return x
+
+
  class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, CacheMixin):
  r"""
  Args:
@@ -664,7 +698,7 @@ class CogView4Transformer2DModel(ModelMixin, ConfigMixin, PeftAdapterMixin, Cach
  )

  # 4. Output projection
- self.norm_out = AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
+ self.norm_out = CogView4AdaLayerNormContinuous(inner_dim, time_embed_dim, elementwise_affine=False)
  self.proj_out = nn.Linear(inner_dim, patch_size * patch_size * out_channels, bias=True)

  self.gradient_checkpointing = False