diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
@@ -41,6 +41,7 @@ from .lora_base import ( # noqa
 )
 from .lora_conversion_utils import (
     _convert_bfl_flux_control_lora_to_diffusers,
+    _convert_fal_kontext_lora_to_diffusers,
     _convert_hunyuan_video_lora_to_diffusers,
     _convert_kohya_flux_lora_to_diffusers,
     _convert_musubi_wan_lora_to_diffusers,
@@ -48,6 +49,7 @@ from .lora_conversion_utils import (
     _convert_non_diffusers_lora_to_diffusers,
     _convert_non_diffusers_ltxv_lora_to_diffusers,
     _convert_non_diffusers_lumina2_lora_to_diffusers,
+    _convert_non_diffusers_qwen_lora_to_diffusers,
     _convert_non_diffusers_wan_lora_to_diffusers,
     _convert_xlabs_flux_lora_to_diffusers,
     _maybe_map_sgm_blocks_to_diffusers,
@@ -2062,6 +2064,17 @@ class FluxLoraLoaderMixin(LoraBaseMixin):
                 return_metadata=return_lora_metadata,
             )

+        is_fal_kontext = any("base_model" in k for k in state_dict)
+        if is_fal_kontext:
+            state_dict = _convert_fal_kontext_lora_to_diffusers(state_dict)
+            return cls._prepare_outputs(
+                state_dict,
+                metadata=metadata,
+                alphas=None,
+                return_alphas=return_alphas,
+                return_metadata=return_lora_metadata,
+            )
+
         # For state dicts like
         # https://huggingface.co/TheLastBen/Jon_Snow_Flux_LoRA
         keys = list(state_dict.keys())
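With the hunk above, `FluxLoraLoaderMixin.lora_state_dict` detects fal Kontext LoRAs by the `base_model` substring in their keys and converts them before loading. A minimal usage sketch; the LoRA repo id is a placeholder and the base checkpoint id is an assumption:

```py
import torch
from diffusers import FluxKontextPipeline

# "<fal-kontext-lora>" is a placeholder. Any state dict whose keys contain
# "base_model" is routed through _convert_fal_kontext_lora_to_diffusers.
pipe = FluxKontextPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-Kontext-dev", torch_dtype=torch.bfloat16
)
pipe.load_lora_weights("<fal-kontext-lora>")
```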
@@ -5052,7 +5065,7 @@ class WanLoraLoaderMixin(LoraBaseMixin):
     Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`WanPipeline`] and `[WanImageToVideoPipeline`].
     """

-    _lora_loadable_modules = ["transformer"]
+    _lora_loadable_modules = ["transformer", "transformer_2"]
     transformer_name = TRANSFORMER_NAME

     @classmethod
@@ -5257,15 +5270,35 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute transformer_2"
+                    "Note that Wan2.1 models do not have a transformer_2 component."
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )

     @classmethod
     # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->WanTransformer3DModel
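The hunk above adds a `load_into_transformer_2` kwarg so a LoRA can target the second transformer of two-transformer Wan pipelines. A minimal sketch, assuming a checkpoint that actually exposes `transformer_2`; repo ids are placeholders:

```py
import torch
from diffusers import WanPipeline

# Placeholders: substitute real repo ids. Wan2.1 checkpoints have no
# transformer_2, so load_into_transformer_2=True raises AttributeError there.
pipe = WanPipeline.from_pretrained("<wan-repo-with-transformer_2>", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("<lora-repo>", adapter_name="high_noise")
pipe.load_lora_weights("<lora-repo>", adapter_name="low_noise", load_into_transformer_2=True)
```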
@@ -5442,9 +5475,9 @@ class WanLoraLoaderMixin(LoraBaseMixin):
         super().unfuse_lora(components=components, **kwargs)


-class CogView4LoraLoaderMixin(LoraBaseMixin):
+class SkyReelsV2LoraLoaderMixin(LoraBaseMixin):
     r"""
-    Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
+    Load LoRA layers into [`SkyReelsV2Transformer3DModel`].
     """

     _lora_loadable_modules = ["transformer"]
@@ -5452,7 +5485,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):

     @classmethod
     @validate_hf_hub_args
-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
+    # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -5503,7 +5536,6 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             return_lora_metadata (`bool`, *optional*, defaults to False):
                 When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
-
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
@@ -5539,6 +5571,10 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
             user_agent=user_agent,
             allow_pickle=allow_pickle,
         )
+        if any(k.startswith("diffusion_model.") for k in state_dict):
+            state_dict = _convert_non_diffusers_wan_lora_to_diffusers(state_dict)
+        elif any(k.startswith("lora_unet_") for k in state_dict):
+            state_dict = _convert_musubi_wan_lora_to_diffusers(state_dict)

         is_dora_scale_present = any("dora_scale" in k for k in state_dict)
         if is_dora_scale_present:
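With the conversion dispatch above, the renamed `SkyReelsV2LoraLoaderMixin` accepts Wan-style (`diffusion_model.*`) and musubi-style (`lora_unet_*`) checkpoints in addition to the diffusers layout. A minimal sketch; both repo ids are placeholders:

```py
import torch
from diffusers import SkyReelsV2Pipeline

# Placeholders below. Keys prefixed "diffusion_model." or "lora_unet_" are
# converted to the diffusers layout before the adapter is injected.
pipe = SkyReelsV2Pipeline.from_pretrained("<skyreels-v2-repo>", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("<skyreels-v2-lora>", adapter_name="example")
```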
@@ -5549,7 +5585,56 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
         out = (state_dict, metadata) if return_lora_metadata else state_dict
         return out

-    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin._maybe_expand_t2v_lora_for_i2v
+    def _maybe_expand_t2v_lora_for_i2v(
+        cls,
+        transformer: torch.nn.Module,
+        state_dict,
+    ):
+        if transformer.config.image_dim is None:
+            return state_dict
+
+        target_device = transformer.device
+
+        if any(k.startswith("transformer.blocks.") for k in state_dict):
+            num_blocks = len({k.split("blocks.")[1].split(".")[0] for k in state_dict if "blocks." in k})
+            is_i2v_lora = any("add_k_proj" in k for k in state_dict) and any("add_v_proj" in k for k in state_dict)
+            has_bias = any(".lora_B.bias" in k for k in state_dict)
+
+            if is_i2v_lora:
+                return state_dict
+
+            for i in range(num_blocks):
+                for o, c in zip(["k_img", "v_img"], ["add_k_proj", "add_v_proj"]):
+                    # These keys should exist if the block `i` was part of the T2V LoRA.
+                    ref_key_lora_A = f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"
+                    ref_key_lora_B = f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"
+
+                    if ref_key_lora_A not in state_dict or ref_key_lora_B not in state_dict:
+                        continue
+
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_A.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_A.weight"], device=target_device
+                    )
+                    state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.weight"] = torch.zeros_like(
+                        state_dict[f"transformer.blocks.{i}.attn2.to_k.lora_B.weight"], device=target_device
+                    )
+
+                    # If the original LoRA had biases (indicated by has_bias)
+                    # AND the specific reference bias key exists for this block.
+
+                    ref_key_lora_B_bias = f"transformer.blocks.{i}.attn2.to_k.lora_B.bias"
+                    if has_bias and ref_key_lora_B_bias in state_dict:
+                        ref_lora_B_bias_tensor = state_dict[ref_key_lora_B_bias]
+                        state_dict[f"transformer.blocks.{i}.attn2.{c}.lora_B.bias"] = torch.zeros_like(
+                            ref_lora_B_bias_tensor,
+                            device=target_device,
+                        )
+
+        return state_dict
+
+    # Copied from diffusers.loaders.lora_pipeline.WanLoraLoaderMixin.load_lora_weights
     def load_lora_weights(
         self,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
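`_maybe_expand_t2v_lora_for_i2v` above lets a text-to-video LoRA load into an image-to-video transformer by synthesizing zero tensors for the missing `add_k_proj`/`add_v_proj` entries. A toy illustration of that zero-fill, with made-up shapes:

```py
import torch

# Made-up shapes. A T2V LoRA has no add_k_proj/add_v_proj entries, so zeros
# shaped like the existing to_k entries are inserted; the adapter then loads
# into an I2V transformer and is a no-op on the image-conditioning projections.
sd = {
    "transformer.blocks.0.attn2.to_k.lora_A.weight": torch.randn(8, 1536),
    "transformer.blocks.0.attn2.to_k.lora_B.weight": torch.randn(1536, 8),
}
for proj in ("add_k_proj", "add_v_proj"):
    for mat in ("lora_A", "lora_B"):
        ref = sd[f"transformer.blocks.0.attn2.to_k.{mat}.weight"]
        sd[f"transformer.blocks.0.attn2.{proj}.{mat}.weight"] = torch.zeros_like(ref)
```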
@@ -5594,23 +5679,47 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
         # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
         kwargs["return_lora_metadata"] = True
         state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
-
+        # convert T2V LoRA to I2V LoRA (when loaded to Wan I2V) by adding zeros for the additional (missing) _img layers
+        state_dict = self._maybe_expand_t2v_lora_for_i2v(
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            state_dict=state_dict,
+        )
         is_correct_format = all("lora" in key for key in state_dict.keys())
         if not is_correct_format:
             raise ValueError("Invalid LoRA checkpoint.")

-        self.load_lora_into_transformer(
-            state_dict,
-            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
-            adapter_name=adapter_name,
-            metadata=metadata,
-            _pipeline=self,
-            low_cpu_mem_usage=low_cpu_mem_usage,
-            hotswap=hotswap,
-        )
+        load_into_transformer_2 = kwargs.pop("load_into_transformer_2", False)
+        if load_into_transformer_2:
+            if not hasattr(self, "transformer_2"):
+                raise AttributeError(
+                    f"'{type(self).__name__}' object has no attribute transformer_2"
+                    "Note that Wan2.1 models do not have a transformer_2 component."
+                    "Ensure the model has a transformer_2 component before setting load_into_transformer_2=True."
+                )
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=self.transformer_2,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )
+        else:
+            self.load_lora_into_transformer(
+                state_dict,
+                transformer=getattr(self, self.transformer_name)
+                if not hasattr(self, "transformer")
+                else self.transformer,
+                adapter_name=adapter_name,
+                metadata=metadata,
+                _pipeline=self,
+                low_cpu_mem_usage=low_cpu_mem_usage,
+                hotswap=hotswap,
+            )

     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->SkyReelsV2Transformer3DModel
     def load_lora_into_transformer(
         cls,
         state_dict,
@@ -5629,7 +5738,7 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
-            transformer (`CogView4Transformer2DModel`):
+            transformer (`SkyReelsV2Transformer3DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -5784,9 +5893,9 @@ class CogView4LoraLoaderMixin(LoraBaseMixin):
         super().unfuse_lora(components=components, **kwargs)


-class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
+class CogView4LoraLoaderMixin(LoraBaseMixin):
     r"""
-    Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
+    Load LoRA layers into [`WanTransformer3DModel`]. Specific to [`CogView4Pipeline`].
     """

     _lora_loadable_modules = ["transformer"]
@@ -5794,6 +5903,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):

     @classmethod
     @validate_hf_hub_args
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.lora_state_dict
     def lora_state_dict(
         cls,
         pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
@@ -5844,6 +5954,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
                 The subfolder location of a model file within a larger model repository on the Hub or locally.
             return_lora_metadata (`bool`, *optional*, defaults to False):
                 When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+
         """
         # Load the main state dict first which has the LoRA layers for either of
         # transformer and text encoder or both.
@@ -5886,10 +5997,6 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
             logger.warning(warn_msg)
             state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}

-        is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
-        if is_non_diffusers_format:
-            state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
-
         out = (state_dict, metadata) if return_lora_metadata else state_dict
         return out

@@ -5954,7 +6061,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
         )

     @classmethod
-    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->CogView4Transformer2DModel
     def load_lora_into_transformer(
         cls,
         state_dict,
@@ -5973,7 +6080,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
                 A standard state dict containing the lora layer parameters. The keys can either be indexed directly
                 into the unet or prefixed with an additional `unet` which can be used to distinguish between text
                 encoder lora layers.
-            transformer (`HiDreamImageTransformer2DModel`):
+            transformer (`CogView4Transformer2DModel`):
                 The Transformer model to load the LoRA layers into.
             adapter_name (`str`, *optional*):
                 Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
@@ -6061,7 +6168,7 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
            lora_adapter_metadata=lora_adapter_metadata,
         )

-    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
     def fuse_lora(
         self,
         components: List[str] = ["transformer"],
@@ -6109,7 +6216,697 @@ class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
             **kwargs,
         )

-    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
+class HiDreamImageLoraLoaderMixin(LoraBaseMixin):
+    r"""
+    Load LoRA layers into [`HiDreamImageTransformer2DModel`]. Specific to [`HiDreamImagePipeline`].
+    """
+
+    _lora_loadable_modules = ["transformer"]
+    transformer_name = TRANSFORMER_NAME
+
+    @classmethod
+    @validate_hf_hub_args
+    def lora_state_dict(
+        cls,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        **kwargs,
+    ):
+        r"""
+        Return state dict for lora weights and the network alphas.
+
+        <Tip warning={true}>
+
+        We support loading A1111 formatted LoRA checkpoints in a limited capacity.
+
+        This function is experimental and might change in the future.
+
+        </Tip>
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                Can be either:
+
+                    - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
+                      the Hub.
+                    - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
+                      with [`ModelMixin.save_pretrained`].
+                    - A [torch state
+                      dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
+
+            cache_dir (`Union[str, os.PathLike]`, *optional*):
+                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
+                is not used.
+            force_download (`bool`, *optional*, defaults to `False`):
+                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
+                cached versions if they exist.
+
+            proxies (`Dict[str, str]`, *optional*):
+                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
+                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
+            local_files_only (`bool`, *optional*, defaults to `False`):
+                Whether to only load local model weights and configuration files or not. If set to `True`, the model
+                won't be downloaded from the Hub.
+            token (`str` or *bool*, *optional*):
+                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
+                `diffusers-cli login` (stored in `~/.huggingface`) is used.
+            revision (`str`, *optional*, defaults to `"main"`):
+                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
+                allowed by Git.
+            subfolder (`str`, *optional*, defaults to `""`):
+                The subfolder location of a model file within a larger model repository on the Hub or locally.
+            return_lora_metadata (`bool`, *optional*, defaults to False):
+                When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
+        """
+        # Load the main state dict first which has the LoRA layers for either of
+        # transformer and text encoder or both.
+        cache_dir = kwargs.pop("cache_dir", None)
+        force_download = kwargs.pop("force_download", False)
+        proxies = kwargs.pop("proxies", None)
+        local_files_only = kwargs.pop("local_files_only", None)
+        token = kwargs.pop("token", None)
+        revision = kwargs.pop("revision", None)
+        subfolder = kwargs.pop("subfolder", None)
+        weight_name = kwargs.pop("weight_name", None)
+        use_safetensors = kwargs.pop("use_safetensors", None)
+        return_lora_metadata = kwargs.pop("return_lora_metadata", False)
+
+        allow_pickle = False
+        if use_safetensors is None:
+            use_safetensors = True
+            allow_pickle = True
+
+        user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
+
+        state_dict, metadata = _fetch_state_dict(
+            pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
+            weight_name=weight_name,
+            use_safetensors=use_safetensors,
+            local_files_only=local_files_only,
+            cache_dir=cache_dir,
+            force_download=force_download,
+            proxies=proxies,
+            token=token,
+            revision=revision,
+            subfolder=subfolder,
+            user_agent=user_agent,
+            allow_pickle=allow_pickle,
+        )
+
+        is_dora_scale_present = any("dora_scale" in k for k in state_dict)
+        if is_dora_scale_present:
+            warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
+            logger.warning(warn_msg)
+            state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
+
+        is_non_diffusers_format = any("diffusion_model" in k for k in state_dict)
+        if is_non_diffusers_format:
+            state_dict = _convert_non_diffusers_hidream_lora_to_diffusers(state_dict)
+
+        out = (state_dict, metadata) if return_lora_metadata else state_dict
+        return out
+
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
+    def load_lora_weights(
+        self,
+        pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
+        adapter_name: Optional[str] = None,
+        hotswap: bool = False,
+        **kwargs,
+    ):
+        """
+        Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
+        `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
+        [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
+        See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
+        dict is loaded into `self.transformer`.
+
+        Parameters:
+            pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            kwargs (`dict`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
+        """
+        if not USE_PEFT_BACKEND:
+            raise ValueError("PEFT backend is required for this method.")
+
+        low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # if a dict is passed, copy it instead of modifying it inplace
+        if isinstance(pretrained_model_name_or_path_or_dict, dict):
+            pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
+
+        # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
+        kwargs["return_lora_metadata"] = True
+        state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
+
+        is_correct_format = all("lora" in key for key in state_dict.keys())
+        if not is_correct_format:
+            raise ValueError("Invalid LoRA checkpoint.")
+
+        self.load_lora_into_transformer(
+            state_dict,
+            transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=self,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->HiDreamImageTransformer2DModel
+    def load_lora_into_transformer(
+        cls,
+        state_dict,
+        transformer,
+        adapter_name=None,
+        _pipeline=None,
+        low_cpu_mem_usage=False,
+        hotswap: bool = False,
+        metadata=None,
+    ):
+        """
+        This will load the LoRA layers specified in `state_dict` into `transformer`.
+
+        Parameters:
+            state_dict (`dict`):
+                A standard state dict containing the lora layer parameters. The keys can either be indexed directly
+                into the unet or prefixed with an additional `unet` which can be used to distinguish between text
+                encoder lora layers.
+            transformer (`HiDreamImageTransformer2DModel`):
+                The Transformer model to load the LoRA layers into.
+            adapter_name (`str`, *optional*):
+                Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
+                `default_{i}` where i is the total number of adapters being loaded.
+            low_cpu_mem_usage (`bool`, *optional*):
+                Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
+                weights.
+            hotswap (`bool`, *optional*):
+                See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
+            metadata (`dict`):
+                Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
+                from the state dict.
+        """
+        if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
+            raise ValueError(
+                "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
+            )
+
+        # Load the layers corresponding to transformer.
+        logger.info(f"Loading {cls.transformer_name}.")
+        transformer.load_lora_adapter(
+            state_dict,
+            network_alphas=None,
+            adapter_name=adapter_name,
+            metadata=metadata,
+            _pipeline=_pipeline,
+            low_cpu_mem_usage=low_cpu_mem_usage,
+            hotswap=hotswap,
+        )
+
+    @classmethod
+    # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
+    def save_lora_weights(
+        cls,
+        save_directory: Union[str, os.PathLike],
+        transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
+        is_main_process: bool = True,
+        weight_name: str = None,
+        save_function: Callable = None,
+        safe_serialization: bool = True,
+        transformer_lora_adapter_metadata: Optional[dict] = None,
+    ):
+        r"""
+        Save the LoRA parameters corresponding to the transformer.
+
+        Arguments:
+            save_directory (`str` or `os.PathLike`):
+                Directory to save LoRA parameters to. Will be created if it doesn't exist.
+            transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
+                State dict of the LoRA layers corresponding to the `transformer`.
+            is_main_process (`bool`, *optional*, defaults to `True`):
+                Whether the process calling this is the main process or not. Useful during distributed training and you
+                need to call this function on all processes. In this case, set `is_main_process=True` only on the main
+                process to avoid race conditions.
+            save_function (`Callable`):
+                The function to use to save the state dictionary. Useful during distributed training when you need to
+                replace `torch.save` with another method. Can be configured with the environment variable
+                `DIFFUSERS_SAVE_MODE`.
+            safe_serialization (`bool`, *optional*, defaults to `True`):
+                Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
+            transformer_lora_adapter_metadata:
+                LoRA adapter metadata associated with the transformer to be serialized with the state dict.
+        """
+        state_dict = {}
+        lora_adapter_metadata = {}
+
+        if not transformer_lora_layers:
+            raise ValueError("You must pass `transformer_lora_layers`.")
+
+        state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
+
+        if transformer_lora_adapter_metadata is not None:
+            lora_adapter_metadata.update(
+                _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
+            )
+
+        # Save the model
+        cls.write_lora_layers(
+            state_dict=state_dict,
+            save_directory=save_directory,
+            is_main_process=is_main_process,
+            weight_name=weight_name,
+            save_function=save_function,
+            safe_serialization=safe_serialization,
+            lora_adapter_metadata=lora_adapter_metadata,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.fuse_lora
+    def fuse_lora(
+        self,
+        components: List[str] = ["transformer"],
+        lora_scale: float = 1.0,
+        safe_fusing: bool = False,
+        adapter_names: Optional[List[str]] = None,
+        **kwargs,
+    ):
+        r"""
+        Fuses the LoRA parameters into the original parameters of the corresponding blocks.
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
+            lora_scale (`float`, defaults to 1.0):
+                Controls how much to influence the outputs with the LoRA parameters.
+            safe_fusing (`bool`, defaults to `False`):
+                Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
+            adapter_names (`List[str]`, *optional*):
+                Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
+
+        Example:
+
+        ```py
+        from diffusers import DiffusionPipeline
+        import torch
+
+        pipeline = DiffusionPipeline.from_pretrained(
+            "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
+        ).to("cuda")
+        pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
+        pipeline.fuse_lora(lora_scale=0.7)
+        ```
+        """
+        super().fuse_lora(
+            components=components,
+            lora_scale=lora_scale,
+            safe_fusing=safe_fusing,
+            adapter_names=adapter_names,
+            **kwargs,
+        )
+
+    # Copied from diffusers.loaders.lora_pipeline.SanaLoraLoaderMixin.unfuse_lora
+    def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
+        r"""
+        Reverses the effect of
+        [`pipe.fuse_lora()`](https://huggingface.co/docs/diffusers/main/en/api/loaders#diffusers.loaders.LoraBaseMixin.fuse_lora).
+
+        <Tip warning={true}>
+
+        This is an experimental API.
+
+        </Tip>
+
+        Args:
+            components (`List[str]`): List of LoRA-injectable components to unfuse LoRA from.
+            unfuse_transformer (`bool`, defaults to `True`): Whether to unfuse the UNet LoRA parameters.
+        """
+        super().unfuse_lora(components=components, **kwargs)
+
+
+ class QwenImageLoraLoaderMixin(LoraBaseMixin):
6583
+ r"""
6584
+ Load LoRA layers into [`QwenImageTransformer2DModel`]. Specific to [`QwenImagePipeline`].
6585
+ """
6586
+
6587
+ _lora_loadable_modules = ["transformer"]
6588
+ transformer_name = TRANSFORMER_NAME
6589
+
6590
+ @classmethod
6591
+ @validate_hf_hub_args
6592
+ def lora_state_dict(
6593
+ cls,
6594
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
6595
+ **kwargs,
6596
+ ):
6597
+ r"""
6598
+ Return state dict for lora weights and the network alphas.
6599
+
6600
+ <Tip warning={true}>
6601
+
6602
+ We support loading A1111 formatted LoRA checkpoints in a limited capacity.
6603
+
6604
+ This function is experimental and might change in the future.
6605
+
6606
+ </Tip>
6607
+
6608
+ Parameters:
6609
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
6610
+ Can be either:
6611
+
6612
+ - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
6613
+ the Hub.
6614
+ - A path to a *directory* (for example `./my_model_directory`) containing the model weights saved
6615
+ with [`ModelMixin.save_pretrained`].
6616
+ - A [torch state
6617
+ dict](https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict).
6618
+
6619
+ cache_dir (`Union[str, os.PathLike]`, *optional*):
6620
+ Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
6621
+ is not used.
6622
+ force_download (`bool`, *optional*, defaults to `False`):
6623
+ Whether or not to force the (re-)download of the model weights and configuration files, overriding the
6624
+ cached versions if they exist.
6625
+
6626
+ proxies (`Dict[str, str]`, *optional*):
6627
+ A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
6628
+ 'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
6629
+ local_files_only (`bool`, *optional*, defaults to `False`):
6630
+ Whether to only load local model weights and configuration files or not. If set to `True`, the model
6631
+ won't be downloaded from the Hub.
6632
+ token (`str` or *bool*, *optional*):
6633
+ The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
6634
+ `diffusers-cli login` (stored in `~/.huggingface`) is used.
6635
+ revision (`str`, *optional*, defaults to `"main"`):
6636
+ The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
6637
+ allowed by Git.
6638
+ subfolder (`str`, *optional*, defaults to `""`):
6639
+ The subfolder location of a model file within a larger model repository on the Hub or locally.
6640
+ return_lora_metadata (`bool`, *optional*, defaults to False):
6641
+ When enabled, additionally return the LoRA adapter metadata, typically found in the state dict.
6642
+
6643
+ """
6644
+ # Load the main state dict first which has the LoRA layers for either of
6645
+ # transformer and text encoder or both.
6646
+ cache_dir = kwargs.pop("cache_dir", None)
6647
+ force_download = kwargs.pop("force_download", False)
6648
+ proxies = kwargs.pop("proxies", None)
6649
+ local_files_only = kwargs.pop("local_files_only", None)
6650
+ token = kwargs.pop("token", None)
6651
+ revision = kwargs.pop("revision", None)
6652
+ subfolder = kwargs.pop("subfolder", None)
6653
+ weight_name = kwargs.pop("weight_name", None)
6654
+ use_safetensors = kwargs.pop("use_safetensors", None)
6655
+ return_lora_metadata = kwargs.pop("return_lora_metadata", False)
6656
+
6657
+ allow_pickle = False
6658
+ if use_safetensors is None:
6659
+ use_safetensors = True
6660
+ allow_pickle = True
6661
+
6662
+ user_agent = {"file_type": "attn_procs_weights", "framework": "pytorch"}
6663
+
6664
+ state_dict, metadata = _fetch_state_dict(
6665
+ pretrained_model_name_or_path_or_dict=pretrained_model_name_or_path_or_dict,
6666
+ weight_name=weight_name,
6667
+ use_safetensors=use_safetensors,
6668
+ local_files_only=local_files_only,
6669
+ cache_dir=cache_dir,
6670
+ force_download=force_download,
6671
+ proxies=proxies,
6672
+ token=token,
6673
+ revision=revision,
6674
+ subfolder=subfolder,
6675
+ user_agent=user_agent,
6676
+ allow_pickle=allow_pickle,
6677
+ )
6678
+
6679
+ is_dora_scale_present = any("dora_scale" in k for k in state_dict)
6680
+ if is_dora_scale_present:
6681
+ warn_msg = "It seems like you are using a DoRA checkpoint that is not compatible in Diffusers at the moment. So, we are going to filter out the keys associated to 'dora_scale` from the state dict. If you think this is a mistake please open an issue https://github.com/huggingface/diffusers/issues/new."
6682
+ logger.warning(warn_msg)
6683
+ state_dict = {k: v for k, v in state_dict.items() if "dora_scale" not in k}
6684
+
6685
+ has_alphas_in_sd = any(k.endswith(".alpha") for k in state_dict)
6686
+ has_lora_unet = any(k.startswith("lora_unet_") for k in state_dict)
6687
+ if has_alphas_in_sd or has_lora_unet:
6688
+ state_dict = _convert_non_diffusers_qwen_lora_to_diffusers(state_dict)
6689
+
6690
+ out = (state_dict, metadata) if return_lora_metadata else state_dict
6691
+ return out
6692
+
6693
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.load_lora_weights
6694
+ def load_lora_weights(
6695
+ self,
6696
+ pretrained_model_name_or_path_or_dict: Union[str, Dict[str, torch.Tensor]],
6697
+ adapter_name: Optional[str] = None,
6698
+ hotswap: bool = False,
6699
+ **kwargs,
6700
+ ):
6701
+ """
6702
+ Load LoRA weights specified in `pretrained_model_name_or_path_or_dict` into `self.transformer` and
6703
+ `self.text_encoder`. All kwargs are forwarded to `self.lora_state_dict`. See
6704
+ [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`] for more details on how the state dict is loaded.
6705
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_into_transformer`] for more details on how the state
6706
+ dict is loaded into `self.transformer`.
6707
+
6708
+ Parameters:
6709
+ pretrained_model_name_or_path_or_dict (`str` or `os.PathLike` or `dict`):
6710
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
6711
+ adapter_name (`str`, *optional*):
6712
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
6713
+ `default_{i}` where i is the total number of adapters being loaded.
6714
+ low_cpu_mem_usage (`bool`, *optional*):
6715
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
6716
+ weights.
6717
+ hotswap (`bool`, *optional*):
6718
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
6719
+ kwargs (`dict`, *optional*):
6720
+ See [`~loaders.StableDiffusionLoraLoaderMixin.lora_state_dict`].
6721
+ """
6722
+ if not USE_PEFT_BACKEND:
6723
+ raise ValueError("PEFT backend is required for this method.")
6724
+
6725
+ low_cpu_mem_usage = kwargs.pop("low_cpu_mem_usage", _LOW_CPU_MEM_USAGE_DEFAULT_LORA)
6726
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
6727
+ raise ValueError(
6728
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
6729
+ )
6730
+
6731
+ # if a dict is passed, copy it instead of modifying it inplace
6732
+ if isinstance(pretrained_model_name_or_path_or_dict, dict):
6733
+ pretrained_model_name_or_path_or_dict = pretrained_model_name_or_path_or_dict.copy()
6734
+
6735
+ # First, ensure that the checkpoint is a compatible one and can be successfully loaded.
6736
+ kwargs["return_lora_metadata"] = True
6737
+ state_dict, metadata = self.lora_state_dict(pretrained_model_name_or_path_or_dict, **kwargs)
6738
+
6739
+ is_correct_format = all("lora" in key for key in state_dict.keys())
6740
+ if not is_correct_format:
6741
+ raise ValueError("Invalid LoRA checkpoint.")
6742
+
6743
+ self.load_lora_into_transformer(
6744
+ state_dict,
6745
+ transformer=getattr(self, self.transformer_name) if not hasattr(self, "transformer") else self.transformer,
6746
+ adapter_name=adapter_name,
6747
+ metadata=metadata,
6748
+ _pipeline=self,
6749
+ low_cpu_mem_usage=low_cpu_mem_usage,
6750
+ hotswap=hotswap,
6751
+ )
6752
+
6753
+ @classmethod
6754
+ # Copied from diffusers.loaders.lora_pipeline.SD3LoraLoaderMixin.load_lora_into_transformer with SD3Transformer2DModel->QwenImageTransformer2DModel
6755
+ def load_lora_into_transformer(
6756
+ cls,
6757
+ state_dict,
6758
+ transformer,
6759
+ adapter_name=None,
6760
+ _pipeline=None,
6761
+ low_cpu_mem_usage=False,
6762
+ hotswap: bool = False,
6763
+ metadata=None,
6764
+ ):
6765
+ """
6766
+ This will load the LoRA layers specified in `state_dict` into `transformer`.
6767
+
6768
+ Parameters:
6769
+ state_dict (`dict`):
6770
+ A standard state dict containing the lora layer parameters. The keys can either be indexed directly
6771
+ into the unet or prefixed with an additional `unet` which can be used to distinguish between text
6772
+ encoder lora layers.
6773
+ transformer (`QwenImageTransformer2DModel`):
6774
+ The Transformer model to load the LoRA layers into.
6775
+ adapter_name (`str`, *optional*):
6776
+ Adapter name to be used for referencing the loaded adapter model. If not specified, it will use
6777
+ `default_{i}` where i is the total number of adapters being loaded.
6778
+ low_cpu_mem_usage (`bool`, *optional*):
6779
+ Speed up model loading by only loading the pretrained LoRA weights and not initializing the random
6780
+ weights.
6781
+ hotswap (`bool`, *optional*):
6782
+ See [`~loaders.StableDiffusionLoraLoaderMixin.load_lora_weights`].
6783
+ metadata (`dict`):
6784
+ Optional LoRA adapter metadata. When supplied, the `LoraConfig` arguments of `peft` won't be derived
6785
+ from the state dict.
6786
+ """
6787
+ if low_cpu_mem_usage and is_peft_version("<", "0.13.0"):
6788
+ raise ValueError(
6789
+ "`low_cpu_mem_usage=True` is not compatible with this `peft` version. Please update it with `pip install -U peft`."
6790
+ )
6791
+
6792
+ # Load the layers corresponding to transformer.
6793
+ logger.info(f"Loading {cls.transformer_name}.")
6794
+ transformer.load_lora_adapter(
6795
+ state_dict,
6796
+ network_alphas=None,
6797
+ adapter_name=adapter_name,
6798
+ metadata=metadata,
6799
+ _pipeline=_pipeline,
6800
+ low_cpu_mem_usage=low_cpu_mem_usage,
6801
+ hotswap=hotswap,
6802
+ )
6803
+
6804
+ @classmethod
6805
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.save_lora_weights
6806
+ def save_lora_weights(
6807
+ cls,
6808
+ save_directory: Union[str, os.PathLike],
6809
+ transformer_lora_layers: Dict[str, Union[torch.nn.Module, torch.Tensor]] = None,
6810
+ is_main_process: bool = True,
6811
+ weight_name: str = None,
6812
+ save_function: Callable = None,
6813
+ safe_serialization: bool = True,
6814
+ transformer_lora_adapter_metadata: Optional[dict] = None,
6815
+ ):
6816
+ r"""
6817
+ Save the LoRA parameters corresponding to the transformer.
6818
+
6819
+ Arguments:
6820
+ save_directory (`str` or `os.PathLike`):
6821
+ Directory to save LoRA parameters to. Will be created if it doesn't exist.
6822
+ transformer_lora_layers (`Dict[str, torch.nn.Module]` or `Dict[str, torch.Tensor]`):
6823
+ State dict of the LoRA layers corresponding to the `transformer`.
6824
+ is_main_process (`bool`, *optional*, defaults to `True`):
6825
+ Whether the process calling this is the main process or not. Useful during distributed training and you
6826
+ need to call this function on all processes. In this case, set `is_main_process=True` only on the main
6827
+ process to avoid race conditions.
6828
+ save_function (`Callable`):
6829
+ The function to use to save the state dictionary. Useful during distributed training when you need to
6830
+ replace `torch.save` with another method. Can be configured with the environment variable
6831
+ `DIFFUSERS_SAVE_MODE`.
6832
+ safe_serialization (`bool`, *optional*, defaults to `True`):
6833
+ Whether to save the model using `safetensors` or the traditional PyTorch way with `pickle`.
6834
+ transformer_lora_adapter_metadata:
6835
+ LoRA adapter metadata associated with the transformer to be serialized with the state dict.
6836
+ """
6837
+ state_dict = {}
6838
+ lora_adapter_metadata = {}
6839
+
6840
+ if not transformer_lora_layers:
6841
+ raise ValueError("You must pass `transformer_lora_layers`.")
6842
+
6843
+ state_dict.update(cls.pack_weights(transformer_lora_layers, cls.transformer_name))
6844
+
6845
+ if transformer_lora_adapter_metadata is not None:
6846
+ lora_adapter_metadata.update(
6847
+ _pack_dict_with_prefix(transformer_lora_adapter_metadata, cls.transformer_name)
6848
+ )
6849
+
6850
+ # Save the model
6851
+ cls.write_lora_layers(
6852
+ state_dict=state_dict,
6853
+ save_directory=save_directory,
6854
+ is_main_process=is_main_process,
6855
+ weight_name=weight_name,
6856
+ save_function=save_function,
6857
+ safe_serialization=safe_serialization,
6858
+ lora_adapter_metadata=lora_adapter_metadata,
6859
+ )
6860
+
6861
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.fuse_lora
6862
+ def fuse_lora(
6863
+ self,
6864
+ components: List[str] = ["transformer"],
6865
+ lora_scale: float = 1.0,
6866
+ safe_fusing: bool = False,
6867
+ adapter_names: Optional[List[str]] = None,
6868
+ **kwargs,
6869
+ ):
6870
+ r"""
6871
+ Fuses the LoRA parameters into the original parameters of the corresponding blocks.
6872
+
6873
+ <Tip warning={true}>
6874
+
6875
+ This is an experimental API.
6876
+
6877
+ </Tip>
6878
+
6879
+ Args:
6880
+ components: (`List[str]`): List of LoRA-injectable components to fuse the LoRAs into.
6881
+ lora_scale (`float`, defaults to 1.0):
6882
+ Controls how much to influence the outputs with the LoRA parameters.
6883
+ safe_fusing (`bool`, defaults to `False`):
6884
+ Whether to check fused weights for NaN values before fusing and if values are NaN not fusing them.
6885
+ adapter_names (`List[str]`, *optional*):
6886
+ Adapter names to be used for fusing. If nothing is passed, all active adapters will be fused.
6887
+
6888
+ Example:
6889
+
6890
+ ```py
6891
+ from diffusers import DiffusionPipeline
6892
+ import torch
6893
+
6894
+ pipeline = DiffusionPipeline.from_pretrained(
6895
+ "stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16
6896
+ ).to("cuda")
6897
+ pipeline.load_lora_weights("nerijs/pixel-art-xl", weight_name="pixel-art-xl.safetensors", adapter_name="pixel")
6898
+ pipeline.fuse_lora(lora_scale=0.7)
6899
+ ```
6900
+ """
6901
+ super().fuse_lora(
6902
+ components=components,
6903
+ lora_scale=lora_scale,
6904
+ safe_fusing=safe_fusing,
6905
+ adapter_names=adapter_names,
6906
+ **kwargs,
6907
+ )
6908
+
6909
+ # Copied from diffusers.loaders.lora_pipeline.CogVideoXLoraLoaderMixin.unfuse_lora
6113
6910
  def unfuse_lora(self, components: List[str] = ["transformer"], **kwargs):
6114
6911
  r"""
6115
6912
  Reverses the effect of
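The new `QwenImageLoraLoaderMixin` added in this hunk wires LoRA loading into the QwenImage pipelines. A minimal usage sketch; the base repo id and LoRA id are assumptions for illustration:

```py
import torch
from diffusers import QwenImagePipeline

# Repo ids are assumptions. Checkpoints with ".alpha" keys or a "lora_unet_"
# prefix are converted via _convert_non_diffusers_qwen_lora_to_diffusers.
pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image", torch_dtype=torch.bfloat16)
pipe.load_lora_weights("<qwen-image-lora>", adapter_name="example")
pipe.fuse_lora(lora_scale=1.0)
```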