cache_dit-1.0.3-py3-none-any.whl → cache_dit-1.0.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (29)
  1. cache_dit/__init__.py +3 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +8 -1
  4. cache_dit/cache_factory/cache_adapters/cache_adapter.py +90 -76
  5. cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
  6. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
  7. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +256 -24
  8. cache_dit/cache_factory/cache_blocks/pattern_base.py +273 -38
  9. cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
  10. cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
  11. cache_dit/cache_factory/cache_contexts/cache_config.py +102 -0
  12. cache_dit/cache_factory/cache_contexts/cache_context.py +15 -93
  13. cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
  14. cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
  15. cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
  16. cache_dit/cache_factory/cache_contexts/prune_config.py +69 -0
  17. cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
  18. cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
  19. cache_dit/cache_factory/cache_interface.py +20 -14
  20. cache_dit/cache_factory/cache_types.py +19 -2
  21. cache_dit/cache_factory/params_modifier.py +7 -7
  22. cache_dit/cache_factory/utils.py +18 -7
  23. cache_dit/utils.py +191 -54
  24. {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/METADATA +9 -9
  25. {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/RECORD +29 -24
  26. {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/WHEEL +0 -0
  27. {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/top_level.txt +0 -0
@@ -3,11 +3,16 @@ import torch
 import torch.distributed as dist
 
 from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
     CachedContextManager,
-    CacheNotExistError,
+    ContextNotExistError,
+)
+from cache_dit.cache_factory.cache_contexts.prune_manager import (
+    PrunedContextManager,
 )
 from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
@@ -31,7 +36,8 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # 1. Cache context configuration
         cache_prefix: str = None,  # maybe un-need.
         cache_context: CachedContext | str = None,
-        cache_manager: CachedContextManager = None,
+        context_manager: CachedContextManager = None,
+        cache_type: CacheType = CacheType.DBCache,
         **kwargs,
     ):
         super().__init__()
@@ -45,13 +51,15 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # 1. Cache context configuration
         self.cache_prefix = cache_prefix
         self.cache_context = cache_context
-        self.cache_manager = cache_manager
+        self.context_manager = context_manager
+        self.cache_type = cache_type
 
         self._check_forward_pattern()
+        self._check_cache_type()
         logger.info(
-            f"Match Cached Blocks: {self.__class__.__name__}, for "
+            f"Match Blocks: {self.__class__.__name__}, for "
             f"{self.cache_prefix}, cache_context: {self.cache_context}, "
-            f"cache_manager: {self.cache_manager.name}."
+            f"context_manager: {self.context_manager.name}."
         )
 
     def _check_forward_pattern(self):
@@ -94,18 +102,25 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 required_param in forward_parameters
             ), f"The input parameters must contains: {required_param}."
 
+    @torch.compiler.disable
+    def _check_cache_type(self):
+        assert (
+            self.cache_type == CacheType.DBCache
+        ), f"Cache type {self.cache_type} is not supported for CachedBlocks."
+
     @torch.compiler.disable
     def _check_cache_params(self):
-        assert self.cache_manager.Fn_compute_blocks() <= len(
+        self._check_cache_type()
+        assert self.context_manager.Fn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
+            f"Fn_compute_blocks {self.context_manager.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        assert self.cache_manager.Bn_compute_blocks() <= len(
+        assert self.context_manager.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
+            f"Bn_compute_blocks {self.context_manager.Bn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
 
@@ -174,9 +189,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         ):
             # Use it's own cache context.
             try:
-                self.cache_manager.set_context(self.cache_context)
+                self.context_manager.set_context(self.cache_context)
                 self._check_cache_params()
-            except CacheNotExistError as e:
+            except ContextNotExistError as e:
                 logger.warning(f"Cache context not exist: {e}, skip cache.")
                 # Call all blocks to process the hidden states.
                 hidden_states, encoder_hidden_states = self.call_blocks(
@@ -203,38 +218,38 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states
 
-        self.cache_manager.mark_step_begin()
+        self.context_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = self.cache_manager.can_cache(
+        can_use_cache = self.context_manager.can_cache(
             (
                 Fn_hidden_states_residual
-                if not self.cache_manager.is_l1_diff_enabled()
+                if not self.context_manager.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
                 f"{self.cache_prefix}_Fn_residual"
-                if not self.cache_manager.is_l1_diff_enabled()
+                if not self.context_manager.is_l1_diff_enabled()
                 else f"{self.cache_prefix}_Fn_hidden_states"
             ),
         )
 
         torch._dynamo.graph_break()
         if can_use_cache:
-            self.cache_manager.add_cached_step()
+            self.context_manager.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-                self.cache_manager.apply_cache(
+                self.context_manager.apply_cache(
                     hidden_states,
                     encoder_hidden_states,
                     prefix=(
                         f"{self.cache_prefix}_Bn_residual"
-                        if self.cache_manager.is_cache_residual()
+                        if self.context_manager.is_cache_residual()
                         else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
                         f"{self.cache_prefix}_Bn_residual"
-                        if self.cache_manager.is_encoder_cache_residual()
+                        if self.context_manager.is_encoder_cache_residual()
                         else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                 )
@@ -249,13 +264,13 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                **kwargs,
            )
        else:
-            self.cache_manager.set_Fn_buffer(
+            self.context_manager.set_Fn_buffer(
                Fn_hidden_states_residual,
                prefix=f"{self.cache_prefix}_Fn_residual",
            )
-            if self.cache_manager.is_l1_diff_enabled():
+            if self.context_manager.is_l1_diff_enabled():
                # for hidden states L1 diff
-                self.cache_manager.set_Fn_buffer(
+                self.context_manager.set_Fn_buffer(
                    hidden_states,
                    f"{self.cache_prefix}_Fn_hidden_states",
                )
@@ -273,24 +288,24 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                **kwargs,
            )
            torch._dynamo.graph_break()
-            if self.cache_manager.is_cache_residual():
-                self.cache_manager.set_Bn_buffer(
+            if self.context_manager.is_cache_residual():
+                self.context_manager.set_Bn_buffer(
                    hidden_states_residual,
                    prefix=f"{self.cache_prefix}_Bn_residual",
                )
            else:
-                self.cache_manager.set_Bn_buffer(
+                self.context_manager.set_Bn_buffer(
                    hidden_states,
                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                )
 
-            if self.cache_manager.is_encoder_cache_residual():
-                self.cache_manager.set_Bn_encoder_buffer(
+            if self.context_manager.is_encoder_cache_residual():
+                self.context_manager.set_Bn_encoder_buffer(
                    encoder_hidden_states_residual,
                    prefix=f"{self.cache_prefix}_Bn_residual",
                )
            else:
-                self.cache_manager.set_Bn_encoder_buffer(
+                self.context_manager.set_Bn_encoder_buffer(
                    encoder_hidden_states,
                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                )
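
At a high level, the renamed context_manager calls in the hunks above implement the DBCache decision: a residual L1 diff over the first Fn blocks decides whether the remaining blocks are computed or approximated from cached Bn residuals. The sketch below is a simplified, hypothetical illustration only; the method names mirror the diff, but the real signatures also take prefix/encoder/parallelized arguments, and the block grouping is handled by _Fn_blocks/_Mn_blocks/_Bn_blocks.

    # Simplified DBCache step, assuming `manager` behaves like CachedContextManager
    # and `fn_blocks` / `rest_blocks` are plain callables (illustrative only).
    import torch

    def dbcache_step(manager, fn_blocks, rest_blocks, hidden_states: torch.Tensor) -> torch.Tensor:
        fn_out = hidden_states
        for block in fn_blocks:                  # Fn: first n compute blocks
            fn_out = block(fn_out)
        fn_residual = fn_out - hidden_states     # residual used for the L1 diff

        manager.mark_step_begin()
        if manager.can_cache(fn_residual):       # diff below threshold -> reuse cache
            manager.add_cached_step()
            return manager.apply_cache(fn_out)   # add back the cached Bn residual
        manager.set_Fn_buffer(fn_residual)       # otherwise refresh the Fn buffer
        out = fn_out
        for block in rest_blocks:                # compute the Mn + Bn blocks
            out = block(out)
        manager.set_Bn_buffer(out - fn_out)      # and cache the new Bn residual
        return out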
@@ -333,11 +348,11 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
        # If so, we can skip some Bn blocks and directly
        # use the cached values.
        return (
-            self.cache_manager.get_current_step()
-            in self.cache_manager.get_cached_steps()
+            self.context_manager.get_current_step()
+            in self.context_manager.get_cached_steps()
        ) or (
-            self.cache_manager.get_current_step()
-            in self.cache_manager.get_cfg_cached_steps()
+            self.context_manager.get_current_step()
+            in self.context_manager.get_cfg_cached_steps()
        )
 
    @torch.compiler.disable
@@ -346,20 +361,20 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
        # more stable diff calculation.
        # Fn: [0,...,n-1]
        selected_Fn_blocks = self.transformer_blocks[
-            : self.cache_manager.Fn_compute_blocks()
+            : self.context_manager.Fn_compute_blocks()
        ]
        return selected_Fn_blocks
 
    @torch.compiler.disable
    def _Mn_blocks(self):  # middle blocks
        # M(N-2n): only transformer_blocks [n,...,N-n], middle
-        if self.cache_manager.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
+        if self.context_manager.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
            selected_Mn_blocks = self.transformer_blocks[
-                self.cache_manager.Fn_compute_blocks() :
+                self.context_manager.Fn_compute_blocks() :
            ]
        else:
            selected_Mn_blocks = self.transformer_blocks[
-                self.cache_manager.Fn_compute_blocks() : -self.cache_manager.Bn_compute_blocks()
+                self.context_manager.Fn_compute_blocks() : -self.context_manager.Bn_compute_blocks()
            ]
        return selected_Mn_blocks
 
@@ -367,7 +382,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
    def _Bn_blocks(self):
        # Bn: transformer_blocks [N-n+1,...,N-1]
        selected_Bn_blocks = self.transformer_blocks[
-            -self.cache_manager.Bn_compute_blocks() :
+            -self.context_manager.Bn_compute_blocks() :
        ]
        return selected_Bn_blocks
 
@@ -441,7 +456,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
        *args,
        **kwargs,
    ):
-        if self.cache_manager.Bn_compute_blocks() == 0:
+        if self.context_manager.Bn_compute_blocks() == 0:
            return hidden_states, encoder_hidden_states
 
        for block in self._Bn_blocks():
@@ -456,3 +471,223 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
            )
 
        return hidden_states, encoder_hidden_states
+
+
+class PrunedBlocks_Pattern_Base(CachedBlocks_Pattern_Base):
+    pruned_blocks_step: int = 0  # number of pruned blocks in current step
+
+    def __init__(
+        self,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
+        check_num_outputs: bool = True,
+        # 1. Prune context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: PrunedContext | str = None,
+        context_manager: PrunedContextManager = None,
+        cache_type: CacheType = CacheType.DBPrune,
+        **kwargs,
+    ):
+        super().__init__(
+            # 0. Transformer blocks configuration
+            transformer_blocks,
+            transformer=transformer,
+            forward_pattern=forward_pattern,
+            check_forward_pattern=check_forward_pattern,
+            check_num_outputs=check_num_outputs,
+            # 1. Cache context configuration
+            cache_prefix=cache_prefix,
+            cache_context=cache_context,
+            context_manager=context_manager,
+            cache_type=cache_type,
+            **kwargs,
+        )
+        assert isinstance(
+            self.context_manager, PrunedContextManager
+        ), "context_manager must be PrunedContextManager for PrunedBlocks."
+        self.context_manager: PrunedContextManager = (
+            self.context_manager
+        )  # For type hint
+
+    @torch.compiler.disable
+    def _check_cache_type(self):
+        assert (
+            self.cache_type == CacheType.DBPrune
+        ), f"Cache type {self.cache_type} is not supported for PrunedBlocks."
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        self.pruned_blocks_step: int = 0  # reset for each step
+
+        # Use it's own cache context.
+        try:
+            self.context_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except ContextNotExistError as e:
+            logger.warning(f"Cache context not exist: {e}, skip prune.")
+            # Fallback to call all blocks to process the hidden states w/o prune.
+            hidden_states, encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._process_forward_outputs(
+                hidden_states,
+                encoder_hidden_states,
+            )
+
+        self.context_manager.mark_step_begin()
+
+        # Call all blocks with prune strategy to process the hidden states.
+        for i, block in enumerate(self.transformer_blocks):
+            hidden_states, encoder_hidden_states = self.compute_or_prune(
+                i,
+                block,
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        self.context_manager.add_pruned_block(self.pruned_blocks_step)
+        self.context_manager.add_actual_block(self.num_blocks)
+
+        return self._process_forward_outputs(
+            hidden_states,
+            encoder_hidden_states,
+        )
+
+    @property
+    @torch.compiler.disable
+    def num_blocks(self):
+        return len(self.transformer_blocks)
+
+    @torch.compiler.disable
+    def _skip_prune(self, block_id: int) -> bool:
+        # Wrap for non compiled mode.
+        return block_id in self.context_manager.get_non_prune_blocks_ids(
+            self.num_blocks
+        )
+
+    @torch.compiler.disable
+    def _maybe_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        hidden_states: torch.Tensor,  # hidden_states or residual
+        prefix: str = "Bn_original",  # prev step name for single blocks
+    ):
+        # Wrap for non compiled mode.
+        can_use_prune = False
+        if not self._skip_prune(block_id):
+            can_use_prune = self.context_manager.can_prune(
+                hidden_states,  # curr step
+                parallelized=self._is_parallelized(),
+                prefix=prefix,  # prev step
+            )
+        self.pruned_blocks_step += int(can_use_prune)
+        return can_use_prune
+
+    def compute_or_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        # Below are the inputs to the block
+        block,  # The transformer block to be executed
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = encoder_hidden_states
+
+        can_use_prune = self._maybe_prune(
+            block_id,
+            hidden_states,
+            prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+        )
+
+        # Prune steps: Prune current block and reuse the cached
+        # residuals for hidden states approximate.
+        torch._dynamo.graph_break()
+        if can_use_prune:
+            self.context_manager.add_pruned_step()
+            hidden_states, encoder_hidden_states = (
+                self.context_manager.apply_prune(
+                    hidden_states,
+                    encoder_hidden_states,
+                    prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_residual"
+                        if self.context_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
+                        if self.context_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+        else:
+            # Normal steps: Compute the block and cache the residuals.
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
+                hidden_states, encoder_hidden_states
+            )
+            if not self._skip_prune(block_id):
+                hidden_states = hidden_states.contiguous()
+                hidden_states_residual = hidden_states - original_hidden_states
+
+                if (
+                    encoder_hidden_states is not None
+                    and original_encoder_hidden_states is not None
+                ):
+                    encoder_hidden_states = encoder_hidden_states.contiguous()
+                    encoder_hidden_states_residual = (
+                        encoder_hidden_states - original_encoder_hidden_states
+                    )
+                else:
+                    encoder_hidden_states_residual = None
+
+                self.context_manager.set_Fn_buffer(
+                    original_hidden_states,
+                    prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+                )
+                if self.context_manager.is_cache_residual():
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
+                    )
+                else:
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
+                    )
+                if encoder_hidden_states_residual is not None:
+                    if self.context_manager.is_encoder_cache_residual():
+                        self.context_manager.set_Bn_encoder_buffer(
+                            encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
+                        )
+                    else:
+                        self.context_manager.set_Bn_encoder_buffer(
+                            encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
+                        )
+            torch._dynamo.graph_break()
+
+        return hidden_states, encoder_hidden_states
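
The new PrunedBlocks_Pattern_Base class decides per block and per step whether to run the block or approximate its output from a cached residual. Below is a minimal sketch of that decision, assuming a manager object with PrunedContextManager-like semantics; names and signatures are simplified for illustration and are not the library's actual API surface.

    import torch

    def compute_or_prune_sketch(manager, block, block_id: int,
                                hidden_states: torch.Tensor) -> torch.Tensor:
        original = hidden_states
        prunable = block_id not in manager.get_non_prune_blocks_ids()
        if prunable and manager.can_prune(hidden_states):
            # Pruned: skip the block and approximate its output from the
            # residual cached on a previous, fully computed step.
            return manager.apply_prune(hidden_states)
        out = block(hidden_states)               # compute the block normally
        manager.set_Fn_buffer(original)          # input used for the next diff check
        manager.set_Bn_buffer(out - original)    # residual reused when pruning later
        return out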
@@ -3,34 +3,61 @@ import torch
 from typing import Any
 from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import CachedContextManager
+from cache_dit.cache_factory import PrunedContextManager
 
 
-def patch_cached_stats(
+def apply_stats(
     module: torch.nn.Module | Any,
     cache_context: CachedContext | str = None,
-    cache_manager: CachedContextManager = None,
+    context_manager: CachedContextManager | PrunedContextManager = None,
 ):
     # Patch the cached stats to the module, the cached stats
     # will be reset for each calling of pipe.__call__(**kwargs).
-    if module is None or cache_manager is None:
+    if module is None or context_manager is None:
         return
 
     if cache_context is not None:
-        cache_manager.set_context(cache_context)
+        context_manager.set_context(cache_context)
 
-    # TODO: Patch more cached stats to the module
-    module._cached_steps = cache_manager.get_cached_steps()
-    module._residual_diffs = cache_manager.get_residual_diffs()
-    module._cfg_cached_steps = cache_manager.get_cfg_cached_steps()
-    module._cfg_residual_diffs = cache_manager.get_cfg_residual_diffs()
+    # Cache stats for Dual Block Cache
+    module._cached_steps = context_manager.get_cached_steps()
+    module._residual_diffs = context_manager.get_residual_diffs()
+    module._cfg_cached_steps = context_manager.get_cfg_cached_steps()
+    module._cfg_residual_diffs = context_manager.get_cfg_residual_diffs()
+    # Pruned stats for Dynamic Block Prune
+    if not isinstance(context_manager, PrunedContextManager):
+        return
+    module._pruned_steps = context_manager.get_pruned_steps()
+    module._cfg_pruned_steps = context_manager.get_cfg_pruned_steps()
+    module._pruned_blocks = context_manager.get_pruned_blocks()
+    module._cfg_pruned_blocks = context_manager.get_cfg_pruned_blocks()
+    module._actual_blocks = context_manager.get_actual_blocks()
+    module._cfg_actual_blocks = context_manager.get_cfg_actual_blocks()
+    # Caculate pruned ratio
+    if len(module._pruned_blocks) > 0 and sum(module._actual_blocks) > 0:
+        module._pruned_ratio = sum(module._pruned_blocks) / sum(
+            module._actual_blocks
+        )
+    else:
+        module._pruned_ratio = None
+    if (
+        len(module._cfg_pruned_blocks) > 0
+        and sum(module._cfg_actual_blocks) > 0
+    ):
+        module._cfg_pruned_ratio = sum(module._cfg_pruned_blocks) / sum(
+            module._cfg_actual_blocks
+        )
+    else:
+        module._cfg_pruned_ratio = None
 
 
-def remove_cached_stats(
+def remove_stats(
     module: torch.nn.Module | Any,
 ):
     if module is None:
         return
 
+    # Dual Block Cache
     if hasattr(module, "_cached_steps"):
         del module._cached_steps
     if hasattr(module, "_residual_diffs"):
@@ -39,3 +66,21 @@ def remove_cached_stats(
         del module._cfg_cached_steps
     if hasattr(module, "_cfg_residual_diffs"):
         del module._cfg_residual_diffs
+
+    # Dynamic Block Prune
+    if hasattr(module, "_pruned_steps"):
+        del module._pruned_steps
+    if hasattr(module, "_cfg_pruned_steps"):
+        del module._cfg_pruned_steps
+    if hasattr(module, "_pruned_blocks"):
+        del module._pruned_blocks
+    if hasattr(module, "_cfg_pruned_blocks"):
+        del module._cfg_pruned_blocks
+    if hasattr(module, "_actual_blocks"):
+        del module._actual_blocks
+    if hasattr(module, "_cfg_actual_blocks"):
+        del module._cfg_actual_blocks
+    if hasattr(module, "_pruned_ratio"):
+        del module._pruned_ratio
+    if hasattr(module, "_cfg_pruned_ratio"):
+        del module._cfg_pruned_ratio
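
Assuming apply_stats() is invoked on the pipeline's transformer module (as patch_cached_stats was in earlier releases), the new prune statistics can be read back as plain attributes; per the hunk above, _pruned_ratio is simply sum(_pruned_blocks) / sum(_actual_blocks). A hypothetical usage sketch, where pipe is a diffusers pipeline:

    transformer = pipe.transformer
    if getattr(transformer, "_pruned_ratio", None) is not None:
        print(
            f"pruned blocks: {sum(transformer._pruned_blocks)} / "
            f"{sum(transformer._actual_blocks)} "
            f"({transformer._pruned_ratio:.2%})"
        )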
@@ -5,11 +5,24 @@ from cache_dit.cache_factory.cache_contexts.calibrators import (
     TaylorSeerCalibratorConfig,
     FoCaCalibratorConfig,
 )
+from cache_dit.cache_factory.cache_contexts.cache_config import (
+    BasicCacheConfig,
+    DBCacheConfig,
+)
 from cache_dit.cache_factory.cache_contexts.cache_context import (
     CachedContext,
-    BasicCacheConfig,
 )
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
     CachedContextManager,
-    CacheNotExistError,
+    ContextNotExistError,
+)
+from cache_dit.cache_factory.cache_contexts.prune_config import DBPruneConfig
+from cache_dit.cache_factory.cache_contexts.prune_context import (
+    PrunedContext,
+)
+from cache_dit.cache_factory.cache_contexts.prune_manager import (
+    PrunedContextManager,
+)
+from cache_dit.cache_factory.cache_contexts.context_manager import (
+    ContextManager,
 )
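
Given the exports above, the prune-related names should be importable from the cache_contexts package next to the existing cache ones (a sketch, not verified against the installed wheel):

    from cache_dit.cache_factory.cache_contexts import (
        BasicCacheConfig,
        DBCacheConfig,
        DBPruneConfig,
        PrunedContext,
        PrunedContextManager,
        ContextNotExistError,
    )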
@@ -0,0 +1,102 @@
+import torch
+import dataclasses
+from typing import Optional, Union
+from cache_dit.cache_factory.cache_types import CacheType
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class BasicCacheConfig:
+    # Default: Dual Block Cache with Flexible FnBn configuration.
+    cache_type: CacheType = CacheType.DBCache  # DBCache, DBPrune, NONE
+
+    # Fn_compute_blocks: (`int`, *required*, defaults to 8):
+    #     Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
+    #     at time step t, enabling the calculation of a more stable L1 diff and delivering more
+    #     accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
+    #     for more details of DBCache.
+    Fn_compute_blocks: int = 8
+    # Bn_compute_blocks: (`int`, *required*, defaults to 0):
+    #     Further fuses approximate information in the **last n** Transformer blocks to enhance
+    #     prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
+    #     that use residual cache.
+    Bn_compute_blocks: int = 0
+    # residual_diff_threshold (`float`, *required*, defaults to 0.08):
+    #     the value of residual diff threshold, a higher value leads to faster performance at the
+    #     cost of lower precision.
+    residual_diff_threshold: Union[torch.Tensor, float] = 0.08
+    # max_warmup_steps (`int`, *required*, defaults to 8):
+    #     DBCache does not apply the caching strategy when the number of running steps is less than
+    #     or equal to this value, ensuring the model sufficiently learns basic features during warmup.
+    max_warmup_steps: int = 8  # DON'T Cache in warmup steps
+    # warmup_interval (`int`, *required*, defaults to 1):
+    #     Skip interval in warmup steps, e.g., when warmup_interval is 2, only 0, 2, 4, ... steps
+    #     in warmup steps will be computed, others will use dynamic cache.
+    warmup_interval: int = 1  # skip interval in warmup steps
+    # max_cached_steps (`int`, *required*, defaults to -1):
+    #     DBCache disables the caching strategy when the previous cached steps exceed this value to
+    #     prevent precision degradation.
+    max_cached_steps: int = -1  # for both CFG and non-CFG
+    # max_continuous_cached_steps (`int`, *required*, defaults to -1):
+    #     DBCache disables the caching strategy when the previous continous cached steps exceed this value to
+    #     prevent precision degradation.
+    max_continuous_cached_steps: int = -1  # the max continuous cached steps
+    # enable_separate_cfg (`bool`, *required*, defaults to None):
+    #     Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
+    #     and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
+    #     CogVideoX, HunyuanVideo, Mochi, etc.
+    enable_separate_cfg: Optional[bool] = None
+    # cfg_compute_first (`bool`, *required*, defaults to False):
+    #     Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
+    #     1, 3, 5, ... -> CFG step.
+    cfg_compute_first: bool = False
+    # cfg_diff_compute_separate (`bool`, *required*, defaults to True):
+    #     Compute separate diff values for CFG and non-CFG step, default True. If False, we will
+    #     use the computed diff from current non-CFG transformer step for current CFG step.
+    cfg_diff_compute_separate: bool = True
+
+    def update(self, **kwargs) -> "BasicCacheConfig":
+        for key, value in kwargs.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+        return self
+
+    def strify(self) -> str:
+        return (
+            f"{self.cache_type}_"
+            f"F{self.Fn_compute_blocks}"
+            f"B{self.Bn_compute_blocks}_"
+            f"W{self.max_warmup_steps}"
+            f"I{self.warmup_interval}"
+            f"M{max(0, self.max_cached_steps)}"
+            f"MC{max(0, self.max_continuous_cached_steps)}_"
+            f"R{self.residual_diff_threshold}"
+        )
+
+
+@dataclasses.dataclass
+class ExtraCacheConfig:
+    # Some other not very important settings for Dual Block Cache.
+    # NOTE: These flags maybe deprecated in the future and users
+    # should never use these extra configurations in their cases.
+
+    # l1_hidden_states_diff_threshold (`float`, *optional*, defaults to None):
+    #     The hidden states diff threshold for DBCache if use hidden_states as
+    #     cache (not residual).
+    l1_hidden_states_diff_threshold: float = None
+    # important_condition_threshold (`float`, *optional*, defaults to 0.0):
+    #     Only select the most important tokens while calculating the l1 diff.
+    important_condition_threshold: float = 0.0
+    # downsample_factor (`int`, *optional*, defaults to 1):
+    #     Downsample factor for Fn buffer, in order the save GPU memory.
+    downsample_factor: int = 1
+    # num_inference_steps (`int`, *optional*, defaults to -1):
+    #     num_inference_steps for DiffusionPipeline, for future use.
+    num_inference_steps: int = -1
+
+
+@dataclasses.dataclass
+class DBCacheConfig(BasicCacheConfig):
+    pass  # Just an alias for BasicCacheConfig
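
To illustrate how the new config dataclass composes, here is a small sketch based only on the definitions above; the exact string produced by strify() depends on how CacheType renders, so the printed prefix is an assumption:

    from cache_dit.cache_factory.cache_contexts.cache_config import DBCacheConfig

    cfg = DBCacheConfig().update(Fn_compute_blocks=4, residual_diff_threshold=0.12)
    # With defaults otherwise: F4, B0, 8 warmup steps, interval 1, no caching limits.
    print(cfg.strify())  # e.g. "<cache_type>_F4B0_W8I1M0MC0_R0.12"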