cache-dit 0.1.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

Files changed (30) hide show
  1. cache_dit/__init__.py +0 -0
  2. cache_dit/_version.py +21 -0
  3. cache_dit/cache_factory/__init__.py +166 -0
  4. cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
  5. cache_dit/cache_factory/dual_block_cache/cache_context.py +1361 -0
  6. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +45 -0
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +89 -0
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +100 -0
  9. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +88 -0
  10. cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  11. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +45 -0
  12. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +89 -0
  13. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +100 -0
  14. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +89 -0
  15. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +979 -0
  16. cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  17. cache_dit/cache_factory/first_block_cache/cache_context.py +727 -0
  18. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +53 -0
  19. cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +89 -0
  20. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +100 -0
  21. cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +89 -0
  22. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +98 -0
  23. cache_dit/cache_factory/taylorseer.py +76 -0
  24. cache_dit/cache_factory/utils.py +0 -0
  25. cache_dit/logger.py +97 -0
  26. cache_dit/primitives.py +152 -0
  27. cache_dit-0.1.1.dev2.dist-info/METADATA +31 -0
  28. cache_dit-0.1.1.dev2.dist-info/RECORD +30 -0
  29. cache_dit-0.1.1.dev2.dist-info/WHEEL +5 -0
  30. cache_dit-0.1.1.dev2.dist-info/top_level.txt +1 -0
@@ -0,0 +1,1361 @@
1
+ # Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/context.py
2
+
3
+ import logging
4
+ import contextlib
5
+ import dataclasses
6
+ from collections import defaultdict
7
+ from typing import Any, DefaultDict, Dict, List, Optional, Union
8
+
9
+ import torch
10
+
11
+ import cache_dit.primitives as DP
12
+ from cache_dit.logger import init_logger
13
+
14
+ logger = init_logger(__name__)
15
+
16
+
17
@dataclasses.dataclass
class DBCacheContext:
    """Settings, tensor buffers and per-run statistics for Dual Block Cache.

    NOTE: the field order below is part of the generated constructor
    signature (positional arguments), so it must not be reordered.
    """

    # Dual Block Cache
    # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
    Fn_compute_blocks: int = 1
    Bn_compute_blocks: int = 0
    # We have added residual cache pattern for selected compute blocks
    Fn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
    Bn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
    # non compute blocks diff threshold, we don't skip the non
    # compute blocks if the diff >= threshold
    non_compute_blocks_diff_threshold: float = 0.08
    max_Fn_compute_blocks: int = -1
    max_Bn_compute_blocks: int = -1
    # L1 hidden states or residual diff threshold for Fn
    residual_diff_threshold: Union[torch.Tensor, float] = 0.0
    # None means "not set"; when set it overrides the residual thresholds.
    l1_hidden_states_diff_threshold: Optional[float] = None
    important_condition_threshold: float = 0.0

    # Alter Cache Settings
    # Pattern: 0 F 1 T 2 F 3 T 4 F 5 T ...
    enable_alter_cache: bool = False
    is_alter_cache: bool = True
    # 1.0 means we always cache the residuals if alter_cache is enabled.
    alter_residual_diff_threshold: Optional[Union[torch.Tensor, float]] = 1.0

    # Buffer for storing the residuals and other tensors
    buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
        default_factory=lambda: defaultdict(int),
    )

    # Other settings
    downsample_factor: int = 1
    num_inference_steps: int = -1
    warmup_steps: int = 0  # DON'T Cache in warmup steps
    # DON'T Cache if the number of cached steps >= max_cached_steps
    max_cached_steps: int = -1

    # Statistics for both alter cache and non-alter cache
    # Record the steps that have been cached, both alter cache and non-alter cache
    executed_steps: int = 0  # cache + non-cache steps
    cached_steps: List[int] = dataclasses.field(default_factory=list)
    residual_diffs: DefaultDict[str, float] = dataclasses.field(
        default_factory=lambda: defaultdict(float),
    )

    def get_incremental_name(self, name=None):
        """Return ``f"{name}_{idx}"`` where ``idx`` counts prior calls for ``name``.

        ``name`` falls back to ``"default"`` when it is None.
        """
        if name is None:
            name = "default"
        idx = self.incremental_name_counters[name]
        self.incremental_name_counters[name] += 1
        return f"{name}_{idx}"

    def reset_incremental_names(self):
        """Forget all per-name counters used by ``get_incremental_name``."""
        self.incremental_name_counters.clear()

    def get_residual_diff_threshold(self):
        """Return the effective diff threshold as a plain Python number.

        Priority: ``l1_hidden_states_diff_threshold`` (when not None), then
        ``alter_residual_diff_threshold`` (when alter cache is enabled), then
        ``residual_diff_threshold``. Tensor thresholds are unwrapped via
        ``.item()``.
        """
        if self.enable_alter_cache:
            residual_diff_threshold = self.alter_residual_diff_threshold
        else:
            residual_diff_threshold = self.residual_diff_threshold
        if self.l1_hidden_states_diff_threshold is not None:
            # Use the L1 hidden states diff threshold if set
            residual_diff_threshold = self.l1_hidden_states_diff_threshold
        if isinstance(residual_diff_threshold, torch.Tensor):
            residual_diff_threshold = residual_diff_threshold.item()
        return residual_diff_threshold

    def _buffer_key(self, name):
        # Alter-cache steps keep their tensors under a separate "<name>_alter" key.
        if self.enable_alter_cache and self.is_alter_cache:
            return f"{name}_alter"
        return name

    def get_buffer(self, name):
        """Return the buffer stored for ``name`` (or None when absent)."""
        return self.buffers.get(self._buffer_key(name))

    def set_buffer(self, name, buffer):
        """Store ``buffer`` under ``name`` (alter-cache aware)."""
        self.buffers[self._buffer_key(name)] = buffer

    def remove_buffer(self, name):
        """Drop the buffer stored for ``name``; a missing entry is ignored."""
        self.buffers.pop(self._buffer_key(name), None)

    def clear_buffers(self):
        """Remove every stored buffer."""
        self.buffers.clear()

    def mark_step_begin(self):
        """Advance the step counter and update per-step bookkeeping."""
        self.executed_steps += 1
        if self.enable_alter_cache:
            # 0 F 1 T 2 F 3 T 4 F 5 T ...
            self.is_alter_cache = not self.is_alter_cache

        # Reset the cached steps and residual diffs at the beginning
        # of each inference.
        if self.get_current_step() == 0:
            self.cached_steps.clear()
            self.residual_diffs.clear()
            self.reset_incremental_names()

    def add_residual_diff(self, diff):
        """Record ``diff`` for the current step unless one is already recorded."""
        step = str(self.get_current_step())
        if step not in self.residual_diffs:
            # Only add the diff if it is not already recorded for this step
            self.residual_diffs[step] = diff

    def get_residual_diffs(self):
        """Return a shallow copy of the recorded per-step diffs."""
        return self.residual_diffs.copy()

    def add_cached_step(self):
        """Append the current step index to ``cached_steps``."""
        self.cached_steps.append(self.get_current_step())

    def get_cached_steps(self):
        """Return a copy of the list of cached step indices."""
        return self.cached_steps.copy()

    def get_current_step(self):
        """Return the zero-based index of the current step (-1 before any step)."""
        return self.executed_steps - 1

    def is_in_warmup(self):
        """Return True while the current step index is below ``warmup_steps``."""
        return self.get_current_step() < self.warmup_steps
140
+
141
+
142
@torch.compiler.disable
def get_residual_diff_threshold():
    """Return the effective diff threshold of the active cache context."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.get_residual_diff_threshold()
147
+
148
+
149
@torch.compiler.disable
def get_buffer(name):
    """Look up ``name`` in the buffers of the active cache context."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.get_buffer(name)
154
+
155
+
156
@torch.compiler.disable
def set_buffer(name, buffer):
    """Store ``buffer`` under ``name`` in the active cache context."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    ctx.set_buffer(name, buffer)
161
+
162
+
163
@torch.compiler.disable
def remove_buffer(name):
    """Delete the buffer ``name`` from the active cache context."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    ctx.remove_buffer(name)
168
+
169
+
170
@torch.compiler.disable
def mark_step_begin():
    """Tell the active cache context that a new step is starting."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    ctx.mark_step_begin()
175
+
176
+
177
@torch.compiler.disable
def get_current_step():
    """Return the current step index reported by the active cache context."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.get_current_step()
182
+
183
+
184
@torch.compiler.disable
def get_cached_steps():
    """Return the cached step indices reported by the active cache context."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.get_cached_steps()
189
+
190
+
191
@torch.compiler.disable
def get_max_cached_steps():
    """Return ``max_cached_steps`` of the active cache context."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.max_cached_steps
196
+
197
+
198
@torch.compiler.disable
def add_cached_step():
    """Ask the active cache context to record the current step as cached."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    ctx.add_cached_step()
203
+
204
+
205
@torch.compiler.disable
def add_residual_diff(diff):
    """Forward ``diff`` to ``add_residual_diff`` of the active cache context."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    ctx.add_residual_diff(diff)
210
+
211
+
212
@torch.compiler.disable
def get_residual_diffs():
    """Return the residual diffs reported by the active cache context."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.get_residual_diffs()
217
+
218
+
219
@torch.compiler.disable
def is_alter_cache_enabled():
    """Return ``enable_alter_cache`` of the active cache context."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.enable_alter_cache
224
+
225
+
226
@torch.compiler.disable
def is_alter_cache():
    """Return the ``is_alter_cache`` flag of the active cache context."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.is_alter_cache
231
+
232
+
233
@torch.compiler.disable
def is_in_warmup():
    """Return whether the active cache context reports a warmup step."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.is_in_warmup()
238
+
239
+
240
@torch.compiler.disable
def is_l1_diff_enabled():
    """Return True when the L1 hidden-states threshold is set and positive."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    l1_threshold = ctx.l1_hidden_states_diff_threshold
    if l1_threshold is None:
        return False
    return l1_threshold > 0.0
248
+
249
+
250
@torch.compiler.disable
def get_important_condition_threshold():
    """Return ``important_condition_threshold`` of the active cache context."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.important_condition_threshold
255
+
256
+
257
@torch.compiler.disable
def non_compute_blocks_diff_threshold():
    """Return ``non_compute_blocks_diff_threshold`` of the active cache context."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.non_compute_blocks_diff_threshold
262
+
263
+
264
@torch.compiler.disable
def Fn_compute_blocks():
    """Return the validated ``Fn_compute_blocks`` of the active cache context.

    Asserts that the value is >= 1 and, when ``max_Fn_compute_blocks`` is
    positive, that it does not exceed that maximum.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    num_blocks = ctx.Fn_compute_blocks
    assert num_blocks >= 1, "Fn_compute_blocks must be >= 1"
    upper_bound = ctx.max_Fn_compute_blocks
    if upper_bound > 0:
        # NOTE: Fn_compute_blocks can be 1, which means FB Cache
        # but it must be less than or equal to max_Fn_compute_blocks
        assert num_blocks <= upper_bound, (
            f"Fn_compute_blocks must be <= {upper_bound}, "
            f"but got {num_blocks}"
        )
    return num_blocks
282
+
283
+
284
@torch.compiler.disable
def Fn_compute_blocks_ids():
    """Return ``Fn_compute_blocks_ids`` of the active cache context.

    Asserts that the list is no longer than ``Fn_compute_blocks``.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    block_ids = ctx.Fn_compute_blocks_ids
    assert len(block_ids) <= ctx.Fn_compute_blocks, (
        "The num of Fn_compute_blocks_ids must be <= Fn_compute_blocks "
        f"{ctx.Fn_compute_blocks}, but got "
        f"{len(block_ids)}"
    )
    return block_ids
297
+
298
+
299
@torch.compiler.disable
def Bn_compute_blocks():
    """Return the validated ``Bn_compute_blocks`` of the active cache context.

    Asserts that the value is >= 0 and, when ``max_Bn_compute_blocks`` is
    positive, that it does not exceed that maximum.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    num_blocks = ctx.Bn_compute_blocks
    assert num_blocks >= 0, "Bn_compute_blocks must be >= 0"
    upper_bound = ctx.max_Bn_compute_blocks
    if upper_bound > 0:
        # NOTE: Bn_compute_blocks can be 0, which means FB Cache
        # but it must be less than or equal to max_Bn_compute_blocks
        assert num_blocks <= upper_bound, (
            f"Bn_compute_blocks must be <= {upper_bound}, "
            f"but got {num_blocks}"
        )
    return num_blocks
317
+
318
+
319
@torch.compiler.disable
def Bn_compute_blocks_ids():
    """Return ``Bn_compute_blocks_ids`` of the active cache context.

    Asserts that the list is no longer than ``Bn_compute_blocks``.
    """
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    block_ids = ctx.Bn_compute_blocks_ids
    assert len(block_ids) <= ctx.Bn_compute_blocks, (
        "The num of Bn_compute_blocks_ids must be <= Bn_compute_blocks "
        f"{ctx.Bn_compute_blocks}, but got "
        f"{len(block_ids)}"
    )
    return block_ids
332
+
333
+
334
# Module-wide "active" cache context; None until one is installed.
_current_cache_context: Optional[DBCacheContext] = None
335
+
336
+
337
def create_cache_context(*args, **kwargs):
    """Build a new ``DBCacheContext`` from the given constructor arguments."""
    new_context = DBCacheContext(*args, **kwargs)
    return new_context
339
+
340
+
341
def get_current_cache_context():
    """Return the module-wide active cache context (may be None)."""
    active_context = _current_cache_context
    return active_context
343
+
344
+
345
def set_current_cache_context(cache_context=None):
    """Install ``cache_context`` as the module-wide active context.

    Calling it without an argument resets the active context to None.
    """
    global _current_cache_context
    _current_cache_context = cache_context
348
+
349
+
350
def collect_cache_kwargs(default_attrs: dict, **kwargs):
    """Split ``kwargs`` into DBCacheContext settings and the remaining kwargs.

    Args:
        default_attrs: pipeline-specific settings; any entry whose key is a
            collected field name overrides the value taken from ``kwargs``.
        **kwargs: mixed keyword arguments; cache-related keys are popped.

    Returns:
        ``(cache_kwargs, other_kwargs)``.
    """
    # Only fields that are still readable as class attributes are collected
    # here; the two ``*_compute_blocks_ids`` list fields are handled below.
    collected_fields = []
    for field in dataclasses.fields(DBCacheContext):
        if hasattr(DBCacheContext, field.name):
            collected_fields.append(field)

    cache_kwargs = {}
    for field in collected_fields:
        class_default = getattr(DBCacheContext, field.name)
        cache_kwargs[field.name] = kwargs.pop(field.name, class_default)

    # Manually set sequence fields, namely, Fn_compute_blocks_ids
    # and Bn_compute_blocks_ids, which are lists or sets.
    for ids_key in ("Fn_compute_blocks_ids", "Bn_compute_blocks_ids"):
        cache_kwargs[ids_key] = kwargs.pop(ids_key, [])

    assert default_attrs is not None, "default_attrs must be set before"
    for field in collected_fields:
        if field.name in default_attrs:
            cache_kwargs[field.name] = default_attrs[field.name]

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"Collected DBCache kwargs: {cache_kwargs}")

    return cache_kwargs, kwargs
390
+
391
+
392
@contextlib.contextmanager
def cache_context(cache_context):
    """Temporarily install ``cache_context`` as the active cache context.

    The previously active context is restored on exit, including when the
    body of the ``with`` statement raises.
    """
    global _current_cache_context
    previous_context, _current_cache_context = (
        _current_cache_context,
        cache_context,
    )
    try:
        yield
    finally:
        _current_cache_context = previous_context
401
+
402
+
403
@torch.compiler.disable
def are_two_tensors_similar(
    t1: torch.Tensor,  # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
    t2: torch.Tensor,  # curr residual R(t  ,n) = H(t  ,n) - H(t  ,0)
    *,
    threshold: float,
    parallelized: bool = False,
    prefix: str = "Fn",  # for debugging
):
    """Return True when the relative mean abs difference of t1/t2 is < threshold.

    The value recorded through ``add_residual_diff`` is the computed diff, or
    a sentinel: -0.0 when ``threshold <= 0`` (never similar), -1.0 when
    ``threshold >= 1`` (always similar), -2.0 when the shapes differ.
    """
    if threshold <= 0.0:
        add_residual_diff(-0.0)
        return False

    if threshold >= 1.0:
        # A threshold of 1.0 or more means "always similar".
        add_residual_diff(-1.0)
        return True

    if t1.shape != t2.shape:
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(f"{prefix}, shape error: {t1.shape} != {t2.shape}")
        add_residual_diff(-2.0)
        return False

    # Optionally restrict the comparison to the tokens whose relative
    # change exceeds the "important condition" threshold.
    condition_thresh = get_important_condition_threshold()
    abs_diff = (t1 - t2).abs()  # presumably [B, seq_len, d] — TODO confirm
    important_mask = None
    if condition_thresh > 0.0:
        per_token_diff = abs_diff.mean(dim=-1)
        per_token_ref = t1.abs().mean(dim=-1)
        # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
        exceeds = (per_token_diff / per_token_ref) > condition_thresh
        if exceeds.sum() > 0:
            important_mask = exceeds.unsqueeze(-1).expand_as(abs_diff)

    if important_mask is not None:
        mean_diff = abs_diff[important_mask].mean()
        mean_t1 = t1[important_mask].abs().mean()
    else:
        # Fall back to the mean over every element.
        mean_diff = abs_diff.mean()
        mean_t1 = t1.abs().mean()

    if parallelized:
        mean_diff = DP.all_reduce_sync(mean_diff, "avg")
        mean_t1 = DP.all_reduce_sync(mean_t1, "avg")

    # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
    # Further, if we assume that (H(t,  0) - H(t-1,0)) ~ 0, then,
    # H(t-1,n) ~ H(t  ,n), which means the hidden states are similar.
    diff = (mean_diff / mean_t1).item()

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(f"{prefix}, diff: {diff:.6f}, threshold: {threshold:.6f}")

    add_residual_diff(diff)

    return diff < threshold
468
+
469
+
470
+ # Fn buffers
471
@torch.compiler.disable
def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
    """Store ``buffer`` as ``"{prefix}_buffer"``, downsampled on the last dim.

    When the context's downsample factor is > 1, only every ``factor``-th
    element along the last dimension is kept (made contiguous).
    """
    factor = get_downsample_factor()
    if factor > 1:
        buffer = buffer[..., ::factor].contiguous()
    set_buffer(f"{prefix}_buffer", buffer)
479
+
480
+
481
@torch.compiler.disable
def get_Fn_buffer(prefix: str = "Fn"):
    """Return the buffer stored as ``"{prefix}_buffer"``."""
    stored = get_buffer(f"{prefix}_buffer")
    return stored
484
+
485
+
486
@torch.compiler.disable
def set_Fn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
    """Store ``buffer`` as ``"{prefix}_encoder_buffer"`` (no downsampling)."""
    set_buffer(
        f"{prefix}_encoder_buffer",
        buffer,
    )
489
+
490
+
491
@torch.compiler.disable
def get_Fn_encoder_buffer(prefix: str = "Fn"):
    """Return the buffer stored as ``"{prefix}_encoder_buffer"``."""
    stored = get_buffer(f"{prefix}_encoder_buffer")
    return stored
494
+
495
+
496
+ # Bn buffers
497
@torch.compiler.disable
def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
    """Store hidden states or a residual for Bn blocks as ``"{prefix}_buffer"``."""
    set_buffer(
        f"{prefix}_buffer",
        buffer,
    )
501
+
502
+
503
@torch.compiler.disable
def get_Bn_buffer(prefix: str = "Bn"):
    """Return the buffer stored as ``"{prefix}_buffer"``."""
    stored = get_buffer(f"{prefix}_buffer")
    return stored
506
+
507
+
508
@torch.compiler.disable
def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
    """Store ``buffer`` as ``"{prefix}_encoder_buffer"``."""
    set_buffer(
        f"{prefix}_encoder_buffer",
        buffer,
    )
511
+
512
+
513
@torch.compiler.disable
def get_Bn_encoder_buffer(prefix: str = "Bn"):
    """Return the buffer stored as ``"{prefix}_encoder_buffer"``."""
    stored = get_buffer(f"{prefix}_encoder_buffer")
    return stored
516
+
517
+
518
@torch.compiler.disable
def apply_hidden_states_residual(
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    prefix: str = "Bn",
):
    """Add the cached residuals for ``prefix`` to both hidden-state tensors.

    A prefix containing "Bn" reads the Bn buffers, any other prefix reads
    the Fn buffers. Both residual buffers must already be set.

    Returns:
        ``(hidden_states, encoder_hidden_states)``, both contiguous.
    """
    # Allow Bn and Fn prefix to be used for residual cache.
    reads_Bn = "Bn" in prefix

    cached_residual = (
        get_Bn_buffer(prefix) if reads_Bn else get_Fn_buffer(prefix)
    )
    assert cached_residual is not None, f"{prefix}_buffer must be set before"
    hidden_states = cached_residual + hidden_states

    cached_encoder_residual = (
        get_Bn_encoder_buffer(prefix)
        if reads_Bn
        else get_Fn_encoder_buffer(prefix)
    )
    assert (
        cached_encoder_residual is not None
    ), f"{prefix}_encoder_buffer must be set before"
    encoder_hidden_states = cached_encoder_residual + encoder_hidden_states

    return hidden_states.contiguous(), encoder_hidden_states.contiguous()
551
+
552
+
553
@torch.compiler.disable
def get_downsample_factor():
    """Return ``downsample_factor`` of the active cache context."""
    ctx = get_current_cache_context()
    assert ctx is not None, "cache_context must be set before"
    return ctx.downsample_factor
558
+
559
+
560
@torch.compiler.disable
def get_can_use_cache(
    states_tensor: torch.Tensor,  # hidden_states or residual
    parallelized: bool = False,
    threshold: Optional[float] = None,  # can manually set threshold
    prefix: str = "Fn",
):
    """Decide whether the cached result can be reused for the current step.

    Returns False during warmup, once ``max_cached_steps`` is reached, when
    the effective threshold is <= 0, or when no previous buffer exists for
    ``prefix``. Otherwise the result is the similarity of the previous
    buffer and ``states_tensor``; with alter cache enabled the step must
    additionally be an alter-cache step.
    """
    if is_in_warmup():
        return False

    num_cached = len(get_cached_steps())
    cached_limit = get_max_cached_steps()
    if 0 <= cached_limit <= num_cached:
        if logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                f"{prefix}, max_cached_steps reached: {cached_limit}, "
                "cannot use cache."
            )
        return False

    # A missing or non-positive manual threshold falls back to the context's.
    if threshold is None or threshold <= 0.0:
        threshold = get_residual_diff_threshold()
    if threshold <= 0.0:
        return False

    # Fn buffers are stored downsampled, so downsample the same way here.
    factor = get_downsample_factor()
    if factor > 1 and "Bn" not in prefix:
        states_tensor = states_tensor[..., ::factor].contiguous()

    # Allow Bn and Fn prefix to be used for diff calculation.
    lookup_previous = get_Bn_buffer if "Bn" in prefix else get_Fn_buffer
    prev_states_tensor = lookup_previous(prefix)

    alter_enabled = is_alter_cache_enabled()
    if prev_states_tensor is None:
        return False

    # The similarity check also records the diff, so it always runs before
    # the alter-cache flag is consulted.
    similar = are_two_tensors_similar(
        prev_states_tensor,
        states_tensor,
        threshold=threshold,
        parallelized=parallelized,
        prefix=prefix,
    )
    if alter_enabled:
        # Only cache in the alter cache steps
        return similar and is_alter_cache()
    # Dynamic cache according to the residual diff
    return similar
619
+
620
+
621
+ class DBCachedTransformerBlocks(torch.nn.Module):
622
def __init__(
    self,
    transformer_blocks,
    single_transformer_blocks=None,
    *,
    transformer=None,
    return_hidden_states_first=True,
    return_hidden_states_only=False,
):
    """Wrap the given transformer blocks with Dual Block Cache logic.

    Args:
        transformer_blocks: the main sequence of transformer blocks.
        single_transformer_blocks: optional second sequence of blocks.
        transformer: optional owning transformer object.
        return_hidden_states_first: whether block outputs (and this
            module's output) are ordered hidden states first.
        return_hidden_states_only: return only the hidden states from
            ``forward``.
    """
    super().__init__()

    # Plain flags first; they are not registered as sub-modules.
    self.return_hidden_states_only = return_hidden_states_only
    self.return_hidden_states_first = return_hidden_states_first

    # Keep this assignment order: it determines sub-module registration order.
    self.transformer = transformer
    self.transformer_blocks = transformer_blocks
    self.single_transformer_blocks = single_transformer_blocks
638
+
639
def forward(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    *args,
    **kwargs,
):
    """Run the blocks with Dual Block Cache.

    Flow: run the first Fn blocks, decide whether the cache can be used,
    then either add the cached "Bn_residual" or run the middle blocks and
    refresh the buffers, and finally run the last Bn blocks.

    Returns:
        Only the hidden states when ``return_hidden_states_only`` is set,
        otherwise a pair ordered per ``return_hidden_states_first``.
    """
    fn_input = hidden_states
    # Call first `n` blocks to process the hidden states for
    # more stable diff calculation.
    hidden_states, encoder_hidden_states = self.call_Fn_transformer_blocks(
        hidden_states,
        encoder_hidden_states,
        *args,
        **kwargs,
    )

    Fn_residual = hidden_states - fn_input
    del fn_input

    mark_step_begin()
    # Residual L1 diff or Hidden States L1 diff
    can_use_cache = get_can_use_cache(
        hidden_states if is_l1_diff_enabled() else Fn_residual,
        parallelized=self._is_parallelized(),
        prefix=(
            "Fn_hidden_states" if is_l1_diff_enabled() else "Fn_residual"
        ),
    )

    torch._dynamo.graph_break()
    if can_use_cache:
        add_cached_step()
        del Fn_residual
        hidden_states, encoder_hidden_states = apply_hidden_states_residual(
            hidden_states, encoder_hidden_states, prefix="Bn_residual"
        )
    else:
        set_Fn_buffer(Fn_residual, prefix="Fn_residual")
        if is_l1_diff_enabled():
            # for hidden states L1 diff
            set_Fn_buffer(hidden_states, "Fn_hidden_states")
        del Fn_residual
        (
            hidden_states,
            encoder_hidden_states,
            middle_residual,
            middle_encoder_residual,
        ) = self.call_MN2n_transformer_blocks(  # middle
            hidden_states,
            encoder_hidden_states,
            *args,
            **kwargs,
        )
        set_Bn_buffer(middle_residual, prefix="Bn_residual")
        set_Bn_encoder_buffer(middle_encoder_residual, prefix="Bn_residual")

    # Call last `n` blocks to further process the hidden states
    # for higher precision (done on both the cached and non-cached path).
    hidden_states, encoder_hidden_states = self.call_Bn_transformer_blocks(
        hidden_states,
        encoder_hidden_states,
        *args,
        **kwargs,
    )

    patch_cached_stats(self.transformer)
    torch._dynamo.graph_break()

    if self.return_hidden_states_only:
        return hidden_states
    if self.return_hidden_states_first:
        return (hidden_states, encoder_hidden_states)
    return (encoder_hidden_states, hidden_states)
736
+
737
+ @torch.compiler.disable
738
+ def _is_parallelized(self):
739
+ # Compatible with distributed inference.
740
+ return all(
741
+ (
742
+ self.transformer is not None,
743
+ getattr(self.transformer, "_is_parallelized", False),
744
+ )
745
+ )
746
+
747
@torch.compiler.disable
def _is_in_cache_step(self):
    """Return True when the current step index is among the cached steps."""
    # In a cache step some Bn blocks can be skipped and the cached
    # values used directly.
    current = get_current_step()
    return current in get_cached_steps()
753
+
754
@torch.compiler.disable
def _Fn_transformer_blocks(self):
    """Return the first ``Fn_compute_blocks()`` transformer blocks.

    When ``Fn_compute_blocks_ids()`` is non-empty, only the blocks at those
    (in-range) positions within that leading slice are returned, as a list.
    """
    # Fn: [0,...,n-1]
    leading_blocks = self.transformer_blocks[: Fn_compute_blocks()]
    # WARN: DON'T set len(Fn_compute_blocks_ids) > 0 NOW, still have
    # some precision issues. We don't know whether a step should be
    # cached or not before the first Fn blocks are processed.
    selected_ids = Fn_compute_blocks_ids()
    if len(selected_ids) == 0:
        return leading_blocks
    num_leading = len(leading_blocks)
    return [leading_blocks[i] for i in selected_ids if i < num_leading]
773
+
774
@torch.compiler.disable
def _MN2n_single_transformer_blocks(self):  # middle
    """Return the single transformer blocks minus the last ``Bn_compute_blocks()``.

    Returns an empty list when there are no single transformer blocks.
    """
    # M(N-2n): transformer_blocks [n,...] + single_transformer_blocks [0,...,N-n]
    if self.single_transformer_blocks is None:
        return []
    num_tail = Bn_compute_blocks()
    if num_tail == 0:  # WARN: x[:-0] = []
        return self.single_transformer_blocks
    return self.single_transformer_blocks[:-num_tail]
788
+
789
@torch.compiler.disable
def _MN2n_transformer_blocks(self):
    """Return the middle transformer blocks: after the Fn head, before the Bn tail."""
    # M(N-2n): only transformer_blocks [n,...,N-n], middle
    num_tail = Bn_compute_blocks()
    num_head = Fn_compute_blocks()
    # WARN: x[:-0] = [], so an empty tail must slice with an open end.
    stop = None if num_tail == 0 else -num_tail
    return self.transformer_blocks[num_head:stop]
801
+
802
@torch.compiler.disable
def _Bn_single_transformer_blocks(self):
    """Return the last ``Bn_compute_blocks()`` single transformer blocks.

    Returns an empty list when there are no single transformer blocks, and
    an empty slice when ``Bn_compute_blocks()`` is 0.
    """
    # Bn: single_transformer_blocks [N-n+1,...,N-1]
    if self.single_transformer_blocks is None:
        return []
    num_tail = Bn_compute_blocks()
    if num_tail == 0:
        # WARN: x[-0:] == x[0:] would select EVERY block, not none.
        return self.single_transformer_blocks[:0]
    return self.single_transformer_blocks[-num_tail:]
811
+
812
@torch.compiler.disable
def _Bn_transformer_blocks(self):
    """Return the last ``Bn_compute_blocks()`` transformer blocks.

    Returns an empty slice when ``Bn_compute_blocks()`` is 0.
    """
    # Bn: transformer_blocks [N-n+1,...,N-1]
    num_tail = Bn_compute_blocks()
    if num_tail == 0:
        # WARN: x[-0:] == x[0:] would select EVERY block, not none.
        return self.transformer_blocks[:0]
    return self.transformer_blocks[-num_tail:]
819
+
820
def call_Fn_transformer_blocks(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    *args,
    **kwargs,
):
    """Run the selected Fn blocks in order and return both state tensors.

    A block may return a single tensor (new hidden states) or a pair; a pair
    is interpreted per ``return_hidden_states_first``.

    Returns:
        ``(hidden_states, encoder_hidden_states)``.
    """
    assert Fn_compute_blocks() <= len(self.transformer_blocks), (
        f"Fn_compute_blocks {Fn_compute_blocks()} must be less than "
        f"the number of transformer blocks {len(self.transformer_blocks)}"
    )
    for block in self._Fn_transformer_blocks():
        block_output = block(
            hidden_states,
            encoder_hidden_states,
            *args,
            **kwargs,
        )
        if isinstance(block_output, torch.Tensor):
            # Only the hidden states were returned; keep the encoder states.
            hidden_states = block_output
            continue
        first, second = block_output
        if self.return_hidden_states_first:
            hidden_states, encoder_hidden_states = first, second
        else:
            hidden_states, encoder_hidden_states = second, first

    return hidden_states, encoder_hidden_states
847
+
848
def call_MN2n_transformer_blocks(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    *args,
    **kwargs,
):
    """Run the middle blocks and report how much they changed both streams.

    When ``self.single_transformer_blocks`` is ``None`` the blocks come from
    ``self._MN2n_transformer_blocks()``. Otherwise the blocks
    ``self.transformer_blocks[Fn_compute_blocks():]`` run first, both streams
    are concatenated along dim 1 (encoder part first), the blocks from
    ``self._MN2n_single_transformer_blocks()`` run on the merged tensor, and
    the result is split back into the two streams.

    Args:
        hidden_states: First positional input forwarded to the blocks.
        encoder_hidden_states: Second positional input forwarded to the
            two-input blocks.
        *args: Extra positional arguments forwarded unchanged to each block.
        **kwargs: Extra keyword arguments forwarded unchanged to each block.

    Returns:
        ``(hidden_states, encoder_hidden_states, hidden_states_residual,
        encoder_hidden_states_residual)`` where each residual is the output
        minus the corresponding input; all four are reshaped to the shape of
        the corresponding input.
    """
    input_hidden = hidden_states
    input_encoder = encoder_hidden_states

    if self.single_transformer_blocks is None:
        for blk in self._MN2n_transformer_blocks():
            blk_out = blk(hidden_states, encoder_hidden_states, *args, **kwargs)
            if isinstance(blk_out, torch.Tensor):
                hidden_states = blk_out
            else:
                first, second = blk_out
                if self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = first, second
                else:
                    hidden_states, encoder_hidden_states = second, first
    else:
        for blk in self.transformer_blocks[Fn_compute_blocks() :]:
            blk_out = blk(hidden_states, encoder_hidden_states, *args, **kwargs)
            if isinstance(blk_out, torch.Tensor):
                hidden_states = blk_out
            else:
                first, second = blk_out
                if self.return_hidden_states_first:
                    hidden_states, encoder_hidden_states = first, second
                else:
                    hidden_states, encoder_hidden_states = second, first

        # The single-stream blocks take one merged tensor: encoder part
        # first, then the hidden part, along dim 1.
        merged = torch.cat([encoder_hidden_states, hidden_states], dim=1)
        for blk in self._MN2n_single_transformer_blocks():
            merged = blk(merged, *args, **kwargs)
        encoder_len = encoder_hidden_states.shape[1]
        encoder_hidden_states, hidden_states = merged.split(
            [encoder_len, merged.shape[1] - encoder_len],
            dim=1,
        )

    # Flatten, force contiguity, and restore the shapes of the inputs.
    hidden_states = (
        hidden_states.reshape(-1).contiguous().reshape(input_hidden.shape)
    )
    encoder_hidden_states = (
        encoder_hidden_states.reshape(-1)
        .contiguous()
        .reshape(input_encoder.shape)
    )

    hidden_delta = hidden_states - input_hidden
    encoder_delta = encoder_hidden_states - input_encoder

    hidden_delta = (
        hidden_delta.reshape(-1).contiguous().reshape(input_hidden.shape)
    )
    encoder_delta = (
        encoder_delta.reshape(-1).contiguous().reshape(input_encoder.shape)
    )

    return (
        hidden_states,
        encoder_hidden_states,
        hidden_delta,
        encoder_delta,
    )
943
+
944
+ @torch.compiler.disable
945
+ def _Bn_i_single_hidden_states_residual(
946
+ self,
947
+ Bn_i_hidden_states: torch.Tensor,
948
+ Bn_i_original_hidden_states: torch.Tensor,
949
+ original_hidden_states: torch.Tensor,
950
+ original_encoder_hidden_states: torch.Tensor,
951
+ ):
952
+ # Split the Bn_i_hidden_states and Bn_i_original_hidden_states
953
+ # into encoder_hidden_states and hidden_states.
954
+ Bn_i_hidden_states, Bn_i_encoder_hidden_states = (
955
+ self._split_Bn_i_single_hidden_states(
956
+ Bn_i_hidden_states,
957
+ original_hidden_states,
958
+ original_encoder_hidden_states,
959
+ )
960
+ )
961
+ # Split the Bn_i_original_hidden_states into encoder_hidden_states
962
+ # and hidden_states.
963
+ Bn_i_original_hidden_states, Bn_i_original_encoder_hidden_states = (
964
+ self._split_Bn_i_single_hidden_states(
965
+ Bn_i_original_hidden_states,
966
+ original_hidden_states,
967
+ original_encoder_hidden_states,
968
+ )
969
+ )
970
+
971
+ # Compute the residuals for the Bn_i_hidden_states and
972
+ # Bn_i_encoder_hidden_states.
973
+ Bn_i_hidden_states_residual = (
974
+ Bn_i_hidden_states - Bn_i_original_hidden_states
975
+ )
976
+ Bn_i_encoder_hidden_states_residual = (
977
+ Bn_i_encoder_hidden_states - Bn_i_original_encoder_hidden_states
978
+ )
979
+ return (
980
+ Bn_i_hidden_states_residual,
981
+ Bn_i_encoder_hidden_states_residual,
982
+ )
983
+
984
+ @torch.compiler.disable
985
+ def _split_Bn_i_single_hidden_states(
986
+ self,
987
+ Bn_i_hidden_states: torch.Tensor,
988
+ original_hidden_states: torch.Tensor,
989
+ original_encoder_hidden_states: torch.Tensor,
990
+ ):
991
+ # Split the Bn_i_hidden_states into encoder_hidden_states and hidden_states.
992
+ Bn_i_encoder_hidden_states, Bn_i_hidden_states = (
993
+ Bn_i_hidden_states.split(
994
+ [
995
+ original_encoder_hidden_states.shape[1],
996
+ Bn_i_hidden_states.shape[1]
997
+ - original_encoder_hidden_states.shape[1],
998
+ ],
999
+ dim=1,
1000
+ )
1001
+ )
1002
+ # Reshape the Bn_i_hidden_states and Bn_i_encoder_hidden_states
1003
+ # to the original shape. This is necessary to ensure that the
1004
+ # residuals are computed correctly.
1005
+ Bn_i_hidden_states = (
1006
+ Bn_i_hidden_states.reshape(-1)
1007
+ .contiguous()
1008
+ .reshape(original_hidden_states.shape)
1009
+ )
1010
+ Bn_i_encoder_hidden_states = (
1011
+ Bn_i_encoder_hidden_states.reshape(-1)
1012
+ .contiguous()
1013
+ .reshape(original_encoder_hidden_states.shape)
1014
+ )
1015
+ return Bn_i_hidden_states, Bn_i_encoder_hidden_states
1016
+
1017
def _compute_and_cache_single_transformer_block(
    self,
    i: int,  # Block index in the transformer blocks
    # Helper inputs for hidden states split and reshape
    original_hidden_states: torch.Tensor,
    original_encoder_hidden_states: torch.Tensor,
    # Below are the inputs to the block
    block,  # The transformer block to be executed
    hidden_states: torch.Tensor,
    *args,
    **kwargs,
):
    """Run, or skip via cached residuals, one single-stream Bn block.

    Helper for ``call_Bn_transformer_blocks``.

    * Non-cache step: the block is always computed. If ``i`` is not in
      ``Bn_compute_blocks_ids()``, the block input and both residuals are
      stored under the ``Bn_{i}_single_original`` / ``Bn_{i}_single_residual``
      prefixes for later cache steps.
    * Cache step: the block is computed if ``i`` is in
      ``Bn_compute_blocks_ids()`` or if ``get_can_use_cache`` rejects the
      cached entry; otherwise the cached residuals are applied instead.

    Args:
        i: Index of ``block``; used for the buffer prefixes and the
            ``Bn_compute_blocks_ids()`` membership test.
        original_hidden_states: Reference tensor for splitting/reshaping the
            hidden part of the merged tensor.
        original_encoder_hidden_states: Reference tensor for splitting/
            reshaping the encoder part of the merged tensor.
        block: Callable invoked as ``block(hidden_states, *args, **kwargs)``.
        hidden_states: Merged input tensor for the block.
        *args: Extra positional arguments forwarded to ``block``.
        **kwargs: Extra keyword arguments forwarded to ``block``.

    Returns:
        The merged hidden states after the block (computed or reconstructed).
    """
    in_cache_step = self._is_in_cache_step()

    # Cache steps only: a block outside Bn_compute_blocks_ids may be skipped
    # when the current input is close enough to the cached one.
    if (
        in_cache_step
        and i not in Bn_compute_blocks_ids()
        and get_can_use_cache(
            hidden_states,  # curr step
            parallelized=self._is_parallelized(),
            threshold=non_compute_blocks_diff_threshold(),
            prefix=f"Bn_{i}_single_original",  # prev step
        )
    ):
        split_hidden, split_encoder = self._split_Bn_i_single_hidden_states(
            hidden_states,
            original_hidden_states,
            original_encoder_hidden_states,
        )
        reused_hidden, reused_encoder = apply_hidden_states_residual(
            split_hidden,
            split_encoder,
            prefix=f"Bn_{i}_single_residual",
        )
        # Re-merge in the same order the caller used: encoder part first.
        return torch.cat([reused_encoder, reused_hidden], dim=1)

    # Every other case computes the block.
    block_input = hidden_states
    hidden_states = block(
        block_input,
        *args,
        **kwargs,
    )

    # Residuals are only recorded in non-cache steps, and only for blocks
    # that later cache steps are allowed to skip.
    if in_cache_step or i in Bn_compute_blocks_ids():
        return hidden_states

    hidden_residual, encoder_residual = (
        self._Bn_i_single_hidden_states_residual(
            hidden_states,
            block_input,
            original_hidden_states,
            original_encoder_hidden_states,
        )
    )
    # Keep the block input so later steps can measure how far they drifted.
    set_Bn_buffer(
        block_input,
        prefix=f"Bn_{i}_single_original",
    )
    set_Bn_buffer(
        hidden_residual,
        prefix=f"Bn_{i}_single_residual",
    )
    set_Bn_encoder_buffer(
        encoder_residual,
        prefix=f"Bn_{i}_single_residual",
    )
    return hidden_states
1125
+
1126
def _compute_and_cache_transformer_block(
    self,
    i: int,  # Block index in the transformer blocks
    # Below are the inputs to the block
    block,  # The transformer block to be executed
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    *args,
    **kwargs,
):
    """Run, or skip via cached residuals, one two-stream Bn block.

    Helper for ``call_Bn_transformer_blocks``.

    * Non-cache step: the block is always computed. If ``i`` is not in
      ``Bn_compute_blocks_ids()``, the block input and both residuals are
      stored under the ``Bn_{i}_original`` / ``Bn_{i}_residual`` prefixes for
      later cache steps.
    * Cache step: the block is computed if ``i`` is in
      ``Bn_compute_blocks_ids()`` or if ``get_can_use_cache`` rejects the
      cached entry; otherwise the cached residuals are applied instead.

    Args:
        i: Index of ``block``; used for the buffer prefixes and the
            ``Bn_compute_blocks_ids()`` membership test.
        block: Callable invoked as
            ``block(hidden_states, encoder_hidden_states, *args, **kwargs)``.
        hidden_states: First input of the block.
        encoder_hidden_states: Second input of the block.
        *args: Extra positional arguments forwarded to ``block``.
        **kwargs: Extra keyword arguments forwarded to ``block``.

    Returns:
        ``(hidden_states, encoder_hidden_states)`` after the block (computed
        or reconstructed).
    """
    in_cache_step = self._is_in_cache_step()

    # Cache steps only: a block outside Bn_compute_blocks_ids may be skipped
    # when the current input is close enough to the cached one.
    if (
        in_cache_step
        and i not in Bn_compute_blocks_ids()
        and get_can_use_cache(
            hidden_states,  # curr step
            parallelized=self._is_parallelized(),
            threshold=non_compute_blocks_diff_threshold(),
            prefix=f"Bn_{i}_original",  # prev step
        )
    ):
        hidden_states, encoder_hidden_states = apply_hidden_states_residual(
            hidden_states,
            encoder_hidden_states,
            prefix=f"Bn_{i}_residual",
        )
        return hidden_states, encoder_hidden_states

    # Every other case computes the block.
    input_hidden = hidden_states
    input_encoder = encoder_hidden_states
    block_out = block(
        hidden_states,
        encoder_hidden_states,
        *args,
        **kwargs,
    )
    if isinstance(block_out, torch.Tensor):
        # A bare tensor only updates hidden_states.
        hidden_states = block_out
    else:
        first, second = block_out
        if self.return_hidden_states_first:
            hidden_states, encoder_hidden_states = first, second
        else:
            hidden_states, encoder_hidden_states = second, first

    # Residuals are only recorded in non-cache steps, and only for blocks
    # that later cache steps are allowed to skip.
    if not in_cache_step and i not in Bn_compute_blocks_ids():
        hidden_residual = hidden_states - input_hidden
        encoder_residual = encoder_hidden_states - input_encoder

        # Keep the block input so later steps can measure how far they
        # drifted.
        set_Bn_buffer(
            input_hidden,
            prefix=f"Bn_{i}_original",
        )
        set_Bn_buffer(
            hidden_residual,
            prefix=f"Bn_{i}_residual",
        )
        set_Bn_encoder_buffer(
            encoder_residual,
            prefix=f"Bn_{i}_residual",
        )

    return hidden_states, encoder_hidden_states
1238
+
1239
def call_Bn_transformer_blocks(
    self,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    *args,
    **kwargs,
):
    """Run the trailing Bn blocks, optionally reusing cached residuals.

    Returns the inputs untouched when ``Bn_compute_blocks()`` is 0. When
    ``self.single_transformer_blocks`` is set, both streams are concatenated
    along dim 1 (encoder part first), the blocks from
    ``self._Bn_single_transformer_blocks()`` run on the merged tensor, and
    the result is split back. Otherwise the blocks from
    ``self._Bn_transformer_blocks()`` run on both streams. If
    ``Bn_compute_blocks_ids()`` is non-empty, each block goes through the
    matching ``_compute_and_cache_*`` helper; otherwise every block is
    simply computed.

    Args:
        hidden_states: First input stream.
        encoder_hidden_states: Second input stream.
        *args: Extra positional arguments forwarded to each block.
        **kwargs: Extra keyword arguments forwarded to each block.

    Returns:
        ``(hidden_states, encoder_hidden_states)``, each reshaped to the
        shape of the corresponding input.

    Raises:
        AssertionError: If ``Bn_compute_blocks()`` exceeds the number of
            available blocks in the branch taken.
    """
    if Bn_compute_blocks() == 0:
        return hidden_states, encoder_hidden_states

    original_hidden_states = hidden_states
    original_encoder_hidden_states = encoder_hidden_states
    if self.single_transformer_blocks is not None:
        # The check allows equality, so the message says
        # "less than or equal to".
        assert Bn_compute_blocks() <= len(self.single_transformer_blocks), (
            f"Bn_compute_blocks {Bn_compute_blocks()} must be less than or equal to "
            f"the number of single transformer blocks {len(self.single_transformer_blocks)}"
        )

        torch._dynamo.graph_break()
        # Single-stream blocks take one merged tensor: encoder part first.
        hidden_states = torch.cat(
            [encoder_hidden_states, hidden_states], dim=1
        )
        if len(Bn_compute_blocks_ids()) > 0:
            for i, block in enumerate(self._Bn_single_transformer_blocks()):
                hidden_states = (
                    self._compute_and_cache_single_transformer_block(
                        i,
                        original_hidden_states,
                        original_encoder_hidden_states,
                        block,
                        hidden_states,
                        *args,
                        **kwargs,
                    )
                )
        else:
            # Compute all Bn blocks if no specific Bn compute blocks ids are set.
            for block in self._Bn_single_transformer_blocks():
                hidden_states = block(
                    hidden_states,
                    *args,
                    **kwargs,
                )
        # The right-hand side is evaluated before assignment, so
        # ``hidden_states.shape[1]`` below is still the merged length.
        encoder_hidden_states, hidden_states = hidden_states.split(
            [
                encoder_hidden_states.shape[1],
                hidden_states.shape[1] - encoder_hidden_states.shape[1],
            ],
            dim=1,
        )
        torch._dynamo.graph_break()
    else:
        assert Bn_compute_blocks() <= len(self.transformer_blocks), (
            f"Bn_compute_blocks {Bn_compute_blocks()} must be less than or equal to "
            f"the number of transformer blocks {len(self.transformer_blocks)}"
        )
        torch._dynamo.graph_break()
        if len(Bn_compute_blocks_ids()) > 0:
            for i, block in enumerate(self._Bn_transformer_blocks()):
                hidden_states, encoder_hidden_states = (
                    self._compute_and_cache_transformer_block(
                        i,
                        block,
                        hidden_states,
                        encoder_hidden_states,
                        *args,
                        **kwargs,
                    )
                )
        else:
            # Compute all Bn blocks if no specific Bn compute blocks ids are set.
            for block in self._Bn_transformer_blocks():
                hidden_states = block(
                    hidden_states,
                    encoder_hidden_states,
                    *args,
                    **kwargs,
                )
                if not isinstance(hidden_states, torch.Tensor):
                    # A non-tensor result is unpacked as a pair whose order
                    # is given by ``self.return_hidden_states_first``.
                    hidden_states, encoder_hidden_states = hidden_states
                    if not self.return_hidden_states_first:
                        hidden_states, encoder_hidden_states = (
                            encoder_hidden_states,
                            hidden_states,
                        )
        torch._dynamo.graph_break()

    # Flatten, force contiguity, and restore the shapes of the inputs.
    hidden_states = (
        hidden_states.reshape(-1)
        .contiguous()
        .reshape(original_hidden_states.shape)
    )
    encoder_hidden_states = (
        encoder_hidden_states.reshape(-1)
        .contiguous()
        .reshape(original_encoder_hidden_states.shape)
    )
    return hidden_states, encoder_hidden_states
1337
+
1338
+
1339
+ @torch.compiler.disable
1340
+ def patch_cached_stats(
1341
+ transformer,
1342
+ ):
1343
+ # Patch the cached stats to the transformer, the cached stats
1344
+ # will be reset for each calling of pipe.__call__(**kwargs).
1345
+ if transformer is None:
1346
+ return
1347
+
1348
+ cached_transformer_blocks = getattr(transformer, "transformer_blocks", None)
1349
+ if cached_transformer_blocks is None:
1350
+ return
1351
+
1352
+ if isinstance(cached_transformer_blocks, torch.nn.ModuleList):
1353
+ cached_transformer_blocks = cached_transformer_blocks[0]
1354
+ if not isinstance(
1355
+ cached_transformer_blocks, DBCachedTransformerBlocks
1356
+ ) or not isinstance(transformer, torch.nn.Module):
1357
+ return
1358
+
1359
+ # TODO: Patch more cached stats to the transformer
1360
+ transformer._cached_steps = get_cached_steps()
1361
+ transformer._residual_diffs = get_residual_diffs()