cache-dit 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (29)
  1. cache_dit/__init__.py +3 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +8 -1
  4. cache_dit/cache_factory/cache_adapters/cache_adapter.py +90 -76
  5. cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
  6. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
  7. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +271 -36
  8. cache_dit/cache_factory/cache_blocks/pattern_base.py +286 -45
  9. cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
  10. cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
  11. cache_dit/cache_factory/cache_contexts/cache_config.py +102 -0
  12. cache_dit/cache_factory/cache_contexts/cache_context.py +26 -89
  13. cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
  14. cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
  15. cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
  16. cache_dit/cache_factory/cache_contexts/prune_config.py +69 -0
  17. cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
  18. cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
  19. cache_dit/cache_factory/cache_interface.py +23 -14
  20. cache_dit/cache_factory/cache_types.py +19 -2
  21. cache_dit/cache_factory/params_modifier.py +7 -7
  22. cache_dit/cache_factory/utils.py +38 -27
  23. cache_dit/utils.py +191 -54
  24. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/METADATA +14 -7
  25. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/RECORD +29 -24
  26. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/WHEEL +0 -0
  27. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/top_level.txt +0 -0
@@ -3,11 +3,16 @@ import torch
  import torch.distributed as dist

  from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+ from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext
  from cache_dit.cache_factory.cache_contexts.cache_manager import (
  CachedContextManager,
- CacheNotExistError,
+ ContextNotExistError,
+ )
+ from cache_dit.cache_factory.cache_contexts.prune_manager import (
+ PrunedContextManager,
  )
  from cache_dit.cache_factory import ForwardPattern
+ from cache_dit.cache_factory.cache_types import CacheType
  from cache_dit.logger import init_logger

  logger = init_logger(__name__)
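
The hunk above renames CacheNotExistError to ContextNotExistError and pulls in the new prune-side modules. A minimal, hedged import sketch of that surface, using only module paths shown in the hunk and assuming the 1.0.4 wheel is installed:

# Hedged sketch: paths copied from the hunk above; not part of the diff itself.
from cache_dit.cache_factory.cache_contexts.cache_manager import (
    CachedContextManager,
    ContextNotExistError,  # replaces CacheNotExistError from 1.0.2
)
from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext
from cache_dit.cache_factory.cache_contexts.prune_manager import PrunedContextManager
from cache_dit.cache_factory.cache_types import CacheType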
@@ -31,7 +36,8 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  # 1. Cache context configuration
  cache_prefix: str = None, # maybe un-need.
  cache_context: CachedContext | str = None,
- cache_manager: CachedContextManager = None,
+ context_manager: CachedContextManager = None,
+ cache_type: CacheType = CacheType.DBCache,
  **kwargs,
  ):
  super().__init__()
@@ -45,13 +51,15 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  # 1. Cache context configuration
  self.cache_prefix = cache_prefix
  self.cache_context = cache_context
- self.cache_manager = cache_manager
+ self.context_manager = context_manager
+ self.cache_type = cache_type

  self._check_forward_pattern()
+ self._check_cache_type()
  logger.info(
- f"Match Cached Blocks: {self.__class__.__name__}, for "
+ f"Match Blocks: {self.__class__.__name__}, for "
  f"{self.cache_prefix}, cache_context: {self.cache_context}, "
- f"cache_manager: {self.cache_manager.name}."
+ f"context_manager: {self.context_manager.name}."
  )

  def _check_forward_pattern(self):
@@ -94,18 +102,25 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  required_param in forward_parameters
  ), f"The input parameters must contains: {required_param}."

+ @torch.compiler.disable
+ def _check_cache_type(self):
+ assert (
+ self.cache_type == CacheType.DBCache
+ ), f"Cache type {self.cache_type} is not supported for CachedBlocks."
+
  @torch.compiler.disable
  def _check_cache_params(self):
- assert self.cache_manager.Fn_compute_blocks() <= len(
+ self._check_cache_type()
+ assert self.context_manager.Fn_compute_blocks() <= len(
  self.transformer_blocks
  ), (
- f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
+ f"Fn_compute_blocks {self.context_manager.Fn_compute_blocks()} must be less than "
  f"the number of transformer blocks {len(self.transformer_blocks)}"
  )
- assert self.cache_manager.Bn_compute_blocks() <= len(
+ assert self.context_manager.Bn_compute_blocks() <= len(
  self.transformer_blocks
  ), (
- f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
+ f"Bn_compute_blocks {self.context_manager.Bn_compute_blocks()} must be less than "
  f"the number of transformer blocks {len(self.transformer_blocks)}"
  )

@@ -135,7 +150,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  return hidden_states, encoder_hidden_states

  @torch.compiler.disable
- def _process_outputs(
+ def _process_block_outputs(
  self,
  hidden_states: torch.Tensor | tuple,
  encoder_hidden_states: torch.Tensor | None,
@@ -150,7 +165,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  return hidden_states, encoder_hidden_states

  @torch.compiler.disable
- def _forward_outputs(
+ def _process_forward_outputs(
  self,
  hidden_states: torch.Tensor,
  encoder_hidden_states: torch.Tensor | None,
@@ -174,9 +189,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  ):
  # Use it's own cache context.
  try:
- self.cache_manager.set_context(self.cache_context)
+ self.context_manager.set_context(self.cache_context)
  self._check_cache_params()
- except CacheNotExistError as e:
+ except ContextNotExistError as e:
  logger.warning(f"Cache context not exist: {e}, skip cache.")
  # Call all blocks to process the hidden states.
  hidden_states, encoder_hidden_states = self.call_blocks(
@@ -185,7 +200,10 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  *args,
  **kwargs,
  )
- return self._forward_outputs(hidden_states, encoder_hidden_states)
+ return self._process_forward_outputs(
+ hidden_states,
+ encoder_hidden_states,
+ )

  original_hidden_states = hidden_states
  # Call first `n` blocks to process the hidden states for
@@ -200,38 +218,38 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  Fn_hidden_states_residual = hidden_states - original_hidden_states
  del original_hidden_states

- self.cache_manager.mark_step_begin()
+ self.context_manager.mark_step_begin()
  # Residual L1 diff or Hidden States L1 diff
- can_use_cache = self.cache_manager.can_cache(
+ can_use_cache = self.context_manager.can_cache(
  (
  Fn_hidden_states_residual
- if not self.cache_manager.is_l1_diff_enabled()
+ if not self.context_manager.is_l1_diff_enabled()
  else hidden_states
  ),
  parallelized=self._is_parallelized(),
  prefix=(
  f"{self.cache_prefix}_Fn_residual"
- if not self.cache_manager.is_l1_diff_enabled()
+ if not self.context_manager.is_l1_diff_enabled()
  else f"{self.cache_prefix}_Fn_hidden_states"
  ),
  )

  torch._dynamo.graph_break()
  if can_use_cache:
- self.cache_manager.add_cached_step()
+ self.context_manager.add_cached_step()
  del Fn_hidden_states_residual
  hidden_states, encoder_hidden_states = (
- self.cache_manager.apply_cache(
+ self.context_manager.apply_cache(
  hidden_states,
  encoder_hidden_states,
  prefix=(
  f"{self.cache_prefix}_Bn_residual"
- if self.cache_manager.is_cache_residual()
+ if self.context_manager.is_cache_residual()
  else f"{self.cache_prefix}_Bn_hidden_states"
  ),
  encoder_prefix=(
  f"{self.cache_prefix}_Bn_residual"
- if self.cache_manager.is_encoder_cache_residual()
+ if self.context_manager.is_encoder_cache_residual()
  else f"{self.cache_prefix}_Bn_hidden_states"
  ),
  )
@@ -246,13 +264,13 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  **kwargs,
  )
  else:
- self.cache_manager.set_Fn_buffer(
+ self.context_manager.set_Fn_buffer(
  Fn_hidden_states_residual,
  prefix=f"{self.cache_prefix}_Fn_residual",
  )
- if self.cache_manager.is_l1_diff_enabled():
+ if self.context_manager.is_l1_diff_enabled():
  # for hidden states L1 diff
- self.cache_manager.set_Fn_buffer(
+ self.context_manager.set_Fn_buffer(
  hidden_states,
  f"{self.cache_prefix}_Fn_hidden_states",
  )
@@ -270,24 +288,24 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  **kwargs,
  )
  torch._dynamo.graph_break()
- if self.cache_manager.is_cache_residual():
- self.cache_manager.set_Bn_buffer(
+ if self.context_manager.is_cache_residual():
+ self.context_manager.set_Bn_buffer(
  hidden_states_residual,
  prefix=f"{self.cache_prefix}_Bn_residual",
  )
  else:
- self.cache_manager.set_Bn_buffer(
+ self.context_manager.set_Bn_buffer(
  hidden_states,
  prefix=f"{self.cache_prefix}_Bn_hidden_states",
  )

- if self.cache_manager.is_encoder_cache_residual():
- self.cache_manager.set_Bn_encoder_buffer(
+ if self.context_manager.is_encoder_cache_residual():
+ self.context_manager.set_Bn_encoder_buffer(
  encoder_hidden_states_residual,
  prefix=f"{self.cache_prefix}_Bn_residual",
  )
  else:
- self.cache_manager.set_Bn_encoder_buffer(
+ self.context_manager.set_Bn_encoder_buffer(
  encoder_hidden_states,
  prefix=f"{self.cache_prefix}_Bn_hidden_states",
  )
@@ -304,7 +322,10 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  # patch cached stats for blocks or remove it.
  torch._dynamo.graph_break()

- return self._forward_outputs(hidden_states, encoder_hidden_states)
+ return self._process_forward_outputs(
+ hidden_states,
+ encoder_hidden_states,
+ )

  @torch.compiler.disable
  def _is_parallelized(self):
@@ -327,11 +348,11 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  # If so, we can skip some Bn blocks and directly
  # use the cached values.
  return (
- self.cache_manager.get_current_step()
- in self.cache_manager.get_cached_steps()
+ self.context_manager.get_current_step()
+ in self.context_manager.get_cached_steps()
  ) or (
- self.cache_manager.get_current_step()
- in self.cache_manager.get_cfg_cached_steps()
+ self.context_manager.get_current_step()
+ in self.context_manager.get_cfg_cached_steps()
  )

  @torch.compiler.disable
@@ -340,20 +361,20 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  # more stable diff calculation.
  # Fn: [0,...,n-1]
  selected_Fn_blocks = self.transformer_blocks[
- : self.cache_manager.Fn_compute_blocks()
+ : self.context_manager.Fn_compute_blocks()
  ]
  return selected_Fn_blocks

  @torch.compiler.disable
  def _Mn_blocks(self): # middle blocks
  # M(N-2n): only transformer_blocks [n,...,N-n], middle
- if self.cache_manager.Bn_compute_blocks() == 0: # WARN: x[:-0] = []
+ if self.context_manager.Bn_compute_blocks() == 0: # WARN: x[:-0] = []
  selected_Mn_blocks = self.transformer_blocks[
- self.cache_manager.Fn_compute_blocks() :
+ self.context_manager.Fn_compute_blocks() :
  ]
  else:
  selected_Mn_blocks = self.transformer_blocks[
- self.cache_manager.Fn_compute_blocks() : -self.cache_manager.Bn_compute_blocks()
+ self.context_manager.Fn_compute_blocks() : -self.context_manager.Bn_compute_blocks()
  ]
  return selected_Mn_blocks

@@ -361,7 +382,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  def _Bn_blocks(self):
  # Bn: transformer_blocks [N-n+1,...,N-1]
  selected_Bn_blocks = self.transformer_blocks[
- -self.cache_manager.Bn_compute_blocks() :
+ -self.context_manager.Bn_compute_blocks() :
  ]
  return selected_Bn_blocks

@@ -379,7 +400,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  *args,
  **kwargs,
  )
- hidden_states, encoder_hidden_states = self._process_outputs(
+ hidden_states, encoder_hidden_states = self._process_block_outputs(
  hidden_states, encoder_hidden_states
  )

@@ -401,7 +422,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  *args,
  **kwargs,
  )
- hidden_states, encoder_hidden_states = self._process_outputs(
+ hidden_states, encoder_hidden_states = self._process_block_outputs(
  hidden_states, encoder_hidden_states
  )

@@ -435,7 +456,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  *args,
  **kwargs,
  ):
- if self.cache_manager.Bn_compute_blocks() == 0:
+ if self.context_manager.Bn_compute_blocks() == 0:
  return hidden_states, encoder_hidden_states

  for block in self._Bn_blocks():
@@ -445,8 +466,228 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
  *args,
  **kwargs,
  )
- hidden_states, encoder_hidden_states = self._process_outputs(
+ hidden_states, encoder_hidden_states = self._process_block_outputs(
+ hidden_states, encoder_hidden_states
+ )
+
+ return hidden_states, encoder_hidden_states
+
+
+ class PrunedBlocks_Pattern_Base(CachedBlocks_Pattern_Base):
+ pruned_blocks_step: int = 0 # number of pruned blocks in current step
+
+ def __init__(
+ self,
+ # 0. Transformer blocks configuration
+ transformer_blocks: torch.nn.ModuleList,
+ transformer: torch.nn.Module = None,
+ forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+ check_forward_pattern: bool = True,
+ check_num_outputs: bool = True,
+ # 1. Prune context configuration
+ cache_prefix: str = None, # maybe un-need.
+ cache_context: PrunedContext | str = None,
+ context_manager: PrunedContextManager = None,
+ cache_type: CacheType = CacheType.DBPrune,
+ **kwargs,
+ ):
+ super().__init__(
+ # 0. Transformer blocks configuration
+ transformer_blocks,
+ transformer=transformer,
+ forward_pattern=forward_pattern,
+ check_forward_pattern=check_forward_pattern,
+ check_num_outputs=check_num_outputs,
+ # 1. Cache context configuration
+ cache_prefix=cache_prefix,
+ cache_context=cache_context,
+ context_manager=context_manager,
+ cache_type=cache_type,
+ **kwargs,
+ )
+ assert isinstance(
+ self.context_manager, PrunedContextManager
+ ), "context_manager must be PrunedContextManager for PrunedBlocks."
+ self.context_manager: PrunedContextManager = (
+ self.context_manager
+ ) # For type hint
+
+ @torch.compiler.disable
+ def _check_cache_type(self):
+ assert (
+ self.cache_type == CacheType.DBPrune
+ ), f"Cache type {self.cache_type} is not supported for PrunedBlocks."
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ *args,
+ **kwargs,
+ ):
+ self.pruned_blocks_step: int = 0 # reset for each step
+
+ # Use it's own cache context.
+ try:
+ self.context_manager.set_context(self.cache_context)
+ self._check_cache_params()
+ except ContextNotExistError as e:
+ logger.warning(f"Cache context not exist: {e}, skip prune.")
+ # Fallback to call all blocks to process the hidden states w/o prune.
+ hidden_states, encoder_hidden_states = self.call_blocks(
+ hidden_states,
+ encoder_hidden_states,
+ *args,
+ **kwargs,
+ )
+ return self._process_forward_outputs(
+ hidden_states,
+ encoder_hidden_states,
+ )
+
+ self.context_manager.mark_step_begin()
+
+ # Call all blocks with prune strategy to process the hidden states.
+ for i, block in enumerate(self.transformer_blocks):
+ hidden_states, encoder_hidden_states = self.compute_or_prune(
+ i,
+ block,
+ hidden_states,
+ encoder_hidden_states,
+ *args,
+ **kwargs,
+ )
+
+ self.context_manager.add_pruned_block(self.pruned_blocks_step)
+ self.context_manager.add_actual_block(self.num_blocks)
+
+ return self._process_forward_outputs(
+ hidden_states,
+ encoder_hidden_states,
+ )
+
+ @property
+ @torch.compiler.disable
+ def num_blocks(self):
+ return len(self.transformer_blocks)
+
+ @torch.compiler.disable
+ def _skip_prune(self, block_id: int) -> bool:
+ # Wrap for non compiled mode.
+ return block_id in self.context_manager.get_non_prune_blocks_ids(
+ self.num_blocks
+ )
+
+ @torch.compiler.disable
+ def _maybe_prune(
+ self,
+ block_id: int, # Block index in the transformer blocks
+ hidden_states: torch.Tensor, # hidden_states or residual
+ prefix: str = "Bn_original", # prev step name for single blocks
+ ):
+ # Wrap for non compiled mode.
+ can_use_prune = False
+ if not self._skip_prune(block_id):
+ can_use_prune = self.context_manager.can_prune(
+ hidden_states, # curr step
+ parallelized=self._is_parallelized(),
+ prefix=prefix, # prev step
+ )
+ self.pruned_blocks_step += int(can_use_prune)
+ return can_use_prune
+
+ def compute_or_prune(
+ self,
+ block_id: int, # Block index in the transformer blocks
+ # Below are the inputs to the block
+ block, # The transformer block to be executed
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: torch.Tensor,
+ *args,
+ **kwargs,
+ ):
+ original_hidden_states = hidden_states
+ original_encoder_hidden_states = encoder_hidden_states
+
+ can_use_prune = self._maybe_prune(
+ block_id,
+ hidden_states,
+ prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+ )
+
+ # Prune steps: Prune current block and reuse the cached
+ # residuals for hidden states approximate.
+ torch._dynamo.graph_break()
+ if can_use_prune:
+ self.context_manager.add_pruned_step()
+ hidden_states, encoder_hidden_states = (
+ self.context_manager.apply_prune(
+ hidden_states,
+ encoder_hidden_states,
+ prefix=(
+ f"{self.cache_prefix}_{block_id}_Bn_residual"
+ if self.context_manager.is_cache_residual()
+ else f"{self.cache_prefix}_{block_id}_Bn_hidden_states"
+ ),
+ encoder_prefix=(
+ f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
+ if self.context_manager.is_encoder_cache_residual()
+ else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
+ ),
+ )
+ )
+ torch._dynamo.graph_break()
+ else:
+ # Normal steps: Compute the block and cache the residuals.
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states,
+ *args,
+ **kwargs,
+ )
+ hidden_states, encoder_hidden_states = self._process_block_outputs(
  hidden_states, encoder_hidden_states
  )
+ if not self._skip_prune(block_id):
+ hidden_states = hidden_states.contiguous()
+ hidden_states_residual = hidden_states - original_hidden_states
+
+ if (
+ encoder_hidden_states is not None
+ and original_encoder_hidden_states is not None
+ ):
+ encoder_hidden_states = encoder_hidden_states.contiguous()
+ encoder_hidden_states_residual = (
+ encoder_hidden_states - original_encoder_hidden_states
+ )
+ else:
+ encoder_hidden_states_residual = None
+
+ self.context_manager.set_Fn_buffer(
+ original_hidden_states,
+ prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+ )
+ if self.context_manager.is_cache_residual():
+ self.context_manager.set_Bn_buffer(
+ hidden_states_residual,
+ prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
+ )
+ else:
+ self.context_manager.set_Bn_buffer(
+ hidden_states,
+ prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
+ )
+ if encoder_hidden_states_residual is not None:
+ if self.context_manager.is_encoder_cache_residual():
+ self.context_manager.set_Bn_encoder_buffer(
+ encoder_hidden_states_residual,
+ prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
+ )
+ else:
+ self.context_manager.set_Bn_encoder_buffer(
+ encoder_hidden_states_residual,
+ prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
+ )
+ torch._dynamo.graph_break()

  return hidden_states, encoder_hidden_states
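
The new PrunedBlocks_Pattern_Base subclass reuses the CachedBlocks_Pattern_Base plumbing but requires a PrunedContextManager and cache_type=CacheType.DBPrune; the _check_cache_type assertions enforce the pairing on both classes. A minimal sketch of how the two wrappers might be instantiated: the keyword arguments and CacheType members come from the hunks above, while the import path (inferred from the Files changed list), `transformer`, `cached_mgr`, `pruned_mgr`, and the "ctx_default" context name are illustrative assumptions.

# Hypothetical wiring; only argument names, defaults, and CacheType members
# are taken from the diff. Import path assumed from the Files changed list.
from cache_dit.cache_factory.cache_blocks.pattern_base import (
    CachedBlocks_Pattern_Base,
    PrunedBlocks_Pattern_Base,
)

cached_blocks = CachedBlocks_Pattern_Base(
    transformer.transformer_blocks,       # assumed ModuleList on the model
    transformer=transformer,
    forward_pattern=ForwardPattern.Pattern_0,
    cache_prefix="transformer_blocks",
    cache_context="ctx_default",          # context name registered on the manager
    context_manager=cached_mgr,           # a CachedContextManager instance
    cache_type=CacheType.DBCache,         # anything else trips _check_cache_type()
)

pruned_blocks = PrunedBlocks_Pattern_Base(
    transformer.transformer_blocks,
    transformer=transformer,
    forward_pattern=ForwardPattern.Pattern_0,
    cache_prefix="transformer_blocks",
    cache_context="ctx_default",
    context_manager=pruned_mgr,           # must be a PrunedContextManager
    cache_type=CacheType.DBPrune,         # DBCache here would trip the assert
)

Per the forward() added above, the pruned variant walks every transformer block through compute_or_prune(), accumulates pruned_blocks_step, and reports it via add_pruned_block()/add_actual_block() at the end of each step.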
@@ -3,34 +3,61 @@ import torch
  from typing import Any
  from cache_dit.cache_factory import CachedContext
  from cache_dit.cache_factory import CachedContextManager
+ from cache_dit.cache_factory import PrunedContextManager


- def patch_cached_stats(
+ def apply_stats(
  module: torch.nn.Module | Any,
  cache_context: CachedContext | str = None,
- cache_manager: CachedContextManager = None,
+ context_manager: CachedContextManager | PrunedContextManager = None,
  ):
  # Patch the cached stats to the module, the cached stats
  # will be reset for each calling of pipe.__call__(**kwargs).
- if module is None or cache_manager is None:
+ if module is None or context_manager is None:
  return

  if cache_context is not None:
- cache_manager.set_context(cache_context)
+ context_manager.set_context(cache_context)

- # TODO: Patch more cached stats to the module
- module._cached_steps = cache_manager.get_cached_steps()
- module._residual_diffs = cache_manager.get_residual_diffs()
- module._cfg_cached_steps = cache_manager.get_cfg_cached_steps()
- module._cfg_residual_diffs = cache_manager.get_cfg_residual_diffs()
+ # Cache stats for Dual Block Cache
+ module._cached_steps = context_manager.get_cached_steps()
+ module._residual_diffs = context_manager.get_residual_diffs()
+ module._cfg_cached_steps = context_manager.get_cfg_cached_steps()
+ module._cfg_residual_diffs = context_manager.get_cfg_residual_diffs()
+ # Pruned stats for Dynamic Block Prune
+ if not isinstance(context_manager, PrunedContextManager):
+ return
+ module._pruned_steps = context_manager.get_pruned_steps()
+ module._cfg_pruned_steps = context_manager.get_cfg_pruned_steps()
+ module._pruned_blocks = context_manager.get_pruned_blocks()
+ module._cfg_pruned_blocks = context_manager.get_cfg_pruned_blocks()
+ module._actual_blocks = context_manager.get_actual_blocks()
+ module._cfg_actual_blocks = context_manager.get_cfg_actual_blocks()
+ # Caculate pruned ratio
+ if len(module._pruned_blocks) > 0 and sum(module._actual_blocks) > 0:
+ module._pruned_ratio = sum(module._pruned_blocks) / sum(
+ module._actual_blocks
+ )
+ else:
+ module._pruned_ratio = None
+ if (
+ len(module._cfg_pruned_blocks) > 0
+ and sum(module._cfg_actual_blocks) > 0
+ ):
+ module._cfg_pruned_ratio = sum(module._cfg_pruned_blocks) / sum(
+ module._cfg_actual_blocks
+ )
+ else:
+ module._cfg_pruned_ratio = None


- def remove_cached_stats(
+ def remove_stats(
  module: torch.nn.Module | Any,
  ):
  if module is None:
  return

+ # Dual Block Cache
  if hasattr(module, "_cached_steps"):
  del module._cached_steps
  if hasattr(module, "_residual_diffs"):
@@ -39,3 +66,21 @@ def remove_cached_stats(
  del module._cfg_cached_steps
  if hasattr(module, "_cfg_residual_diffs"):
  del module._cfg_residual_diffs
+
+ # Dynamic Block Prune
+ if hasattr(module, "_pruned_steps"):
+ del module._pruned_steps
+ if hasattr(module, "_cfg_pruned_steps"):
+ del module._cfg_pruned_steps
+ if hasattr(module, "_pruned_blocks"):
+ del module._pruned_blocks
+ if hasattr(module, "_cfg_pruned_blocks"):
+ del module._cfg_pruned_blocks
+ if hasattr(module, "_actual_blocks"):
+ del module._actual_blocks
+ if hasattr(module, "_cfg_actual_blocks"):
+ del module._cfg_actual_blocks
+ if hasattr(module, "_pruned_ratio"):
+ del module._pruned_ratio
+ if hasattr(module, "_cfg_pruned_ratio"):
+ del module._cfg_pruned_ratio
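
With these changes, apply_stats (formerly patch_cached_stats) attaches both the Dual Block Cache and the Dynamic Block Prune statistics to the module, and remove_stats clears them again. A small, hedged sketch of reading those attributes back; the attribute names and function signatures come from the code above, while the import path (inferred from the Files changed list), `transformer`, and `pruned_mgr` are assumptions.

# Illustrative only; import path assumed from the Files changed list.
from cache_dit.cache_factory.cache_blocks.pattern_utils import apply_stats, remove_stats

apply_stats(transformer, cache_context="ctx_default", context_manager=pruned_mgr)

print(transformer._cached_steps)            # DBCache: steps that reused the cache
print(transformer._residual_diffs)          # DBCache: per-step residual diffs
if hasattr(transformer, "_pruned_ratio"):   # only set for a PrunedContextManager
    print(transformer._pruned_blocks)       # DBPrune: pruned block count per step
    print(transformer._pruned_ratio)        # sum(pruned) / sum(actual), or None

remove_stats(transformer)                   # deletes every patched attribute again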
@@ -5,11 +5,24 @@ from cache_dit.cache_factory.cache_contexts.calibrators import (
  TaylorSeerCalibratorConfig,
  FoCaCalibratorConfig,
  )
+ from cache_dit.cache_factory.cache_contexts.cache_config import (
+ BasicCacheConfig,
+ DBCacheConfig,
+ )
  from cache_dit.cache_factory.cache_contexts.cache_context import (
  CachedContext,
- BasicCacheConfig,
  )
  from cache_dit.cache_factory.cache_contexts.cache_manager import (
  CachedContextManager,
- CacheNotExistError,
+ ContextNotExistError,
+ )
+ from cache_dit.cache_factory.cache_contexts.prune_config import DBPruneConfig
+ from cache_dit.cache_factory.cache_contexts.prune_context import (
+ PrunedContext,
+ )
+ from cache_dit.cache_factory.cache_contexts.prune_manager import (
+ PrunedContextManager,
+ )
+ from cache_dit.cache_factory.cache_contexts.context_manager import (
+ ContextManager,
  )
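
Because cache_dit/cache_factory/cache_contexts/__init__.py now imports these names, they become importable from the subpackage itself. A hedged sketch of the resulting surface, assuming standard re-export semantics and no __all__ filtering that would hide them:

# Assumes the names imported in the __init__.py hunk above remain re-exported.
from cache_dit.cache_factory.cache_contexts import (
    BasicCacheConfig,       # now sourced from cache_config instead of cache_context
    DBCacheConfig,          # newly imported here
    DBPruneConfig,          # newly imported here
    CachedContext,
    CachedContextManager,
    ContextNotExistError,   # replaces CacheNotExistError
    PrunedContext,          # newly imported here
    PrunedContextManager,   # newly imported here
    ContextManager,         # newly imported here (from context_manager.py)
)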