diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py
@@ -383,7 +383,8 @@ class OnnxStableDiffusionPipeline(DiffusionPipeline):
         # set timesteps
         self.scheduler.set_timesteps(num_inference_steps)

-        latents = latents * np.float64(self.scheduler.init_noise_sigma)
+        # scale the initial noise by the standard deviation required by the scheduler
+        latents = latents * self.scheduler.init_noise_sigma

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.

diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py
@@ -483,7 +483,7 @@ class OnnxStableDiffusionInpaintPipeline(DiffusionPipeline):
         self.scheduler.set_timesteps(num_inference_steps)

         # scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * np.float64(self.scheduler.init_noise_sigma)
+        latents = latents * self.scheduler.init_noise_sigma

         # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
         # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py
@@ -481,7 +481,7 @@ class OnnxStableDiffusionUpscalePipeline(DiffusionPipeline):
         timesteps = self.scheduler.timesteps

         # Scale the initial noise by the standard deviation required by the scheduler
-        latents = latents * np.float64(self.scheduler.init_noise_sigma)
+        latents = latents * self.scheduler.init_noise_sigma

         # 5. Add noise to image
         noise_level = np.array([noise_level]).astype(np.int64)
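The three ONNX hunks above drop the explicit `np.float64(...)` cast: `init_noise_sigma` already scales a NumPy latent array directly. A minimal sketch of that scaling step, assuming the PNDM scheduler commonly shipped with the ONNX Stable Diffusion checkpoints (variable names are illustrative):

    import numpy as np
    from diffusers import PNDMScheduler

    scheduler = PNDMScheduler()  # default config; init_noise_sigma is a plain float (1.0) here
    scheduler.set_timesteps(50)

    # the ONNX pipelines keep latents as NumPy arrays
    latents = np.random.randn(1, 4, 64, 64).astype(np.float32)

    # scale the initial noise by the standard deviation required by the scheduler;
    # NumPy broadcasting handles the scalar, so no np.float64 cast is needed
    latents = latents * scheduler.init_noise_sigma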
diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py
@@ -1034,7 +1034,8 @@ class StableDiffusionPipeline(

                 # expand the latents if we are doing classifier free guidance
                 latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
-                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+                if hasattr(self.scheduler, "scale_model_input"):
+                    latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                 # predict the noise residual
                 noise_pred = self.unet(
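The added `hasattr` guard above lets the pipeline accept scheduler-like objects that do not define `scale_model_input`. A standalone sketch of the same guard, with a toy scheduler class used purely for illustration:

    import torch

    class MinimalScheduler:
        """Toy stand-in for a scheduler that lacks scale_model_input."""
        init_noise_sigma = 1.0

    def maybe_scale_model_input(scheduler, latent_model_input, t):
        # mirrors the pipeline change: only call scale_model_input when the scheduler provides it
        if hasattr(scheduler, "scale_model_input"):
            latent_model_input = scheduler.scale_model_input(latent_model_input, t)
        return latent_model_input

    latents = torch.randn(2, 4, 64, 64)
    out = maybe_scale_model_input(MinimalScheduler(), latents, t=torch.tensor(999))
    assert out.shape == latents.shape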
diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py
@@ -25,6 +25,7 @@ from transformers import (
     T5TokenizerFast,
 )

+from ...callbacks import MultiPipelineCallbacks, PipelineCallback
 from ...image_processor import PipelineImageInput, VaeImageProcessor
 from ...loaders import FromSingleFileMixin, SD3IPAdapterMixin, SD3LoraLoaderMixin
 from ...models.autoencoders import AutoencoderKL
@@ -184,7 +185,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle

     model_cpu_offload_seq = "text_encoder->text_encoder_2->text_encoder_3->image_encoder->transformer->vae"
     _optional_components = ["image_encoder", "feature_extractor"]
-    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds", "negative_pooled_prompt_embeds"]
+    _callback_tensor_inputs = ["latents", "prompt_embeds", "pooled_prompt_embeds"]

     def __init__(
         self,
@@ -923,6 +924,9 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle
         height = height or self.default_sample_size * self.vae_scale_factor
         width = width or self.default_sample_size * self.vae_scale_factor

+        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
+            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs
+
         # 1. Check inputs. Raise error if not correct
         self.check_inputs(
             prompt,
@@ -1109,10 +1113,7 @@ class StableDiffusion3Pipeline(DiffusionPipeline, SD3LoraLoaderMixin, FromSingle

                     latents = callback_outputs.pop("latents", latents)
                     prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
-                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)
-                    negative_pooled_prompt_embeds = callback_outputs.pop(
-                        "negative_pooled_prompt_embeds", negative_pooled_prompt_embeds
-                    )
+                    pooled_prompt_embeds = callback_outputs.pop("pooled_prompt_embeds", pooled_prompt_embeds)

                 # call the callback, if provided
                 if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
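Taken together, the StableDiffusion3Pipeline hunks make `pooled_prompt_embeds` (rather than the negative embeddings) available to step-end callbacks, and honor the `tensor_inputs` of a `PipelineCallback`/`MultiPipelineCallbacks` object. A hedged usage sketch with a plain function callback; the checkpoint id and the zeroing at step 10 are illustrative only:

    import torch
    from diffusers import StableDiffusion3Pipeline

    def on_step_end(pipe, step_index, timestep, callback_kwargs):
        # "pooled_prompt_embeds" is now one of the allowed tensor inputs,
        # alongside "latents" and "prompt_embeds"
        pooled = callback_kwargs["pooled_prompt_embeds"]
        if step_index == 10:
            pooled = torch.zeros_like(pooled)  # hypothetical intervention
        return {"pooled_prompt_embeds": pooled}

    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16
    ).to("cuda")
    image = pipe(
        "a photo of an astronaut riding a horse",
        callback_on_step_end=on_step_end,
        callback_on_step_end_tensor_inputs=["pooled_prompt_embeds"],
    ).images[0]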
diffusers/pipelines/wan/pipeline_wan.py
@@ -112,18 +112,31 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        transformer_2 ([`WanTransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. If provided, enables
+            two-stage denoising where `transformer` handles high-noise stages and `transformer_2` handles low-noise
+            stages. If not provided, only `transformer` is used.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+            The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+            `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
     """

-    model_cpu_offload_seq = "text_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer", "transformer_2"]

     def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        transformer: Optional[WanTransformer3DModel] = None,
+        transformer_2: Optional[WanTransformer3DModel] = None,
+        boundary_ratio: Optional[float] = None,
+        expand_timesteps: bool = False,  # Wan2.2 ti2v
     ):
         super().__init__()

@@ -133,10 +146,12 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             tokenizer=tokenizer,
             transformer=transformer,
             scheduler=scheduler,
+            transformer_2=transformer_2,
         )
-
-        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+        self.register_to_config(boundary_ratio=boundary_ratio)
+        self.register_to_config(expand_timesteps=expand_timesteps)
+        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

     def _get_t5_prompt_embeds(
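A hedged usage sketch for the two-transformer constructor above. The checkpoint id is an assumption (a Wan 2.2 repository that ships `transformer`, `transformer_2`, and a `boundary_ratio` in its `model_index.json`); the prompt and guidance values are illustrative:

    import torch
    from diffusers import WanPipeline
    from diffusers.utils import export_to_video

    pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.2-T2V-A14B-Diffusers", torch_dtype=torch.bfloat16)
    pipe.to("cuda")

    video = pipe(
        prompt="a corgi running on a beach at sunset",
        num_frames=81,
        guidance_scale=4.0,    # applied by `transformer` in the high-noise stage
        guidance_scale_2=3.0,  # applied by `transformer_2` in the low-noise stage
    ).frames[0]
    export_to_video(video, "wan22.mp4", fps=16)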
@@ -270,6 +285,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         prompt_embeds=None,
         negative_prompt_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        guidance_scale_2=None,
     ):
         if height % 16 != 0 or width % 16 != 0:
             raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")
@@ -302,6 +318,9 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
     def prepare_latents(
         self,
         batch_size: int,
@@ -369,6 +388,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -407,6 +427,10 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                 `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                 the text `prompt`, usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+                `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+                and the pipeline's `boundary_ratio` are not None.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -461,6 +485,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             prompt_embeds,
             negative_prompt_embeds,
             callback_on_step_end_tensor_inputs,
+            guidance_scale_2,
         )

         if num_frames % self.vae_scale_factor_temporal != 1:
@@ -470,7 +495,11 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
             num_frames = max(num_frames, 1)

+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
@@ -497,7 +526,7 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             device=device,
         )

-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
         prompt_embeds = prompt_embeds.to(transformer_dtype)
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)
@@ -507,7 +536,11 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         timesteps = self.scheduler.timesteps

         # 5. Prepare latent variables
-        num_channels_latents = self.transformer.config.in_channels
+        num_channels_latents = (
+            self.transformer.config.in_channels
+            if self.transformer is not None
+            else self.transformer_2.config.in_channels
+        )
         latents = self.prepare_latents(
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -520,36 +553,61 @@ class WanPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             latents,
         )

+        mask = torch.ones(latents.shape, dtype=torch.float32, device=device)
+
         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

                 self._current_timestep = t
-                latent_model_input = latents.to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])

-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2

-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                latent_model_input = latents.to(transformer_dtype)
+                if self.config.expand_timesteps:
+                    # seq_len: num_latent_frames * latent_height//2 * latent_width//2
+                    temp_ts = (mask[0][0][:, ::2, ::2] * t).flatten()
+                    # batch_size, seq_len
+                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+                else:
+                    timestep = t.expand(latents.shape[0])
+
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+                if self.do_classifier_free_guidance:
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
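The denoising-loop hunk above routes each timestep to one of the two transformers. A standalone sketch of that routing rule; the helper name and the example `boundary_ratio` of 0.875 are illustrative, not taken from the diff:

    from typing import Optional, Tuple

    def select_stage(
        t: float,
        boundary_ratio: Optional[float],
        num_train_timesteps: int,
        guidance_scale: float,
        guidance_scale_2: Optional[float],
    ) -> Tuple[str, float]:
        # boundary_timestep = boundary_ratio * num_train_timesteps, as in the pipeline
        boundary_timestep = None if boundary_ratio is None else boundary_ratio * num_train_timesteps
        if boundary_timestep is None or t >= boundary_timestep:
            return "transformer", guidance_scale                    # Wan 2.1, or Wan 2.2 high-noise stage
        return "transformer_2", guidance_scale_2 or guidance_scale  # Wan 2.2 low-noise stage

    # with boundary_ratio=0.875 and 1000 training timesteps, t >= 875 goes to `transformer`
    assert select_stage(900, 0.875, 1000, 4.0, 3.0) == ("transformer", 4.0)
    assert select_stage(500, 0.875, 1000, 4.0, 3.0) == ("transformer_2", 3.0)
    assert select_stage(500, None, 1000, 4.0, None) == ("transformer", 4.0)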
diffusers/pipelines/wan/pipeline_wan_i2v.py
@@ -149,20 +149,33 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             A scheduler to be used in combination with `transformer` to denoise the encoded image latents.
         vae ([`AutoencoderKLWan`]):
             Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
+        transformer_2 ([`WanTransformer3DModel`], *optional*):
+            Conditional Transformer to denoise the input latents during the low-noise stage. In two-stage denoising,
+            `transformer` handles high-noise stages and `transformer_2` handles low-noise stages. If not provided, only
+            `transformer` is used.
+        boundary_ratio (`float`, *optional*, defaults to `None`):
+            Ratio of total timesteps to use as the boundary for switching between transformers in two-stage denoising.
+            The actual boundary timestep is calculated as `boundary_ratio * num_train_timesteps`. When provided,
+            `transformer` handles timesteps >= boundary_timestep and `transformer_2` handles timesteps <
+            boundary_timestep. If `None`, only `transformer` is used for the entire denoising process.
     """

-    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->vae"
+    model_cpu_offload_seq = "text_encoder->image_encoder->transformer->transformer_2->vae"
     _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
+    _optional_components = ["transformer", "transformer_2", "image_encoder", "image_processor"]

     def __init__(
         self,
         tokenizer: AutoTokenizer,
         text_encoder: UMT5EncoderModel,
-        image_encoder: CLIPVisionModel,
-        image_processor: CLIPImageProcessor,
-        transformer: WanTransformer3DModel,
         vae: AutoencoderKLWan,
         scheduler: FlowMatchEulerDiscreteScheduler,
+        image_processor: CLIPImageProcessor = None,
+        image_encoder: CLIPVisionModel = None,
+        transformer: WanTransformer3DModel = None,
+        transformer_2: WanTransformer3DModel = None,
+        boundary_ratio: Optional[float] = None,
+        expand_timesteps: bool = False,
     ):
         super().__init__()

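A hedged loading sketch for the reordered, now-optional `__init__` arguments above: components listed in `_optional_components` can be omitted, and a Wan 2.2 image-to-video checkpoint would provide `transformer_2` plus a `boundary_ratio`. The repository id and image URL are assumptions, not taken from the diff:

    import torch
    from diffusers import WanImageToVideoPipeline
    from diffusers.utils import load_image

    pipe = WanImageToVideoPipeline.from_pretrained(
        "Wan-AI/Wan2.2-I2V-A14B-Diffusers", torch_dtype=torch.bfloat16
    ).to("cuda")

    image = load_image("https://example.com/first_frame.png")  # placeholder URL
    video = pipe(
        image=image,
        prompt="the camera slowly pans to the right",
        num_frames=81,
        guidance_scale=3.5,    # high-noise stage
        guidance_scale_2=3.5,  # low-noise stage
    ).frames[0]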
@@ -174,10 +187,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             transformer=transformer,
             scheduler=scheduler,
             image_processor=image_processor,
+            transformer_2=transformer_2,
         )
+        self.register_to_config(boundary_ratio=boundary_ratio, expand_timesteps=expand_timesteps)

-        self.vae_scale_factor_temporal = 2 ** sum(self.vae.temperal_downsample) if getattr(self, "vae", None) else 4
-        self.vae_scale_factor_spatial = 2 ** len(self.vae.temperal_downsample) if getattr(self, "vae", None) else 8
+        self.vae_scale_factor_temporal = self.vae.config.scale_factor_temporal if getattr(self, "vae", None) else 4
+        self.vae_scale_factor_spatial = self.vae.config.scale_factor_spatial if getattr(self, "vae", None) else 8
         self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)
         self.image_processor = image_processor

@@ -325,6 +340,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         negative_prompt_embeds=None,
         image_embeds=None,
         callback_on_step_end_tensor_inputs=None,
+        guidance_scale_2=None,
     ):
         if image is not None and image_embeds is not None:
             raise ValueError(
@@ -368,6 +384,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         ):
             raise ValueError(f"`negative_prompt` has to be of type `str` or `list` but is {type(negative_prompt)}")

+        if self.config.boundary_ratio is None and guidance_scale_2 is not None:
+            raise ValueError("`guidance_scale_2` is only supported when the pipeline's `boundary_ratio` is not None.")
+
+        if self.config.boundary_ratio is not None and image_embeds is not None:
+            raise ValueError("Cannot forward `image_embeds` when the pipeline's `boundary_ratio` is not configured.")
+
     def prepare_latents(
         self,
         image: PipelineImageInput,
@@ -398,8 +420,12 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         else:
             latents = latents.to(device=device, dtype=dtype)

-        image = image.unsqueeze(2)
-        if last_image is None:
+        image = image.unsqueeze(2)  # [batch_size, channels, 1, height, width]
+
+        if self.config.expand_timesteps:
+            video_condition = image
+
+        elif last_image is None:
             video_condition = torch.cat(
                 [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
             )
@@ -432,6 +458,13 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         latent_condition = latent_condition.to(dtype)
         latent_condition = (latent_condition - latents_mean) * latents_std

+        if self.config.expand_timesteps:
+            first_frame_mask = torch.ones(
+                1, 1, num_latent_frames, latent_height, latent_width, dtype=dtype, device=device
+            )
+            first_frame_mask[:, :, 0] = 0
+            return latents, latent_condition, first_frame_mask
+
         mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)

         if last_image is None:
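The `expand_timesteps` branch above returns a first-frame mask instead of the usual frame-wise mask latents. A small sketch of how that mask pins the conditioning frame while every other latent frame is denoised; shapes are illustrative:

    import torch

    num_latent_frames, latent_height, latent_width = 21, 60, 104  # illustrative sizes

    first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width)
    first_frame_mask[:, :, 0] = 0  # frame 0 stays pinned to the encoded condition

    latents = torch.randn(1, 16, num_latent_frames, latent_height, latent_width)
    condition = torch.randn_like(latents)

    # as used later in the denoising loop: condition where the mask is 0, noisy latents elsewhere
    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
    assert torch.equal(latent_model_input[:, :, 0], condition[:, :, 0])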
@@ -483,6 +516,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         num_frames: int = 81,
         num_inference_steps: int = 50,
         guidance_scale: float = 5.0,
+        guidance_scale_2: Optional[float] = None,
         num_videos_per_prompt: Optional[int] = 1,
         generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
         latents: Optional[torch.Tensor] = None,
@@ -527,6 +561,10 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
                 of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting
                 `guidance_scale > 1`. Higher guidance scale encourages to generate images that are closely linked to
                 the text `prompt`, usually at the expense of lower image quality.
+            guidance_scale_2 (`float`, *optional*, defaults to `None`):
+                Guidance scale for the low-noise stage transformer (`transformer_2`). If `None` and the pipeline's
+                `boundary_ratio` is not None, uses the same value as `guidance_scale`. Only used when `transformer_2`
+                and the pipeline's `boundary_ratio` are not None.
             num_videos_per_prompt (`int`, *optional*, defaults to 1):
                 The number of images to generate per prompt.
             generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
@@ -589,6 +627,7 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             negative_prompt_embeds,
             image_embeds,
             callback_on_step_end_tensor_inputs,
+            guidance_scale_2,
         )

         if num_frames % self.vae_scale_factor_temporal != 1:
@@ -598,7 +637,11 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             num_frames = num_frames // self.vae_scale_factor_temporal * self.vae_scale_factor_temporal + 1
             num_frames = max(num_frames, 1)

+        if self.config.boundary_ratio is not None and guidance_scale_2 is None:
+            guidance_scale_2 = guidance_scale
+
         self._guidance_scale = guidance_scale
+        self._guidance_scale_2 = guidance_scale_2
         self._attention_kwargs = attention_kwargs
         self._current_timestep = None
         self._interrupt = False
@@ -626,18 +669,20 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
         )

         # Encode image embedding
-        transformer_dtype = self.transformer.dtype
+        transformer_dtype = self.transformer.dtype if self.transformer is not None else self.transformer_2.dtype
         prompt_embeds = prompt_embeds.to(transformer_dtype)
         if negative_prompt_embeds is not None:
             negative_prompt_embeds = negative_prompt_embeds.to(transformer_dtype)

-        if image_embeds is None:
-            if last_image is None:
-                image_embeds = self.encode_image(image, device)
-            else:
-                image_embeds = self.encode_image([image, last_image], device)
-            image_embeds = image_embeds.repeat(batch_size, 1, 1)
-            image_embeds = image_embeds.to(transformer_dtype)
+        # only wan 2.1 i2v transformer accepts image_embeds
+        if self.transformer is not None and self.transformer.config.image_dim is not None:
+            if image_embeds is None:
+                if last_image is None:
+                    image_embeds = self.encode_image(image, device)
+                else:
+                    image_embeds = self.encode_image([image, last_image], device)
+                image_embeds = image_embeds.repeat(batch_size, 1, 1)
+                image_embeds = image_embeds.to(transformer_dtype)

         # 4. Prepare timesteps
         self.scheduler.set_timesteps(num_inference_steps, device=device)
@@ -650,7 +695,8 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             last_image = self.video_processor.preprocess(last_image, height=height, width=width).to(
                 device, dtype=torch.float32
             )
-        latents, condition = self.prepare_latents(
+
+        latents_outputs = self.prepare_latents(
             image,
             batch_size * num_videos_per_prompt,
             num_channels_latents,
@@ -663,39 +709,70 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             latents,
             last_image,
         )
+        if self.config.expand_timesteps:
+            # wan 2.2 5b i2v use firt_frame_mask to mask timesteps
+            latents, condition, first_frame_mask = latents_outputs
+        else:
+            latents, condition = latents_outputs

         # 6. Denoising loop
         num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
         self._num_timesteps = len(timesteps)

+        if self.config.boundary_ratio is not None:
+            boundary_timestep = self.config.boundary_ratio * self.scheduler.config.num_train_timesteps
+        else:
+            boundary_timestep = None
+
         with self.progress_bar(total=num_inference_steps) as progress_bar:
             for i, t in enumerate(timesteps):
                 if self.interrupt:
                     continue

                 self._current_timestep = t
-                latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
-                timestep = t.expand(latents.shape[0])
-
-                noise_pred = self.transformer(
-                    hidden_states=latent_model_input,
-                    timestep=timestep,
-                    encoder_hidden_states=prompt_embeds,
-                    encoder_hidden_states_image=image_embeds,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=False,
-                )[0]

-                if self.do_classifier_free_guidance:
-                    noise_uncond = self.transformer(
+                if boundary_timestep is None or t >= boundary_timestep:
+                    # wan2.1 or high-noise stage in wan2.2
+                    current_model = self.transformer
+                    current_guidance_scale = guidance_scale
+                else:
+                    # low-noise stage in wan2.2
+                    current_model = self.transformer_2
+                    current_guidance_scale = guidance_scale_2
+
+                if self.config.expand_timesteps:
+                    latent_model_input = (1 - first_frame_mask) * condition + first_frame_mask * latents
+                    latent_model_input = latent_model_input.to(transformer_dtype)
+
+                    # seq_len: num_latent_frames * (latent_height // patch_size) * (latent_width // patch_size)
+                    temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
+                    # batch_size, seq_len
+                    timestep = temp_ts.unsqueeze(0).expand(latents.shape[0], -1)
+                else:
+                    latent_model_input = torch.cat([latents, condition], dim=1).to(transformer_dtype)
+                    timestep = t.expand(latents.shape[0])
+
+                with current_model.cache_context("cond"):
+                    noise_pred = current_model(
                         hidden_states=latent_model_input,
                         timestep=timestep,
-                        encoder_hidden_states=negative_prompt_embeds,
+                        encoder_hidden_states=prompt_embeds,
                         encoder_hidden_states_image=image_embeds,
                         attention_kwargs=attention_kwargs,
                         return_dict=False,
                     )[0]
-                    noise_pred = noise_uncond + guidance_scale * (noise_pred - noise_uncond)
+
+                if self.do_classifier_free_guidance:
+                    with current_model.cache_context("uncond"):
+                        noise_uncond = current_model(
+                            hidden_states=latent_model_input,
+                            timestep=timestep,
+                            encoder_hidden_states=negative_prompt_embeds,
+                            encoder_hidden_states_image=image_embeds,
+                            attention_kwargs=attention_kwargs,
+                            return_dict=False,
+                        )[0]
+                    noise_pred = noise_uncond + current_guidance_scale * (noise_pred - noise_uncond)

                 # compute the previous noisy sample x_t -> x_t-1
                 latents = self.scheduler.step(noise_pred, t, latents, return_dict=False)[0]
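The per-token timestep built in the `expand_timesteps` branch above can be reproduced in isolation: the scalar timestep is broadcast over the 2x2-patchified latent grid, and tokens from the pinned first frame receive timestep 0. Sizes are illustrative:

    import torch

    t = torch.tensor(941.0)
    num_latent_frames, latent_height, latent_width = 21, 60, 104  # illustrative sizes

    first_frame_mask = torch.ones(1, 1, num_latent_frames, latent_height, latent_width)
    first_frame_mask[:, :, 0] = 0

    # seq_len = num_latent_frames * (latent_height // 2) * (latent_width // 2)
    temp_ts = (first_frame_mask[0][0][:, ::2, ::2] * t).flatten()
    timestep = temp_ts.unsqueeze(0).expand(1, -1)  # (batch_size, seq_len)

    assert timestep.shape == (1, num_latent_frames * (latent_height // 2) * (latent_width // 2))
    assert timestep[0, 0] == 0        # tokens from the pinned first frame
    assert timestep[0, -1] == 941.0   # tokens from frames that are being denoised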
@@ -719,6 +796,9 @@ class WanImageToVideoPipeline(DiffusionPipeline, WanLoraLoaderMixin):

         self._current_timestep = None

+        if self.config.expand_timesteps:
+            latents = (1 - first_frame_mask) * condition + first_frame_mask * latents
+
         if not output_type == "latent":
             latents = latents.to(self.vae.dtype)
             latents_mean = (
diffusers/pipelines/wan/pipeline_wan_vace.py
@@ -525,8 +525,7 @@ class WanVACEPipeline(DiffusionPipeline, WanLoraLoaderMixin):
             latents = retrieve_latents(self.vae.encode(video), generator, sample_mode="argmax").unbind(0)
             latents = ((latents.float() - latents_mean) * latents_std).to(vae_dtype)
         else:
-            mask = mask.to(dtype=vae_dtype)
-            mask = torch.where(mask > 0.5, 1.0, 0.0)
+            mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
             inactive = video * (1 - mask)
             reactive = video * mask
             inactive = retrieve_latents(self.vae.encode(inactive), generator, sample_mode="argmax")
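The WanVACE change simply fuses the threshold and the dtype cast, so the comparison runs on the mask's original dtype before a single cast to the VAE dtype; a minimal sketch with an illustrative mask shape:

    import torch

    vae_dtype = torch.float16
    mask = torch.rand(1, 1, 17, 480, 832)  # illustrative mask shape

    # threshold first, then cast once to the VAE dtype
    mask = torch.where(mask > 0.5, 1.0, 0.0).to(dtype=vae_dtype)
    assert mask.dtype == vae_dtype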