cache-dit 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package as they appear in the supported public registries. It is provided for informational purposes only.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (37)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/__init__.py +1 -0
  3. cache_dit/cache_factory/adapters.py +47 -5
  4. cache_dit/cache_factory/dual_block_cache/__init__.py +4 -0
  5. cache_dit/cache_factory/dual_block_cache/cache_blocks.py +487 -0
  6. cache_dit/cache_factory/dual_block_cache/cache_context.py +10 -860
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +4 -0
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +5 -2
  9. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +14 -4
  10. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +5 -2
  11. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +5 -2
  12. cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py} +14 -12
  13. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +7 -4
  14. cache_dit/cache_factory/dynamic_block_prune/__init__.py +4 -0
  15. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +4 -0
  16. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +5 -2
  17. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +10 -4
  18. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +5 -2
  19. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +5 -2
  20. cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py} +28 -23
  21. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +5 -2
  22. cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +276 -0
  23. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +228 -516
  24. cache_dit/cache_factory/patch/flux.py +241 -0
  25. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA +22 -80
  26. cache_dit-0.2.16.dist-info/RECORD +47 -0
  27. cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
  28. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
  29. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
  30. cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
  31. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
  32. cache_dit-0.2.14.dist-info/RECORD +0 -49
  33. /cache_dit/cache_factory/{first_block_cache → patch}/__init__.py +0 -0
  34. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/WHEEL +0 -0
  35. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/entry_points.txt +0 -0
  36. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/licenses/LICENSE +0 -0
  37. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/top_level.txt +0 -0
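
The listing above captures the 0.2.16 restructuring: the first_block_cache adapters are removed, the dual-block-cache runtime is split out into cache_blocks.py, and Qwen-Image adapters are added to both the dual_block_cache and dynamic_block_prune paths. For orientation, the sketch below shows how this release's cache factory is typically wired into a Diffusers pipeline. It assumes the apply_cache_on_pipe and CacheType helpers exported by cache_dit.cache_factory in the 0.2.x line; verify the names against the installed wheel before relying on them.

import torch
from diffusers import FluxPipeline

# Assumed 0.2.x entry points; this diff only shows that cache_factory/__init__.py
# was touched, so confirm the exports in the installed wheel.
from cache_dit.cache_factory import CacheType, apply_cache_on_pipe

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Wrap the pipeline's transformer blocks with the dual block cache (DBCache).
apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBCache))

image = pipe("a cup of coffee on a wooden desk", num_inference_steps=28).images[0]
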
@@ -1,16 +1,13 @@
- # Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/context.py
-
  import logging
  import contextlib
  import dataclasses
  from collections import defaultdict
- from typing import Any, DefaultDict, Dict, List, Optional, Union
+ from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple

  import torch

  import cache_dit.primitives as primitives
  from cache_dit.cache_factory.taylorseer import TaylorSeer
- from cache_dit.utils import is_diffusers_at_least_0_3_5
  from cache_dit.logger import init_logger

  logger = init_logger(__name__)
@@ -55,8 +52,7 @@ class DBCacheContext:
  # DON'T Cache if the number of cached steps >= max_cached_steps
  max_cached_steps: int = -1 # for both CFG and non-CFG

- # Statistics for botch alter cache and non-alter cache
- # Record the steps that have been cached, both alter cache and non-alter cache
+ # Record the steps that have been cached, both cached and non-cache
  executed_steps: int = 0 # cache + non-cache steps pippeline
  # steps for transformer, for CFG, transformer_executed_steps will
  # be double of executed_steps.
@@ -73,10 +69,10 @@ class DBCacheContext:
  taylorseer: Optional[TaylorSeer] = None
  encoder_tarlorseer: Optional[TaylorSeer] = None

- # Support do_separate_classifier_free_guidance, such as Wan 2.1
- # For model that fused CFG and non-CFG into single forward step,
- # should set do_separate_classifier_free_guidance as False. For
- # example: CogVideoX, HunyuanVideo, Mochi.
+ # Support do_separate_classifier_free_guidance, such as Wan 2.1,
+ # Qwen-Image. For model that fused CFG and non-CFG into single
+ # forward step, should set do_separate_classifier_free_guidance
+ # as False. For example: CogVideoX, HunyuanVideo, Mochi.
  do_separate_classifier_free_guidance: bool = False
  # Compute cfg forward first or not, default False, namely,
  # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
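
The comment change above extends the separate-CFG path to Qwen-Image. In practice this maps to one boolean on the cache context: pipelines such as Wan 2.1 and Qwen-Image run the CFG and non-CFG branches as two transformer forwards per step, while CogVideoX, HunyuanVideo, and Mochi fuse them into one. The snippet below is a hypothetical continuation of the earlier sketch: the option names mirror the DBCacheContext fields shown in this hunk, but passing them as keyword options through apply_cache_on_pipe is an assumption about the 0.2.x calling convention, not something this diff documents.

# Hypothetical: option keys mirror the DBCacheContext fields above; the
# pass-through via apply_cache_on_pipe is an assumed convention.
from cache_dit.cache_factory import CacheType, apply_cache_on_pipe  # assumed exports

dbcache_options = CacheType.default_options(CacheType.DBCache)
# Wan 2.1 / Qwen-Image run CFG as a separate transformer forward per step,
# so the cache must track CFG and non-CFG steps independently.
dbcache_options["do_separate_classifier_free_guidance"] = True
# CogVideoX / HunyuanVideo / Mochi fuse CFG into a single forward step and
# should keep the default False.
dbcache_options["max_cached_steps"] = -1  # no cap on the number of cached steps

apply_cache_on_pipe(pipe, **dbcache_options)  # `pipe` as in the earlier sketch
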
@@ -98,13 +94,6 @@ class DBCacheContext:
  default_factory=lambda: defaultdict(float),
  )

- # TODO: Support SLG in Dual Block Cache
- # Skip Layer Guidance, SLG
- # https://github.com/huggingface/candle/issues/2588
- slg_layers: Optional[List[int]] = None
- slg_start: float = 0.0
- slg_end: float = 0.1
-
  @torch.compiler.disable
  def __post_init__(self):
  # Some checks for settings
@@ -144,18 +133,6 @@ class DBCacheContext:
  **self.taylorseer_kwargs
  )

- @torch.compiler.disable
- def get_incremental_name(self, name=None):
- if name is None:
- name = "default"
- idx = self.incremental_name_counters[name]
- self.incremental_name_counters[name] += 1
- return f"{name}_{idx}"
-
- @torch.compiler.disable
- def reset_incremental_names(self):
- self.incremental_name_counters.clear()
-
  @torch.compiler.disable
  def get_residual_diff_threshold(self):
  if self.enable_alter_cache:
@@ -222,7 +199,6 @@ class DBCacheContext:
  self.residual_diffs.clear()
  self.cfg_cached_steps.clear()
  self.cfg_residual_diffs.clear()
- self.reset_incremental_names()
  # Reset the TaylorSeers cache at the beginning of each inference.
  # reset_cache will set the current step to -1 for TaylorSeer,
  if self.enable_taylorseer or self.enable_encoder_taylorseer:
@@ -264,12 +240,10 @@ class DBCacheContext:
  if encoder_taylorseer is not None:
  encoder_taylorseer.mark_step_begin()

- @torch.compiler.disable
- def get_taylorseers(self):
+ def get_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
  return self.taylorseer, self.encoder_tarlorseer

- @torch.compiler.disable
- def get_cfg_taylorseers(self):
+ def get_cfg_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
  return self.cfg_taylorseer, self.cfg_encoder_taylorseer

  @torch.compiler.disable
@@ -464,15 +438,13 @@ def is_encoder_taylorseer_enabled():
  return cache_context.enable_encoder_taylorseer


- @torch.compiler.disable
- def get_taylorseers():
+ def get_taylorseers() -> Tuple[TaylorSeer, TaylorSeer]:
  cache_context = get_current_cache_context()
  assert cache_context is not None, "cache_context must be set before"
  return cache_context.get_taylorseers()


- @torch.compiler.disable
- def get_cfg_taylorseers():
+ def get_cfg_taylorseers() -> Tuple[TaylorSeer, TaylorSeer]:
  cache_context = get_current_cache_context()
  assert cache_context is not None, "cache_context must be set before"
  return cache_context.get_cfg_taylorseers()
@@ -1105,825 +1077,3 @@ def get_can_use_cache(
  and is_alter_cache()
  )
  return can_use_cache
-
-
- class DBCachedTransformerBlocks(torch.nn.Module):
- def __init__(
- self,
- transformer_blocks,
- single_transformer_blocks=None,
- *,
- transformer=None,
- return_hidden_states_first=True,
- return_hidden_states_only=False,
- ):
- super().__init__()
-
- self.transformer = transformer
- self.transformer_blocks = transformer_blocks
- self.single_transformer_blocks = single_transformer_blocks
- self.return_hidden_states_first = return_hidden_states_first
- self.return_hidden_states_only = return_hidden_states_only
-
- def forward(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- *args,
- **kwargs,
- ):
- original_hidden_states = hidden_states
- # Call first `n` blocks to process the hidden states for
- # more stable diff calculation.
- hidden_states, encoder_hidden_states = self.call_Fn_transformer_blocks(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
-
- Fn_hidden_states_residual = hidden_states - original_hidden_states
- del original_hidden_states
-
- mark_step_begin()
- # Residual L1 diff or Hidden States L1 diff
- can_use_cache = get_can_use_cache(
- (
- Fn_hidden_states_residual
- if not is_l1_diff_enabled()
- else hidden_states
- ),
- parallelized=self._is_parallelized(),
- prefix=(
- "Fn_residual"
- if not is_l1_diff_enabled()
- else "Fn_hidden_states"
- ),
- )
-
- torch._dynamo.graph_break()
- if can_use_cache:
- add_cached_step()
- del Fn_hidden_states_residual
- hidden_states, encoder_hidden_states = apply_hidden_states_residual(
- hidden_states,
- encoder_hidden_states,
- prefix=(
- "Bn_residual" if is_cache_residual() else "Bn_hidden_states"
- ),
- encoder_prefix=(
- "Bn_residual"
- if is_encoder_cache_residual()
- else "Bn_hidden_states"
- ),
- )
- torch._dynamo.graph_break()
- # Call last `n` blocks to further process the hidden states
- # for higher precision.
- hidden_states, encoder_hidden_states = (
- self.call_Bn_transformer_blocks(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- )
- else:
- set_Fn_buffer(Fn_hidden_states_residual, prefix="Fn_residual")
- if is_l1_diff_enabled():
- # for hidden states L1 diff
- set_Fn_buffer(hidden_states, "Fn_hidden_states")
- del Fn_hidden_states_residual
- torch._dynamo.graph_break()
- (
- hidden_states,
- encoder_hidden_states,
- hidden_states_residual,
- encoder_hidden_states_residual,
- ) = self.call_Mn_transformer_blocks( # middle
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- torch._dynamo.graph_break()
- if is_cache_residual():
- set_Bn_buffer(
- hidden_states_residual,
- prefix="Bn_residual",
- )
- else:
- # TaylorSeer
- set_Bn_buffer(
- hidden_states,
- prefix="Bn_hidden_states",
- )
- if is_encoder_cache_residual():
- set_Bn_encoder_buffer(
- encoder_hidden_states_residual,
- prefix="Bn_residual",
- )
- else:
- # TaylorSeer
- set_Bn_encoder_buffer(
- encoder_hidden_states,
- prefix="Bn_hidden_states",
- )
- torch._dynamo.graph_break()
- # Call last `n` blocks to further process the hidden states
- # for higher precision.
- hidden_states, encoder_hidden_states = (
- self.call_Bn_transformer_blocks(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- )
-
- patch_cached_stats(self.transformer)
- torch._dynamo.graph_break()
-
- return (
- hidden_states
- if self.return_hidden_states_only
- else (
- (hidden_states, encoder_hidden_states)
- if self.return_hidden_states_first
- else (encoder_hidden_states, hidden_states)
- )
- )
-
- @torch.compiler.disable
- def _is_parallelized(self):
- # Compatible with distributed inference.
- return all(
- (
- self.transformer is not None,
- getattr(self.transformer, "_is_parallelized", False),
- )
- )
-
- @torch.compiler.disable
- def _is_in_cache_step(self):
- # Check if the current step is in cache steps.
- # If so, we can skip some Bn blocks and directly
- # use the cached values.
- return (get_current_step() in get_cached_steps()) or (
- get_current_step() in get_cfg_cached_steps()
- )
-
- @torch.compiler.disable
- def _Fn_transformer_blocks(self):
- # Select first `n` blocks to process the hidden states for
- # more stable diff calculation.
- # Fn: [0,...,n-1]
- selected_Fn_transformer_blocks = self.transformer_blocks[
- : Fn_compute_blocks()
- ]
- return selected_Fn_transformer_blocks
-
- @torch.compiler.disable
- def _Mn_single_transformer_blocks(self): # middle blocks
- # M(N-2n): transformer_blocks [n,...] + single_transformer_blocks [0,...,N-n]
- selected_Mn_single_transformer_blocks = []
- if self.single_transformer_blocks is not None:
- if Bn_compute_blocks() == 0: # WARN: x[:-0] = []
- selected_Mn_single_transformer_blocks = (
- self.single_transformer_blocks
- )
- else:
- selected_Mn_single_transformer_blocks = (
- self.single_transformer_blocks[: -Bn_compute_blocks()]
- )
- return selected_Mn_single_transformer_blocks
-
- @torch.compiler.disable
- def _Mn_transformer_blocks(self): # middle blocks
- # M(N-2n): only transformer_blocks [n,...,N-n], middle
- if Bn_compute_blocks() == 0: # WARN: x[:-0] = []
- selected_Mn_transformer_blocks = self.transformer_blocks[
- Fn_compute_blocks() :
- ]
- else:
- selected_Mn_transformer_blocks = self.transformer_blocks[
- Fn_compute_blocks() : -Bn_compute_blocks()
- ]
- return selected_Mn_transformer_blocks
-
- @torch.compiler.disable
- def _Bn_single_transformer_blocks(self):
- # Bn: single_transformer_blocks [N-n+1,...,N-1]
- selected_Bn_single_transformer_blocks = []
- if self.single_transformer_blocks is not None:
- selected_Bn_single_transformer_blocks = (
- self.single_transformer_blocks[-Bn_compute_blocks() :]
- )
- return selected_Bn_single_transformer_blocks
-
- @torch.compiler.disable
- def _Bn_transformer_blocks(self):
- # Bn: transformer_blocks [N-n+1,...,N-1]
- selected_Bn_transformer_blocks = self.transformer_blocks[
- -Bn_compute_blocks() :
- ]
- return selected_Bn_transformer_blocks
-
- def call_Fn_transformer_blocks(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- *args,
- **kwargs,
- ):
- assert Fn_compute_blocks() <= len(self.transformer_blocks), (
- f"Fn_compute_blocks {Fn_compute_blocks()} must be less than "
- f"the number of transformer blocks {len(self.transformer_blocks)}"
- )
- for block in self._Fn_transformer_blocks():
- hidden_states = block(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- if not isinstance(hidden_states, torch.Tensor):
- hidden_states, encoder_hidden_states = hidden_states
- if not self.return_hidden_states_first:
- hidden_states, encoder_hidden_states = (
- encoder_hidden_states,
- hidden_states,
- )
-
- return hidden_states, encoder_hidden_states
-
- def call_Mn_transformer_blocks(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- *args,
- **kwargs,
- ):
- original_hidden_states = hidden_states
- original_encoder_hidden_states = encoder_hidden_states
- # This condition branch is mainly for FLUX series.
- if self.single_transformer_blocks is not None:
- for block in self.transformer_blocks[Fn_compute_blocks() :]:
- hidden_states = block(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- if not isinstance(hidden_states, torch.Tensor):
- hidden_states, encoder_hidden_states = hidden_states
- if not self.return_hidden_states_first:
- hidden_states, encoder_hidden_states = (
- encoder_hidden_states,
- hidden_states,
- )
-
- # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
- if is_diffusers_at_least_0_3_5():
- for block in self._Mn_single_transformer_blocks():
- encoder_hidden_states, hidden_states = block(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- else:
- hidden_states = torch.cat(
- [encoder_hidden_states, hidden_states], dim=1
- )
- for block in self._Mn_single_transformer_blocks():
- hidden_states = block(
- hidden_states,
- *args,
- **kwargs,
- )
- encoder_hidden_states, hidden_states = hidden_states.split(
- [
- encoder_hidden_states.shape[1],
- hidden_states.shape[1] - encoder_hidden_states.shape[1],
- ],
- dim=1,
- )
- else:
- for block in self._Mn_transformer_blocks():
- hidden_states = block(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- if not isinstance(hidden_states, torch.Tensor):
- hidden_states, encoder_hidden_states = hidden_states
- if not self.return_hidden_states_first:
- hidden_states, encoder_hidden_states = (
- encoder_hidden_states,
- hidden_states,
- )
-
- # hidden_states_shape = hidden_states.shape
- # encoder_hidden_states_shape = encoder_hidden_states.shape
- hidden_states = (
- hidden_states.reshape(-1)
- .contiguous()
- .reshape(original_hidden_states.shape)
- )
- encoder_hidden_states = (
- encoder_hidden_states.reshape(-1)
- .contiguous()
- .reshape(original_encoder_hidden_states.shape)
- )
-
- # hidden_states = hidden_states.contiguous()
- # encoder_hidden_states = encoder_hidden_states.contiguous()
-
- hidden_states_residual = hidden_states - original_hidden_states
- encoder_hidden_states_residual = (
- encoder_hidden_states - original_encoder_hidden_states
- )
-
- hidden_states_residual = (
- hidden_states_residual.reshape(-1)
- .contiguous()
- .reshape(original_hidden_states.shape)
- )
- encoder_hidden_states_residual = (
- encoder_hidden_states_residual.reshape(-1)
- .contiguous()
- .reshape(original_encoder_hidden_states.shape)
- )
-
- return (
- hidden_states,
- encoder_hidden_states,
- hidden_states_residual,
- encoder_hidden_states_residual,
- )
-
- @torch.compiler.disable
- def _Bn_i_single_hidden_states_residual(
- self,
- Bn_i_hidden_states: torch.Tensor,
- Bn_i_original_hidden_states: torch.Tensor,
- original_hidden_states: torch.Tensor,
- original_encoder_hidden_states: torch.Tensor,
- ):
- # Split the Bn_i_hidden_states and Bn_i_original_hidden_states
- # into encoder_hidden_states and hidden_states.
- Bn_i_hidden_states, Bn_i_encoder_hidden_states = (
- self._split_Bn_i_single_hidden_states(
- Bn_i_hidden_states,
- original_hidden_states,
- original_encoder_hidden_states,
- )
- )
- # Split the Bn_i_original_hidden_states into encoder_hidden_states
- # and hidden_states.
- Bn_i_original_hidden_states, Bn_i_original_encoder_hidden_states = (
- self._split_Bn_i_single_hidden_states(
- Bn_i_original_hidden_states,
- original_hidden_states,
- original_encoder_hidden_states,
- )
- )
-
- # Compute the residuals for the Bn_i_hidden_states and
- # Bn_i_encoder_hidden_states.
- Bn_i_hidden_states_residual = (
- Bn_i_hidden_states - Bn_i_original_hidden_states
- )
- Bn_i_encoder_hidden_states_residual = (
- Bn_i_encoder_hidden_states - Bn_i_original_encoder_hidden_states
- )
- return (
- Bn_i_hidden_states_residual,
- Bn_i_encoder_hidden_states_residual,
- )
-
- @torch.compiler.disable
- def _split_Bn_i_single_hidden_states(
- self,
- Bn_i_hidden_states: torch.Tensor,
- original_hidden_states: torch.Tensor,
- original_encoder_hidden_states: torch.Tensor,
- ):
- # Split the Bn_i_hidden_states into encoder_hidden_states and hidden_states.
- Bn_i_encoder_hidden_states, Bn_i_hidden_states = (
- Bn_i_hidden_states.split(
- [
- original_encoder_hidden_states.shape[1],
- Bn_i_hidden_states.shape[1]
- - original_encoder_hidden_states.shape[1],
- ],
- dim=1,
- )
- )
- # Reshape the Bn_i_hidden_states and Bn_i_encoder_hidden_states
- # to the original shape. This is necessary to ensure that the
- # residuals are computed correctly.
- Bn_i_hidden_states = (
- Bn_i_hidden_states.reshape(-1)
- .contiguous()
- .reshape(original_hidden_states.shape)
- )
- Bn_i_encoder_hidden_states = (
- Bn_i_encoder_hidden_states.reshape(-1)
- .contiguous()
- .reshape(original_encoder_hidden_states.shape)
- )
- return Bn_i_hidden_states, Bn_i_encoder_hidden_states
-
- def _compute_and_cache_single_transformer_block(
- self,
- # Block index in the transformer blocks
- # Bn: 8, block_id should be in [0, 8)
- block_id: int,
- # Helper inputs for hidden states split and reshape
- original_hidden_states: torch.Tensor,
- original_encoder_hidden_states: torch.Tensor,
- # Below are the inputs to the block
- block, # The transformer block to be executed
- hidden_states: torch.Tensor,
- *args,
- **kwargs,
- ):
- # Helper function for `call_Bn_transformer_blocks`
- # Skip the blocks by reuse residual cache if they are not
- # in the Bn_compute_blocks_ids. NOTE: We should only skip
- # the specific Bn blocks in cache steps. Compute the block
- # and cache the residuals in non-cache steps.
-
- # Normal steps: Compute the block and cache the residuals.
- if not self._is_in_cache_step():
- Bn_i_original_hidden_states = hidden_states
- hidden_states = block(
- hidden_states,
- *args,
- **kwargs,
- )
- # Cache residuals for the non-compute Bn blocks for
- # subsequent cache steps.
- if block_id not in Bn_compute_blocks_ids():
- Bn_i_hidden_states = hidden_states
- (
- Bn_i_hidden_states_residual,
- Bn_i_encoder_hidden_states_residual,
- ) = self._Bn_i_single_hidden_states_residual(
- Bn_i_hidden_states,
- Bn_i_original_hidden_states,
- original_hidden_states,
- original_encoder_hidden_states,
- )
-
- # Save original_hidden_states for diff calculation.
- set_Bn_buffer(
- Bn_i_original_hidden_states,
- prefix=f"Bn_{block_id}_single_original",
- )
- set_Bn_encoder_buffer(
- Bn_i_original_hidden_states,
- prefix=f"Bn_{block_id}_single_original",
- )
-
- set_Bn_buffer(
- Bn_i_hidden_states_residual,
- prefix=f"Bn_{block_id}_single_residual",
- )
- set_Bn_encoder_buffer(
- Bn_i_encoder_hidden_states_residual,
- prefix=f"Bn_{block_id}_single_residual",
- )
- del Bn_i_hidden_states
- del Bn_i_hidden_states_residual
- del Bn_i_encoder_hidden_states_residual
-
- del Bn_i_original_hidden_states
-
- else:
- # Cache steps: Reuse the cached residuals.
- # Check if the block is in the Bn_compute_blocks_ids.
- if block_id in Bn_compute_blocks_ids():
- hidden_states = block(
- hidden_states,
- *args,
- **kwargs,
- )
- else:
- # Skip the block if it is not in the Bn_compute_blocks_ids.
- # Use the cached residuals instead.
- # Check if can use the cached residuals.
- if get_can_use_cache(
- hidden_states, # curr step
- parallelized=self._is_parallelized(),
- threshold=non_compute_blocks_diff_threshold(),
- prefix=f"Bn_{block_id}_single_original", # prev step
- ):
- Bn_i_original_hidden_states = hidden_states
- (
- Bn_i_original_hidden_states,
- Bn_i_original_encoder_hidden_states,
- ) = self._split_Bn_i_single_hidden_states(
- Bn_i_original_hidden_states,
- original_hidden_states,
- original_encoder_hidden_states,
- )
- hidden_states, encoder_hidden_states = (
- apply_hidden_states_residual(
- Bn_i_original_hidden_states,
- Bn_i_original_encoder_hidden_states,
- prefix=(
- f"Bn_{block_id}_single_residual"
- if is_cache_residual()
- else f"Bn_{block_id}_single_original"
- ),
- encoder_prefix=(
- f"Bn_{block_id}_single_residual"
- if is_encoder_cache_residual()
- else f"Bn_{block_id}_single_original"
- ),
- )
- )
- hidden_states = torch.cat(
- [encoder_hidden_states, hidden_states],
- dim=1,
- )
- del Bn_i_original_hidden_states
- del Bn_i_original_encoder_hidden_states
- else:
- hidden_states = block(
- hidden_states,
- *args,
- **kwargs,
- )
- return hidden_states
-
- def _compute_and_cache_transformer_block(
- self,
- # Block index in the transformer blocks
- # Bn: 8, block_id should be in [0, 8)
- block_id: int,
- # Below are the inputs to the block
- block, # The transformer block to be executed
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- *args,
- **kwargs,
- ):
- # Helper function for `call_Bn_transformer_blocks`
- # Skip the blocks by reuse residual cache if they are not
- # in the Bn_compute_blocks_ids. NOTE: We should only skip
- # the specific Bn blocks in cache steps. Compute the block
- # and cache the residuals in non-cache steps.
-
- # Normal steps: Compute the block and cache the residuals.
- if not self._is_in_cache_step():
- Bn_i_original_hidden_states = hidden_states
- Bn_i_original_encoder_hidden_states = encoder_hidden_states
- hidden_states = block(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- if not isinstance(hidden_states, torch.Tensor):
- hidden_states, encoder_hidden_states = hidden_states
- if not self.return_hidden_states_first:
- hidden_states, encoder_hidden_states = (
- encoder_hidden_states,
- hidden_states,
- )
- # Cache residuals for the non-compute Bn blocks for
- # subsequent cache steps.
- if block_id not in Bn_compute_blocks_ids():
- Bn_i_hidden_states_residual = (
- hidden_states - Bn_i_original_hidden_states
- )
- Bn_i_encoder_hidden_states_residual = (
- encoder_hidden_states - Bn_i_original_encoder_hidden_states
- )
-
- # Save original_hidden_states for diff calculation.
- set_Bn_buffer(
- Bn_i_original_hidden_states,
- prefix=f"Bn_{block_id}_original",
- )
- set_Bn_encoder_buffer(
- Bn_i_original_encoder_hidden_states,
- prefix=f"Bn_{block_id}_original",
- )
-
- set_Bn_buffer(
- Bn_i_hidden_states_residual,
- prefix=f"Bn_{block_id}_residual",
- )
- set_Bn_encoder_buffer(
- Bn_i_encoder_hidden_states_residual,
- prefix=f"Bn_{block_id}_residual",
- )
- del Bn_i_hidden_states_residual
- del Bn_i_encoder_hidden_states_residual
-
- del Bn_i_original_hidden_states
- del Bn_i_original_encoder_hidden_states
-
- else:
- # Cache steps: Reuse the cached residuals.
- # Check if the block is in the Bn_compute_blocks_ids.
- if block_id in Bn_compute_blocks_ids():
- hidden_states = block(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- if not isinstance(hidden_states, torch.Tensor):
- hidden_states, encoder_hidden_states = hidden_states
- if not self.return_hidden_states_first:
- hidden_states, encoder_hidden_states = (
- encoder_hidden_states,
- hidden_states,
- )
- else:
- # Skip the block if it is not in the Bn_compute_blocks_ids.
- # Use the cached residuals instead.
- # Check if can use the cached residuals.
- if get_can_use_cache(
- hidden_states, # curr step
- parallelized=self._is_parallelized(),
- threshold=non_compute_blocks_diff_threshold(),
- prefix=f"Bn_{block_id}_original", # prev step
- ):
- hidden_states, encoder_hidden_states = (
- apply_hidden_states_residual(
- hidden_states,
- encoder_hidden_states,
- prefix=(
- f"Bn_{block_id}_residual"
- if is_cache_residual()
- else f"Bn_{block_id}_original"
- ),
- encoder_prefix=(
- f"Bn_{block_id}_residual"
- if is_encoder_cache_residual()
- else f"Bn_{block_id}_original"
- ),
- )
- )
- else:
- hidden_states = block(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- if not isinstance(hidden_states, torch.Tensor):
- hidden_states, encoder_hidden_states = hidden_states
- if not self.return_hidden_states_first:
- hidden_states, encoder_hidden_states = (
- encoder_hidden_states,
- hidden_states,
- )
- return hidden_states, encoder_hidden_states
-
- def call_Bn_transformer_blocks(
- self,
- hidden_states: torch.Tensor,
- encoder_hidden_states: torch.Tensor,
- *args,
- **kwargs,
- ):
- if Bn_compute_blocks() == 0:
- return hidden_states, encoder_hidden_states
-
- original_hidden_states = hidden_states
- original_encoder_hidden_states = encoder_hidden_states
- # This condition branch is mainly for FLUX series.
- if self.single_transformer_blocks is not None:
- assert Bn_compute_blocks() <= len(self.single_transformer_blocks), (
- f"Bn_compute_blocks {Bn_compute_blocks()} must be less than "
- f"the number of single transformer blocks {len(self.single_transformer_blocks)}"
- )
- if is_diffusers_at_least_0_3_5():
- if len(Bn_compute_blocks_ids()) > 0:
- # NOTE: Reuse _compute_and_cache_transformer_block here.
- for i, block in enumerate(
- self._Bn_single_transformer_blocks()
- ):
- hidden_states, encoder_hidden_states = (
- self._compute_and_cache_transformer_block(
- i,
- block,
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- )
- else:
- # Compute all Bn blocks if no specific Bn compute blocks ids are set.
- for block in self._Bn_single_transformer_blocks():
- encoder_hidden_states, hidden_states = block(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- else:
- hidden_states = torch.cat(
- [encoder_hidden_states, hidden_states], dim=1
- )
- if len(Bn_compute_blocks_ids()) > 0:
- for i, block in enumerate(
- self._Bn_single_transformer_blocks()
- ):
- hidden_states = (
- self._compute_and_cache_single_transformer_block(
- i,
- original_hidden_states,
- original_encoder_hidden_states,
- block,
- hidden_states,
- *args,
- **kwargs,
- )
- )
- else:
- # Compute all Bn blocks if no specific Bn compute blocks ids are set.
- for block in self._Bn_single_transformer_blocks():
- hidden_states = block(
- hidden_states,
- *args,
- **kwargs,
- )
- encoder_hidden_states, hidden_states = hidden_states.split(
- [
- encoder_hidden_states.shape[1],
- hidden_states.shape[1] - encoder_hidden_states.shape[1],
- ],
- dim=1,
- )
- else:
- assert Bn_compute_blocks() <= len(self.transformer_blocks), (
- f"Bn_compute_blocks {Bn_compute_blocks()} must be less than "
- f"the number of transformer blocks {len(self.transformer_blocks)}"
- )
- if len(Bn_compute_blocks_ids()) > 0:
- for i, block in enumerate(self._Bn_transformer_blocks()):
- hidden_states, encoder_hidden_states = (
- self._compute_and_cache_transformer_block(
- i,
- block,
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- )
- else:
- # Compute all Bn blocks if no specific Bn compute blocks ids are set.
- for block in self._Bn_transformer_blocks():
- hidden_states = block(
- hidden_states,
- encoder_hidden_states,
- *args,
- **kwargs,
- )
- if not isinstance(hidden_states, torch.Tensor):
- hidden_states, encoder_hidden_states = hidden_states
- if not self.return_hidden_states_first:
- hidden_states, encoder_hidden_states = (
- encoder_hidden_states,
- hidden_states,
- )
-
- hidden_states = (
- hidden_states.reshape(-1)
- .contiguous()
- .reshape(original_hidden_states.shape)
- )
- encoder_hidden_states = (
- encoder_hidden_states.reshape(-1)
- .contiguous()
- .reshape(original_encoder_hidden_states.shape)
- )
- return hidden_states, encoder_hidden_states
-
-
- @torch.compiler.disable
- def patch_cached_stats(
- transformer,
- ):
- # Patch the cached stats to the transformer, the cached stats
- # will be reset for each calling of pipe.__call__(**kwargs).
- if transformer is None:
- return
-
- # TODO: Patch more cached stats to the transformer
- transformer._cached_steps = get_cached_steps()
- transformer._residual_diffs = get_residual_diffs()
- transformer._cfg_cached_steps = get_cfg_cached_steps()
- transformer._cfg_residual_diffs = get_cfg_residual_diffs()
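
The block removed above is the DBCachedTransformerBlocks wrapper and the patch_cached_stats helper, which this release relocates to dual_block_cache/cache_blocks.py (file 5 in the listing). Its core idea is to partition the transformer blocks into Fn (first n, always computed so the residual diff is stable), Mn (middle, skipped on cache hits), and Bn (last n, recomputed for precision), with an explicit guard for Bn == 0 because a Python slice like blocks[:-0] is empty rather than the full list. The standalone sketch below reduces that partitioning to plain list slicing; it is an illustration of the scheme, not the library's code.

from typing import List, Sequence, Tuple


def split_fn_mn_bn(
    blocks: Sequence[str], fn: int, bn: int
) -> Tuple[List[str], List[str], List[str]]:
    # Fn: first `fn` blocks, always executed so the residual diff is stable.
    # Mn: middle blocks, the ones a cache hit allows us to skip.
    # Bn: last `bn` blocks, re-executed on cache hits for extra precision.
    assert 0 <= fn <= len(blocks) and 0 <= bn <= len(blocks)
    first = list(blocks[:fn])
    # Guard bn == 0 explicitly: blocks[fn:-0] would be [], not "the rest".
    middle = list(blocks[fn:]) if bn == 0 else list(blocks[fn:-bn])
    last = [] if bn == 0 else list(blocks[-bn:])
    return first, middle, last


if __name__ == "__main__":
    blocks = [f"block_{i}" for i in range(10)]
    # bn=0: the middle partition absorbs every block after Fn.
    print(split_fn_mn_bn(blocks, fn=1, bn=0))
    # A small non-overlapping configuration:
    print(split_fn_mn_bn(blocks, fn=2, bn=2))
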