cache-dit 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +3 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +8 -1
- cache_dit/cache_factory/block_adapters/__init__.py +4 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +126 -80
- cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +256 -24
- cache_dit/cache_factory/cache_blocks/pattern_base.py +273 -38
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
- cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
- cache_dit/cache_factory/cache_contexts/cache_config.py +118 -0
- cache_dit/cache_factory/cache_contexts/cache_context.py +15 -93
- cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
- cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +22 -0
- cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
- cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
- cache_dit/cache_factory/cache_contexts/prune_config.py +63 -0
- cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
- cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
- cache_dit/cache_factory/cache_interface.py +20 -14
- cache_dit/cache_factory/cache_types.py +19 -2
- cache_dit/cache_factory/params_modifier.py +7 -7
- cache_dit/cache_factory/utils.py +18 -7
- cache_dit/quantize/quantize_ao.py +58 -17
- cache_dit/utils.py +191 -54
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/METADATA +11 -10
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/RECORD +32 -27
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/top_level.txt +0 -0
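The main addition in this release is a Dynamic Block Prune (DBPrune) path alongside the existing Dual Block Cache (DBCache): new prune_config.py / prune_context.py / prune_manager.py modules, a CacheType.DBPrune variant, and a PrunedBlocks wrapper next to the CachedBlocks one (see the pattern_base.py diff below). Below is a minimal usage sketch, not taken from this diff: it assumes the public cache_dit.enable_cache() entry point and its cache_config argument behave as in earlier releases, DBPruneConfig is constructed with defaults because its fields are not shown here, and FluxPipeline is only an arbitrary example model.

import cache_dit
from diffusers import FluxPipeline
# DBPruneConfig is re-exported from cache_contexts in 1.0.5 (see the
# cache_contexts/__init__.py hunk below); its fields are not shown in this diff.
from cache_dit.cache_factory.cache_contexts import DBPruneConfig

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")
# Hypothetical call: enable the new dynamic-block-prune mode instead of DBCache.
cache_dit.enable_cache(pipe, cache_config=DBPruneConfig())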
cache_dit/cache_factory/cache_blocks/pattern_base.py

@@ -3,11 +3,16 @@ import torch
 import torch.distributed as dist

 from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
     CachedContextManager,
-
+    ContextNotExistError,
+)
+from cache_dit.cache_factory.cache_contexts.prune_manager import (
+    PrunedContextManager,
 )
 from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)

@@ -31,7 +36,8 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # 1. Cache context configuration
         cache_prefix: str = None,  # maybe un-need.
         cache_context: CachedContext | str = None,
-
+        context_manager: CachedContextManager = None,
+        cache_type: CacheType = CacheType.DBCache,
         **kwargs,
     ):
         super().__init__()

@@ -45,13 +51,15 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # 1. Cache context configuration
         self.cache_prefix = cache_prefix
         self.cache_context = cache_context
-        self.
+        self.context_manager = context_manager
+        self.cache_type = cache_type

         self._check_forward_pattern()
+        self._check_cache_type()
         logger.info(
-            f"Match
+            f"Match Blocks: {self.__class__.__name__}, for "
             f"{self.cache_prefix}, cache_context: {self.cache_context}, "
-            f"
+            f"context_manager: {self.context_manager.name}."
         )

     def _check_forward_pattern(self):

@@ -94,18 +102,25 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                     required_param in forward_parameters
                 ), f"The input parameters must contains: {required_param}."

+    @torch.compiler.disable
+    def _check_cache_type(self):
+        assert (
+            self.cache_type == CacheType.DBCache
+        ), f"Cache type {self.cache_type} is not supported for CachedBlocks."
+
     @torch.compiler.disable
     def _check_cache_params(self):
-
+        self._check_cache_type()
+        assert self.context_manager.Fn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Fn_compute_blocks {self.
+            f"Fn_compute_blocks {self.context_manager.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        assert self.
+        assert self.context_manager.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Bn_compute_blocks {self.
+            f"Bn_compute_blocks {self.context_manager.Bn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )

@@ -174,9 +189,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
     ):
         # Use it's own cache context.
         try:
-            self.
+            self.context_manager.set_context(self.cache_context)
             self._check_cache_params()
-        except
+        except ContextNotExistError as e:
             logger.warning(f"Cache context not exist: {e}, skip cache.")
             # Call all blocks to process the hidden states.
             hidden_states, encoder_hidden_states = self.call_blocks(

@@ -203,38 +218,38 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states

-        self.
+        self.context_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = self.
+        can_use_cache = self.context_manager.can_cache(
             (
                 Fn_hidden_states_residual
-                if not self.
+                if not self.context_manager.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
                 f"{self.cache_prefix}_Fn_residual"
-                if not self.
+                if not self.context_manager.is_l1_diff_enabled()
                 else f"{self.cache_prefix}_Fn_hidden_states"
             ),
         )

         torch._dynamo.graph_break()
         if can_use_cache:
-            self.
+            self.context_manager.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-                self.
+                self.context_manager.apply_cache(
                     hidden_states,
                     encoder_hidden_states,
                     prefix=(
                         f"{self.cache_prefix}_Bn_residual"
-                        if self.
+                        if self.context_manager.is_cache_residual()
                         else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
                         f"{self.cache_prefix}_Bn_residual"
-                        if self.
+                        if self.context_manager.is_encoder_cache_residual()
                         else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                 )

@@ -249,13 +264,13 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
             )
         else:
-            self.
+            self.context_manager.set_Fn_buffer(
                 Fn_hidden_states_residual,
                 prefix=f"{self.cache_prefix}_Fn_residual",
             )
-            if self.
+            if self.context_manager.is_l1_diff_enabled():
                 # for hidden states L1 diff
-                self.
+                self.context_manager.set_Fn_buffer(
                     hidden_states,
                     f"{self.cache_prefix}_Fn_hidden_states",
                 )

@@ -273,24 +288,24 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
             )
             torch._dynamo.graph_break()
-            if self.
-                self.
+            if self.context_manager.is_cache_residual():
+                self.context_manager.set_Bn_buffer(
                     hidden_states_residual,
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                self.
+                self.context_manager.set_Bn_buffer(
                     hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )

-            if self.
-                self.
+            if self.context_manager.is_encoder_cache_residual():
+                self.context_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states_residual,
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                self.
+                self.context_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )

@@ -333,11 +348,11 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # If so, we can skip some Bn blocks and directly
         # use the cached values.
         return (
-            self.
-            in self.
+            self.context_manager.get_current_step()
+            in self.context_manager.get_cached_steps()
         ) or (
-            self.
-            in self.
+            self.context_manager.get_current_step()
+            in self.context_manager.get_cfg_cached_steps()
         )

     @torch.compiler.disable

@@ -346,20 +361,20 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # more stable diff calculation.
         # Fn: [0,...,n-1]
         selected_Fn_blocks = self.transformer_blocks[
-            : self.
+            : self.context_manager.Fn_compute_blocks()
         ]
         return selected_Fn_blocks

     @torch.compiler.disable
     def _Mn_blocks(self):  # middle blocks
         # M(N-2n): only transformer_blocks [n,...,N-n], middle
-        if self.
+        if self.context_manager.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
             selected_Mn_blocks = self.transformer_blocks[
-                self.
+                self.context_manager.Fn_compute_blocks() :
             ]
         else:
             selected_Mn_blocks = self.transformer_blocks[
-                self.
+                self.context_manager.Fn_compute_blocks() : -self.context_manager.Bn_compute_blocks()
             ]
         return selected_Mn_blocks

@@ -367,7 +382,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
     def _Bn_blocks(self):
         # Bn: transformer_blocks [N-n+1,...,N-1]
         selected_Bn_blocks = self.transformer_blocks[
-            -self.
+            -self.context_manager.Bn_compute_blocks() :
         ]
         return selected_Bn_blocks

@@ -441,7 +456,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        if self.
+        if self.context_manager.Bn_compute_blocks() == 0:
            return hidden_states, encoder_hidden_states

         for block in self._Bn_blocks():

@@ -456,3 +471,223 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
             )

         return hidden_states, encoder_hidden_states
+
+
+class PrunedBlocks_Pattern_Base(CachedBlocks_Pattern_Base):
+    pruned_blocks_step: int = 0  # number of pruned blocks in current step
+
+    def __init__(
+        self,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
+        check_num_outputs: bool = True,
+        # 1. Prune context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: PrunedContext | str = None,
+        context_manager: PrunedContextManager = None,
+        cache_type: CacheType = CacheType.DBPrune,
+        **kwargs,
+    ):
+        super().__init__(
+            # 0. Transformer blocks configuration
+            transformer_blocks,
+            transformer=transformer,
+            forward_pattern=forward_pattern,
+            check_forward_pattern=check_forward_pattern,
+            check_num_outputs=check_num_outputs,
+            # 1. Cache context configuration
+            cache_prefix=cache_prefix,
+            cache_context=cache_context,
+            context_manager=context_manager,
+            cache_type=cache_type,
+            **kwargs,
+        )
+        assert isinstance(
+            self.context_manager, PrunedContextManager
+        ), "context_manager must be PrunedContextManager for PrunedBlocks."
+        self.context_manager: PrunedContextManager = (
+            self.context_manager
+        )  # For type hint
+
+    @torch.compiler.disable
+    def _check_cache_type(self):
+        assert (
+            self.cache_type == CacheType.DBPrune
+        ), f"Cache type {self.cache_type} is not supported for PrunedBlocks."
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        self.pruned_blocks_step: int = 0  # reset for each step
+
+        # Use it's own cache context.
+        try:
+            self.context_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except ContextNotExistError as e:
+            logger.warning(f"Cache context not exist: {e}, skip prune.")
+            # Fallback to call all blocks to process the hidden states w/o prune.
+            hidden_states, encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._process_forward_outputs(
+                hidden_states,
+                encoder_hidden_states,
+            )
+
+        self.context_manager.mark_step_begin()
+
+        # Call all blocks with prune strategy to process the hidden states.
+        for i, block in enumerate(self.transformer_blocks):
+            hidden_states, encoder_hidden_states = self.compute_or_prune(
+                i,
+                block,
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        self.context_manager.add_pruned_block(self.pruned_blocks_step)
+        self.context_manager.add_actual_block(self.num_blocks)
+
+        return self._process_forward_outputs(
+            hidden_states,
+            encoder_hidden_states,
+        )
+
+    @property
+    @torch.compiler.disable
+    def num_blocks(self):
+        return len(self.transformer_blocks)
+
+    @torch.compiler.disable
+    def _skip_prune(self, block_id: int) -> bool:
+        # Wrap for non compiled mode.
+        return block_id in self.context_manager.get_non_prune_blocks_ids(
+            self.num_blocks
+        )
+
+    @torch.compiler.disable
+    def _maybe_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        hidden_states: torch.Tensor,  # hidden_states or residual
+        prefix: str = "Bn_original",  # prev step name for single blocks
+    ):
+        # Wrap for non compiled mode.
+        can_use_prune = False
+        if not self._skip_prune(block_id):
+            can_use_prune = self.context_manager.can_prune(
+                hidden_states,  # curr step
+                parallelized=self._is_parallelized(),
+                prefix=prefix,  # prev step
+            )
+        self.pruned_blocks_step += int(can_use_prune)
+        return can_use_prune
+
+    def compute_or_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        # Below are the inputs to the block
+        block,  # The transformer block to be executed
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = encoder_hidden_states
+
+        can_use_prune = self._maybe_prune(
+            block_id,
+            hidden_states,
+            prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+        )
+
+        # Prune steps: Prune current block and reuse the cached
+        # residuals for hidden states approximate.
+        torch._dynamo.graph_break()
+        if can_use_prune:
+            self.context_manager.add_pruned_step()
+            hidden_states, encoder_hidden_states = (
+                self.context_manager.apply_prune(
+                    hidden_states,
+                    encoder_hidden_states,
+                    prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_residual"
+                        if self.context_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
+                        if self.context_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+        else:
+            # Normal steps: Compute the block and cache the residuals.
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
+                hidden_states, encoder_hidden_states
+            )
+            if not self._skip_prune(block_id):
+                hidden_states = hidden_states.contiguous()
+                hidden_states_residual = hidden_states - original_hidden_states
+
+                if (
+                    encoder_hidden_states is not None
+                    and original_encoder_hidden_states is not None
+                ):
+                    encoder_hidden_states = encoder_hidden_states.contiguous()
+                    encoder_hidden_states_residual = (
+                        encoder_hidden_states - original_encoder_hidden_states
+                    )
+                else:
+                    encoder_hidden_states_residual = None
+
+                self.context_manager.set_Fn_buffer(
+                    original_hidden_states,
+                    prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+                )
+                if self.context_manager.is_cache_residual():
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
+                    )
+                else:
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
+                    )
+                if encoder_hidden_states_residual is not None:
+                    if self.context_manager.is_encoder_cache_residual():
+                        self.context_manager.set_Bn_encoder_buffer(
+                            encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
+                        )
+                    else:
+                        self.context_manager.set_Bn_encoder_buffer(
+                            encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
+                        )
+            torch._dynamo.graph_break()
+
+        return hidden_states, encoder_hidden_states
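For orientation, the per-block decision in compute_or_prune() above works like this: the input to a prunable block is compared against the input cached for that block on the previous step (the "Fn_original" buffer), and if the change is small enough the block is skipped and its cached output residual (the "Bn_residual" buffer) is reused. The sketch below is conceptual only; the actual criterion lives in PrunedContextManager.can_prune()/apply_prune(), whose internals are not part of this diff, and the relative-L1 threshold rule and names here are illustrative assumptions.

import torch

# Illustrative threshold rule; cache-dit's real criterion and default
# threshold are defined in PrunedContextManager and are not shown in this diff.
def can_prune_block(curr_input: torch.Tensor,
                    prev_input: torch.Tensor,
                    threshold: float = 0.08) -> bool:
    # Relative L1 difference between this step's block input and the input
    # cached on the previous step ("Fn_original" buffer).
    diff = (curr_input - prev_input).abs().mean()
    rel_diff = diff / prev_input.abs().mean().clamp_min(1e-6)
    return rel_diff.item() < threshold

def prune_block_output(block_input: torch.Tensor,
                       cached_residual: torch.Tensor) -> torch.Tensor:
    # Approximate the skipped block's output as input + cached residual
    # (the "Bn_residual" buffer stored on the last computed step).
    return block_input + cached_residual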
cache_dit/cache_factory/cache_blocks/pattern_utils.py

@@ -3,34 +3,61 @@ import torch
 from typing import Any
 from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import CachedContextManager
+from cache_dit.cache_factory import PrunedContextManager


-def
+def apply_stats(
     module: torch.nn.Module | Any,
     cache_context: CachedContext | str = None,
-
+    context_manager: CachedContextManager | PrunedContextManager = None,
 ):
     # Patch the cached stats to the module, the cached stats
     # will be reset for each calling of pipe.__call__(**kwargs).
-    if module is None or
+    if module is None or context_manager is None:
         return

     if cache_context is not None:
-
+        context_manager.set_context(cache_context)

-    #
-    module._cached_steps =
-    module._residual_diffs =
-    module._cfg_cached_steps =
-    module._cfg_residual_diffs =
+    # Cache stats for Dual Block Cache
+    module._cached_steps = context_manager.get_cached_steps()
+    module._residual_diffs = context_manager.get_residual_diffs()
+    module._cfg_cached_steps = context_manager.get_cfg_cached_steps()
+    module._cfg_residual_diffs = context_manager.get_cfg_residual_diffs()
+    # Pruned stats for Dynamic Block Prune
+    if not isinstance(context_manager, PrunedContextManager):
+        return
+    module._pruned_steps = context_manager.get_pruned_steps()
+    module._cfg_pruned_steps = context_manager.get_cfg_pruned_steps()
+    module._pruned_blocks = context_manager.get_pruned_blocks()
+    module._cfg_pruned_blocks = context_manager.get_cfg_pruned_blocks()
+    module._actual_blocks = context_manager.get_actual_blocks()
+    module._cfg_actual_blocks = context_manager.get_cfg_actual_blocks()
+    # Caculate pruned ratio
+    if len(module._pruned_blocks) > 0 and sum(module._actual_blocks) > 0:
+        module._pruned_ratio = sum(module._pruned_blocks) / sum(
+            module._actual_blocks
+        )
+    else:
+        module._pruned_ratio = None
+    if (
+        len(module._cfg_pruned_blocks) > 0
+        and sum(module._cfg_actual_blocks) > 0
+    ):
+        module._cfg_pruned_ratio = sum(module._cfg_pruned_blocks) / sum(
+            module._cfg_actual_blocks
+        )
+    else:
+        module._cfg_pruned_ratio = None


-def
+def remove_stats(
     module: torch.nn.Module | Any,
 ):
     if module is None:
         return

+    # Dual Block Cache
     if hasattr(module, "_cached_steps"):
         del module._cached_steps
     if hasattr(module, "_residual_diffs"):

@@ -39,3 +66,21 @@ def remove_cached_stats(
         del module._cfg_cached_steps
     if hasattr(module, "_cfg_residual_diffs"):
         del module._cfg_residual_diffs
+
+    # Dynamic Block Prune
+    if hasattr(module, "_pruned_steps"):
+        del module._pruned_steps
+    if hasattr(module, "_cfg_pruned_steps"):
+        del module._cfg_pruned_steps
+    if hasattr(module, "_pruned_blocks"):
+        del module._pruned_blocks
+    if hasattr(module, "_cfg_pruned_blocks"):
+        del module._cfg_pruned_blocks
+    if hasattr(module, "_actual_blocks"):
+        del module._actual_blocks
+    if hasattr(module, "_cfg_actual_blocks"):
+        del module._cfg_actual_blocks
+    if hasattr(module, "_pruned_ratio"):
+        del module._pruned_ratio
+    if hasattr(module, "_cfg_pruned_ratio"):
+        del module._cfg_pruned_ratio
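apply_stats() above attaches the run's statistics directly onto the patched module as underscore attributes, which is what the summary/reporting utilities read back. Below is a small sketch of inspecting them after a pipeline call; the attribute names come from the hunk above, while the assumption that the stats land on pipe.transformer is mine.

# Continuing the enable_cache() sketch earlier; `pipe` has already run once.
transformer = pipe.transformer

cached_steps = getattr(transformer, "_cached_steps", [])
print(f"cached steps: {len(cached_steps)}")

# Prune-specific stats are only attached when a PrunedContextManager was used.
pruned_ratio = getattr(transformer, "_pruned_ratio", None)
if pruned_ratio is not None:
    print(f"pruned blocks / actual blocks: {pruned_ratio:.2%}")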
cache_dit/cache_factory/cache_contexts/__init__.py

@@ -5,11 +5,24 @@ from cache_dit.cache_factory.cache_contexts.calibrators import (
     TaylorSeerCalibratorConfig,
     FoCaCalibratorConfig,
 )
+from cache_dit.cache_factory.cache_contexts.cache_config import (
+    BasicCacheConfig,
+    DBCacheConfig,
+)
 from cache_dit.cache_factory.cache_contexts.cache_context import (
     CachedContext,
-    BasicCacheConfig,
 )
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
     CachedContextManager,
-
+    ContextNotExistError,
+)
+from cache_dit.cache_factory.cache_contexts.prune_config import DBPruneConfig
+from cache_dit.cache_factory.cache_contexts.prune_context import (
+    PrunedContext,
+)
+from cache_dit.cache_factory.cache_contexts.prune_manager import (
+    PrunedContextManager,
+)
+from cache_dit.cache_factory.cache_contexts.context_manager import (
+    ContextManager,
 )