diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (191)
  1. diffusers/__init__.py +98 -1
  2. diffusers/callbacks.py +35 -0
  3. diffusers/commands/custom_blocks.py +134 -0
  4. diffusers/commands/diffusers_cli.py +2 -0
  5. diffusers/commands/fp16_safetensors.py +1 -1
  6. diffusers/configuration_utils.py +11 -2
  7. diffusers/dependency_versions_table.py +3 -3
  8. diffusers/guiders/__init__.py +41 -0
  9. diffusers/guiders/adaptive_projected_guidance.py +188 -0
  10. diffusers/guiders/auto_guidance.py +190 -0
  11. diffusers/guiders/classifier_free_guidance.py +141 -0
  12. diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
  13. diffusers/guiders/frequency_decoupled_guidance.py +327 -0
  14. diffusers/guiders/guider_utils.py +309 -0
  15. diffusers/guiders/perturbed_attention_guidance.py +271 -0
  16. diffusers/guiders/skip_layer_guidance.py +262 -0
  17. diffusers/guiders/smoothed_energy_guidance.py +251 -0
  18. diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
  19. diffusers/hooks/__init__.py +17 -0
  20. diffusers/hooks/_common.py +56 -0
  21. diffusers/hooks/_helpers.py +293 -0
  22. diffusers/hooks/faster_cache.py +7 -6
  23. diffusers/hooks/first_block_cache.py +259 -0
  24. diffusers/hooks/group_offloading.py +292 -286
  25. diffusers/hooks/hooks.py +56 -1
  26. diffusers/hooks/layer_skip.py +263 -0
  27. diffusers/hooks/layerwise_casting.py +2 -7
  28. diffusers/hooks/pyramid_attention_broadcast.py +14 -11
  29. diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
  30. diffusers/hooks/utils.py +43 -0
  31. diffusers/loaders/__init__.py +6 -0
  32. diffusers/loaders/ip_adapter.py +255 -4
  33. diffusers/loaders/lora_base.py +63 -30
  34. diffusers/loaders/lora_conversion_utils.py +434 -53
  35. diffusers/loaders/lora_pipeline.py +834 -37
  36. diffusers/loaders/peft.py +28 -5
  37. diffusers/loaders/single_file_model.py +44 -11
  38. diffusers/loaders/single_file_utils.py +170 -2
  39. diffusers/loaders/transformer_flux.py +9 -10
  40. diffusers/loaders/transformer_sd3.py +6 -1
  41. diffusers/loaders/unet.py +22 -5
  42. diffusers/loaders/unet_loader_utils.py +5 -2
  43. diffusers/models/__init__.py +8 -0
  44. diffusers/models/attention.py +484 -3
  45. diffusers/models/attention_dispatch.py +1218 -0
  46. diffusers/models/attention_processor.py +105 -663
  47. diffusers/models/auto_model.py +2 -2
  48. diffusers/models/autoencoders/__init__.py +1 -0
  49. diffusers/models/autoencoders/autoencoder_dc.py +14 -1
  50. diffusers/models/autoencoders/autoencoder_kl.py +1 -1
  51. diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
  52. diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
  53. diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
  54. diffusers/models/cache_utils.py +31 -9
  55. diffusers/models/controlnets/controlnet_flux.py +5 -5
  56. diffusers/models/controlnets/controlnet_union.py +4 -4
  57. diffusers/models/embeddings.py +26 -34
  58. diffusers/models/model_loading_utils.py +233 -1
  59. diffusers/models/modeling_flax_utils.py +1 -2
  60. diffusers/models/modeling_utils.py +159 -94
  61. diffusers/models/transformers/__init__.py +2 -0
  62. diffusers/models/transformers/transformer_chroma.py +16 -117
  63. diffusers/models/transformers/transformer_cogview4.py +36 -2
  64. diffusers/models/transformers/transformer_cosmos.py +11 -4
  65. diffusers/models/transformers/transformer_flux.py +372 -132
  66. diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
  67. diffusers/models/transformers/transformer_ltx.py +104 -23
  68. diffusers/models/transformers/transformer_qwenimage.py +645 -0
  69. diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
  70. diffusers/models/transformers/transformer_wan.py +298 -85
  71. diffusers/models/transformers/transformer_wan_vace.py +15 -21
  72. diffusers/models/unets/unet_2d_condition.py +2 -1
  73. diffusers/modular_pipelines/__init__.py +83 -0
  74. diffusers/modular_pipelines/components_manager.py +1068 -0
  75. diffusers/modular_pipelines/flux/__init__.py +66 -0
  76. diffusers/modular_pipelines/flux/before_denoise.py +689 -0
  77. diffusers/modular_pipelines/flux/decoders.py +109 -0
  78. diffusers/modular_pipelines/flux/denoise.py +227 -0
  79. diffusers/modular_pipelines/flux/encoders.py +412 -0
  80. diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
  81. diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
  82. diffusers/modular_pipelines/modular_pipeline.py +2446 -0
  83. diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
  84. diffusers/modular_pipelines/node_utils.py +665 -0
  85. diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
  86. diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
  87. diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
  88. diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
  89. diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
  90. diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
  91. diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
  92. diffusers/modular_pipelines/wan/__init__.py +66 -0
  93. diffusers/modular_pipelines/wan/before_denoise.py +365 -0
  94. diffusers/modular_pipelines/wan/decoders.py +105 -0
  95. diffusers/modular_pipelines/wan/denoise.py +261 -0
  96. diffusers/modular_pipelines/wan/encoders.py +242 -0
  97. diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
  98. diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
  99. diffusers/pipelines/__init__.py +31 -0
  100. diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
  101. diffusers/pipelines/auto_pipeline.py +17 -13
  102. diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
  103. diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
  104. diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
  105. diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
  106. diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
  107. diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
  108. diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
  109. diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
  110. diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
  111. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
  112. diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
  113. diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
  114. diffusers/pipelines/dit/pipeline_dit.py +3 -1
  115. diffusers/pipelines/flux/__init__.py +4 -0
  116. diffusers/pipelines/flux/pipeline_flux.py +34 -26
  117. diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
  118. diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
  119. diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
  120. diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
  121. diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
  122. diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
  123. diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
  124. diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
  125. diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
  126. diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
  127. diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
  128. diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
  129. diffusers/pipelines/flux/pipeline_output.py +6 -4
  130. diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
  131. diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
  132. diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
  133. diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
  134. diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
  135. diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
  136. diffusers/pipelines/pipeline_flax_utils.py +2 -2
  137. diffusers/pipelines/pipeline_loading_utils.py +24 -2
  138. diffusers/pipelines/pipeline_utils.py +22 -15
  139. diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
  140. diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
  141. diffusers/pipelines/qwenimage/__init__.py +55 -0
  142. diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
  143. diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
  144. diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
  145. diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
  146. diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
  147. diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
  148. diffusers/pipelines/skyreels_v2/__init__.py +59 -0
  149. diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
  150. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
  151. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
  152. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
  153. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
  154. diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
  155. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
  156. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
  157. diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
  158. diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
  159. diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
  160. diffusers/pipelines/wan/pipeline_wan.py +78 -20
  161. diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
  162. diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
  163. diffusers/quantizers/__init__.py +1 -177
  164. diffusers/quantizers/base.py +11 -0
  165. diffusers/quantizers/gguf/utils.py +92 -3
  166. diffusers/quantizers/pipe_quant_config.py +202 -0
  167. diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
  168. diffusers/schedulers/scheduling_deis_multistep.py +8 -1
  169. diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
  170. diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
  171. diffusers/schedulers/scheduling_scm.py +0 -1
  172. diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
  173. diffusers/schedulers/scheduling_utils.py +2 -2
  174. diffusers/schedulers/scheduling_utils_flax.py +1 -1
  175. diffusers/training_utils.py +78 -0
  176. diffusers/utils/__init__.py +10 -0
  177. diffusers/utils/constants.py +4 -0
  178. diffusers/utils/dummy_pt_objects.py +312 -0
  179. diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
  180. diffusers/utils/dynamic_modules_utils.py +84 -25
  181. diffusers/utils/hub_utils.py +33 -17
  182. diffusers/utils/import_utils.py +70 -0
  183. diffusers/utils/peft_utils.py +11 -8
  184. diffusers/utils/testing_utils.py +136 -10
  185. diffusers/utils/torch_utils.py +18 -0
  186. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
  187. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
  188. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
  189. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
  190. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
  191. {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
diffusers/models/autoencoders/autoencoder_kl_wan.py
@@ -34,6 +34,103 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 CACHE_T = 2
 
 
+class AvgDown3D(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        out_channels,
+        factor_t,
+        factor_s=1,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+        self.factor_t = factor_t
+        self.factor_s = factor_s
+        self.factor = self.factor_t * self.factor_s * self.factor_s
+
+        assert in_channels * self.factor % out_channels == 0
+        self.group_size = in_channels * self.factor // out_channels
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
+        pad = (0, 0, 0, 0, pad_t, 0)
+        x = F.pad(x, pad)
+        B, C, T, H, W = x.shape
+        x = x.view(
+            B,
+            C,
+            T // self.factor_t,
+            self.factor_t,
+            H // self.factor_s,
+            self.factor_s,
+            W // self.factor_s,
+            self.factor_s,
+        )
+        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
+        x = x.view(
+            B,
+            C * self.factor,
+            T // self.factor_t,
+            H // self.factor_s,
+            W // self.factor_s,
+        )
+        x = x.view(
+            B,
+            self.out_channels,
+            self.group_size,
+            T // self.factor_t,
+            H // self.factor_s,
+            W // self.factor_s,
+        )
+        x = x.mean(dim=2)
+        return x
+
+
+class DupUp3D(nn.Module):
+    def __init__(
+        self,
+        in_channels: int,
+        out_channels: int,
+        factor_t,
+        factor_s=1,
+    ):
+        super().__init__()
+        self.in_channels = in_channels
+        self.out_channels = out_channels
+
+        self.factor_t = factor_t
+        self.factor_s = factor_s
+        self.factor = self.factor_t * self.factor_s * self.factor_s
+
+        assert out_channels * self.factor % in_channels == 0
+        self.repeats = out_channels * self.factor // in_channels
+
+    def forward(self, x: torch.Tensor, first_chunk=False) -> torch.Tensor:
+        x = x.repeat_interleave(self.repeats, dim=1)
+        x = x.view(
+            x.size(0),
+            self.out_channels,
+            self.factor_t,
+            self.factor_s,
+            self.factor_s,
+            x.size(2),
+            x.size(3),
+            x.size(4),
+        )
+        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
+        x = x.view(
+            x.size(0),
+            self.out_channels,
+            x.size(2) * self.factor_t,
+            x.size(4) * self.factor_s,
+            x.size(6) * self.factor_s,
+        )
+        if first_chunk:
+            x = x[:, :, self.factor_t - 1 :, :, :]
+        return x
+
+
 class WanCausalConv3d(nn.Conv3d):
     r"""
     A custom 3D causal convolution layer with feature caching support.
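
The two new helpers are space-to-depth style resamplers used by the Wan 2.2 VAE blocks further down in this diff: AvgDown3D folds a factor_t x factor_s x factor_s neighbourhood into the channel dimension and averages channel groups down to out_channels, while DupUp3D inverts the operation by repeating channels and unfolding them back into time and space (dropping the first factor_t - 1 frames when first_chunk=True). A minimal shape check with illustrative sizes, assuming diffusers >= 0.35 and PyTorch are installed:

import torch

from diffusers.models.autoencoders.autoencoder_kl_wan import AvgDown3D, DupUp3D

# Illustrative channel counts and factors; the module asserts that
# in_channels * factor_t * factor_s**2 is divisible by out_channels.
down = AvgDown3D(in_channels=16, out_channels=32, factor_t=2, factor_s=2)
up = DupUp3D(in_channels=32, out_channels=16, factor_t=2, factor_s=2)

video = torch.randn(1, 16, 8, 32, 32)  # (batch, channels, frames, height, width)
latent = down(video)     # (1, 32, 4, 16, 16): 2x temporal / 2x spatial pooling folded into channels
restored = up(latent)    # (1, 16, 8, 32, 32): channels unfolded back into time and space
print(latent.shape, restored.shape)
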
@@ -134,19 +231,25 @@ class WanResample(nn.Module):
             - 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
     """
 
-    def __init__(self, dim: int, mode: str) -> None:
+    def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
         super().__init__()
         self.dim = dim
         self.mode = mode
 
+        # default to dim //2
+        if upsample_out_dim is None:
+            upsample_out_dim = dim // 2
+
         # layers
         if mode == "upsample2d":
             self.resample = nn.Sequential(
-                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
             )
         elif mode == "upsample3d":
             self.resample = nn.Sequential(
-                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"), nn.Conv2d(dim, dim // 2, 3, padding=1)
+                WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
+                nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
             )
             self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
 
@@ -363,6 +466,42 @@ class WanMidBlock(nn.Module):
         return x
 
 
+class WanResidualDownBlock(nn.Module):
+    def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False):
+        super().__init__()
+
+        # Shortcut path with downsample
+        self.avg_shortcut = AvgDown3D(
+            in_dim,
+            out_dim,
+            factor_t=2 if temperal_downsample else 1,
+            factor_s=2 if down_flag else 1,
+        )
+
+        # Main path with residual blocks and downsample
+        resnets = []
+        for _ in range(num_res_blocks):
+            resnets.append(WanResidualBlock(in_dim, out_dim, dropout))
+            in_dim = out_dim
+        self.resnets = nn.ModuleList(resnets)
+
+        # Add the final downsample block
+        if down_flag:
+            mode = "downsample3d" if temperal_downsample else "downsample2d"
+            self.downsampler = WanResample(out_dim, mode=mode)
+        else:
+            self.downsampler = None
+
+    def forward(self, x, feat_cache=None, feat_idx=[0]):
+        x_copy = x.clone()
+        for resnet in self.resnets:
+            x = resnet(x, feat_cache, feat_idx)
+        if self.downsampler is not None:
+            x = self.downsampler(x, feat_cache, feat_idx)
+
+        return x + self.avg_shortcut(x_copy)
+
+
 class WanEncoder3d(nn.Module):
     r"""
     A 3D encoder module.
@@ -380,6 +519,7 @@ class WanEncoder3d(nn.Module):
 
     def __init__(
        self,
+        in_channels: int = 3,
         dim=128,
         z_dim=4,
         dim_mult=[1, 2, 4, 4],
@@ -388,6 +528,7 @@
         temperal_downsample=[True, True, False],
         dropout=0.0,
         non_linearity: str = "silu",
+        is_residual: bool = False,  # wan 2.2 vae use a residual downblock
     ):
         super().__init__()
         self.dim = dim
@@ -403,23 +544,35 @@
         scale = 1.0
 
         # init block
-        self.conv_in = WanCausalConv3d(3, dims[0], 3, padding=1)
+        self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
 
         # downsample blocks
         self.down_blocks = nn.ModuleList([])
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
             # residual (+attention) blocks
-            for _ in range(num_res_blocks):
-                self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
-                if scale in attn_scales:
-                    self.down_blocks.append(WanAttentionBlock(out_dim))
-                in_dim = out_dim
-
-            # downsample block
-            if i != len(dim_mult) - 1:
-                mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
-                self.down_blocks.append(WanResample(out_dim, mode=mode))
-                scale /= 2.0
+            if is_residual:
+                self.down_blocks.append(
+                    WanResidualDownBlock(
+                        in_dim,
+                        out_dim,
+                        dropout,
+                        num_res_blocks,
+                        temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
+                        down_flag=i != len(dim_mult) - 1,
+                    )
+                )
+            else:
+                for _ in range(num_res_blocks):
+                    self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
+                    if scale in attn_scales:
+                        self.down_blocks.append(WanAttentionBlock(out_dim))
+                    in_dim = out_dim
+
+                # downsample block
+                if i != len(dim_mult) - 1:
+                    mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
+                    self.down_blocks.append(WanResample(out_dim, mode=mode))
+                    scale /= 2.0
 
         # middle blocks
         self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
@@ -470,6 +623,94 @@ class WanEncoder3d(nn.Module):
         return x
 
 
+class WanResidualUpBlock(nn.Module):
+    """
+    A block that handles upsampling for the WanVAE decoder.
+
+    Args:
+        in_dim (int): Input dimension
+        out_dim (int): Output dimension
+        num_res_blocks (int): Number of residual blocks
+        dropout (float): Dropout rate
+        temperal_upsample (bool): Whether to upsample on temporal dimension
+        up_flag (bool): Whether to upsample or not
+        non_linearity (str): Type of non-linearity to use
+    """
+
+    def __init__(
+        self,
+        in_dim: int,
+        out_dim: int,
+        num_res_blocks: int,
+        dropout: float = 0.0,
+        temperal_upsample: bool = False,
+        up_flag: bool = False,
+        non_linearity: str = "silu",
+    ):
+        super().__init__()
+        self.in_dim = in_dim
+        self.out_dim = out_dim
+
+        if up_flag:
+            self.avg_shortcut = DupUp3D(
+                in_dim,
+                out_dim,
+                factor_t=2 if temperal_upsample else 1,
+                factor_s=2,
+            )
+        else:
+            self.avg_shortcut = None
+
+        # create residual blocks
+        resnets = []
+        current_dim = in_dim
+        for _ in range(num_res_blocks + 1):
+            resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
+            current_dim = out_dim
+
+        self.resnets = nn.ModuleList(resnets)
+
+        # Add upsampling layer if needed
+        if up_flag:
+            upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
+            self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
+        else:
+            self.upsampler = None
+
+        self.gradient_checkpointing = False
+
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
+        """
+        Forward pass through the upsampling block.
+
+        Args:
+            x (torch.Tensor): Input tensor
+            feat_cache (list, optional): Feature cache for causal convolutions
+            feat_idx (list, optional): Feature index for cache management
+
+        Returns:
+            torch.Tensor: Output tensor
+        """
+        x_copy = x.clone()
+
+        for resnet in self.resnets:
+            if feat_cache is not None:
+                x = resnet(x, feat_cache, feat_idx)
+            else:
+                x = resnet(x)
+
+        if self.upsampler is not None:
+            if feat_cache is not None:
+                x = self.upsampler(x, feat_cache, feat_idx)
+            else:
+                x = self.upsampler(x)
+
+        if self.avg_shortcut is not None:
+            x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)
+
+        return x
+
+
 class WanUpBlock(nn.Module):
     """
     A block that handles upsampling for the WanVAE decoder.
@@ -513,7 +754,7 @@ class WanUpBlock(nn.Module):
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
         """
         Forward pass through the upsampling block.
 
@@ -564,6 +805,8 @@ class WanDecoder3d(nn.Module):
         temperal_upsample=[False, True, True],
         dropout=0.0,
         non_linearity: str = "silu",
+        out_channels: int = 3,
+        is_residual: bool = False,
     ):
         super().__init__()
         self.dim = dim
@@ -577,7 +820,6 @@ class WanDecoder3d(nn.Module):
 
         # dimensions
         dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
-        scale = 1.0 / 2 ** (len(dim_mult) - 2)
 
         # init block
         self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
@@ -589,36 +831,47 @@ class WanDecoder3d(nn.Module):
         self.up_blocks = nn.ModuleList([])
         for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
             # residual (+attention) blocks
-            if i > 0:
+            if i > 0 and not is_residual:
+                # wan vae 2.1
                 in_dim = in_dim // 2
 
-            # Determine if we need upsampling
+            # determine if we need upsampling
+            up_flag = i != len(dim_mult) - 1
+            # determine upsampling mode, if not upsampling, set to None
             upsample_mode = None
-            if i != len(dim_mult) - 1:
-                upsample_mode = "upsample3d" if temperal_upsample[i] else "upsample2d"
-
+            if up_flag and temperal_upsample[i]:
+                upsample_mode = "upsample3d"
+            elif up_flag:
+                upsample_mode = "upsample2d"
             # Create and add the upsampling block
-            up_block = WanUpBlock(
-                in_dim=in_dim,
-                out_dim=out_dim,
-                num_res_blocks=num_res_blocks,
-                dropout=dropout,
-                upsample_mode=upsample_mode,
-                non_linearity=non_linearity,
-            )
+            if is_residual:
+                up_block = WanResidualUpBlock(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    num_res_blocks=num_res_blocks,
+                    dropout=dropout,
+                    temperal_upsample=temperal_upsample[i] if up_flag else False,
+                    up_flag=up_flag,
+                    non_linearity=non_linearity,
+                )
+            else:
+                up_block = WanUpBlock(
+                    in_dim=in_dim,
+                    out_dim=out_dim,
+                    num_res_blocks=num_res_blocks,
+                    dropout=dropout,
+                    upsample_mode=upsample_mode,
+                    non_linearity=non_linearity,
+                )
             self.up_blocks.append(up_block)
 
-            # Update scale for next iteration
-            if upsample_mode is not None:
-                scale *= 2.0
-
         # output blocks
         self.norm_out = WanRMS_norm(out_dim, images=False)
-        self.conv_out = WanCausalConv3d(out_dim, 3, 3, padding=1)
+        self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
 
         self.gradient_checkpointing = False
 
-    def forward(self, x, feat_cache=None, feat_idx=[0]):
+    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
         ## conv1
         if feat_cache is not None:
             idx = feat_idx[0]
@@ -637,7 +890,7 @@ class WanDecoder3d(nn.Module):
 
         ## upsamples
         for up_block in self.up_blocks:
-            x = up_block(x, feat_cache, feat_idx)
+            x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
 
         ## head
         x = self.norm_out(x)
@@ -656,6 +909,49 @@ class WanDecoder3d(nn.Module):
         return x
 
 
+def patchify(x, patch_size):
+    if patch_size == 1:
+        return x
+
+    if x.dim() != 5:
+        raise ValueError(f"Invalid input shape: {x.shape}")
+    # x shape: [batch_size, channels, frames, height, width]
+    batch_size, channels, frames, height, width = x.shape
+
+    # Ensure height and width are divisible by patch_size
+    if height % patch_size != 0 or width % patch_size != 0:
+        raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")
+
+    # Reshape to [batch_size, channels, frames, height//patch_size, patch_size, width//patch_size, patch_size]
+    x = x.view(batch_size, channels, frames, height // patch_size, patch_size, width // patch_size, patch_size)
+
+    # Rearrange to [batch_size, channels * patch_size * patch_size, frames, height//patch_size, width//patch_size]
+    x = x.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
+    x = x.view(batch_size, channels * patch_size * patch_size, frames, height // patch_size, width // patch_size)
+
+    return x
+
+
+def unpatchify(x, patch_size):
+    if patch_size == 1:
+        return x
+
+    if x.dim() != 5:
+        raise ValueError(f"Invalid input shape: {x.shape}")
+    # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
+    batch_size, c_patches, frames, height, width = x.shape
+    channels = c_patches // (patch_size * patch_size)
+
+    # Reshape to [b, c, patch_size, patch_size, f, h, w]
+    x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)
+
+    # Rearrange to [b, c, f, h * patch_size, w * patch_size]
+    x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
+    x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)
+
+    return x
+
+
 class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     r"""
     A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
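
patchify and unpatchify are plain module-level helpers: they fold each patch_size x patch_size spatial neighbourhood into the channel dimension before the encoder runs and unfold it again after the decoder, which is what lets the VAE be configured with in_channels/out_channels larger than 3 (e.g. 3 * patch_size**2). A small round-trip check with illustrative shapes, assuming diffusers >= 0.35:

import torch

from diffusers.models.autoencoders.autoencoder_kl_wan import patchify, unpatchify

video = torch.randn(1, 3, 9, 64, 64)    # (batch, channels, frames, height, width)
packed = patchify(video, patch_size=2)  # (1, 12, 9, 32, 32): 2x2 spatial patches folded into channels
assert torch.equal(unpatchify(packed, patch_size=2), video)  # pure reshapes, so the round trip is exact
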
@@ -671,6 +967,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
     def __init__(
         self,
         base_dim: int = 96,
+        decoder_base_dim: Optional[int] = None,
         z_dim: int = 16,
         dim_mult: Tuple[int] = [1, 2, 4, 4],
         num_res_blocks: int = 2,
@@ -713,6 +1010,12 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             2.8251,
             1.9160,
         ],
+        is_residual: bool = False,
+        in_channels: int = 3,
+        out_channels: int = 3,
+        patch_size: Optional[int] = None,
+        scale_factor_temporal: Optional[int] = 4,
+        scale_factor_spatial: Optional[int] = 8,
     ) -> None:
         super().__init__()
 
@@ -720,14 +1023,33 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         self.temperal_downsample = temperal_downsample
         self.temperal_upsample = temperal_downsample[::-1]
 
+        if decoder_base_dim is None:
+            decoder_base_dim = base_dim
+
         self.encoder = WanEncoder3d(
-            base_dim, z_dim * 2, dim_mult, num_res_blocks, attn_scales, self.temperal_downsample, dropout
+            in_channels=in_channels,
+            dim=base_dim,
+            z_dim=z_dim * 2,
+            dim_mult=dim_mult,
+            num_res_blocks=num_res_blocks,
+            attn_scales=attn_scales,
+            temperal_downsample=temperal_downsample,
+            dropout=dropout,
+            is_residual=is_residual,
         )
         self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
         self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
 
         self.decoder = WanDecoder3d(
-            base_dim, z_dim, dim_mult, num_res_blocks, attn_scales, self.temperal_upsample, dropout
+            dim=decoder_base_dim,
+            z_dim=z_dim,
+            dim_mult=dim_mult,
+            num_res_blocks=num_res_blocks,
+            attn_scales=attn_scales,
+            temperal_upsample=self.temperal_upsample,
+            dropout=dropout,
+            out_channels=out_channels,
+            is_residual=is_residual,
         )
 
         self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
@@ -827,6 +1149,8 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
             return self.tiled_encode(x)
 
         self.clear_cache()
+        if self.config.patch_size is not None:
+            x = patchify(x, patch_size=self.config.patch_size)
         iter_ = 1 + (num_frame - 1) // 4
         for i in range(iter_):
             self._enc_conv_idx = [0]
@@ -884,12 +1208,18 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
         for i in range(num_frame):
             self._conv_idx = [0]
             if i == 0:
-                out = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
+                out = self.decoder(
+                    x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True
+                )
             else:
                 out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
                 out = torch.cat([out, out_], 2)
 
+        if self.config.patch_size is not None:
+            out = unpatchify(out, patch_size=self.config.patch_size)
+
         out = torch.clamp(out, min=-1.0, max=1.0)
+
         self.clear_cache()
         if not return_dict:
             return (out,)
diffusers/models/cache_utils.py
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+from contextlib import contextmanager
+
 from ..utils.logging import get_logger
 
 
@@ -25,6 +27,7 @@ class CacheMixin:
     Supported caching techniques:
         - [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
         - [FasterCache](https://huggingface.co/papers/2410.19355)
+        - [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
     """
 
     _cache_config = None
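
With FirstBlockCache now listed as a supported technique, enabling it goes through the same enable_cache entry point as the existing caches. A hedged usage sketch (illustrative model id and threshold; the exact FirstBlockCacheConfig fields live in diffusers/hooks/first_block_cache.py):

import torch

from diffusers import DiffusionPipeline
from diffusers.hooks import FirstBlockCacheConfig

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16)
pipe.to("cuda")

# enable_cache dispatches on the config type (see the hunks below) and routes
# FirstBlockCacheConfig to apply_first_block_cache.
pipe.transformer.enable_cache(FirstBlockCacheConfig(threshold=0.2))
image = pipe("a photo of a cat", num_inference_steps=28).images[0]

# disable_cache removes the first-block-cache hooks again.
pipe.transformer.disable_cache()
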
@@ -62,8 +65,10 @@
 
         from ..hooks import (
             FasterCacheConfig,
+            FirstBlockCacheConfig,
             PyramidAttentionBroadcastConfig,
             apply_faster_cache,
+            apply_first_block_cache,
             apply_pyramid_attention_broadcast,
         )
 
@@ -72,31 +77,36 @@
                 f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
             )
 
-        if isinstance(config, PyramidAttentionBroadcastConfig):
-            apply_pyramid_attention_broadcast(self, config)
-        elif isinstance(config, FasterCacheConfig):
+        if isinstance(config, FasterCacheConfig):
             apply_faster_cache(self, config)
+        elif isinstance(config, FirstBlockCacheConfig):
+            apply_first_block_cache(self, config)
+        elif isinstance(config, PyramidAttentionBroadcastConfig):
+            apply_pyramid_attention_broadcast(self, config)
         else:
             raise ValueError(f"Cache config {type(config)} is not supported.")
 
         self._cache_config = config
 
     def disable_cache(self) -> None:
-        from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
+        from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
         from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
+        from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
         from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
 
         if self._cache_config is None:
             logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
             return
 
-        if isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
-            registry = HookRegistry.check_if_exists_or_initialize(self)
-            registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
-        elif isinstance(self._cache_config, FasterCacheConfig):
-            registry = HookRegistry.check_if_exists_or_initialize(self)
+        registry = HookRegistry.check_if_exists_or_initialize(self)
+        if isinstance(self._cache_config, FasterCacheConfig):
             registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
             registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
+        elif isinstance(self._cache_config, FirstBlockCacheConfig):
+            registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
+            registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
+        elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
+            registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
         else:
             raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
 
@@ -106,3 +116,15 @@
         from ..hooks import HookRegistry
 
         HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
+
+    @contextmanager
+    def cache_context(self, name: str):
+        r"""Context manager that provides additional methods for cache management."""
+        from ..hooks import HookRegistry
+
+        registry = HookRegistry.check_if_exists_or_initialize(self)
+        registry._set_context(name)
+
+        yield
+
+        registry._set_context(None)
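
cache_context is a thin wrapper over the hook registry's context name: callers (typically the pipelines) wrap separate denoiser passes so that stateful cache hooks can keep per-branch state. A minimal sketch of the mechanics on a toy module that mixes in CacheMixin (hypothetical module, for illustration only):

import torch
from torch import nn

from diffusers.models.cache_utils import CacheMixin


class ToyDenoiser(nn.Module, CacheMixin):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        return self.proj(x)


model = ToyDenoiser()
x = torch.randn(2, 8)

# Tag the conditional and unconditional passes; context-aware cache hooks
# (e.g. FirstBlockCache) can use the active name to separate cached state.
with model.cache_context("cond"):
    out_cond = model(x)
with model.cache_context("uncond"):
    out_uncond = model(x)
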
diffusers/models/controlnets/controlnet_flux.py
@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
             )
             block_samples = block_samples + (hidden_states,)
 
-        hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
-
         single_block_samples = ()
         for index_block, block in enumerate(self.single_transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
-                hidden_states = self._gradient_checkpointing_func(
+                encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
                     block,
                     hidden_states,
+                    encoder_hidden_states,
                     temb,
                     image_rotary_emb,
                 )
 
             else:
-                hidden_states = block(
+                encoder_hidden_states, hidden_states = block(
                     hidden_states=hidden_states,
+                    encoder_hidden_states=encoder_hidden_states,
                     temb=temb,
                     image_rotary_emb=image_rotary_emb,
                 )
-            single_block_samples = single_block_samples + (hidden_states[:, encoder_hidden_states.shape[1] :],)
+            single_block_samples = single_block_samples + (hidden_states,)
 
         # controlnet block
         controlnet_block_samples = ()
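
This mirrors the updated single-stream block interface in transformer_flux.py (see the file list above): FluxSingleTransformerBlock now takes encoder_hidden_states as a separate input and returns an (encoder_hidden_states, hidden_states) tuple, so FluxControlNetModel no longer concatenates the text tokens in front of the image tokens and then slices them back off when collecting single_block_samples.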