cache-dit 0.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (31):
  1. cache_dit/__init__.py +0 -0
  2. cache_dit/_version.py +21 -0
  3. cache_dit/cache_factory/__init__.py +166 -0
  4. cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
  5. cache_dit/cache_factory/dual_block_cache/cache_context.py +1361 -0
  6. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +45 -0
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +89 -0
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +100 -0
  9. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +88 -0
  10. cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  11. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +45 -0
  12. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +89 -0
  13. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +100 -0
  14. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +89 -0
  15. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +979 -0
  16. cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  17. cache_dit/cache_factory/first_block_cache/cache_context.py +727 -0
  18. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +53 -0
  19. cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +89 -0
  20. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +100 -0
  21. cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +89 -0
  22. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +98 -0
  23. cache_dit/cache_factory/taylorseer.py +76 -0
  24. cache_dit/cache_factory/utils.py +0 -0
  25. cache_dit/logger.py +97 -0
  26. cache_dit/primitives.py +152 -0
  27. cache_dit-0.1.0.dist-info/METADATA +350 -0
  28. cache_dit-0.1.0.dist-info/RECORD +31 -0
  29. cache_dit-0.1.0.dist-info/WHEEL +5 -0
  30. cache_dit-0.1.0.dist-info/licenses/LICENSE +53 -0
  31. cache_dit-0.1.0.dist-info/top_level.txt +1 -0
cache_dit/cache_factory/dynamic_block_prune/prune_context.py
@@ -0,0 +1,979 @@
+ # Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/context.py
+ import logging
+ import contextlib
+ import dataclasses
+ from collections import defaultdict
+ from typing import Any, Dict, List, Optional, Union
+
+ import torch
+
+ import cache_dit.primitives as DP
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @dataclasses.dataclass
+ class DBPPruneContext:
+     # Dynamic Block Prune
+     # Always compute at least the first `Fn` and the last `Bn` blocks.
+     # The FnBn design is inspired by the Dual Block Cache.
+     Fn_compute_blocks: int = 8
+     Bn_compute_blocks: int = 8
+     # IDs of blocks that are never pruned, e.g., [0, 1, 2, 3, 4, 5, 6, 7]
+     non_prune_blocks_ids: List[int] = dataclasses.field(default_factory=list)
+     # L1 hidden states or residual diff threshold for Fn
+     residual_diff_threshold: Union[torch.Tensor, float] = 0.0
+     l1_hidden_states_diff_threshold: Optional[float] = None
+     important_condition_threshold: float = 0.0
+     # Compute the dynamic prune threshold based on the mean of the
+     # residual diffs of the previously computed or pruned blocks,
+     # but cap it (by default at 2x residual_diff_threshold) to avoid
+     # overly aggressive pruning.
+     enable_dynamic_prune_threshold: bool = False
+     max_dynamic_prune_threshold: Optional[float] = None
+     dynamic_prune_threshold_relax_ratio: float = 1.25
+     # Residual cache update interval, in steps.
+     residual_cache_update_interval: int = 1
+
+     # Buffer for storing the residuals and other tensors
+     buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
+
+     # Other settings
+     downsample_factor: int = 1
+     num_inference_steps: int = -1
+     warmup_steps: int = 0  # DON'T prune during warmup steps
+     # DON'T prune if the number of pruned steps >= max_pruned_steps
+     max_pruned_steps: int = -1
+
+     # Statistics
+     executed_steps: int = 0
+     pruned_blocks: List[int] = dataclasses.field(default_factory=list)
+     actual_blocks: List[int] = dataclasses.field(default_factory=list)
+     # Residual diffs for each step, {step: list[float]}
+     residual_diffs: Dict[int, List[float]] = dataclasses.field(
+         default_factory=lambda: defaultdict(list),
+     )
+
+     def get_residual_diff_threshold(self):
+         residual_diff_threshold = self.residual_diff_threshold
+         if self.l1_hidden_states_diff_threshold is not None:
+             # Use the L1 hidden states diff threshold if set
+             residual_diff_threshold = self.l1_hidden_states_diff_threshold
+         if isinstance(residual_diff_threshold, torch.Tensor):
+             residual_diff_threshold = residual_diff_threshold.item()
+         if self.enable_dynamic_prune_threshold:
+             # Compute the dynamic prune threshold based on the mean of the
+             # residual diffs of the previously computed or pruned blocks.
+             step = self.get_current_step()
+             if step >= 0 and step in self.residual_diffs:
+                 # TODO: Should we only use the last 5 diffs?
+                 diffs = self.residual_diffs[step][:]
+                 diffs = [d for d in diffs if d > 0.0]
+                 if diffs:
+                     mean_diff = sum(diffs) / len(diffs)
+                     relaxed_diff = (
+                         mean_diff * self.dynamic_prune_threshold_relax_ratio
+                     )
+                     if self.max_dynamic_prune_threshold is None:
+                         max_dynamic_prune_threshold = (
+                             2 * residual_diff_threshold
+                         )
+                     else:
+                         max_dynamic_prune_threshold = (
+                             self.max_dynamic_prune_threshold
+                         )
+                     if relaxed_diff < max_dynamic_prune_threshold:
+                         # If the relaxed mean diff stays below the cap, use
+                         # it as the dynamic prune threshold (never lower
+                         # than the static threshold).
+                         residual_diff_threshold = (
+                             relaxed_diff
+                             if relaxed_diff > residual_diff_threshold
+                             else residual_diff_threshold
+                         )
+                         if logger.isEnabledFor(logging.DEBUG):
+                             logger.debug(
+                                 f"Dynamic prune threshold for step {step}: "
+                                 f"{residual_diff_threshold:.6f}"
+                             )
+         return residual_diff_threshold
+
+     def get_buffer(self, name):
+         return self.buffers.get(name)
+
+     def set_buffer(self, name, buffer):
+         self.buffers[name] = buffer
+
+     def remove_buffer(self, name):
+         if name in self.buffers:
+             del self.buffers[name]
+
+     def clear_buffers(self):
+         self.buffers.clear()
+
+     def mark_step_begin(self):
+         self.executed_steps += 1
+         if self.get_current_step() == 0:
+             self.pruned_blocks.clear()
+             self.actual_blocks.clear()
+             self.residual_diffs.clear()
+
+     def add_pruned_block(self, num_blocks):
+         self.pruned_blocks.append(num_blocks)
+
+     def add_actual_block(self, num_blocks):
+         self.actual_blocks.append(num_blocks)
+
+     def add_residual_diff(self, diff):
+         if isinstance(diff, torch.Tensor):
+             diff = diff.item()
+         step = self.get_current_step()
+         self.residual_diffs[step].append(diff)
+         max_num_block_diffs = 1000
+         # Avoid memory leaks: keep only the last 1000 diffs.
+         if len(self.residual_diffs[step]) > max_num_block_diffs:
+             self.residual_diffs[step] = self.residual_diffs[step][
+                 -max_num_block_diffs:
+             ]
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(
+                 f"Step {step}, block: {len(self.residual_diffs[step])}, "
+                 f"residual diff: {diff:.6f}"
+             )
+
+     def get_current_step(self):
+         return self.executed_steps - 1
+
+     def is_in_warmup(self):
+         return self.get_current_step() < self.warmup_steps
+
+
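As a quick illustration of the dynamic-threshold logic above (hypothetical numbers, not part of the package source):

    diffs = [0.03, 0.05, 0.04]           # positive residual diffs for this step
    base = 0.04                          # residual_diff_threshold
    mean_diff = sum(diffs) / len(diffs)  # 0.04
    relaxed_diff = mean_diff * 1.25      # 0.05 (default relax ratio applied)
    cap = 2 * base                       # 0.08 (max_dynamic_prune_threshold unset)
    # relaxed_diff < cap, so the effective threshold for later prune decisions
    # in this step becomes max(relaxed_diff, base) == 0.05.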
+ @torch.compiler.disable
+ def get_residual_diff_threshold():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     return prune_context.get_residual_diff_threshold()
+
+
+ @torch.compiler.disable
+ def get_buffer(name):
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     return prune_context.get_buffer(name)
+
+
+ @torch.compiler.disable
+ def set_buffer(name, buffer):
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     prune_context.set_buffer(name, buffer)
+
+
+ @torch.compiler.disable
+ def remove_buffer(name):
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     prune_context.remove_buffer(name)
+
+
+ @torch.compiler.disable
+ def mark_step_begin():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     prune_context.mark_step_begin()
+
+
+ @torch.compiler.disable
+ def get_current_step():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     return prune_context.get_current_step()
+
+
+ @torch.compiler.disable
+ def get_max_pruned_steps():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     return prune_context.max_pruned_steps
+
+
+ @torch.compiler.disable
+ def add_pruned_block(num_blocks):
+     assert (
+         isinstance(num_blocks, int) and num_blocks >= 0
+     ), "num_blocks must be a non-negative integer"
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     prune_context.add_pruned_block(num_blocks)
+
+
+ @torch.compiler.disable
+ def get_pruned_blocks():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     return prune_context.pruned_blocks.copy()
+
+
+ @torch.compiler.disable
+ def add_actual_block(num_blocks):
+     assert (
+         isinstance(num_blocks, int) and num_blocks >= 0
+     ), "num_blocks must be a non-negative integer"
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     prune_context.add_actual_block(num_blocks)
+
+
+ @torch.compiler.disable
+ def get_actual_blocks():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     return prune_context.actual_blocks.copy()
+
+
+ @torch.compiler.disable
+ def get_pruned_steps():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     pruned_blocks = get_pruned_blocks()
+     pruned_blocks = [x for x in pruned_blocks if x > 0]
+     return len(pruned_blocks)
+
+
+ @torch.compiler.disable
+ def is_in_warmup():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     return prune_context.is_in_warmup()
+
+
+ @torch.compiler.disable
+ def is_l1_diff_enabled():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     return (
+         prune_context.l1_hidden_states_diff_threshold is not None
+         and prune_context.l1_hidden_states_diff_threshold > 0.0
+     )
+
+
+ @torch.compiler.disable
+ def add_residual_diff(diff):
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     prune_context.add_residual_diff(diff)
+
+
+ @torch.compiler.disable
+ def get_residual_diffs():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     # Return a copy of the residual diffs to avoid modification
+     return prune_context.residual_diffs.copy()
+
+
+ @torch.compiler.disable
+ def get_important_condition_threshold():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     return prune_context.important_condition_threshold
+
+
+ @torch.compiler.disable
+ def residual_cache_update_interval():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     return prune_context.residual_cache_update_interval
+
+
+ @torch.compiler.disable
+ def Fn_compute_blocks():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     assert (
+         prune_context.Fn_compute_blocks >= 0
+     ), "Fn_compute_blocks must be >= 0"
+     return prune_context.Fn_compute_blocks
+
+
+ @torch.compiler.disable
+ def Bn_compute_blocks():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     assert (
+         prune_context.Bn_compute_blocks >= 0
+     ), "Bn_compute_blocks must be >= 0"
+     return prune_context.Bn_compute_blocks
+
+
+ @torch.compiler.disable
+ def get_non_prune_blocks_ids():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     return prune_context.non_prune_blocks_ids
+
+
+ _current_prune_context: Optional[DBPPruneContext] = None
+
+
+ def create_prune_context(*args, **kwargs):
+     return DBPPruneContext(*args, **kwargs)
+
+
+ def get_current_prune_context():
+     return _current_prune_context
+
+
+ def set_current_prune_context(prune_context=None):
+     global _current_prune_context
+     _current_prune_context = prune_context
+
+
+ def collect_prune_kwargs(default_attrs: dict, **kwargs):
+     # NOTE: This API will split kwargs into prune_kwargs and other_kwargs
+     # default_attrs: specific settings for different pipelines
+     prune_attrs = dataclasses.fields(DBPPruneContext)
+     prune_attrs = [
+         attr
+         for attr in prune_attrs
+         if hasattr(
+             DBPPruneContext,
+             attr.name,
+         )
+     ]
+     prune_kwargs = {
+         attr.name: kwargs.pop(
+             attr.name,
+             getattr(DBPPruneContext, attr.name),
+         )
+         for attr in prune_attrs
+     }
+     # Manually set sequence fields, such as non_prune_blocks_ids
+     prune_kwargs["non_prune_blocks_ids"] = kwargs.pop(
+         "non_prune_blocks_ids",
+         [],
+     )
+
+     assert default_attrs is not None, "default_attrs must be set before"
+     for attr in prune_attrs:
+         if attr.name in default_attrs:
+             prune_kwargs[attr.name] = default_attrs[attr.name]
+
+     if logger.isEnabledFor(logging.DEBUG):
+         logger.debug(f"Collected DBPrune kwargs: {prune_kwargs}")
+
+     return prune_kwargs, kwargs
+
+
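For illustration, a hypothetical call (a sketch; `guidance_scale` stands in for any pipeline kwarg that is not a DBPPruneContext field):

    prune_kwargs, other_kwargs = collect_prune_kwargs(
        {"Fn_compute_blocks": 4},      # pipeline-specific defaults take priority
        residual_diff_threshold=0.05,  # a DBPPruneContext field
        guidance_scale=3.5,            # unknown key, passed through untouched
    )
    # prune_kwargs["Fn_compute_blocks"] == 4
    # prune_kwargs["residual_diff_threshold"] == 0.05
    # other_kwargs == {"guidance_scale": 3.5}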
+ @contextlib.contextmanager
+ def prune_context(prune_context):
+     global _current_prune_context
+     old_prune_context = _current_prune_context
+     _current_prune_context = prune_context
+     try:
+         yield
+     finally:
+         _current_prune_context = old_prune_context
+
+
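A minimal usage sketch of the context machinery above (hypothetical values, not from the package):

    ctx = create_prune_context(residual_diff_threshold=0.05, warmup_steps=2)
    with prune_context(ctx):
        for _ in range(4):        # one iteration per denoising step
            mark_step_begin()
            if not is_in_warmup():
                pass              # blocks may be pruned from step 2 onward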
+ @torch.compiler.disable
+ def are_two_tensors_similar(
+     t1: torch.Tensor,  # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
+     t2: torch.Tensor,  # curr residual R(t  ,n) = H(t  ,n) - H(t  ,0)
+     *,
+     threshold: float,
+     parallelized: bool = False,
+     name: str = "Bn",  # for debugging
+ ):
+     # Sentinel diff values recorded below: -0.0 means pruning is disabled
+     # by the threshold, -1.0 means the tensors are always treated as
+     # similar (threshold >= 1.0), and -2.0 means the shapes did not match.
+     if threshold <= 0.0:
+         add_residual_diff(-0.0)
+         return False
+
+     if threshold >= 1.0:
+         # If the threshold is 1.0 or more, we consider them always similar.
+         add_residual_diff(-1.0)
+         return True
+
+     if t1.shape != t2.shape:
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(f"{name}, shape error: {t1.shape} != {t2.shape}")
+         add_residual_diff(-2.0)
+         return False
+
+     # Find the most significant tokens across t1 and t2 and base the diff
+     # on them: the more significant a token's change, the more it matters.
+     condition_thresh = get_important_condition_threshold()
+     if condition_thresh > 0.0:
+         raw_diff = (t1 - t2).abs()  # [B, seq_len, d]
+         token_m_df = raw_diff.mean(dim=-1)  # [B, seq_len]
+         token_m_t1 = t1.abs().mean(dim=-1)  # [B, seq_len]
+         # D = (t1 - t2) / t1 = 1 - (t2 / t1); if D = 0, then t1 = t2.
+         token_diff = token_m_df / token_m_t1  # [B, seq_len]
+         condition = token_diff > condition_thresh  # [B, seq_len]
+         if condition.sum() > 0:
+             condition = condition.unsqueeze(-1)  # [B, seq_len, 1]
+             condition = condition.expand_as(raw_diff)  # [B, seq_len, d]
+             mean_diff = raw_diff[condition].mean()
+             mean_t1 = t1[condition].abs().mean()
+         else:
+             mean_diff = (t1 - t2).abs().mean()
+             mean_t1 = t1.abs().mean()
+     else:
+         # Use the mean of the absolute difference of the tensors
+         mean_diff = (t1 - t2).abs().mean()
+         mean_t1 = t1.abs().mean()
+
+     if parallelized:
+         mean_diff = DP.all_reduce_sync(mean_diff, "avg")
+         mean_t1 = DP.all_reduce_sync(mean_t1, "avg")
+
+     # D = (t1 - t2) / t1 = 1 - (t2 / t1); if D = 0, then t1 = t2.
+     # Further, if we assume that (H(t,0) - H(t-1,0)) ~ 0, then
+     # H(t-1,n) ~ H(t,n), which means the hidden states are similar.
+     diff = (mean_diff / mean_t1).item()
+
+     if logger.isEnabledFor(logging.DEBUG):
+         logger.debug(f"{name}, diff: {diff:.6f}, threshold: {threshold:.6f}")
+
+     add_residual_diff(diff)
+
+     return diff < threshold
+
+
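In effect the test above is a relative L1 metric, diff = mean(|t1 - t2|) / mean(|t1|), compared against the threshold. A tiny self-contained check (illustrative only; relies on the torch import at the top of this file):

    t_prev = torch.ones(1, 4, 8)
    t_curr = t_prev * 1.03  # a uniform 3% perturbation
    rel = (t_prev - t_curr).abs().mean() / t_prev.abs().mean()
    assert abs(rel.item() - 0.03) < 1e-6  # similar iff threshold > 0.03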
+ @torch.compiler.disable
+ def apply_hidden_states_residual(
+     hidden_states: torch.Tensor,
+     encoder_hidden_states: torch.Tensor,
+     name: str = "Bn",
+     encoder_name: str = "Bn_encoder",
+ ):
+     hidden_states_residual = get_buffer(f"{name}")
+     assert hidden_states_residual is not None, f"{name} must be set before"
+     hidden_states = hidden_states_residual + hidden_states
+
+     encoder_hidden_states_residual = get_buffer(f"{encoder_name}")
+     assert (
+         encoder_hidden_states_residual is not None
+     ), f"{encoder_name} must be set before"
+     encoder_hidden_states = (
+         encoder_hidden_states_residual + encoder_hidden_states
+     )
+
+     hidden_states = hidden_states.contiguous()
+     encoder_hidden_states = encoder_hidden_states.contiguous()
+
+     return hidden_states, encoder_hidden_states
+
+
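The function assumes matching residual buffers were cached on an earlier compute step. A roundtrip sketch with hypothetical buffer names (not from the package):

    with prune_context(create_prune_context()):
        h = torch.zeros(1, 4, 8)
        set_buffer("demo", torch.full_like(h, 0.50))          # hidden residual
        set_buffer("demo_encoder", torch.full_like(h, 0.25))  # encoder residual
        h2, e2 = apply_hidden_states_residual(
            h, h.clone(), name="demo", encoder_name="demo_encoder"
        )
        assert h2.mean().item() == 0.50 and e2.mean().item() == 0.25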
+ @torch.compiler.disable
+ def get_downsample_factor():
+     prune_context = get_current_prune_context()
+     assert prune_context is not None, "prune_context must be set before"
+     return prune_context.downsample_factor
+
+
+ @torch.compiler.disable
+ def get_can_use_prune(
+     states_tensor: torch.Tensor,  # hidden_states or residual
+     parallelized: bool = False,
+     threshold: Optional[float] = None,  # can manually set threshold
+     name: str = "Bn",
+ ):
+     if is_in_warmup():
+         return False
+
+     pruned_steps = get_pruned_steps()
+     max_pruned_steps = get_max_pruned_steps()
+     if max_pruned_steps >= 0 and (pruned_steps >= max_pruned_steps):
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(
+                 f"{name}, max_pruned_steps reached: {max_pruned_steps}, "
+                 "cannot use prune."
+             )
+         return False
+
+     if threshold is None or threshold <= 0.0:
+         threshold = get_residual_diff_threshold()
+     if threshold <= 0.0:
+         return False
+
+     downsample_factor = get_downsample_factor()
+     prev_states_tensor = get_buffer(f"{name}")
+
+     if downsample_factor > 1:
+         states_tensor = states_tensor[..., ::downsample_factor]
+         states_tensor = states_tensor.contiguous()
+         if prev_states_tensor is not None:
+             prev_states_tensor = prev_states_tensor[..., ::downsample_factor]
+             prev_states_tensor = prev_states_tensor.contiguous()
+
+     return prev_states_tensor is not None and are_two_tensors_similar(
+         prev_states_tensor,
+         states_tensor,
+         threshold=threshold,
+         parallelized=parallelized,
+         name=name,
+     )
+
+
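Putting the pieces together (a sketch with a hypothetical buffer name): the first step can never prune because nothing is cached yet; later steps prune when the drift stays under the threshold:

    with prune_context(create_prune_context(residual_diff_threshold=0.05)):
        mark_step_begin()
        x = torch.randn(1, 4, 8)
        assert not get_can_use_prune(x, name="demo")  # no cached tensor yet
        set_buffer("demo", x.clone())
        mark_step_begin()
        assert get_can_use_prune(x * 1.001, name="demo")  # ~0.1% drift < 5%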
+ class DBPrunedTransformerBlocks(torch.nn.Module):
+     def __init__(
+         self,
+         transformer_blocks,
+         single_transformer_blocks=None,
+         *,
+         transformer=None,
+         return_hidden_states_first=True,
+         return_hidden_states_only=False,
+     ):
+         super().__init__()
+
+         self.transformer = transformer
+         self.transformer_blocks = transformer_blocks
+         self.single_transformer_blocks = single_transformer_blocks
+         self.return_hidden_states_first = return_hidden_states_first
+         self.return_hidden_states_only = return_hidden_states_only
+         self.pruned_blocks_step: int = 0
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         mark_step_begin()
+         self.pruned_blocks_step = 0
+         original_hidden_states = hidden_states
+
+         torch._dynamo.graph_break()
+         hidden_states, encoder_hidden_states = self.call_transformer_blocks(
+             hidden_states,
+             encoder_hidden_states,
+             *args,
+             **kwargs,
+         )
+
+         del original_hidden_states
+         torch._dynamo.graph_break()
+
+         add_pruned_block(self.pruned_blocks_step)
+         add_actual_block(self._num_transformer_blocks)
+         patch_pruned_stats(self.transformer)
+
+         return (
+             hidden_states
+             if self.return_hidden_states_only
+             else (
+                 (hidden_states, encoder_hidden_states)
+                 if self.return_hidden_states_first
+                 else (encoder_hidden_states, hidden_states)
+             )
+         )
+
+     @property
+     @torch.compiler.disable
+     def _num_transformer_blocks(self):
+         # Total number of transformer blocks, including single transformer blocks.
+         num_blocks = len(self.transformer_blocks)
+         if self.single_transformer_blocks is not None:
+             num_blocks += len(self.single_transformer_blocks)
+         return num_blocks
+
+     @torch.compiler.disable
+     def _is_parallelized(self):
+         # Compatible with distributed inference.
+         return all(
+             (
+                 self.transformer is not None,
+                 getattr(self.transformer, "_is_parallelized", False),
+             )
+         )
+
+     @torch.compiler.disable
+     def _non_prune_blocks_ids(self):
+         # Never prune the first `Fn` and last `Bn` blocks.
+         num_blocks = self._num_transformer_blocks
+         Fn_compute_blocks_ = (
+             Fn_compute_blocks()
+             if Fn_compute_blocks() < num_blocks
+             else num_blocks
+         )
+         Fn_compute_blocks_ids = list(range(Fn_compute_blocks_))
+         Bn_compute_blocks_ = (
+             Bn_compute_blocks()
+             if Bn_compute_blocks() < num_blocks
+             else num_blocks
+         )
+         Bn_compute_blocks_ids = list(
+             range(
+                 num_blocks - Bn_compute_blocks_,
+                 num_blocks,
+             )
+         )
+         non_prune_blocks_ids = list(
+             set(
+                 Fn_compute_blocks_ids
+                 + Bn_compute_blocks_ids
+                 + get_non_prune_blocks_ids()
+             )
+         )
+         non_prune_blocks_ids = [
+             d for d in non_prune_blocks_ids if d < num_blocks
+         ]
+         return sorted(non_prune_blocks_ids)
+
+     @torch.compiler.disable
+     def _compute_single_hidden_states_residual(
+         self,
+         single_hidden_states: torch.Tensor,
+         single_original_hidden_states: torch.Tensor,
+         # Global original single hidden states
+         original_single_hidden_states: torch.Tensor,
+         original_single_encoder_hidden_states: torch.Tensor,
+     ):
+         single_hidden_states, single_encoder_hidden_states = (
+             self._split_single_hidden_states(
+                 single_hidden_states,
+                 original_single_hidden_states,
+                 original_single_encoder_hidden_states,
+             )
+         )
+
+         single_original_hidden_states, single_original_encoder_hidden_states = (
+             self._split_single_hidden_states(
+                 single_original_hidden_states,
+                 original_single_hidden_states,
+                 original_single_encoder_hidden_states,
+             )
+         )
+
+         single_hidden_states_residual = (
+             single_hidden_states - single_original_hidden_states
+         )
+         single_encoder_hidden_states_residual = (
+             single_encoder_hidden_states - single_original_encoder_hidden_states
+         )
+         return (
+             single_hidden_states_residual,
+             single_encoder_hidden_states_residual,
+         )
+
+     @torch.compiler.disable
+     def _split_single_hidden_states(
+         self,
+         single_hidden_states: torch.Tensor,
+         # Global original single hidden states
+         original_single_hidden_states: torch.Tensor,
+         original_single_encoder_hidden_states: torch.Tensor,
+     ):
+         single_encoder_hidden_states, single_hidden_states = (
+             single_hidden_states.split(
+                 [
+                     original_single_encoder_hidden_states.shape[1],
+                     single_hidden_states.shape[1]
+                     - original_single_encoder_hidden_states.shape[1],
+                 ],
+                 dim=1,
+             )
+         )
+         # Reshape single_hidden_states and single_encoder_hidden_states
+         # back to their original shapes. This is necessary to ensure that
+         # the residuals are computed correctly.
+         single_hidden_states = (
+             single_hidden_states.reshape(-1)
+             .contiguous()
+             .reshape(original_single_hidden_states.shape)
+         )
+         single_encoder_hidden_states = (
+             single_encoder_hidden_states.reshape(-1)
+             .contiguous()
+             .reshape(original_single_encoder_hidden_states.shape)
+         )
+         return single_hidden_states, single_encoder_hidden_states
+
+     @torch.compiler.disable
+     def _should_update_residuals(self):
+         # Wrapper for non-compiled mode.
+         # Check if the current step is a multiple of
+         # the residual cache update interval.
+         return get_current_step() % residual_cache_update_interval() == 0
+
+     @torch.compiler.disable
+     def _get_can_use_prune(
+         self,
+         block_id: int,  # Block index in the transformer blocks
+         hidden_states: torch.Tensor,  # hidden_states or residual
+         name: str = "Bn_original",  # buffer name of the previous step's states
+     ):
+         # Wrapper for non-compiled mode.
+         can_use_prune = False
+         if block_id not in self._non_prune_blocks_ids():
+             can_use_prune = get_can_use_prune(
+                 hidden_states,  # curr step
+                 parallelized=self._is_parallelized(),
+                 name=name,  # prev step
+             )
+         self.pruned_blocks_step += int(can_use_prune)
+         return can_use_prune
+
+     def _compute_or_prune_single_transformer_block(
+         self,
+         block_id: int,  # Block index in the transformer blocks
+         # Helper inputs for hidden states split and reshape:
+         # global original single hidden states
+         original_single_hidden_states: torch.Tensor,
+         original_single_encoder_hidden_states: torch.Tensor,
+         # Below are the inputs to the block
+         block,  # The transformer block to be executed
+         hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         # Helper function for `call_transformer_blocks`.
+         # block_id: global block index across transformer_blocks +
+         # single_transformer_blocks
+         can_use_prune = self._get_can_use_prune(
+             block_id,
+             hidden_states,  # hidden_states or residual
+             name=f"{block_id}_single_original",  # prev step
+         )
+
+         # Prune steps: prune the current block and reuse the cached
+         # residuals to approximate the hidden states.
+         if can_use_prune:
+             single_original_hidden_states = hidden_states
+             (
+                 single_original_hidden_states,
+                 single_original_encoder_hidden_states,
+             ) = self._split_single_hidden_states(
+                 single_original_hidden_states,
+                 original_single_hidden_states,
+                 original_single_encoder_hidden_states,
+             )
+             hidden_states, encoder_hidden_states = apply_hidden_states_residual(
+                 single_original_hidden_states,
+                 single_original_encoder_hidden_states,
+                 name=f"{block_id}_single_residual",
+                 encoder_name=f"{block_id}_single_encoder_residual",
+             )
+             hidden_states = torch.cat(
+                 [encoder_hidden_states, hidden_states],
+                 dim=1,
+             )
+             del single_original_hidden_states
+             del single_original_encoder_hidden_states
+
+         else:
+             # Normal steps: compute the block and cache the residuals.
+             single_original_hidden_states = hidden_states
+             hidden_states = block(
+                 hidden_states,
+                 *args,
+                 **kwargs,
+             )
+
+             # Save original_hidden_states for diff calculation.
+             # It may not be necessary to update the hidden
+             # states and residuals at every step.
+             if self._should_update_residuals():
+                 # Cache residuals of the non-compute Bn blocks for
+                 # subsequent prune steps.
+                 single_hidden_states = hidden_states
+                 (
+                     single_hidden_states_residual,
+                     single_encoder_hidden_states_residual,
+                 ) = self._compute_single_hidden_states_residual(
+                     single_hidden_states,
+                     single_original_hidden_states,
+                     original_single_hidden_states,
+                     original_single_encoder_hidden_states,
+                 )
+
+                 set_buffer(
+                     f"{block_id}_single_original",
+                     single_original_hidden_states,
+                 )
+
+                 set_buffer(
+                     f"{block_id}_single_residual",
+                     single_hidden_states_residual,
+                 )
+                 set_buffer(
+                     f"{block_id}_single_encoder_residual",
+                     single_encoder_hidden_states_residual,
+                 )
+
+                 del single_hidden_states
+                 del single_hidden_states_residual
+                 del single_encoder_hidden_states_residual
+
+             del single_original_hidden_states
+
+         return hidden_states
+
+     def _compute_or_prune_transformer_block(
+         self,
+         block_id: int,  # Block index in the transformer blocks
+         # Below are the inputs to the block
+         block,  # The transformer block to be executed
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         # Helper function for `call_transformer_blocks`.
+         original_hidden_states = hidden_states
+         original_encoder_hidden_states = encoder_hidden_states
+
+         # block_id: global block index across transformer_blocks +
+         # single_transformer_blocks
+         can_use_prune = self._get_can_use_prune(
+             block_id,
+             hidden_states,  # hidden_states or residual
+             name=f"{block_id}_original",  # prev step
+         )
+
+         # Prune steps: prune the current block and reuse the cached
+         # residuals to approximate the hidden states.
+         if can_use_prune:
+             hidden_states, encoder_hidden_states = apply_hidden_states_residual(
+                 hidden_states,
+                 encoder_hidden_states,
+                 name=f"{block_id}_residual",
+                 encoder_name=f"{block_id}_encoder_residual",
+             )
+         else:
+             # Normal steps: compute the block and cache the residuals.
+             hidden_states = block(
+                 hidden_states,
+                 encoder_hidden_states,
+                 *args,
+                 **kwargs,
+             )
+             if not isinstance(hidden_states, torch.Tensor):
+                 hidden_states, encoder_hidden_states = hidden_states
+                 if not self.return_hidden_states_first:
+                     hidden_states, encoder_hidden_states = (
+                         encoder_hidden_states,
+                         hidden_states,
+                     )
+
+             # Save original_hidden_states for diff calculation.
+             # It may not be necessary to update the hidden
+             # states and residuals at every step.
+             if self._should_update_residuals():
+                 # Cache residuals of the non-compute Bn blocks for
+                 # subsequent prune steps.
+                 hidden_states_residual = hidden_states - original_hidden_states
+                 encoder_hidden_states_residual = (
+                     encoder_hidden_states - original_encoder_hidden_states
+                 )
+                 set_buffer(
+                     f"{block_id}_original",
+                     original_hidden_states,
+                 )
+
+                 set_buffer(
+                     f"{block_id}_residual",
+                     hidden_states_residual,
+                 )
+                 set_buffer(
+                     f"{block_id}_encoder_residual",
+                     encoder_hidden_states_residual,
+                 )
+                 del hidden_states_residual
+                 del encoder_hidden_states_residual
+
+         del original_hidden_states
+         del original_encoder_hidden_states
+
+         return hidden_states, encoder_hidden_states
+
+     def call_transformer_blocks(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor,
+         *args,
+         **kwargs,
+     ):
+         original_hidden_states = hidden_states
+         original_encoder_hidden_states = encoder_hidden_states
+
+         for i, block in enumerate(self.transformer_blocks):
+             hidden_states, encoder_hidden_states = (
+                 self._compute_or_prune_transformer_block(
+                     i,
+                     block,
+                     hidden_states,
+                     encoder_hidden_states,
+                     *args,
+                     **kwargs,
+                 )
+             )
+
+         if self.single_transformer_blocks is not None:
+             hidden_states = torch.cat(
+                 [encoder_hidden_states, hidden_states], dim=1
+             )
+             for j, block in enumerate(self.single_transformer_blocks):
+                 hidden_states = self._compute_or_prune_single_transformer_block(
+                     j + len(self.transformer_blocks),
+                     original_hidden_states,
+                     original_encoder_hidden_states,
+                     block,
+                     hidden_states,
+                     *args,
+                     **kwargs,
+                 )
+
+             encoder_hidden_states, hidden_states = hidden_states.split(
+                 [
+                     encoder_hidden_states.shape[1],
+                     hidden_states.shape[1] - encoder_hidden_states.shape[1],
+                 ],
+                 dim=1,
+             )
+
+         hidden_states = (
+             hidden_states.reshape(-1)
+             .contiguous()
+             .reshape(original_hidden_states.shape)
+         )
+         encoder_hidden_states = (
+             encoder_hidden_states.reshape(-1)
+             .contiguous()
+             .reshape(original_encoder_hidden_states.shape)
+         )
+         return hidden_states, encoder_hidden_states
+
+
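How an adapter is expected to drive this class, as a toy sketch (ToyBlock and all shapes are made up; the real adapters live under cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/):

    class ToyBlock(torch.nn.Module):
        def forward(self, hidden_states, encoder_hidden_states):
            return hidden_states + 1.0, encoder_hidden_states

    blocks = torch.nn.ModuleList([ToyBlock() for _ in range(4)])
    wrapped = DBPrunedTransformerBlocks(blocks)
    ctx = create_prune_context(
        Fn_compute_blocks=1,
        Bn_compute_blocks=1,
        residual_diff_threshold=0.05,
    )
    with prune_context(ctx):
        hidden, encoder = wrapped(torch.zeros(1, 4, 8), torch.zeros(1, 2, 8))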
+ @torch.compiler.disable
+ def patch_pruned_stats(
+     transformer,
+ ):
+     # Patch the pruned stats onto the transformer; the pruned stats
+     # are reset on each call to pipe.__call__(**kwargs).
+     if transformer is None:
+         return
+
+     pruned_transformer_blocks = getattr(transformer, "transformer_blocks", None)
+     if pruned_transformer_blocks is None:
+         return
+
+     if isinstance(pruned_transformer_blocks, torch.nn.ModuleList):
+         pruned_transformer_blocks = pruned_transformer_blocks[0]
+     if not isinstance(
+         pruned_transformer_blocks, DBPrunedTransformerBlocks
+     ) or not isinstance(transformer, torch.nn.Module):
+         return
+
+     # TODO: Patch more pruned stats onto the transformer
+     transformer._pruned_blocks = get_pruned_blocks()
+     transformer._pruned_steps = get_pruned_steps()
+     transformer._residual_diffs = get_residual_diffs()
+     transformer._actual_blocks = get_actual_blocks()
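
After a run, the patched statistics can be read back from the transformer; a sketch assuming `transformer` is a module that went through patch_pruned_stats above:

    pruned_per_step = getattr(transformer, "_pruned_blocks", [])
    steps_with_pruning = getattr(transformer, "_pruned_steps", 0)
    diffs_by_step = getattr(transformer, "_residual_diffs", {})
    print(f"pruned {sum(pruned_per_step)} blocks across {steps_with_pruning} steps")
    print({s: len(d) for s, d in diffs_by_step.items()})  # diffs recorded per step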