cache-dit 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/adapters.py +47 -5
- cache_dit/cache_factory/dual_block_cache/__init__.py +4 -0
- cache_dit/cache_factory/dual_block_cache/cache_blocks.py +487 -0
- cache_dit/cache_factory/dual_block_cache/cache_context.py +10 -860
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +5 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +14 -4
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +5 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +5 -2
- cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py} +14 -12
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +7 -4
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +4 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +10 -4
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +5 -2
- cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py} +28 -23
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +276 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +228 -516
- cache_dit/cache_factory/patch/flux.py +241 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA +22 -80
- cache_dit-0.2.16.dist-info/RECORD +47 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
- cache_dit-0.2.14.dist-info/RECORD +0 -49
- /cache_dit/cache_factory/{first_block_cache → patch}/__init__.py +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/top_level.txt +0 -0
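The hunks that follow are the body of this diff; judging by the +10 −860 totals they belong to cache_dit/cache_factory/dual_block_cache/cache_context.py, which in 0.2.16 drops the first-block-cache remnants and the in-file DBCachedTransformerBlocks class. For orientation, the sketch below shows how a dual-block cache is typically attached to a diffusers pipeline (for example via the newly added Qwen-Image adapters). The entry point apply_cache_on_pipe reflects cache-dit's documented API as I understand it, but neither it nor the model id appears in this diff, so treat the snippet as an illustrative assumption rather than a verified example.

    # Illustrative only -- the entry-point name and model id are assumptions, not taken from this diff.
    import torch
    from diffusers import DiffusionPipeline
    from cache_dit.cache_factory import apply_cache_on_pipe  # assumed import path

    pipe = DiffusionPipeline.from_pretrained(
        "Qwen/Qwen-Image",            # hypothetical model id used for illustration
        torch_dtype=torch.bfloat16,
    ).to("cuda")

    # Attach Dual Block Cache (DBCache) hooks to the pipeline's transformer blocks.
    # Option kwargs (Fn/Bn compute blocks, residual diff threshold, ...) vary by version.
    apply_cache_on_pipe(pipe)

    image = pipe("a cup of coffee on a wooden table").images[0]
    image.save("qwen_image_dbcache.png")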
@@ -1,16 +1,13 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/context.py
-
 import logging
 import contextlib
 import dataclasses
 from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Optional, Union
+from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
 
 import torch
 
 import cache_dit.primitives as primitives
 from cache_dit.cache_factory.taylorseer import TaylorSeer
-from cache_dit.utils import is_diffusers_at_least_0_3_5
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
@@ -55,8 +52,7 @@ class DBCacheContext:
     # DON'T Cache if the number of cached steps >= max_cached_steps
     max_cached_steps: int = -1  # for both CFG and non-CFG
 
-    #
-    # Record the steps that have been cached, both alter cache and non-alter cache
+    # Record the steps that have been cached, both cached and non-cache
     executed_steps: int = 0  # cache + non-cache steps pippeline
     # steps for transformer, for CFG, transformer_executed_steps will
     # be double of executed_steps.
@@ -73,10 +69,10 @@ class DBCacheContext:
     taylorseer: Optional[TaylorSeer] = None
     encoder_tarlorseer: Optional[TaylorSeer] = None
 
-    # Support do_separate_classifier_free_guidance, such as Wan 2.1
-    # For model that fused CFG and non-CFG into single
-    # should set do_separate_classifier_free_guidance
-    # example: CogVideoX, HunyuanVideo, Mochi.
+    # Support do_separate_classifier_free_guidance, such as Wan 2.1,
+    # Qwen-Image. For model that fused CFG and non-CFG into single
+    # forward step, should set do_separate_classifier_free_guidance
+    # as False. For example: CogVideoX, HunyuanVideo, Mochi.
     do_separate_classifier_free_guidance: bool = False
     # Compute cfg forward first or not, default False, namely,
     # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
@@ -98,13 +94,6 @@ class DBCacheContext:
         default_factory=lambda: defaultdict(float),
     )
 
-    # TODO: Support SLG in Dual Block Cache
-    # Skip Layer Guidance, SLG
-    # https://github.com/huggingface/candle/issues/2588
-    slg_layers: Optional[List[int]] = None
-    slg_start: float = 0.0
-    slg_end: float = 0.1
-
     @torch.compiler.disable
     def __post_init__(self):
         # Some checks for settings
@@ -144,18 +133,6 @@ class DBCacheContext:
                 **self.taylorseer_kwargs
             )
 
-    @torch.compiler.disable
-    def get_incremental_name(self, name=None):
-        if name is None:
-            name = "default"
-        idx = self.incremental_name_counters[name]
-        self.incremental_name_counters[name] += 1
-        return f"{name}_{idx}"
-
-    @torch.compiler.disable
-    def reset_incremental_names(self):
-        self.incremental_name_counters.clear()
-
     @torch.compiler.disable
     def get_residual_diff_threshold(self):
         if self.enable_alter_cache:
@@ -222,7 +199,6 @@ class DBCacheContext:
         self.residual_diffs.clear()
         self.cfg_cached_steps.clear()
         self.cfg_residual_diffs.clear()
-        self.reset_incremental_names()
         # Reset the TaylorSeers cache at the beginning of each inference.
         # reset_cache will set the current step to -1 for TaylorSeer,
         if self.enable_taylorseer or self.enable_encoder_taylorseer:
@@ -264,12 +240,10 @@ class DBCacheContext:
             if encoder_taylorseer is not None:
                 encoder_taylorseer.mark_step_begin()
 
-
-    def get_taylorseers(self):
+    def get_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
         return self.taylorseer, self.encoder_tarlorseer
 
-
-    def get_cfg_taylorseers(self):
+    def get_cfg_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
         return self.cfg_taylorseer, self.cfg_encoder_taylorseer
 
     @torch.compiler.disable
@@ -464,15 +438,13 @@ def is_encoder_taylorseer_enabled():
     return cache_context.enable_encoder_taylorseer
 
 
-
-def get_taylorseers():
+def get_taylorseers() -> Tuple[TaylorSeer, TaylorSeer]:
     cache_context = get_current_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_taylorseers()
 
 
-
-def get_cfg_taylorseers():
+def get_cfg_taylorseers() -> Tuple[TaylorSeer, TaylorSeer]:
     cache_context = get_current_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_cfg_taylorseers()
@@ -1105,825 +1077,3 @@ def get_can_use_cache(
         and is_alter_cache()
     )
     return can_use_cache
-
-
-class DBCachedTransformerBlocks(torch.nn.Module):
-    def __init__(
-        self,
-        transformer_blocks,
-        single_transformer_blocks=None,
-        *,
-        transformer=None,
-        return_hidden_states_first=True,
-        return_hidden_states_only=False,
-    ):
-        super().__init__()
-
-        self.transformer = transformer
-        self.transformer_blocks = transformer_blocks
-        self.single_transformer_blocks = single_transformer_blocks
-        self.return_hidden_states_first = return_hidden_states_first
-        self.return_hidden_states_only = return_hidden_states_only
-
-    def forward(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        original_hidden_states = hidden_states
-        # Call first `n` blocks to process the hidden states for
-        # more stable diff calculation.
-        hidden_states, encoder_hidden_states = self.call_Fn_transformer_blocks(
-            hidden_states,
-            encoder_hidden_states,
-            *args,
-            **kwargs,
-        )
-
-        Fn_hidden_states_residual = hidden_states - original_hidden_states
-        del original_hidden_states
-
-        mark_step_begin()
-        # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = get_can_use_cache(
-            (
-                Fn_hidden_states_residual
-                if not is_l1_diff_enabled()
-                else hidden_states
-            ),
-            parallelized=self._is_parallelized(),
-            prefix=(
-                "Fn_residual"
-                if not is_l1_diff_enabled()
-                else "Fn_hidden_states"
-            ),
-        )
-
-        torch._dynamo.graph_break()
-        if can_use_cache:
-            add_cached_step()
-            del Fn_hidden_states_residual
-            hidden_states, encoder_hidden_states = apply_hidden_states_residual(
-                hidden_states,
-                encoder_hidden_states,
-                prefix=(
-                    "Bn_residual" if is_cache_residual() else "Bn_hidden_states"
-                ),
-                encoder_prefix=(
-                    "Bn_residual"
-                    if is_encoder_cache_residual()
-                    else "Bn_hidden_states"
-                ),
-            )
-            torch._dynamo.graph_break()
-            # Call last `n` blocks to further process the hidden states
-            # for higher precision.
-            hidden_states, encoder_hidden_states = (
-                self.call_Bn_transformer_blocks(
-                    hidden_states,
-                    encoder_hidden_states,
-                    *args,
-                    **kwargs,
-                )
-            )
-        else:
-            set_Fn_buffer(Fn_hidden_states_residual, prefix="Fn_residual")
-            if is_l1_diff_enabled():
-                # for hidden states L1 diff
-                set_Fn_buffer(hidden_states, "Fn_hidden_states")
-            del Fn_hidden_states_residual
-            torch._dynamo.graph_break()
-            (
-                hidden_states,
-                encoder_hidden_states,
-                hidden_states_residual,
-                encoder_hidden_states_residual,
-            ) = self.call_Mn_transformer_blocks(  # middle
-                hidden_states,
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
-            torch._dynamo.graph_break()
-            if is_cache_residual():
-                set_Bn_buffer(
-                    hidden_states_residual,
-                    prefix="Bn_residual",
-                )
-            else:
-                # TaylorSeer
-                set_Bn_buffer(
-                    hidden_states,
-                    prefix="Bn_hidden_states",
-                )
-            if is_encoder_cache_residual():
-                set_Bn_encoder_buffer(
-                    encoder_hidden_states_residual,
-                    prefix="Bn_residual",
-                )
-            else:
-                # TaylorSeer
-                set_Bn_encoder_buffer(
-                    encoder_hidden_states,
-                    prefix="Bn_hidden_states",
-                )
-            torch._dynamo.graph_break()
-            # Call last `n` blocks to further process the hidden states
-            # for higher precision.
-            hidden_states, encoder_hidden_states = (
-                self.call_Bn_transformer_blocks(
-                    hidden_states,
-                    encoder_hidden_states,
-                    *args,
-                    **kwargs,
-                )
-            )
-
-        patch_cached_stats(self.transformer)
-        torch._dynamo.graph_break()
-
-        return (
-            hidden_states
-            if self.return_hidden_states_only
-            else (
-                (hidden_states, encoder_hidden_states)
-                if self.return_hidden_states_first
-                else (encoder_hidden_states, hidden_states)
-            )
-        )
-
-    @torch.compiler.disable
-    def _is_parallelized(self):
-        # Compatible with distributed inference.
-        return all(
-            (
-                self.transformer is not None,
-                getattr(self.transformer, "_is_parallelized", False),
-            )
-        )
-
-    @torch.compiler.disable
-    def _is_in_cache_step(self):
-        # Check if the current step is in cache steps.
-        # If so, we can skip some Bn blocks and directly
-        # use the cached values.
-        return (get_current_step() in get_cached_steps()) or (
-            get_current_step() in get_cfg_cached_steps()
-        )
-
-    @torch.compiler.disable
-    def _Fn_transformer_blocks(self):
-        # Select first `n` blocks to process the hidden states for
-        # more stable diff calculation.
-        # Fn: [0,...,n-1]
-        selected_Fn_transformer_blocks = self.transformer_blocks[
-            : Fn_compute_blocks()
-        ]
-        return selected_Fn_transformer_blocks
-
-    @torch.compiler.disable
-    def _Mn_single_transformer_blocks(self):  # middle blocks
-        # M(N-2n): transformer_blocks [n,...] + single_transformer_blocks [0,...,N-n]
-        selected_Mn_single_transformer_blocks = []
-        if self.single_transformer_blocks is not None:
-            if Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
-                selected_Mn_single_transformer_blocks = (
-                    self.single_transformer_blocks
-                )
-            else:
-                selected_Mn_single_transformer_blocks = (
-                    self.single_transformer_blocks[: -Bn_compute_blocks()]
-                )
-        return selected_Mn_single_transformer_blocks
-
-    @torch.compiler.disable
-    def _Mn_transformer_blocks(self):  # middle blocks
-        # M(N-2n): only transformer_blocks [n,...,N-n], middle
-        if Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
-            selected_Mn_transformer_blocks = self.transformer_blocks[
-                Fn_compute_blocks() :
-            ]
-        else:
-            selected_Mn_transformer_blocks = self.transformer_blocks[
-                Fn_compute_blocks() : -Bn_compute_blocks()
-            ]
-        return selected_Mn_transformer_blocks
-
-    @torch.compiler.disable
-    def _Bn_single_transformer_blocks(self):
-        # Bn: single_transformer_blocks [N-n+1,...,N-1]
-        selected_Bn_single_transformer_blocks = []
-        if self.single_transformer_blocks is not None:
-            selected_Bn_single_transformer_blocks = (
-                self.single_transformer_blocks[-Bn_compute_blocks() :]
-            )
-        return selected_Bn_single_transformer_blocks
-
-    @torch.compiler.disable
-    def _Bn_transformer_blocks(self):
-        # Bn: transformer_blocks [N-n+1,...,N-1]
-        selected_Bn_transformer_blocks = self.transformer_blocks[
-            -Bn_compute_blocks() :
-        ]
-        return selected_Bn_transformer_blocks
-
-    def call_Fn_transformer_blocks(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        assert Fn_compute_blocks() <= len(self.transformer_blocks), (
-            f"Fn_compute_blocks {Fn_compute_blocks()} must be less than "
-            f"the number of transformer blocks {len(self.transformer_blocks)}"
-        )
-        for block in self._Fn_transformer_blocks():
-            hidden_states = block(
-                hidden_states,
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
-            if not isinstance(hidden_states, torch.Tensor):
-                hidden_states, encoder_hidden_states = hidden_states
-                if not self.return_hidden_states_first:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
-                        hidden_states,
-                    )
-
-        return hidden_states, encoder_hidden_states
-
-    def call_Mn_transformer_blocks(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        original_hidden_states = hidden_states
-        original_encoder_hidden_states = encoder_hidden_states
-        # This condition branch is mainly for FLUX series.
-        if self.single_transformer_blocks is not None:
-            for block in self.transformer_blocks[Fn_compute_blocks() :]:
-                hidden_states = block(
-                    hidden_states,
-                    encoder_hidden_states,
-                    *args,
-                    **kwargs,
-                )
-                if not isinstance(hidden_states, torch.Tensor):
-                    hidden_states, encoder_hidden_states = hidden_states
-                    if not self.return_hidden_states_first:
-                        hidden_states, encoder_hidden_states = (
-                            encoder_hidden_states,
-                            hidden_states,
-                        )
-
-            # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
-            if is_diffusers_at_least_0_3_5():
-                for block in self._Mn_single_transformer_blocks():
-                    encoder_hidden_states, hidden_states = block(
-                        hidden_states,
-                        encoder_hidden_states,
-                        *args,
-                        **kwargs,
-                    )
-            else:
-                hidden_states = torch.cat(
-                    [encoder_hidden_states, hidden_states], dim=1
-                )
-                for block in self._Mn_single_transformer_blocks():
-                    hidden_states = block(
-                        hidden_states,
-                        *args,
-                        **kwargs,
-                    )
-                encoder_hidden_states, hidden_states = hidden_states.split(
-                    [
-                        encoder_hidden_states.shape[1],
-                        hidden_states.shape[1] - encoder_hidden_states.shape[1],
-                    ],
-                    dim=1,
-                )
-        else:
-            for block in self._Mn_transformer_blocks():
-                hidden_states = block(
-                    hidden_states,
-                    encoder_hidden_states,
-                    *args,
-                    **kwargs,
-                )
-                if not isinstance(hidden_states, torch.Tensor):
-                    hidden_states, encoder_hidden_states = hidden_states
-                    if not self.return_hidden_states_first:
-                        hidden_states, encoder_hidden_states = (
-                            encoder_hidden_states,
-                            hidden_states,
-                        )
-
-        # hidden_states_shape = hidden_states.shape
-        # encoder_hidden_states_shape = encoder_hidden_states.shape
-        hidden_states = (
-            hidden_states.reshape(-1)
-            .contiguous()
-            .reshape(original_hidden_states.shape)
-        )
-        encoder_hidden_states = (
-            encoder_hidden_states.reshape(-1)
-            .contiguous()
-            .reshape(original_encoder_hidden_states.shape)
-        )
-
-        # hidden_states = hidden_states.contiguous()
-        # encoder_hidden_states = encoder_hidden_states.contiguous()
-
-        hidden_states_residual = hidden_states - original_hidden_states
-        encoder_hidden_states_residual = (
-            encoder_hidden_states - original_encoder_hidden_states
-        )
-
-        hidden_states_residual = (
-            hidden_states_residual.reshape(-1)
-            .contiguous()
-            .reshape(original_hidden_states.shape)
-        )
-        encoder_hidden_states_residual = (
-            encoder_hidden_states_residual.reshape(-1)
-            .contiguous()
-            .reshape(original_encoder_hidden_states.shape)
-        )
-
-        return (
-            hidden_states,
-            encoder_hidden_states,
-            hidden_states_residual,
-            encoder_hidden_states_residual,
-        )
-
-    @torch.compiler.disable
-    def _Bn_i_single_hidden_states_residual(
-        self,
-        Bn_i_hidden_states: torch.Tensor,
-        Bn_i_original_hidden_states: torch.Tensor,
-        original_hidden_states: torch.Tensor,
-        original_encoder_hidden_states: torch.Tensor,
-    ):
-        # Split the Bn_i_hidden_states and Bn_i_original_hidden_states
-        # into encoder_hidden_states and hidden_states.
-        Bn_i_hidden_states, Bn_i_encoder_hidden_states = (
-            self._split_Bn_i_single_hidden_states(
-                Bn_i_hidden_states,
-                original_hidden_states,
-                original_encoder_hidden_states,
-            )
-        )
-        # Split the Bn_i_original_hidden_states into encoder_hidden_states
-        # and hidden_states.
-        Bn_i_original_hidden_states, Bn_i_original_encoder_hidden_states = (
-            self._split_Bn_i_single_hidden_states(
-                Bn_i_original_hidden_states,
-                original_hidden_states,
-                original_encoder_hidden_states,
-            )
-        )
-
-        # Compute the residuals for the Bn_i_hidden_states and
-        # Bn_i_encoder_hidden_states.
-        Bn_i_hidden_states_residual = (
-            Bn_i_hidden_states - Bn_i_original_hidden_states
-        )
-        Bn_i_encoder_hidden_states_residual = (
-            Bn_i_encoder_hidden_states - Bn_i_original_encoder_hidden_states
-        )
-        return (
-            Bn_i_hidden_states_residual,
-            Bn_i_encoder_hidden_states_residual,
-        )
-
-    @torch.compiler.disable
-    def _split_Bn_i_single_hidden_states(
-        self,
-        Bn_i_hidden_states: torch.Tensor,
-        original_hidden_states: torch.Tensor,
-        original_encoder_hidden_states: torch.Tensor,
-    ):
-        # Split the Bn_i_hidden_states into encoder_hidden_states and hidden_states.
-        Bn_i_encoder_hidden_states, Bn_i_hidden_states = (
-            Bn_i_hidden_states.split(
-                [
-                    original_encoder_hidden_states.shape[1],
-                    Bn_i_hidden_states.shape[1]
-                    - original_encoder_hidden_states.shape[1],
-                ],
-                dim=1,
-            )
-        )
-        # Reshape the Bn_i_hidden_states and Bn_i_encoder_hidden_states
-        # to the original shape. This is necessary to ensure that the
-        # residuals are computed correctly.
-        Bn_i_hidden_states = (
-            Bn_i_hidden_states.reshape(-1)
-            .contiguous()
-            .reshape(original_hidden_states.shape)
-        )
-        Bn_i_encoder_hidden_states = (
-            Bn_i_encoder_hidden_states.reshape(-1)
-            .contiguous()
-            .reshape(original_encoder_hidden_states.shape)
-        )
-        return Bn_i_hidden_states, Bn_i_encoder_hidden_states
-
-    def _compute_and_cache_single_transformer_block(
-        self,
-        # Block index in the transformer blocks
-        # Bn: 8, block_id should be in [0, 8)
-        block_id: int,
-        # Helper inputs for hidden states split and reshape
-        original_hidden_states: torch.Tensor,
-        original_encoder_hidden_states: torch.Tensor,
-        # Below are the inputs to the block
-        block,  # The transformer block to be executed
-        hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        # Helper function for `call_Bn_transformer_blocks`
-        # Skip the blocks by reuse residual cache if they are not
-        # in the Bn_compute_blocks_ids. NOTE: We should only skip
-        # the specific Bn blocks in cache steps. Compute the block
-        # and cache the residuals in non-cache steps.
-
-        # Normal steps: Compute the block and cache the residuals.
-        if not self._is_in_cache_step():
-            Bn_i_original_hidden_states = hidden_states
-            hidden_states = block(
-                hidden_states,
-                *args,
-                **kwargs,
-            )
-            # Cache residuals for the non-compute Bn blocks for
-            # subsequent cache steps.
-            if block_id not in Bn_compute_blocks_ids():
-                Bn_i_hidden_states = hidden_states
-                (
-                    Bn_i_hidden_states_residual,
-                    Bn_i_encoder_hidden_states_residual,
-                ) = self._Bn_i_single_hidden_states_residual(
-                    Bn_i_hidden_states,
-                    Bn_i_original_hidden_states,
-                    original_hidden_states,
-                    original_encoder_hidden_states,
-                )
-
-                # Save original_hidden_states for diff calculation.
-                set_Bn_buffer(
-                    Bn_i_original_hidden_states,
-                    prefix=f"Bn_{block_id}_single_original",
-                )
-                set_Bn_encoder_buffer(
-                    Bn_i_original_hidden_states,
-                    prefix=f"Bn_{block_id}_single_original",
-                )
-
-                set_Bn_buffer(
-                    Bn_i_hidden_states_residual,
-                    prefix=f"Bn_{block_id}_single_residual",
-                )
-                set_Bn_encoder_buffer(
-                    Bn_i_encoder_hidden_states_residual,
-                    prefix=f"Bn_{block_id}_single_residual",
-                )
-                del Bn_i_hidden_states
-                del Bn_i_hidden_states_residual
-                del Bn_i_encoder_hidden_states_residual
-
-            del Bn_i_original_hidden_states
-
-        else:
-            # Cache steps: Reuse the cached residuals.
-            # Check if the block is in the Bn_compute_blocks_ids.
-            if block_id in Bn_compute_blocks_ids():
-                hidden_states = block(
-                    hidden_states,
-                    *args,
-                    **kwargs,
-                )
-            else:
-                # Skip the block if it is not in the Bn_compute_blocks_ids.
-                # Use the cached residuals instead.
-                # Check if can use the cached residuals.
-                if get_can_use_cache(
-                    hidden_states,  # curr step
-                    parallelized=self._is_parallelized(),
-                    threshold=non_compute_blocks_diff_threshold(),
-                    prefix=f"Bn_{block_id}_single_original",  # prev step
-                ):
-                    Bn_i_original_hidden_states = hidden_states
-                    (
-                        Bn_i_original_hidden_states,
-                        Bn_i_original_encoder_hidden_states,
-                    ) = self._split_Bn_i_single_hidden_states(
-                        Bn_i_original_hidden_states,
-                        original_hidden_states,
-                        original_encoder_hidden_states,
-                    )
-                    hidden_states, encoder_hidden_states = (
-                        apply_hidden_states_residual(
-                            Bn_i_original_hidden_states,
-                            Bn_i_original_encoder_hidden_states,
-                            prefix=(
-                                f"Bn_{block_id}_single_residual"
-                                if is_cache_residual()
-                                else f"Bn_{block_id}_single_original"
-                            ),
-                            encoder_prefix=(
-                                f"Bn_{block_id}_single_residual"
-                                if is_encoder_cache_residual()
-                                else f"Bn_{block_id}_single_original"
-                            ),
-                        )
-                    )
-                    hidden_states = torch.cat(
-                        [encoder_hidden_states, hidden_states],
-                        dim=1,
-                    )
-                    del Bn_i_original_hidden_states
-                    del Bn_i_original_encoder_hidden_states
-                else:
-                    hidden_states = block(
-                        hidden_states,
-                        *args,
-                        **kwargs,
-                    )
-        return hidden_states
-
-    def _compute_and_cache_transformer_block(
-        self,
-        # Block index in the transformer blocks
-        # Bn: 8, block_id should be in [0, 8)
-        block_id: int,
-        # Below are the inputs to the block
-        block,  # The transformer block to be executed
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        # Helper function for `call_Bn_transformer_blocks`
-        # Skip the blocks by reuse residual cache if they are not
-        # in the Bn_compute_blocks_ids. NOTE: We should only skip
-        # the specific Bn blocks in cache steps. Compute the block
-        # and cache the residuals in non-cache steps.
-
-        # Normal steps: Compute the block and cache the residuals.
-        if not self._is_in_cache_step():
-            Bn_i_original_hidden_states = hidden_states
-            Bn_i_original_encoder_hidden_states = encoder_hidden_states
-            hidden_states = block(
-                hidden_states,
-                encoder_hidden_states,
-                *args,
-                **kwargs,
-            )
-            if not isinstance(hidden_states, torch.Tensor):
-                hidden_states, encoder_hidden_states = hidden_states
-                if not self.return_hidden_states_first:
-                    hidden_states, encoder_hidden_states = (
-                        encoder_hidden_states,
-                        hidden_states,
-                    )
-            # Cache residuals for the non-compute Bn blocks for
-            # subsequent cache steps.
-            if block_id not in Bn_compute_blocks_ids():
-                Bn_i_hidden_states_residual = (
-                    hidden_states - Bn_i_original_hidden_states
-                )
-                Bn_i_encoder_hidden_states_residual = (
-                    encoder_hidden_states - Bn_i_original_encoder_hidden_states
-                )
-
-                # Save original_hidden_states for diff calculation.
-                set_Bn_buffer(
-                    Bn_i_original_hidden_states,
-                    prefix=f"Bn_{block_id}_original",
-                )
-                set_Bn_encoder_buffer(
-                    Bn_i_original_encoder_hidden_states,
-                    prefix=f"Bn_{block_id}_original",
-                )
-
-                set_Bn_buffer(
-                    Bn_i_hidden_states_residual,
-                    prefix=f"Bn_{block_id}_residual",
-                )
-                set_Bn_encoder_buffer(
-                    Bn_i_encoder_hidden_states_residual,
-                    prefix=f"Bn_{block_id}_residual",
-                )
-                del Bn_i_hidden_states_residual
-                del Bn_i_encoder_hidden_states_residual
-
-            del Bn_i_original_hidden_states
-            del Bn_i_original_encoder_hidden_states
-
-        else:
-            # Cache steps: Reuse the cached residuals.
-            # Check if the block is in the Bn_compute_blocks_ids.
-            if block_id in Bn_compute_blocks_ids():
-                hidden_states = block(
-                    hidden_states,
-                    encoder_hidden_states,
-                    *args,
-                    **kwargs,
-                )
-                if not isinstance(hidden_states, torch.Tensor):
-                    hidden_states, encoder_hidden_states = hidden_states
-                    if not self.return_hidden_states_first:
-                        hidden_states, encoder_hidden_states = (
-                            encoder_hidden_states,
-                            hidden_states,
-                        )
-            else:
-                # Skip the block if it is not in the Bn_compute_blocks_ids.
-                # Use the cached residuals instead.
-                # Check if can use the cached residuals.
-                if get_can_use_cache(
-                    hidden_states,  # curr step
-                    parallelized=self._is_parallelized(),
-                    threshold=non_compute_blocks_diff_threshold(),
-                    prefix=f"Bn_{block_id}_original",  # prev step
-                ):
-                    hidden_states, encoder_hidden_states = (
-                        apply_hidden_states_residual(
-                            hidden_states,
-                            encoder_hidden_states,
-                            prefix=(
-                                f"Bn_{block_id}_residual"
-                                if is_cache_residual()
-                                else f"Bn_{block_id}_original"
-                            ),
-                            encoder_prefix=(
-                                f"Bn_{block_id}_residual"
-                                if is_encoder_cache_residual()
-                                else f"Bn_{block_id}_original"
-                            ),
-                        )
-                    )
-                else:
-                    hidden_states = block(
-                        hidden_states,
-                        encoder_hidden_states,
-                        *args,
-                        **kwargs,
-                    )
-                    if not isinstance(hidden_states, torch.Tensor):
-                        hidden_states, encoder_hidden_states = hidden_states
-                        if not self.return_hidden_states_first:
-                            hidden_states, encoder_hidden_states = (
-                                encoder_hidden_states,
-                                hidden_states,
-                            )
-        return hidden_states, encoder_hidden_states
-
-    def call_Bn_transformer_blocks(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        if Bn_compute_blocks() == 0:
-            return hidden_states, encoder_hidden_states
-
-        original_hidden_states = hidden_states
-        original_encoder_hidden_states = encoder_hidden_states
-        # This condition branch is mainly for FLUX series.
-        if self.single_transformer_blocks is not None:
-            assert Bn_compute_blocks() <= len(self.single_transformer_blocks), (
-                f"Bn_compute_blocks {Bn_compute_blocks()} must be less than "
-                f"the number of single transformer blocks {len(self.single_transformer_blocks)}"
-            )
-            if is_diffusers_at_least_0_3_5():
-                if len(Bn_compute_blocks_ids()) > 0:
-                    # NOTE: Reuse _compute_and_cache_transformer_block here.
-                    for i, block in enumerate(
-                        self._Bn_single_transformer_blocks()
-                    ):
-                        hidden_states, encoder_hidden_states = (
-                            self._compute_and_cache_transformer_block(
-                                i,
-                                block,
-                                hidden_states,
-                                encoder_hidden_states,
-                                *args,
-                                **kwargs,
-                            )
-                        )
-                else:
-                    # Compute all Bn blocks if no specific Bn compute blocks ids are set.
-                    for block in self._Bn_single_transformer_blocks():
-                        encoder_hidden_states, hidden_states = block(
-                            hidden_states,
-                            encoder_hidden_states,
-                            *args,
-                            **kwargs,
-                        )
-            else:
-                hidden_states = torch.cat(
-                    [encoder_hidden_states, hidden_states], dim=1
-                )
-                if len(Bn_compute_blocks_ids()) > 0:
-                    for i, block in enumerate(
-                        self._Bn_single_transformer_blocks()
-                    ):
-                        hidden_states = (
-                            self._compute_and_cache_single_transformer_block(
-                                i,
-                                original_hidden_states,
-                                original_encoder_hidden_states,
-                                block,
-                                hidden_states,
-                                *args,
-                                **kwargs,
-                            )
-                        )
-                else:
-                    # Compute all Bn blocks if no specific Bn compute blocks ids are set.
-                    for block in self._Bn_single_transformer_blocks():
-                        hidden_states = block(
-                            hidden_states,
-                            *args,
-                            **kwargs,
-                        )
-                encoder_hidden_states, hidden_states = hidden_states.split(
-                    [
-                        encoder_hidden_states.shape[1],
-                        hidden_states.shape[1] - encoder_hidden_states.shape[1],
-                    ],
-                    dim=1,
-                )
-        else:
-            assert Bn_compute_blocks() <= len(self.transformer_blocks), (
-                f"Bn_compute_blocks {Bn_compute_blocks()} must be less than "
-                f"the number of transformer blocks {len(self.transformer_blocks)}"
-            )
-            if len(Bn_compute_blocks_ids()) > 0:
-                for i, block in enumerate(self._Bn_transformer_blocks()):
-                    hidden_states, encoder_hidden_states = (
-                        self._compute_and_cache_transformer_block(
-                            i,
-                            block,
-                            hidden_states,
-                            encoder_hidden_states,
-                            *args,
-                            **kwargs,
-                        )
-                    )
-            else:
-                # Compute all Bn blocks if no specific Bn compute blocks ids are set.
-                for block in self._Bn_transformer_blocks():
-                    hidden_states = block(
-                        hidden_states,
-                        encoder_hidden_states,
-                        *args,
-                        **kwargs,
-                    )
-                    if not isinstance(hidden_states, torch.Tensor):
-                        hidden_states, encoder_hidden_states = hidden_states
-                        if not self.return_hidden_states_first:
-                            hidden_states, encoder_hidden_states = (
-                                encoder_hidden_states,
-                                hidden_states,
-                            )
-
-        hidden_states = (
-            hidden_states.reshape(-1)
-            .contiguous()
-            .reshape(original_hidden_states.shape)
-        )
-        encoder_hidden_states = (
-            encoder_hidden_states.reshape(-1)
-            .contiguous()
-            .reshape(original_encoder_hidden_states.shape)
-        )
-        return hidden_states, encoder_hidden_states
-
-
-@torch.compiler.disable
-def patch_cached_stats(
-    transformer,
-):
-    # Patch the cached stats to the transformer, the cached stats
-    # will be reset for each calling of pipe.__call__(**kwargs).
-    if transformer is None:
-        return
-
-    # TODO: Patch more cached stats to the transformer
-    transformer._cached_steps = get_cached_steps()
-    transformer._residual_diffs = get_residual_diffs()
-    transformer._cfg_cached_steps = get_cfg_cached_steps()
-    transformer._cfg_residual_diffs = get_cfg_residual_diffs()