cache-dit 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.



Files changed (37)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/__init__.py +1 -0
  3. cache_dit/cache_factory/adapters.py +47 -5
  4. cache_dit/cache_factory/dual_block_cache/__init__.py +4 -0
  5. cache_dit/cache_factory/dual_block_cache/cache_blocks.py +487 -0
  6. cache_dit/cache_factory/dual_block_cache/cache_context.py +10 -860
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +4 -0
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +5 -2
  9. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +14 -4
  10. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +5 -2
  11. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +5 -2
  12. cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py} +14 -12
  13. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +7 -4
  14. cache_dit/cache_factory/dynamic_block_prune/__init__.py +4 -0
  15. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +4 -0
  16. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +5 -2
  17. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +10 -4
  18. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +5 -2
  19. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +5 -2
  20. cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py} +28 -23
  21. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +5 -2
  22. cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +276 -0
  23. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +228 -516
  24. cache_dit/cache_factory/patch/flux.py +241 -0
  25. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA +22 -80
  26. cache_dit-0.2.16.dist-info/RECORD +47 -0
  27. cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
  28. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
  29. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
  30. cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
  31. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
  32. cache_dit-0.2.14.dist-info/RECORD +0 -49
  33. /cache_dit/cache_factory/{first_block_cache → patch}/__init__.py +0 -0
  34. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/WHEEL +0 -0
  35. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/entry_points.txt +0 -0
  36. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/licenses/LICENSE +0 -0
  37. {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/top_level.txt +0 -0
cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.2.14'
- __version_tuple__ = version_tuple = (0, 2, 14)
+ __version__ = version = '0.2.16'
+ __version_tuple__ = version_tuple = (0, 2, 16)
cache_dit/cache_factory/__init__.py CHANGED
@@ -1,3 +1,4 @@
  from cache_dit.cache_factory.adapters import CacheType
  from cache_dit.cache_factory.adapters import apply_cache_on_pipe
+ from cache_dit.cache_factory.adapters import apply_cache_on_transformer
  from cache_dit.cache_factory.utils import load_cache_options_from_yaml
cache_dit/cache_factory/adapters.py CHANGED
@@ -4,12 +4,12 @@ from diffusers import DiffusionPipeline

  from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
      apply_db_cache_on_pipe,
+     apply_db_cache_on_transformer,
  )
- from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
-     apply_fb_cache_on_pipe,
- )
+
  from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters import (
      apply_db_prune_on_pipe,
+     apply_db_prune_on_transformer,
  )

  from cache_dit.logger import init_logger
@@ -93,7 +93,7 @@ class CacheType(Enum):
          }

          _Fn_compute_blocks = 8
-         _Bn_compute_blocks = 8
+         _Bn_compute_blocks = 0

          _db_options = {
              "cache_type": CacheType.DBCache,
@@ -155,7 +155,9 @@ def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
      cache_type = CacheType.type(cache_type)

      if cache_type == CacheType.FBCache:
-         return apply_fb_cache_on_pipe(pipe, *args, **kwargs)
+         raise ValueError(
+             "FBCache is removed from cache-dit, please use DBCache F1B0 instead."
+         )
      elif cache_type == CacheType.DBCache:
          return apply_db_cache_on_pipe(pipe, *args, **kwargs)
      elif cache_type == CacheType.DBPrune:
@@ -167,3 +169,43 @@ def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
          return pipe
      else:
          raise ValueError(f"Unknown cache type: {cache_type}")
+
+
+ def apply_cache_on_transformer(transformer, *args, **kwargs):
+
+     if hasattr(transformer, "_is_cached") and transformer._is_cached:
+         return transformer
+
+     if hasattr(transformer, "_is_pruned") and transformer._is_pruned:
+         return transformer
+
+     cache_type = kwargs.pop("cache_type", None)
+     if cache_type is None:
+         logger.warning(
+             "No cache type specified, we will use DBCache by default. "
+             "Please specify the cache_type explicitly if you want to "
+             "use a different cache type."
+         )
+         # Force to use DBCache with default cache options
+         return apply_db_cache_on_transformer(
+             transformer,
+             **CacheType.default_options(CacheType.DBCache),
+         )
+
+     cache_type = CacheType.type(cache_type)
+
+     if cache_type == CacheType.FBCache:
+         raise ValueError(
+             "FBCache is removed from cache-dit, please use DBCache F1B0 instead."
+         )
+     elif cache_type == CacheType.DBCache:
+         return apply_db_cache_on_transformer(transformer, *args, **kwargs)
+     elif cache_type == CacheType.DBPrune:
+         return apply_db_prune_on_transformer(transformer, *args, **kwargs)
+     elif cache_type == CacheType.NONE:
+         logger.warning(
+             f"Cache type is {cache_type}, no caching will be applied."
+         )
+         return transformer
+     else:
+         raise ValueError(f"Unknown cache type: {cache_type}")
cache_dit/cache_factory/dual_block_cache/__init__.py CHANGED
@@ -0,0 +1,4 @@
+ from cache_dit.cache_factory.dual_block_cache.cache_blocks import (
+     cache_context,
+     DBCachedTransformerBlocks,
+ )
cache_dit/cache_factory/dual_block_cache/cache_blocks.py ADDED
@@ -0,0 +1,487 @@
+ import inspect
+ import torch
+
+ from cache_dit.cache_factory.dual_block_cache import cache_context
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class DBCachedTransformerBlocks(torch.nn.Module):
+     def __init__(
+         self,
+         transformer_blocks: torch.nn.ModuleList,
+         *,
+         transformer: torch.nn.Module = None,
+         return_hidden_states_first: bool = True,
+         return_hidden_states_only: bool = False,
+     ):
+         super().__init__()
+
+         self.transformer = transformer
+         self.transformer_blocks = transformer_blocks
+         self.return_hidden_states_first = return_hidden_states_first
+         self.return_hidden_states_only = return_hidden_states_only
+         self._check_forward_params()
+
+     def _check_forward_params(self):
+         # NOTE: DBCache only support blocks which have the pattern:
+         # IN/OUT: (hidden_states, encoder_hidden_states)
+         self.required_parameters = [
+             "hidden_states",
+             "encoder_hidden_states",
+         ]
+         if self.transformer_blocks is not None:
+             for block in self.transformer_blocks:
+                 forward_parameters = set(
+                     inspect.signature(block.forward).parameters.keys()
+                 )
+                 for required_param in self.required_parameters:
+                     assert (
+                         required_param in forward_parameters
+                     ), f"The input parameters must contains: {required_param}."
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         original_hidden_states = hidden_states
+         # Call first `n` blocks to process the hidden states for
+         # more stable diff calculation.
+         hidden_states, encoder_hidden_states = self.call_Fn_blocks(
+             hidden_states,
+             encoder_hidden_states,
+             *args,
+             **kwargs,
+         )
+
+         Fn_hidden_states_residual = hidden_states - original_hidden_states
+         del original_hidden_states
+
+         cache_context.mark_step_begin()
+         # Residual L1 diff or Hidden States L1 diff
+         can_use_cache = cache_context.get_can_use_cache(
+             (
+                 Fn_hidden_states_residual
+                 if not cache_context.is_l1_diff_enabled()
+                 else hidden_states
+             ),
+             parallelized=self._is_parallelized(),
+             prefix=(
+                 "Fn_residual"
+                 if not cache_context.is_l1_diff_enabled()
+                 else "Fn_hidden_states"
+             ),
+         )
+
+         torch._dynamo.graph_break()
+         if can_use_cache:
+             cache_context.add_cached_step()
+             del Fn_hidden_states_residual
+             hidden_states, encoder_hidden_states = (
+                 cache_context.apply_hidden_states_residual(
+                     hidden_states,
+                     encoder_hidden_states,
+                     prefix=(
+                         "Bn_residual"
+                         if cache_context.is_cache_residual()
+                         else "Bn_hidden_states"
+                     ),
+                     encoder_prefix=(
+                         "Bn_residual"
+                         if cache_context.is_encoder_cache_residual()
+                         else "Bn_hidden_states"
+                     ),
+                 )
+             )
+             torch._dynamo.graph_break()
+             # Call last `n` blocks to further process the hidden states
+             # for higher precision.
+             hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+         else:
+             cache_context.set_Fn_buffer(
+                 Fn_hidden_states_residual, prefix="Fn_residual"
+             )
+             if cache_context.is_l1_diff_enabled():
+                 # for hidden states L1 diff
+                 cache_context.set_Fn_buffer(hidden_states, "Fn_hidden_states")
+             del Fn_hidden_states_residual
+             torch._dynamo.graph_break()
+             (
+                 hidden_states,
+                 encoder_hidden_states,
+                 hidden_states_residual,
+                 encoder_hidden_states_residual,
+             ) = self.call_Mn_blocks(  # middle
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             torch._dynamo.graph_break()
+             if cache_context.is_cache_residual():
+                 cache_context.set_Bn_buffer(
+                     hidden_states_residual,
+                     prefix="Bn_residual",
+                 )
+             else:
+                 # TaylorSeer
+                 cache_context.set_Bn_buffer(
+                     hidden_states,
+                     prefix="Bn_hidden_states",
+                 )
+             if cache_context.is_encoder_cache_residual():
+                 cache_context.set_Bn_encoder_buffer(
+                     encoder_hidden_states_residual,
+                     prefix="Bn_residual",
+                 )
+             else:
+                 # TaylorSeer
+                 cache_context.set_Bn_encoder_buffer(
+                     encoder_hidden_states,
+                     prefix="Bn_hidden_states",
+                 )
+             torch._dynamo.graph_break()
+             # Call last `n` blocks to further process the hidden states
+             # for higher precision.
+             hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+
+         patch_cached_stats(self.transformer)
+         torch._dynamo.graph_break()
+
+         return (
+             hidden_states
+             if self.return_hidden_states_only
+             else (
+                 (hidden_states, encoder_hidden_states)
+                 if self.return_hidden_states_first
+                 else (encoder_hidden_states, hidden_states)
+             )
+         )
+
+     @torch.compiler.disable
+     def _is_parallelized(self):
+         # Compatible with distributed inference.
+         return all(
+             (
+                 self.transformer is not None,
+                 getattr(self.transformer, "_is_parallelized", False),
+             )
+         )
+
+     @torch.compiler.disable
+     def _is_in_cache_step(self):
+         # Check if the current step is in cache steps.
+         # If so, we can skip some Bn blocks and directly
+         # use the cached values.
+         return (
+             cache_context.get_current_step() in cache_context.get_cached_steps()
+         ) or (
+             cache_context.get_current_step()
+             in cache_context.get_cfg_cached_steps()
+         )
+
+     @torch.compiler.disable
+     def _Fn_blocks(self):
+         # Select first `n` blocks to process the hidden states for
+         # more stable diff calculation.
+         # Fn: [0,...,n-1]
+         selected_Fn_blocks = self.transformer_blocks[
+             : cache_context.Fn_compute_blocks()
+         ]
+         return selected_Fn_blocks
+
+     @torch.compiler.disable
+     def _Mn_blocks(self):  # middle blocks
+         # M(N-2n): only transformer_blocks [n,...,N-n], middle
+         if cache_context.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
+             selected_Mn_blocks = self.transformer_blocks[
+                 cache_context.Fn_compute_blocks() :
+             ]
+         else:
+             selected_Mn_blocks = self.transformer_blocks[
+                 cache_context.Fn_compute_blocks() : -cache_context.Bn_compute_blocks()
+             ]
+         return selected_Mn_blocks
+
+     @torch.compiler.disable
+     def _Bn_blocks(self):
+         # Bn: transformer_blocks [N-n+1,...,N-1]
+         selected_Bn_blocks = self.transformer_blocks[
+             -cache_context.Bn_compute_blocks() :
+         ]
+         return selected_Bn_blocks
+
+     def call_Fn_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         assert cache_context.Fn_compute_blocks() <= len(
+             self.transformer_blocks
+         ), (
+             f"Fn_compute_blocks {cache_context.Fn_compute_blocks()} must be less than "
+             f"the number of transformer blocks {len(self.transformer_blocks)}"
+         )
+         for block in self._Fn_blocks():
+             hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             if not isinstance(hidden_states, torch.Tensor):
+                 hidden_states, encoder_hidden_states = hidden_states
+                 if not self.return_hidden_states_first:
+                     hidden_states, encoder_hidden_states = (
+                         encoder_hidden_states,
+                         hidden_states,
+                     )
+
+         return hidden_states, encoder_hidden_states
+
+     def call_Mn_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         original_hidden_states = hidden_states
+         original_encoder_hidden_states = encoder_hidden_states
+         for block in self._Mn_blocks():
+             hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             if not isinstance(hidden_states, torch.Tensor):
+                 hidden_states, encoder_hidden_states = hidden_states
+                 if not self.return_hidden_states_first:
+                     hidden_states, encoder_hidden_states = (
+                         encoder_hidden_states,
+                         hidden_states,
+                     )
+
+         # compute hidden_states residual
+         hidden_states = hidden_states.contiguous()
+         encoder_hidden_states = encoder_hidden_states.contiguous()
+
+         hidden_states_residual = hidden_states - original_hidden_states
+         encoder_hidden_states_residual = (
+             encoder_hidden_states - original_encoder_hidden_states
+         )
+
+         return (
+             hidden_states,
+             encoder_hidden_states,
+             hidden_states_residual,
+             encoder_hidden_states_residual,
+         )
+
+     def _compute_or_cache_block(
+         self,
+         # Block index in the transformer blocks
+         # Bn: 8, block_id should be in [0, 8)
+         block_id: int,
+         # Below are the inputs to the block
+         block,  # The transformer block to be executed
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         # Helper function for `call_Bn_blocks`
+         # Skip the blocks by reuse residual cache if they are not
+         # in the Bn_compute_blocks_ids. NOTE: We should only skip
+         # the specific Bn blocks in cache steps. Compute the block
+         # and cache the residuals in non-cache steps.
+
+         # Normal steps: Compute the block and cache the residuals.
+         if not self._is_in_cache_step():
+             Bn_i_original_hidden_states = hidden_states
+             Bn_i_original_encoder_hidden_states = encoder_hidden_states
+             hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             if not isinstance(hidden_states, torch.Tensor):
+                 hidden_states, encoder_hidden_states = hidden_states
+                 if not self.return_hidden_states_first:
+                     hidden_states, encoder_hidden_states = (
+                         encoder_hidden_states,
+                         hidden_states,
+                     )
+             # Cache residuals for the non-compute Bn blocks for
+             # subsequent cache steps.
+             if block_id not in cache_context.Bn_compute_blocks_ids():
+                 Bn_i_hidden_states_residual = (
+                     hidden_states - Bn_i_original_hidden_states
+                 )
+                 Bn_i_encoder_hidden_states_residual = (
+                     encoder_hidden_states - Bn_i_original_encoder_hidden_states
+                 )
+
+                 # Save original_hidden_states for diff calculation.
+                 cache_context.set_Bn_buffer(
+                     Bn_i_original_hidden_states,
+                     prefix=f"Bn_{block_id}_original",
+                 )
+                 cache_context.set_Bn_encoder_buffer(
+                     Bn_i_original_encoder_hidden_states,
+                     prefix=f"Bn_{block_id}_original",
+                 )
+
+                 cache_context.set_Bn_buffer(
+                     Bn_i_hidden_states_residual,
+                     prefix=f"Bn_{block_id}_residual",
+                 )
+                 cache_context.set_Bn_encoder_buffer(
+                     Bn_i_encoder_hidden_states_residual,
+                     prefix=f"Bn_{block_id}_residual",
+                 )
+                 del Bn_i_hidden_states_residual
+                 del Bn_i_encoder_hidden_states_residual
+
+             del Bn_i_original_hidden_states
+             del Bn_i_original_encoder_hidden_states
+
+         else:
+             # Cache steps: Reuse the cached residuals.
+             # Check if the block is in the Bn_compute_blocks_ids.
+             if block_id in cache_context.Bn_compute_blocks_ids():
+                 hidden_states = block(
+                     hidden_states,
+                     encoder_hidden_states,
+                     *args,
+                     **kwargs,
+                 )
+                 if not isinstance(hidden_states, torch.Tensor):
+                     hidden_states, encoder_hidden_states = hidden_states
+                     if not self.return_hidden_states_first:
+                         hidden_states, encoder_hidden_states = (
+                             encoder_hidden_states,
+                             hidden_states,
+                         )
+             else:
+                 # Skip the block if it is not in the Bn_compute_blocks_ids.
+                 # Use the cached residuals instead.
+                 # Check if can use the cached residuals.
+                 if cache_context.get_can_use_cache(
+                     hidden_states,  # curr step
+                     parallelized=self._is_parallelized(),
+                     threshold=cache_context.non_compute_blocks_diff_threshold(),
+                     prefix=f"Bn_{block_id}_original",  # prev step
+                 ):
+                     hidden_states, encoder_hidden_states = (
+                         cache_context.apply_hidden_states_residual(
+                             hidden_states,
+                             encoder_hidden_states,
+                             prefix=(
+                                 f"Bn_{block_id}_residual"
+                                 if cache_context.is_cache_residual()
+                                 else f"Bn_{block_id}_original"
+                             ),
+                             encoder_prefix=(
+                                 f"Bn_{block_id}_residual"
+                                 if cache_context.is_encoder_cache_residual()
+                                 else f"Bn_{block_id}_original"
+                             ),
+                         )
+                     )
+                 else:
+                     hidden_states = block(
+                         hidden_states,
+                         encoder_hidden_states,
+                         *args,
+                         **kwargs,
+                     )
+                     if not isinstance(hidden_states, torch.Tensor):
+                         hidden_states, encoder_hidden_states = hidden_states
+                         if not self.return_hidden_states_first:
+                             hidden_states, encoder_hidden_states = (
+                                 encoder_hidden_states,
+                                 hidden_states,
+                             )
+         return hidden_states, encoder_hidden_states
+
+     def call_Bn_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         if cache_context.Bn_compute_blocks() == 0:
+             return hidden_states, encoder_hidden_states
+
+         assert cache_context.Bn_compute_blocks() <= len(
+             self.transformer_blocks
+         ), (
+             f"Bn_compute_blocks {cache_context.Bn_compute_blocks()} must be less than "
+             f"the number of transformer blocks {len(self.transformer_blocks)}"
+         )
+         if len(cache_context.Bn_compute_blocks_ids()) > 0:
+             for i, block in enumerate(self._Bn_blocks()):
+                 hidden_states, encoder_hidden_states = (
+                     self._compute_or_cache_block(
+                         i,
+                         block,
+                         hidden_states,
+                         encoder_hidden_states,
+                         *args,
+                         **kwargs,
+                     )
+                 )
+         else:
+             # Compute all Bn blocks if no specific Bn compute blocks ids are set.
+             for block in self._Bn_blocks():
+                 hidden_states = block(
+                     hidden_states,
+                     encoder_hidden_states,
+                     *args,
+                     **kwargs,
+                 )
+                 if not isinstance(hidden_states, torch.Tensor):
+                     hidden_states, encoder_hidden_states = hidden_states
+                     if not self.return_hidden_states_first:
+                         hidden_states, encoder_hidden_states = (
+                             encoder_hidden_states,
+                             hidden_states,
+                         )
+
+         return hidden_states, encoder_hidden_states
+
+
+ @torch.compiler.disable
+ def patch_cached_stats(
+     transformer,
+ ):
+     # Patch the cached stats to the transformer, the cached stats
+     # will be reset for each calling of pipe.__call__(**kwargs).
+     if transformer is None:
+         return
+
+     # TODO: Patch more cached stats to the transformer
+     transformer._cached_steps = cache_context.get_cached_steps()
+     transformer._residual_diffs = cache_context.get_residual_diffs()
+     transformer._cfg_cached_steps = cache_context.get_cfg_cached_steps()
+     transformer._cfg_residual_diffs = cache_context.get_cfg_residual_diffs()
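For orientation, here is a rough sketch of how a model adapter might plug the new DBCachedTransformerBlocks into a diffusers transformer. The constructor signature comes from the hunk above; the ModuleList swap and the _is_cached flag mirror what the per-model diffusers_adapters in this release appear to do but are assumptions here, and the wrapper still needs an active DBCache cache_context at inference time (normally set up by apply_db_cache_on_pipe, not shown in this sketch).

import torch
from cache_dit.cache_factory.dual_block_cache import DBCachedTransformerBlocks


def wrap_transformer_blocks(transformer: torch.nn.Module) -> torch.nn.Module:
    # Wrap the original blocks; __init__ asserts that every block's forward()
    # accepts hidden_states and encoder_hidden_states.
    cached_blocks = DBCachedTransformerBlocks(
        transformer.transformer_blocks,
        transformer=transformer,
        return_hidden_states_first=True,
        return_hidden_states_only=False,
    )
    # Hypothetical adapter step: replace the block list with the single
    # caching module so the transformer's forward loop runs through it.
    transformer.transformer_blocks = torch.nn.ModuleList([cached_blocks])
    # Mark as cached so apply_cache_on_transformer() becomes a no-op later.
    transformer._is_cached = True
    return transformer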