diffusers 0.34.0__py3-none-any.whl → 0.35.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- diffusers/__init__.py +98 -1
- diffusers/callbacks.py +35 -0
- diffusers/commands/custom_blocks.py +134 -0
- diffusers/commands/diffusers_cli.py +2 -0
- diffusers/commands/fp16_safetensors.py +1 -1
- diffusers/configuration_utils.py +11 -2
- diffusers/dependency_versions_table.py +3 -3
- diffusers/guiders/__init__.py +41 -0
- diffusers/guiders/adaptive_projected_guidance.py +188 -0
- diffusers/guiders/auto_guidance.py +190 -0
- diffusers/guiders/classifier_free_guidance.py +141 -0
- diffusers/guiders/classifier_free_zero_star_guidance.py +152 -0
- diffusers/guiders/frequency_decoupled_guidance.py +327 -0
- diffusers/guiders/guider_utils.py +309 -0
- diffusers/guiders/perturbed_attention_guidance.py +271 -0
- diffusers/guiders/skip_layer_guidance.py +262 -0
- diffusers/guiders/smoothed_energy_guidance.py +251 -0
- diffusers/guiders/tangential_classifier_free_guidance.py +143 -0
- diffusers/hooks/__init__.py +17 -0
- diffusers/hooks/_common.py +56 -0
- diffusers/hooks/_helpers.py +293 -0
- diffusers/hooks/faster_cache.py +7 -6
- diffusers/hooks/first_block_cache.py +259 -0
- diffusers/hooks/group_offloading.py +292 -286
- diffusers/hooks/hooks.py +56 -1
- diffusers/hooks/layer_skip.py +263 -0
- diffusers/hooks/layerwise_casting.py +2 -7
- diffusers/hooks/pyramid_attention_broadcast.py +14 -11
- diffusers/hooks/smoothed_energy_guidance_utils.py +167 -0
- diffusers/hooks/utils.py +43 -0
- diffusers/loaders/__init__.py +6 -0
- diffusers/loaders/ip_adapter.py +255 -4
- diffusers/loaders/lora_base.py +63 -30
- diffusers/loaders/lora_conversion_utils.py +434 -53
- diffusers/loaders/lora_pipeline.py +834 -37
- diffusers/loaders/peft.py +28 -5
- diffusers/loaders/single_file_model.py +44 -11
- diffusers/loaders/single_file_utils.py +170 -2
- diffusers/loaders/transformer_flux.py +9 -10
- diffusers/loaders/transformer_sd3.py +6 -1
- diffusers/loaders/unet.py +22 -5
- diffusers/loaders/unet_loader_utils.py +5 -2
- diffusers/models/__init__.py +8 -0
- diffusers/models/attention.py +484 -3
- diffusers/models/attention_dispatch.py +1218 -0
- diffusers/models/attention_processor.py +105 -663
- diffusers/models/auto_model.py +2 -2
- diffusers/models/autoencoders/__init__.py +1 -0
- diffusers/models/autoencoders/autoencoder_dc.py +14 -1
- diffusers/models/autoencoders/autoencoder_kl.py +1 -1
- diffusers/models/autoencoders/autoencoder_kl_cosmos.py +3 -1
- diffusers/models/autoencoders/autoencoder_kl_qwenimage.py +1070 -0
- diffusers/models/autoencoders/autoencoder_kl_wan.py +370 -40
- diffusers/models/cache_utils.py +31 -9
- diffusers/models/controlnets/controlnet_flux.py +5 -5
- diffusers/models/controlnets/controlnet_union.py +4 -4
- diffusers/models/embeddings.py +26 -34
- diffusers/models/model_loading_utils.py +233 -1
- diffusers/models/modeling_flax_utils.py +1 -2
- diffusers/models/modeling_utils.py +159 -94
- diffusers/models/transformers/__init__.py +2 -0
- diffusers/models/transformers/transformer_chroma.py +16 -117
- diffusers/models/transformers/transformer_cogview4.py +36 -2
- diffusers/models/transformers/transformer_cosmos.py +11 -4
- diffusers/models/transformers/transformer_flux.py +372 -132
- diffusers/models/transformers/transformer_hunyuan_video.py +6 -0
- diffusers/models/transformers/transformer_ltx.py +104 -23
- diffusers/models/transformers/transformer_qwenimage.py +645 -0
- diffusers/models/transformers/transformer_skyreels_v2.py +607 -0
- diffusers/models/transformers/transformer_wan.py +298 -85
- diffusers/models/transformers/transformer_wan_vace.py +15 -21
- diffusers/models/unets/unet_2d_condition.py +2 -1
- diffusers/modular_pipelines/__init__.py +83 -0
- diffusers/modular_pipelines/components_manager.py +1068 -0
- diffusers/modular_pipelines/flux/__init__.py +66 -0
- diffusers/modular_pipelines/flux/before_denoise.py +689 -0
- diffusers/modular_pipelines/flux/decoders.py +109 -0
- diffusers/modular_pipelines/flux/denoise.py +227 -0
- diffusers/modular_pipelines/flux/encoders.py +412 -0
- diffusers/modular_pipelines/flux/modular_blocks.py +181 -0
- diffusers/modular_pipelines/flux/modular_pipeline.py +59 -0
- diffusers/modular_pipelines/modular_pipeline.py +2446 -0
- diffusers/modular_pipelines/modular_pipeline_utils.py +672 -0
- diffusers/modular_pipelines/node_utils.py +665 -0
- diffusers/modular_pipelines/stable_diffusion_xl/__init__.py +77 -0
- diffusers/modular_pipelines/stable_diffusion_xl/before_denoise.py +1874 -0
- diffusers/modular_pipelines/stable_diffusion_xl/decoders.py +208 -0
- diffusers/modular_pipelines/stable_diffusion_xl/denoise.py +771 -0
- diffusers/modular_pipelines/stable_diffusion_xl/encoders.py +887 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_blocks.py +380 -0
- diffusers/modular_pipelines/stable_diffusion_xl/modular_pipeline.py +365 -0
- diffusers/modular_pipelines/wan/__init__.py +66 -0
- diffusers/modular_pipelines/wan/before_denoise.py +365 -0
- diffusers/modular_pipelines/wan/decoders.py +105 -0
- diffusers/modular_pipelines/wan/denoise.py +261 -0
- diffusers/modular_pipelines/wan/encoders.py +242 -0
- diffusers/modular_pipelines/wan/modular_blocks.py +144 -0
- diffusers/modular_pipelines/wan/modular_pipeline.py +90 -0
- diffusers/pipelines/__init__.py +31 -0
- diffusers/pipelines/audioldm2/pipeline_audioldm2.py +2 -3
- diffusers/pipelines/auto_pipeline.py +17 -13
- diffusers/pipelines/chroma/pipeline_chroma.py +5 -5
- diffusers/pipelines/chroma/pipeline_chroma_img2img.py +5 -5
- diffusers/pipelines/cogvideo/pipeline_cogvideox.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_fun_control.py +9 -8
- diffusers/pipelines/cogvideo/pipeline_cogvideox_image2video.py +10 -9
- diffusers/pipelines/cogvideo/pipeline_cogvideox_video2video.py +9 -8
- diffusers/pipelines/cogview4/pipeline_cogview4.py +16 -15
- diffusers/pipelines/controlnet/pipeline_controlnet_blip_diffusion.py +3 -2
- diffusers/pipelines/controlnet/pipeline_controlnet_union_inpaint_sd_xl.py +212 -93
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl.py +7 -3
- diffusers/pipelines/controlnet/pipeline_controlnet_union_sd_xl_img2img.py +194 -92
- diffusers/pipelines/deprecated/stable_diffusion_variants/pipeline_cycle_diffusion.py +1 -1
- diffusers/pipelines/dit/pipeline_dit.py +3 -1
- diffusers/pipelines/flux/__init__.py +4 -0
- diffusers/pipelines/flux/pipeline_flux.py +34 -26
- diffusers/pipelines/flux/pipeline_flux_control.py +8 -8
- diffusers/pipelines/flux/pipeline_flux_control_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_control_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_image_to_image.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_controlnet_inpainting.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_fill.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_img2img.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_inpaint.py +1 -1
- diffusers/pipelines/flux/pipeline_flux_kontext.py +1134 -0
- diffusers/pipelines/flux/pipeline_flux_kontext_inpaint.py +1460 -0
- diffusers/pipelines/flux/pipeline_flux_prior_redux.py +1 -1
- diffusers/pipelines/flux/pipeline_output.py +6 -4
- diffusers/pipelines/hidream_image/pipeline_hidream_image.py +5 -5
- diffusers/pipelines/hunyuan_video/pipeline_hunyuan_video.py +25 -24
- diffusers/pipelines/ltx/pipeline_ltx.py +13 -12
- diffusers/pipelines/ltx/pipeline_ltx_condition.py +10 -9
- diffusers/pipelines/ltx/pipeline_ltx_image2video.py +13 -12
- diffusers/pipelines/mochi/pipeline_mochi.py +9 -8
- diffusers/pipelines/pipeline_flax_utils.py +2 -2
- diffusers/pipelines/pipeline_loading_utils.py +24 -2
- diffusers/pipelines/pipeline_utils.py +22 -15
- diffusers/pipelines/pixart_alpha/pipeline_pixart_alpha.py +3 -1
- diffusers/pipelines/pixart_alpha/pipeline_pixart_sigma.py +20 -0
- diffusers/pipelines/qwenimage/__init__.py +55 -0
- diffusers/pipelines/qwenimage/pipeline_output.py +21 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage.py +726 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_edit.py +849 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_img2img.py +829 -0
- diffusers/pipelines/qwenimage/pipeline_qwenimage_inpaint.py +1015 -0
- diffusers/pipelines/sana/pipeline_sana_sprint.py +5 -5
- diffusers/pipelines/skyreels_v2/__init__.py +59 -0
- diffusers/pipelines/skyreels_v2/pipeline_output.py +20 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2.py +610 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing.py +978 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_i2v.py +1059 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_diffusion_forcing_v2v.py +1063 -0
- diffusers/pipelines/skyreels_v2/pipeline_skyreels_v2_i2v.py +745 -0
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_inpaint.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_onnx_stable_diffusion_upscale.py +1 -1
- diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion.py +2 -1
- diffusers/pipelines/stable_diffusion_3/pipeline_stable_diffusion_3.py +6 -5
- diffusers/pipelines/wan/pipeline_wan.py +78 -20
- diffusers/pipelines/wan/pipeline_wan_i2v.py +112 -32
- diffusers/pipelines/wan/pipeline_wan_vace.py +1 -2
- diffusers/quantizers/__init__.py +1 -177
- diffusers/quantizers/base.py +11 -0
- diffusers/quantizers/gguf/utils.py +92 -3
- diffusers/quantizers/pipe_quant_config.py +202 -0
- diffusers/quantizers/torchao/torchao_quantizer.py +26 -0
- diffusers/schedulers/scheduling_deis_multistep.py +8 -1
- diffusers/schedulers/scheduling_dpmsolver_multistep.py +6 -0
- diffusers/schedulers/scheduling_dpmsolver_singlestep.py +6 -0
- diffusers/schedulers/scheduling_scm.py +0 -1
- diffusers/schedulers/scheduling_unipc_multistep.py +10 -1
- diffusers/schedulers/scheduling_utils.py +2 -2
- diffusers/schedulers/scheduling_utils_flax.py +1 -1
- diffusers/training_utils.py +78 -0
- diffusers/utils/__init__.py +10 -0
- diffusers/utils/constants.py +4 -0
- diffusers/utils/dummy_pt_objects.py +312 -0
- diffusers/utils/dummy_torch_and_transformers_objects.py +255 -0
- diffusers/utils/dynamic_modules_utils.py +84 -25
- diffusers/utils/hub_utils.py +33 -17
- diffusers/utils/import_utils.py +70 -0
- diffusers/utils/peft_utils.py +11 -8
- diffusers/utils/testing_utils.py +136 -10
- diffusers/utils/torch_utils.py +18 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/METADATA +6 -6
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/RECORD +191 -127
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/LICENSE +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/WHEEL +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/entry_points.txt +0 -0
- {diffusers-0.34.0.dist-info → diffusers-0.35.1.dist-info}/top_level.txt +0 -0
@@ -34,6 +34,103 @@ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
|
34
34
|
# NOTE(review): module-level constant whose uses are not visible in this excerpt; the name suggests
# the number of trailing frames kept in the causal-conv feature cache between chunks — confirm.
CACHE_T = 2
|
35
35
|
|
36
36
|
|
37
|
+
class AvgDown3D(nn.Module):
    """Parameter-free 3D downsampler that trades space-time resolution for channels.

    Every non-overlapping ``factor_t x factor_s x factor_s`` window of the input is folded into the
    channel axis (giving ``C * factor`` channels, ``factor = factor_t * factor_s**2``), and each run of
    ``group_size`` consecutive channels is then averaged into one of ``out_channels`` output channels.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels; ``in_channels * factor`` must be divisible by it.
        factor_t: temporal downsampling factor. The time axis is zero-padded at the front up to a
            multiple of ``factor_t`` before folding.
        factor_s: spatial downsampling factor applied to both height and width. Height and width must
            be divisible by it (otherwise the reshape in ``forward`` raises).

    Raises:
        ValueError: if ``in_channels * factor`` is not divisible by ``out_channels``.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        factor_t,
        factor_s=1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        # Raise instead of `assert` so the check is not stripped when running under `python -O`.
        if in_channels * self.factor % out_channels != 0:
            raise ValueError(
                f"in_channels * factor_t * factor_s**2 ({in_channels * self.factor}) must be divisible by "
                f"out_channels ({out_channels})"
            )
        self.group_size = in_channels * self.factor // out_channels

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Downsample ``x`` of shape ``[B, C, T, H, W]`` to ``[B, out_channels, T', H // fs, W // fs]``.

        ``T'`` is ``ceil(T / factor_t)`` because the time axis is front-padded with zeros first.
        """
        # F.pad orders pairs from the last dim backwards: (W_l, W_r, H_t, H_b, T_front, T_back),
        # so this only adds zeros at the front of the time axis.
        pad_t = (self.factor_t - x.shape[2] % self.factor_t) % self.factor_t
        x = F.pad(x, (0, 0, 0, 0, pad_t, 0))
        B, C, T, H, W = x.shape
        t_out = T // self.factor_t
        h_out = H // self.factor_s
        w_out = W // self.factor_s

        # [B, C, t, ft, h, fs, w, fs] -> [B, C, ft, fs, fs, t, h, w] -> [B, C * factor, t, h, w]
        x = x.view(B, C, t_out, self.factor_t, h_out, self.factor_s, w_out, self.factor_s)
        x = x.permute(0, 1, 3, 5, 7, 2, 4, 6).contiguous()
        x = x.view(B, C * self.factor, t_out, h_out, w_out)

        # Average each run of `group_size` consecutive channels into a single output channel.
        x = x.view(B, self.out_channels, self.group_size, t_out, h_out, w_out)
        x = x.mean(dim=2)
        return x
|
88
|
+
|
89
|
+
|
90
|
+
class DupUp3D(nn.Module):
    """Parameter-free 3D upsampler that trades channels for space-time resolution.

    Each input channel is duplicated ``repeats`` times (``repeat_interleave`` on the channel axis), and
    the resulting ``out_channels * factor`` channels (``factor = factor_t * factor_s**2``) are unfolded
    into ``factor_t x factor_s x factor_s`` space-time windows.

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels; ``out_channels * factor`` must be divisible by
            ``in_channels``.
        factor_t: temporal upsampling factor.
        factor_s: spatial upsampling factor applied to both height and width.

    Raises:
        ValueError: if ``out_channels * factor`` is not divisible by ``in_channels``.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        factor_t: int,
        factor_s: int = 1,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels

        self.factor_t = factor_t
        self.factor_s = factor_s
        self.factor = self.factor_t * self.factor_s * self.factor_s

        # Raise instead of `assert` so the check is not stripped when running under `python -O`.
        if out_channels * self.factor % in_channels != 0:
            raise ValueError(
                f"out_channels * factor_t * factor_s**2 ({out_channels * self.factor}) must be divisible by "
                f"in_channels ({in_channels})"
            )
        self.repeats = out_channels * self.factor // in_channels

    def forward(self, x: torch.Tensor, first_chunk: bool = False) -> torch.Tensor:
        """Upsample ``x`` of shape ``[B, C, T, H, W]`` to ``[B, out_channels, T * ft, H * fs, W * fs]``.

        Args:
            x: input tensor.
            first_chunk: when True, drop the first ``factor_t - 1`` frames of the upsampled output.
        """
        x = x.repeat_interleave(self.repeats, dim=1)
        B, T, H, W = x.size(0), x.size(2), x.size(3), x.size(4)

        # [B, out, ft, fs, fs, T, H, W] -> [B, out, T, ft, H, fs, W, fs] -> [B, out, T*ft, H*fs, W*fs]
        x = x.view(B, self.out_channels, self.factor_t, self.factor_s, self.factor_s, T, H, W)
        x = x.permute(0, 1, 5, 2, 6, 3, 7, 4).contiguous()
        x = x.view(B, self.out_channels, T * self.factor_t, H * self.factor_s, W * self.factor_s)
        if first_chunk:
            x = x[:, :, self.factor_t - 1 :, :, :]
        return x
|
132
|
+
|
133
|
+
|
37
134
|
class WanCausalConv3d(nn.Conv3d):
|
38
135
|
r"""
|
39
136
|
A custom 3D causal convolution layer with feature caching support.
|
@@ -134,19 +231,25 @@ class WanResample(nn.Module):
|
|
134
231
|
- 'downsample3d': 3D downsampling with zero-padding, convolution, and causal 3D convolution.
|
135
232
|
"""
|
136
233
|
|
137
|
-
def __init__(self, dim: int, mode: str) -> None:
|
234
|
+
def __init__(self, dim: int, mode: str, upsample_out_dim: int = None) -> None:
|
138
235
|
super().__init__()
|
139
236
|
self.dim = dim
|
140
237
|
self.mode = mode
|
141
238
|
|
239
|
+
# default to dim //2
|
240
|
+
if upsample_out_dim is None:
|
241
|
+
upsample_out_dim = dim // 2
|
242
|
+
|
142
243
|
# layers
|
143
244
|
if mode == "upsample2d":
|
144
245
|
self.resample = nn.Sequential(
|
145
|
-
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
246
|
+
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
247
|
+
nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
|
146
248
|
)
|
147
249
|
elif mode == "upsample3d":
|
148
250
|
self.resample = nn.Sequential(
|
149
|
-
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
251
|
+
WanUpsample(scale_factor=(2.0, 2.0), mode="nearest-exact"),
|
252
|
+
nn.Conv2d(dim, upsample_out_dim, 3, padding=1),
|
150
253
|
)
|
151
254
|
self.time_conv = WanCausalConv3d(dim, dim * 2, (3, 1, 1), padding=(1, 0, 0))
|
152
255
|
|
@@ -363,6 +466,42 @@ class WanMidBlock(nn.Module):
|
|
363
466
|
return x
|
364
467
|
|
365
468
|
|
469
|
+
class WanResidualDownBlock(nn.Module):
    """Residual downsampling stage of the Wan VAE encoder (selected when ``is_residual`` is enabled).

    Main path: ``num_res_blocks`` ``WanResidualBlock`` layers (the first maps ``in_dim -> out_dim``,
    the rest keep ``out_dim``), optionally followed by a ``WanResample`` downsampler. An ``AvgDown3D``
    shortcut computed from the block input is added to the main-path output.

    Args:
        in_dim: channels of the incoming tensor.
        out_dim: channels produced by the block.
        dropout: dropout rate forwarded to every ``WanResidualBlock``.
        num_res_blocks: number of ``WanResidualBlock`` layers in the main path.
        temperal_downsample: if True the shortcut uses ``factor_t=2`` and the downsampler (when present)
            uses mode ``"downsample3d"``; otherwise ``factor_t=1`` / ``"downsample2d"``.
        down_flag: if True the shortcut uses ``factor_s=2`` and a downsampler is appended to the main path.
    """

    def __init__(self, in_dim, out_dim, dropout, num_res_blocks, temperal_downsample=False, down_flag=False):
        super().__init__()

        # Shortcut branch: channel-group average over (factor_t, factor_s, factor_s) windows.
        self.avg_shortcut = AvgDown3D(
            in_dim,
            out_dim,
            factor_t=2 if temperal_downsample else 1,
            factor_s=2 if down_flag else 1,
        )

        # Main branch: only the first residual block changes the channel count.
        self.resnets = nn.ModuleList(
            [WanResidualBlock(in_dim if k == 0 else out_dim, out_dim, dropout) for k in range(num_res_blocks)]
        )

        # Optional trailing downsampler, created after the residual blocks.
        self.downsampler = (
            WanResample(out_dim, mode="downsample3d" if temperal_downsample else "downsample2d") if down_flag else None
        )

    def forward(self, x, feat_cache=None, feat_idx=[0]):
        """Run the main path and add the ``AvgDown3D`` shortcut of the original input.

        ``feat_cache`` and ``feat_idx`` are handed unchanged to every residual block and to the
        downsampler (presumably the causal-conv frame cache and its running index — confirm in
        ``WanResidualBlock`` / ``WanResample``).
        """
        # Snapshot the input before the main path runs; the shortcut consumes this copy at the end.
        shortcut_input = x.clone()

        hidden = x
        for layer in self.resnets:
            hidden = layer(hidden, feat_cache, feat_idx)
        if self.downsampler is not None:
            hidden = self.downsampler(hidden, feat_cache, feat_idx)

        return hidden + self.avg_shortcut(shortcut_input)
|
503
|
+
|
504
|
+
|
366
505
|
class WanEncoder3d(nn.Module):
|
367
506
|
r"""
|
368
507
|
A 3D encoder module.
|
@@ -380,6 +519,7 @@ class WanEncoder3d(nn.Module):
|
|
380
519
|
|
381
520
|
def __init__(
|
382
521
|
self,
|
522
|
+
in_channels: int = 3,
|
383
523
|
dim=128,
|
384
524
|
z_dim=4,
|
385
525
|
dim_mult=[1, 2, 4, 4],
|
@@ -388,6 +528,7 @@ class WanEncoder3d(nn.Module):
|
|
388
528
|
temperal_downsample=[True, True, False],
|
389
529
|
dropout=0.0,
|
390
530
|
non_linearity: str = "silu",
|
531
|
+
is_residual: bool = False, # wan 2.2 vae use a residual downblock
|
391
532
|
):
|
392
533
|
super().__init__()
|
393
534
|
self.dim = dim
|
@@ -403,23 +544,35 @@ class WanEncoder3d(nn.Module):
|
|
403
544
|
scale = 1.0
|
404
545
|
|
405
546
|
# init block
|
406
|
-
self.conv_in = WanCausalConv3d(
|
547
|
+
self.conv_in = WanCausalConv3d(in_channels, dims[0], 3, padding=1)
|
407
548
|
|
408
549
|
# downsample blocks
|
409
550
|
self.down_blocks = nn.ModuleList([])
|
410
551
|
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
411
552
|
# residual (+attention) blocks
|
412
|
-
|
413
|
-
self.down_blocks.append(
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
419
|
-
|
420
|
-
|
421
|
-
|
422
|
-
|
553
|
+
if is_residual:
|
554
|
+
self.down_blocks.append(
|
555
|
+
WanResidualDownBlock(
|
556
|
+
in_dim,
|
557
|
+
out_dim,
|
558
|
+
dropout,
|
559
|
+
num_res_blocks,
|
560
|
+
temperal_downsample=temperal_downsample[i] if i != len(dim_mult) - 1 else False,
|
561
|
+
down_flag=i != len(dim_mult) - 1,
|
562
|
+
)
|
563
|
+
)
|
564
|
+
else:
|
565
|
+
for _ in range(num_res_blocks):
|
566
|
+
self.down_blocks.append(WanResidualBlock(in_dim, out_dim, dropout))
|
567
|
+
if scale in attn_scales:
|
568
|
+
self.down_blocks.append(WanAttentionBlock(out_dim))
|
569
|
+
in_dim = out_dim
|
570
|
+
|
571
|
+
# downsample block
|
572
|
+
if i != len(dim_mult) - 1:
|
573
|
+
mode = "downsample3d" if temperal_downsample[i] else "downsample2d"
|
574
|
+
self.down_blocks.append(WanResample(out_dim, mode=mode))
|
575
|
+
scale /= 2.0
|
423
576
|
|
424
577
|
# middle blocks
|
425
578
|
self.mid_block = WanMidBlock(out_dim, dropout, non_linearity, num_layers=1)
|
@@ -470,6 +623,94 @@ class WanEncoder3d(nn.Module):
|
|
470
623
|
return x
|
471
624
|
|
472
625
|
|
626
|
+
class WanResidualUpBlock(nn.Module):
    """
    A residual block that handles upsampling for the WanVAE decoder (selected when ``is_residual`` is enabled).

    The main path is ``num_res_blocks + 1`` ``WanResidualBlock`` layers followed, when ``up_flag`` is set,
    by a ``WanResample`` upsampler. When ``up_flag`` is set, a ``DupUp3D`` shortcut computed from the block
    input is also added to the main-path output; otherwise the block has no shortcut.

    Args:
        in_dim (int): Input dimension
        out_dim (int): Output dimension
        num_res_blocks (int): Number of residual blocks (the block builds ``num_res_blocks + 1`` of them)
        dropout (float): Dropout rate
        temperal_upsample (bool): Whether to upsample on temporal dimension (only used when ``up_flag`` is True)
        up_flag (bool): Whether to upsample or not
        non_linearity (str): Type of non-linearity to use
    """

    def __init__(
        self,
        in_dim: int,
        out_dim: int,
        num_res_blocks: int,
        dropout: float = 0.0,
        temperal_upsample: bool = False,
        up_flag: bool = False,
        non_linearity: str = "silu",
    ):
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim

        # Shortcut branch exists only for upsampling stages: 2x spatially, and 2x temporally
        # when temperal_upsample is set.
        if up_flag:
            self.avg_shortcut = DupUp3D(
                in_dim,
                out_dim,
                factor_t=2 if temperal_upsample else 1,
                factor_s=2,
            )
        else:
            self.avg_shortcut = None

        # create residual blocks; the first maps in_dim -> out_dim, the rest keep out_dim
        resnets = []
        current_dim = in_dim
        for _ in range(num_res_blocks + 1):
            resnets.append(WanResidualBlock(current_dim, out_dim, dropout, non_linearity))
            current_dim = out_dim

        self.resnets = nn.ModuleList(resnets)

        # Add upsampling layer if needed
        if up_flag:
            upsample_mode = "upsample3d" if temperal_upsample else "upsample2d"
            self.upsampler = WanResample(out_dim, mode=upsample_mode, upsample_out_dim=out_dim)
        else:
            self.upsampler = None

        # Not read by forward() below; kept for parity with WanUpBlock, which sets the same attribute.
        self.gradient_checkpointing = False

    def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
        """
        Forward pass through the upsampling block.

        Args:
            x (torch.Tensor): Input tensor
            feat_cache (list, optional): Feature cache for causal convolutions
            feat_idx (list, optional): Feature index for cache management
            first_chunk (bool, optional): Forwarded to the ``DupUp3D`` shortcut, which then drops the first
                ``factor_t - 1`` upsampled frames. Ignored when the block has no shortcut.

        Returns:
            torch.Tensor: Output tensor
        """
        # The untouched input is only needed by the shortcut branch. When there is none
        # (up_flag=False), skip the clone instead of allocating an unused full-size copy.
        x_copy = x.clone() if self.avg_shortcut is not None else None

        for resnet in self.resnets:
            if feat_cache is not None:
                x = resnet(x, feat_cache, feat_idx)
            else:
                x = resnet(x)

        if self.upsampler is not None:
            if feat_cache is not None:
                x = self.upsampler(x, feat_cache, feat_idx)
            else:
                x = self.upsampler(x)

        if self.avg_shortcut is not None:
            x = x + self.avg_shortcut(x_copy, first_chunk=first_chunk)

        return x
|
712
|
+
|
713
|
+
|
473
714
|
class WanUpBlock(nn.Module):
|
474
715
|
"""
|
475
716
|
A block that handles upsampling for the WanVAE decoder.
|
@@ -513,7 +754,7 @@ class WanUpBlock(nn.Module):
|
|
513
754
|
|
514
755
|
self.gradient_checkpointing = False
|
515
756
|
|
516
|
-
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
757
|
+
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=None):
|
517
758
|
"""
|
518
759
|
Forward pass through the upsampling block.
|
519
760
|
|
@@ -564,6 +805,8 @@ class WanDecoder3d(nn.Module):
|
|
564
805
|
temperal_upsample=[False, True, True],
|
565
806
|
dropout=0.0,
|
566
807
|
non_linearity: str = "silu",
|
808
|
+
out_channels: int = 3,
|
809
|
+
is_residual: bool = False,
|
567
810
|
):
|
568
811
|
super().__init__()
|
569
812
|
self.dim = dim
|
@@ -577,7 +820,6 @@ class WanDecoder3d(nn.Module):
|
|
577
820
|
|
578
821
|
# dimensions
|
579
822
|
dims = [dim * u for u in [dim_mult[-1]] + dim_mult[::-1]]
|
580
|
-
scale = 1.0 / 2 ** (len(dim_mult) - 2)
|
581
823
|
|
582
824
|
# init block
|
583
825
|
self.conv_in = WanCausalConv3d(z_dim, dims[0], 3, padding=1)
|
@@ -589,36 +831,47 @@ class WanDecoder3d(nn.Module):
|
|
589
831
|
self.up_blocks = nn.ModuleList([])
|
590
832
|
for i, (in_dim, out_dim) in enumerate(zip(dims[:-1], dims[1:])):
|
591
833
|
# residual (+attention) blocks
|
592
|
-
if i > 0:
|
834
|
+
if i > 0 and not is_residual:
|
835
|
+
# wan vae 2.1
|
593
836
|
in_dim = in_dim // 2
|
594
837
|
|
595
|
-
#
|
838
|
+
# determine if we need upsampling
|
839
|
+
up_flag = i != len(dim_mult) - 1
|
840
|
+
# determine upsampling mode, if not upsampling, set to None
|
596
841
|
upsample_mode = None
|
597
|
-
if
|
598
|
-
upsample_mode = "upsample3d"
|
599
|
-
|
842
|
+
if up_flag and temperal_upsample[i]:
|
843
|
+
upsample_mode = "upsample3d"
|
844
|
+
elif up_flag:
|
845
|
+
upsample_mode = "upsample2d"
|
600
846
|
# Create and add the upsampling block
|
601
|
-
|
602
|
-
|
603
|
-
|
604
|
-
|
605
|
-
|
606
|
-
|
607
|
-
|
608
|
-
|
847
|
+
if is_residual:
|
848
|
+
up_block = WanResidualUpBlock(
|
849
|
+
in_dim=in_dim,
|
850
|
+
out_dim=out_dim,
|
851
|
+
num_res_blocks=num_res_blocks,
|
852
|
+
dropout=dropout,
|
853
|
+
temperal_upsample=temperal_upsample[i] if up_flag else False,
|
854
|
+
up_flag=up_flag,
|
855
|
+
non_linearity=non_linearity,
|
856
|
+
)
|
857
|
+
else:
|
858
|
+
up_block = WanUpBlock(
|
859
|
+
in_dim=in_dim,
|
860
|
+
out_dim=out_dim,
|
861
|
+
num_res_blocks=num_res_blocks,
|
862
|
+
dropout=dropout,
|
863
|
+
upsample_mode=upsample_mode,
|
864
|
+
non_linearity=non_linearity,
|
865
|
+
)
|
609
866
|
self.up_blocks.append(up_block)
|
610
867
|
|
611
|
-
# Update scale for next iteration
|
612
|
-
if upsample_mode is not None:
|
613
|
-
scale *= 2.0
|
614
|
-
|
615
868
|
# output blocks
|
616
869
|
self.norm_out = WanRMS_norm(out_dim, images=False)
|
617
|
-
self.conv_out = WanCausalConv3d(out_dim,
|
870
|
+
self.conv_out = WanCausalConv3d(out_dim, out_channels, 3, padding=1)
|
618
871
|
|
619
872
|
self.gradient_checkpointing = False
|
620
873
|
|
621
|
-
def forward(self, x, feat_cache=None, feat_idx=[0]):
|
874
|
+
def forward(self, x, feat_cache=None, feat_idx=[0], first_chunk=False):
|
622
875
|
## conv1
|
623
876
|
if feat_cache is not None:
|
624
877
|
idx = feat_idx[0]
|
@@ -637,7 +890,7 @@ class WanDecoder3d(nn.Module):
|
|
637
890
|
|
638
891
|
## upsamples
|
639
892
|
for up_block in self.up_blocks:
|
640
|
-
x = up_block(x, feat_cache, feat_idx)
|
893
|
+
x = up_block(x, feat_cache, feat_idx, first_chunk=first_chunk)
|
641
894
|
|
642
895
|
## head
|
643
896
|
x = self.norm_out(x)
|
@@ -656,6 +909,49 @@ class WanDecoder3d(nn.Module):
|
|
656
909
|
return x
|
657
910
|
|
658
911
|
|
912
|
+
def patchify(x, patch_size):
    """Fold non-overlapping spatial patches of a 5-D tensor into its channel axis.

    Maps ``[B, C, T, H, W]`` to ``[B, C * p * p, T, H // p, W // p]`` with ``p = patch_size``. Within each
    source channel, the new channels are ordered by width offset, then height offset (the height offset
    varies fastest). A ``patch_size`` of 1 returns ``x`` unchanged.

    Raises:
        ValueError: if ``x`` is not 5-D, or if H or W is not divisible by ``patch_size``.
    """
    if patch_size == 1:
        # A 1x1 patch is the identity transform.
        return x
    if x.dim() != 5:
        raise ValueError(f"Invalid input shape: {x.shape}")

    b, c, t, height, width = x.shape
    if height % patch_size or width % patch_size:
        raise ValueError(f"Height ({height}) and width ({width}) must be divisible by patch_size ({patch_size})")

    rows, cols = height // patch_size, width // patch_size
    # Split H and W into (grid index, in-patch offset): [b, c, t, rows, p_h, cols, p_w]
    tiles = x.view(b, c, t, rows, patch_size, cols, patch_size)
    # Bring the offsets next to the channel axis, width offset first: [b, c, p_w, p_h, t, rows, cols]
    tiles = tiles.permute(0, 1, 6, 4, 2, 3, 5).contiguous()
    return tiles.view(b, c * patch_size * patch_size, t, rows, cols)
|
933
|
+
|
934
|
+
|
935
|
+
def unpatchify(x, patch_size):
    """Inverse of ``patchify``: unfold channel-packed patches back into the spatial axes.

    Maps ``[B, C * p * p, T, H, W]`` to ``[B, C, T, H * p, W * p]`` with ``p = patch_size``, assuming the
    per-channel packing order (width offset, then height offset varying fastest). A ``patch_size`` of 1
    returns ``x`` unchanged.

    Raises:
        ValueError: if ``x`` is not 5-D, or if its channel count is not divisible by ``patch_size**2``.
    """
    if patch_size == 1:
        return x

    if x.dim() != 5:
        raise ValueError(f"Invalid input shape: {x.shape}")
    # x shape: [batch_size, (channels * patch_size * patch_size), frame, height, width]
    batch_size, c_patches, frames, height, width = x.shape
    patch_area = patch_size * patch_size

    # Without this check a bad channel count only surfaces as an opaque `view` error.
    if c_patches % patch_area != 0:
        raise ValueError(
            f"Channel dimension ({c_patches}) must be divisible by patch_size**2 ({patch_area})"
        )
    channels = c_patches // patch_area

    # Reshape to [b, c, p_w, p_h, f, h, w]
    x = x.view(batch_size, channels, patch_size, patch_size, frames, height, width)

    # Rearrange to [b, c, f, h, p_h, w, p_w], then merge into [b, c, f, h * patch_size, w * patch_size]
    x = x.permute(0, 1, 4, 5, 3, 6, 2).contiguous()
    x = x.view(batch_size, channels, frames, height * patch_size, width * patch_size)

    return x
|
953
|
+
|
954
|
+
|
659
955
|
class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
660
956
|
r"""
|
661
957
|
A VAE model with KL loss for encoding videos into latents and decoding latent representations into videos.
|
@@ -671,6 +967,7 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
671
967
|
def __init__(
|
672
968
|
self,
|
673
969
|
base_dim: int = 96,
|
970
|
+
decoder_base_dim: Optional[int] = None,
|
674
971
|
z_dim: int = 16,
|
675
972
|
dim_mult: Tuple[int] = [1, 2, 4, 4],
|
676
973
|
num_res_blocks: int = 2,
|
@@ -713,6 +1010,12 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
713
1010
|
2.8251,
|
714
1011
|
1.9160,
|
715
1012
|
],
|
1013
|
+
is_residual: bool = False,
|
1014
|
+
in_channels: int = 3,
|
1015
|
+
out_channels: int = 3,
|
1016
|
+
patch_size: Optional[int] = None,
|
1017
|
+
scale_factor_temporal: Optional[int] = 4,
|
1018
|
+
scale_factor_spatial: Optional[int] = 8,
|
716
1019
|
) -> None:
|
717
1020
|
super().__init__()
|
718
1021
|
|
@@ -720,14 +1023,33 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
720
1023
|
self.temperal_downsample = temperal_downsample
|
721
1024
|
self.temperal_upsample = temperal_downsample[::-1]
|
722
1025
|
|
1026
|
+
if decoder_base_dim is None:
|
1027
|
+
decoder_base_dim = base_dim
|
1028
|
+
|
723
1029
|
self.encoder = WanEncoder3d(
|
724
|
-
|
1030
|
+
in_channels=in_channels,
|
1031
|
+
dim=base_dim,
|
1032
|
+
z_dim=z_dim * 2,
|
1033
|
+
dim_mult=dim_mult,
|
1034
|
+
num_res_blocks=num_res_blocks,
|
1035
|
+
attn_scales=attn_scales,
|
1036
|
+
temperal_downsample=temperal_downsample,
|
1037
|
+
dropout=dropout,
|
1038
|
+
is_residual=is_residual,
|
725
1039
|
)
|
726
1040
|
self.quant_conv = WanCausalConv3d(z_dim * 2, z_dim * 2, 1)
|
727
1041
|
self.post_quant_conv = WanCausalConv3d(z_dim, z_dim, 1)
|
728
1042
|
|
729
1043
|
self.decoder = WanDecoder3d(
|
730
|
-
|
1044
|
+
dim=decoder_base_dim,
|
1045
|
+
z_dim=z_dim,
|
1046
|
+
dim_mult=dim_mult,
|
1047
|
+
num_res_blocks=num_res_blocks,
|
1048
|
+
attn_scales=attn_scales,
|
1049
|
+
temperal_upsample=self.temperal_upsample,
|
1050
|
+
dropout=dropout,
|
1051
|
+
out_channels=out_channels,
|
1052
|
+
is_residual=is_residual,
|
731
1053
|
)
|
732
1054
|
|
733
1055
|
self.spatial_compression_ratio = 2 ** len(self.temperal_downsample)
|
@@ -827,6 +1149,8 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
827
1149
|
return self.tiled_encode(x)
|
828
1150
|
|
829
1151
|
self.clear_cache()
|
1152
|
+
if self.config.patch_size is not None:
|
1153
|
+
x = patchify(x, patch_size=self.config.patch_size)
|
830
1154
|
iter_ = 1 + (num_frame - 1) // 4
|
831
1155
|
for i in range(iter_):
|
832
1156
|
self._enc_conv_idx = [0]
|
@@ -884,12 +1208,18 @@ class AutoencoderKLWan(ModelMixin, ConfigMixin, FromOriginalModelMixin):
|
|
884
1208
|
for i in range(num_frame):
|
885
1209
|
self._conv_idx = [0]
|
886
1210
|
if i == 0:
|
887
|
-
out = self.decoder(
|
1211
|
+
out = self.decoder(
|
1212
|
+
x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx, first_chunk=True
|
1213
|
+
)
|
888
1214
|
else:
|
889
1215
|
out_ = self.decoder(x[:, :, i : i + 1, :, :], feat_cache=self._feat_map, feat_idx=self._conv_idx)
|
890
1216
|
out = torch.cat([out, out_], 2)
|
891
1217
|
|
1218
|
+
if self.config.patch_size is not None:
|
1219
|
+
out = unpatchify(out, patch_size=self.config.patch_size)
|
1220
|
+
|
892
1221
|
out = torch.clamp(out, min=-1.0, max=1.0)
|
1222
|
+
|
893
1223
|
self.clear_cache()
|
894
1224
|
if not return_dict:
|
895
1225
|
return (out,)
|
diffusers/models/cache_utils.py
CHANGED
@@ -12,6 +12,8 @@
|
|
12
12
|
# See the License for the specific language governing permissions and
|
13
13
|
# limitations under the License.
|
14
14
|
|
15
|
+
from contextlib import contextmanager
|
16
|
+
|
15
17
|
from ..utils.logging import get_logger
|
16
18
|
|
17
19
|
|
@@ -25,6 +27,7 @@ class CacheMixin:
|
|
25
27
|
Supported caching techniques:
|
26
28
|
- [Pyramid Attention Broadcast](https://huggingface.co/papers/2408.12588)
|
27
29
|
- [FasterCache](https://huggingface.co/papers/2410.19355)
|
30
|
+
- [FirstBlockCache](https://github.com/chengzeyi/ParaAttention/blob/7a266123671b55e7e5a2fe9af3121f07a36afc78/README.md#first-block-cache-our-dynamic-caching)
|
28
31
|
"""
|
29
32
|
|
30
33
|
_cache_config = None
|
@@ -62,8 +65,10 @@ class CacheMixin:
|
|
62
65
|
|
63
66
|
from ..hooks import (
|
64
67
|
FasterCacheConfig,
|
68
|
+
FirstBlockCacheConfig,
|
65
69
|
PyramidAttentionBroadcastConfig,
|
66
70
|
apply_faster_cache,
|
71
|
+
apply_first_block_cache,
|
67
72
|
apply_pyramid_attention_broadcast,
|
68
73
|
)
|
69
74
|
|
@@ -72,31 +77,36 @@ class CacheMixin:
|
|
72
77
|
f"Caching has already been enabled with {type(self._cache_config)}. To apply a new caching technique, please disable the existing one first."
|
73
78
|
)
|
74
79
|
|
75
|
-
if isinstance(config,
|
76
|
-
apply_pyramid_attention_broadcast(self, config)
|
77
|
-
elif isinstance(config, FasterCacheConfig):
|
80
|
+
if isinstance(config, FasterCacheConfig):
|
78
81
|
apply_faster_cache(self, config)
|
82
|
+
elif isinstance(config, FirstBlockCacheConfig):
|
83
|
+
apply_first_block_cache(self, config)
|
84
|
+
elif isinstance(config, PyramidAttentionBroadcastConfig):
|
85
|
+
apply_pyramid_attention_broadcast(self, config)
|
79
86
|
else:
|
80
87
|
raise ValueError(f"Cache config {type(config)} is not supported.")
|
81
88
|
|
82
89
|
self._cache_config = config
|
83
90
|
|
84
91
|
def disable_cache(self) -> None:
|
85
|
-
from ..hooks import FasterCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
|
92
|
+
from ..hooks import FasterCacheConfig, FirstBlockCacheConfig, HookRegistry, PyramidAttentionBroadcastConfig
|
86
93
|
from ..hooks.faster_cache import _FASTER_CACHE_BLOCK_HOOK, _FASTER_CACHE_DENOISER_HOOK
|
94
|
+
from ..hooks.first_block_cache import _FBC_BLOCK_HOOK, _FBC_LEADER_BLOCK_HOOK
|
87
95
|
from ..hooks.pyramid_attention_broadcast import _PYRAMID_ATTENTION_BROADCAST_HOOK
|
88
96
|
|
89
97
|
if self._cache_config is None:
|
90
98
|
logger.warning("Caching techniques have not been enabled, so there's nothing to disable.")
|
91
99
|
return
|
92
100
|
|
93
|
-
|
94
|
-
|
95
|
-
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
96
|
-
elif isinstance(self._cache_config, FasterCacheConfig):
|
97
|
-
registry = HookRegistry.check_if_exists_or_initialize(self)
|
101
|
+
registry = HookRegistry.check_if_exists_or_initialize(self)
|
102
|
+
if isinstance(self._cache_config, FasterCacheConfig):
|
98
103
|
registry.remove_hook(_FASTER_CACHE_DENOISER_HOOK, recurse=True)
|
99
104
|
registry.remove_hook(_FASTER_CACHE_BLOCK_HOOK, recurse=True)
|
105
|
+
elif isinstance(self._cache_config, FirstBlockCacheConfig):
|
106
|
+
registry.remove_hook(_FBC_LEADER_BLOCK_HOOK, recurse=True)
|
107
|
+
registry.remove_hook(_FBC_BLOCK_HOOK, recurse=True)
|
108
|
+
elif isinstance(self._cache_config, PyramidAttentionBroadcastConfig):
|
109
|
+
registry.remove_hook(_PYRAMID_ATTENTION_BROADCAST_HOOK, recurse=True)
|
100
110
|
else:
|
101
111
|
raise ValueError(f"Cache config {type(self._cache_config)} is not supported.")
|
102
112
|
|
@@ -106,3 +116,15 @@ class CacheMixin:
|
|
106
116
|
from ..hooks import HookRegistry
|
107
117
|
|
108
118
|
HookRegistry.check_if_exists_or_initialize(self).reset_stateful_hooks(recurse=recurse)
|
119
|
+
|
120
|
+
@contextmanager
|
121
|
+
def cache_context(self, name: str):
|
122
|
+
r"""Context manager that provides additional methods for cache management."""
|
123
|
+
from ..hooks import HookRegistry
|
124
|
+
|
125
|
+
registry = HookRegistry.check_if_exists_or_initialize(self)
|
126
|
+
registry._set_context(name)
|
127
|
+
|
128
|
+
yield
|
129
|
+
|
130
|
+
registry._set_context(None)
|
@@ -343,25 +343,25 @@ class FluxControlNetModel(ModelMixin, ConfigMixin, PeftAdapterMixin):
|
|
343
343
|
)
|
344
344
|
block_samples = block_samples + (hidden_states,)
|
345
345
|
|
346
|
-
hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
|
347
|
-
|
348
346
|
single_block_samples = ()
|
349
347
|
for index_block, block in enumerate(self.single_transformer_blocks):
|
350
348
|
if torch.is_grad_enabled() and self.gradient_checkpointing:
|
351
|
-
hidden_states = self._gradient_checkpointing_func(
|
349
|
+
encoder_hidden_states, hidden_states = self._gradient_checkpointing_func(
|
352
350
|
block,
|
353
351
|
hidden_states,
|
352
|
+
encoder_hidden_states,
|
354
353
|
temb,
|
355
354
|
image_rotary_emb,
|
356
355
|
)
|
357
356
|
|
358
357
|
else:
|
359
|
-
hidden_states = block(
|
358
|
+
encoder_hidden_states, hidden_states = block(
|
360
359
|
hidden_states=hidden_states,
|
360
|
+
encoder_hidden_states=encoder_hidden_states,
|
361
361
|
temb=temb,
|
362
362
|
image_rotary_emb=image_rotary_emb,
|
363
363
|
)
|
364
|
-
single_block_samples = single_block_samples + (hidden_states
|
364
|
+
single_block_samples = single_block_samples + (hidden_states,)
|
365
365
|
|
366
366
|
# controlnet block
|
367
367
|
controlnet_block_samples = ()
|