cache-dit 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/__init__.py +3 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +8 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +90 -76
- cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +271 -36
- cache_dit/cache_factory/cache_blocks/pattern_base.py +286 -45
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
- cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
- cache_dit/cache_factory/cache_contexts/cache_config.py +102 -0
- cache_dit/cache_factory/cache_contexts/cache_context.py +26 -89
- cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
- cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
- cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
- cache_dit/cache_factory/cache_contexts/prune_config.py +69 -0
- cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
- cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
- cache_dit/cache_factory/cache_interface.py +23 -14
- cache_dit/cache_factory/cache_types.py +19 -2
- cache_dit/cache_factory/params_modifier.py +7 -7
- cache_dit/cache_factory/utils.py +38 -27
- cache_dit/utils.py +191 -54
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/METADATA +14 -7
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/RECORD +29 -24
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_blocks/pattern_base.py

@@ -3,11 +3,16 @@ import torch
 import torch.distributed as dist

 from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
     CachedContextManager,
-
+    ContextNotExistError,
+)
+from cache_dit.cache_factory.cache_contexts.prune_manager import (
+    PrunedContextManager,
 )
 from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)

@@ -31,7 +36,8 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # 1. Cache context configuration
         cache_prefix: str = None,  # maybe un-need.
         cache_context: CachedContext | str = None,
-
+        context_manager: CachedContextManager = None,
+        cache_type: CacheType = CacheType.DBCache,
         **kwargs,
     ):
         super().__init__()

@@ -45,13 +51,15 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # 1. Cache context configuration
         self.cache_prefix = cache_prefix
         self.cache_context = cache_context
-        self.
+        self.context_manager = context_manager
+        self.cache_type = cache_type

         self._check_forward_pattern()
+        self._check_cache_type()
         logger.info(
-            f"Match
+            f"Match Blocks: {self.__class__.__name__}, for "
             f"{self.cache_prefix}, cache_context: {self.cache_context}, "
-            f"
+            f"context_manager: {self.context_manager.name}."
         )

     def _check_forward_pattern(self):

@@ -94,18 +102,25 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 required_param in forward_parameters
             ), f"The input parameters must contains: {required_param}."

+    @torch.compiler.disable
+    def _check_cache_type(self):
+        assert (
+            self.cache_type == CacheType.DBCache
+        ), f"Cache type {self.cache_type} is not supported for CachedBlocks."
+
     @torch.compiler.disable
     def _check_cache_params(self):
-
+        self._check_cache_type()
+        assert self.context_manager.Fn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Fn_compute_blocks {self.
+            f"Fn_compute_blocks {self.context_manager.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        assert self.
+        assert self.context_manager.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Bn_compute_blocks {self.
+            f"Bn_compute_blocks {self.context_manager.Bn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )

@@ -135,7 +150,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         return hidden_states, encoder_hidden_states

     @torch.compiler.disable
-    def
+    def _process_block_outputs(
         self,
         hidden_states: torch.Tensor | tuple,
         encoder_hidden_states: torch.Tensor | None,

@@ -150,7 +165,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         return hidden_states, encoder_hidden_states

     @torch.compiler.disable
-    def
+    def _process_forward_outputs(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor | None,

@@ -174,9 +189,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
     ):
         # Use it's own cache context.
         try:
-            self.
+            self.context_manager.set_context(self.cache_context)
             self._check_cache_params()
-        except
+        except ContextNotExistError as e:
             logger.warning(f"Cache context not exist: {e}, skip cache.")
             # Call all blocks to process the hidden states.
             hidden_states, encoder_hidden_states = self.call_blocks(

@@ -185,7 +200,10 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 *args,
                 **kwargs,
             )
-            return self.
+            return self._process_forward_outputs(
+                hidden_states,
+                encoder_hidden_states,
+            )

         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for

@@ -200,38 +218,38 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states

-        self.
+        self.context_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = self.
+        can_use_cache = self.context_manager.can_cache(
             (
                 Fn_hidden_states_residual
-                if not self.
+                if not self.context_manager.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
                 f"{self.cache_prefix}_Fn_residual"
-                if not self.
+                if not self.context_manager.is_l1_diff_enabled()
                 else f"{self.cache_prefix}_Fn_hidden_states"
             ),
         )

         torch._dynamo.graph_break()
         if can_use_cache:
-            self.
+            self.context_manager.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-                self.
+                self.context_manager.apply_cache(
                     hidden_states,
                     encoder_hidden_states,
                     prefix=(
                         f"{self.cache_prefix}_Bn_residual"
-                        if self.
+                        if self.context_manager.is_cache_residual()
                         else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
                         f"{self.cache_prefix}_Bn_residual"
-                        if self.
+                        if self.context_manager.is_encoder_cache_residual()
                         else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                 )

@@ -246,13 +264,13 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
             )
         else:
-            self.
+            self.context_manager.set_Fn_buffer(
                 Fn_hidden_states_residual,
                 prefix=f"{self.cache_prefix}_Fn_residual",
             )
-            if self.
+            if self.context_manager.is_l1_diff_enabled():
                 # for hidden states L1 diff
-                self.
+                self.context_manager.set_Fn_buffer(
                     hidden_states,
                     f"{self.cache_prefix}_Fn_hidden_states",
                 )

@@ -270,24 +288,24 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
             )
             torch._dynamo.graph_break()
-            if self.
-                self.
+            if self.context_manager.is_cache_residual():
+                self.context_manager.set_Bn_buffer(
                     hidden_states_residual,
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                self.
+                self.context_manager.set_Bn_buffer(
                     hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )

-            if self.
-                self.
+            if self.context_manager.is_encoder_cache_residual():
+                self.context_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states_residual,
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                self.
+                self.context_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )

@@ -304,7 +322,10 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # patch cached stats for blocks or remove it.
         torch._dynamo.graph_break()

-        return self.
+        return self._process_forward_outputs(
+            hidden_states,
+            encoder_hidden_states,
+        )

     @torch.compiler.disable
     def _is_parallelized(self):
@@ -327,11 +348,11 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # If so, we can skip some Bn blocks and directly
         # use the cached values.
         return (
-            self.
-            in self.
+            self.context_manager.get_current_step()
+            in self.context_manager.get_cached_steps()
         ) or (
-            self.
-            in self.
+            self.context_manager.get_current_step()
+            in self.context_manager.get_cfg_cached_steps()
         )

     @torch.compiler.disable

@@ -340,20 +361,20 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # more stable diff calculation.
         # Fn: [0,...,n-1]
         selected_Fn_blocks = self.transformer_blocks[
-            : self.
+            : self.context_manager.Fn_compute_blocks()
         ]
         return selected_Fn_blocks

     @torch.compiler.disable
     def _Mn_blocks(self):  # middle blocks
         # M(N-2n): only transformer_blocks [n,...,N-n], middle
-        if self.
+        if self.context_manager.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
             selected_Mn_blocks = self.transformer_blocks[
-                self.
+                self.context_manager.Fn_compute_blocks() :
             ]
         else:
             selected_Mn_blocks = self.transformer_blocks[
-                self.
+                self.context_manager.Fn_compute_blocks() : -self.context_manager.Bn_compute_blocks()
             ]
         return selected_Mn_blocks

@@ -361,7 +382,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
     def _Bn_blocks(self):
         # Bn: transformer_blocks [N-n+1,...,N-1]
         selected_Bn_blocks = self.transformer_blocks[
-            -self.
+            -self.context_manager.Bn_compute_blocks() :
         ]
         return selected_Bn_blocks

@@ -379,7 +400,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 *args,
                 **kwargs,
             )
-            hidden_states, encoder_hidden_states = self.
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
                 hidden_states, encoder_hidden_states
             )

@@ -401,7 +422,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 *args,
                 **kwargs,
             )
-            hidden_states, encoder_hidden_states = self.
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
                 hidden_states, encoder_hidden_states
             )

@@ -435,7 +456,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        if self.
+        if self.context_manager.Bn_compute_blocks() == 0:
             return hidden_states, encoder_hidden_states

         for block in self._Bn_blocks():

@@ -445,8 +466,228 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 *args,
                 **kwargs,
             )
-            hidden_states, encoder_hidden_states = self.
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
+                hidden_states, encoder_hidden_states
+            )
+
+        return hidden_states, encoder_hidden_states
+
+
+class PrunedBlocks_Pattern_Base(CachedBlocks_Pattern_Base):
+    pruned_blocks_step: int = 0  # number of pruned blocks in current step
+
+    def __init__(
+        self,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
+        check_num_outputs: bool = True,
+        # 1. Prune context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: PrunedContext | str = None,
+        context_manager: PrunedContextManager = None,
+        cache_type: CacheType = CacheType.DBPrune,
+        **kwargs,
+    ):
+        super().__init__(
+            # 0. Transformer blocks configuration
+            transformer_blocks,
+            transformer=transformer,
+            forward_pattern=forward_pattern,
+            check_forward_pattern=check_forward_pattern,
+            check_num_outputs=check_num_outputs,
+            # 1. Cache context configuration
+            cache_prefix=cache_prefix,
+            cache_context=cache_context,
+            context_manager=context_manager,
+            cache_type=cache_type,
+            **kwargs,
+        )
+        assert isinstance(
+            self.context_manager, PrunedContextManager
+        ), "context_manager must be PrunedContextManager for PrunedBlocks."
+        self.context_manager: PrunedContextManager = (
+            self.context_manager
+        )  # For type hint
+
+    @torch.compiler.disable
+    def _check_cache_type(self):
+        assert (
+            self.cache_type == CacheType.DBPrune
+        ), f"Cache type {self.cache_type} is not supported for PrunedBlocks."
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        self.pruned_blocks_step: int = 0  # reset for each step
+
+        # Use it's own cache context.
+        try:
+            self.context_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except ContextNotExistError as e:
+            logger.warning(f"Cache context not exist: {e}, skip prune.")
+            # Fallback to call all blocks to process the hidden states w/o prune.
+            hidden_states, encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._process_forward_outputs(
+                hidden_states,
+                encoder_hidden_states,
+            )
+
+        self.context_manager.mark_step_begin()
+
+        # Call all blocks with prune strategy to process the hidden states.
+        for i, block in enumerate(self.transformer_blocks):
+            hidden_states, encoder_hidden_states = self.compute_or_prune(
+                i,
+                block,
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        self.context_manager.add_pruned_block(self.pruned_blocks_step)
+        self.context_manager.add_actual_block(self.num_blocks)
+
+        return self._process_forward_outputs(
+            hidden_states,
+            encoder_hidden_states,
+        )
+
+    @property
+    @torch.compiler.disable
+    def num_blocks(self):
+        return len(self.transformer_blocks)
+
+    @torch.compiler.disable
+    def _skip_prune(self, block_id: int) -> bool:
+        # Wrap for non compiled mode.
+        return block_id in self.context_manager.get_non_prune_blocks_ids(
+            self.num_blocks
+        )
+
+    @torch.compiler.disable
+    def _maybe_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        hidden_states: torch.Tensor,  # hidden_states or residual
+        prefix: str = "Bn_original",  # prev step name for single blocks
+    ):
+        # Wrap for non compiled mode.
+        can_use_prune = False
+        if not self._skip_prune(block_id):
+            can_use_prune = self.context_manager.can_prune(
+                hidden_states,  # curr step
+                parallelized=self._is_parallelized(),
+                prefix=prefix,  # prev step
+            )
+        self.pruned_blocks_step += int(can_use_prune)
+        return can_use_prune
+
+    def compute_or_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        # Below are the inputs to the block
+        block,  # The transformer block to be executed
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = encoder_hidden_states
+
+        can_use_prune = self._maybe_prune(
+            block_id,
+            hidden_states,
+            prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+        )
+
+        # Prune steps: Prune current block and reuse the cached
+        # residuals for hidden states approximate.
+        torch._dynamo.graph_break()
+        if can_use_prune:
+            self.context_manager.add_pruned_step()
+            hidden_states, encoder_hidden_states = (
+                self.context_manager.apply_prune(
+                    hidden_states,
+                    encoder_hidden_states,
+                    prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_residual"
+                        if self.context_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
+                        if self.context_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+        else:
+            # Normal steps: Compute the block and cache the residuals.
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, encoder_hidden_states = self._process_block_outputs(
                 hidden_states, encoder_hidden_states
             )
+            if not self._skip_prune(block_id):
+                hidden_states = hidden_states.contiguous()
+                hidden_states_residual = hidden_states - original_hidden_states
+
+                if (
+                    encoder_hidden_states is not None
+                    and original_encoder_hidden_states is not None
+                ):
+                    encoder_hidden_states = encoder_hidden_states.contiguous()
+                    encoder_hidden_states_residual = (
+                        encoder_hidden_states - original_encoder_hidden_states
+                    )
+                else:
+                    encoder_hidden_states_residual = None
+
+                self.context_manager.set_Fn_buffer(
+                    original_hidden_states,
+                    prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+                )
+                if self.context_manager.is_cache_residual():
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
+                    )
+                else:
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
+                    )
+                if encoder_hidden_states_residual is not None:
+                    if self.context_manager.is_encoder_cache_residual():
+                        self.context_manager.set_Bn_encoder_buffer(
+                            encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
+                        )
+                    else:
+                        self.context_manager.set_Bn_encoder_buffer(
+                            encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
+                        )
+            torch._dynamo.graph_break()

         return hidden_states, encoder_hidden_states
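The pattern_base.py changes above route all cache decisions through a `context_manager` handle, add a `cache_type` guard, and introduce `PrunedBlocks_Pattern_Base` for dynamic block pruning. The sketch below is not part of the package; it only illustrates, with the constructor parameters visible in these hunks, how a caller might choose between the two block wrappers. The `make_blocks` helper and its pre-built `manager` argument are hypothetical, and the real wiring lives in cache_adapters/cache_adapter.py, which is not shown here.

```python
import torch

from cache_dit.cache_factory import ForwardPattern
from cache_dit.cache_factory.cache_types import CacheType
from cache_dit.cache_factory.cache_blocks.pattern_base import (
    CachedBlocks_Pattern_Base,
    PrunedBlocks_Pattern_Base,
)


def make_blocks(
    transformer_blocks: torch.nn.ModuleList,
    transformer: torch.nn.Module,
    manager,  # assumed: a CachedContextManager or PrunedContextManager built elsewhere
    cache_context: str,
    cache_type: CacheType = CacheType.DBCache,
):
    # PrunedBlocks asserts cache_type == CacheType.DBPrune and a PrunedContextManager;
    # CachedBlocks asserts CacheType.DBCache (see _check_cache_type in the diff).
    cls = (
        PrunedBlocks_Pattern_Base
        if cache_type == CacheType.DBPrune
        else CachedBlocks_Pattern_Base
    )
    return cls(
        transformer_blocks,
        transformer=transformer,
        forward_pattern=ForwardPattern.Pattern_0,
        cache_prefix="transformer_blocks",  # hypothetical prefix
        cache_context=cache_context,
        context_manager=manager,
        cache_type=cache_type,
    )
```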
cache_dit/cache_factory/cache_blocks/pattern_utils.py

@@ -3,34 +3,61 @@ import torch
 from typing import Any
 from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import CachedContextManager
+from cache_dit.cache_factory import PrunedContextManager


-def
+def apply_stats(
     module: torch.nn.Module | Any,
     cache_context: CachedContext | str = None,
-
+    context_manager: CachedContextManager | PrunedContextManager = None,
 ):
     # Patch the cached stats to the module, the cached stats
     # will be reset for each calling of pipe.__call__(**kwargs).
-    if module is None or
+    if module is None or context_manager is None:
         return

     if cache_context is not None:
-
+        context_manager.set_context(cache_context)

-    #
-    module._cached_steps =
-    module._residual_diffs =
-    module._cfg_cached_steps =
-    module._cfg_residual_diffs =
+    # Cache stats for Dual Block Cache
+    module._cached_steps = context_manager.get_cached_steps()
+    module._residual_diffs = context_manager.get_residual_diffs()
+    module._cfg_cached_steps = context_manager.get_cfg_cached_steps()
+    module._cfg_residual_diffs = context_manager.get_cfg_residual_diffs()
+    # Pruned stats for Dynamic Block Prune
+    if not isinstance(context_manager, PrunedContextManager):
+        return
+    module._pruned_steps = context_manager.get_pruned_steps()
+    module._cfg_pruned_steps = context_manager.get_cfg_pruned_steps()
+    module._pruned_blocks = context_manager.get_pruned_blocks()
+    module._cfg_pruned_blocks = context_manager.get_cfg_pruned_blocks()
+    module._actual_blocks = context_manager.get_actual_blocks()
+    module._cfg_actual_blocks = context_manager.get_cfg_actual_blocks()
+    # Caculate pruned ratio
+    if len(module._pruned_blocks) > 0 and sum(module._actual_blocks) > 0:
+        module._pruned_ratio = sum(module._pruned_blocks) / sum(
+            module._actual_blocks
+        )
+    else:
+        module._pruned_ratio = None
+    if (
+        len(module._cfg_pruned_blocks) > 0
+        and sum(module._cfg_actual_blocks) > 0
+    ):
+        module._cfg_pruned_ratio = sum(module._cfg_pruned_blocks) / sum(
+            module._cfg_actual_blocks
+        )
+    else:
+        module._cfg_pruned_ratio = None


-def
+def remove_stats(
     module: torch.nn.Module | Any,
 ):
     if module is None:
         return

+    # Dual Block Cache
     if hasattr(module, "_cached_steps"):
         del module._cached_steps
     if hasattr(module, "_residual_diffs"):

@@ -39,3 +66,21 @@ def remove_cached_stats(
         del module._cfg_cached_steps
     if hasattr(module, "_cfg_residual_diffs"):
         del module._cfg_residual_diffs
+
+    # Dynamic Block Prune
+    if hasattr(module, "_pruned_steps"):
+        del module._pruned_steps
+    if hasattr(module, "_cfg_pruned_steps"):
+        del module._cfg_pruned_steps
+    if hasattr(module, "_pruned_blocks"):
+        del module._pruned_blocks
+    if hasattr(module, "_cfg_pruned_blocks"):
+        del module._cfg_pruned_blocks
+    if hasattr(module, "_actual_blocks"):
+        del module._actual_blocks
+    if hasattr(module, "_cfg_actual_blocks"):
+        del module._cfg_actual_blocks
+    if hasattr(module, "_pruned_ratio"):
+        del module._pruned_ratio
+    if hasattr(module, "_cfg_pruned_ratio"):
+        del module._cfg_pruned_ratio
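apply_stats() patches the manager's statistics onto the module as private attributes (for example `_cached_steps` and `_residual_diffs`, plus `_pruned_blocks`, `_actual_blocks`, and the derived `_pruned_ratio` when a PrunedContextManager is used), and remove_stats() deletes them again. The helper below is a hypothetical convenience, not part of cache-dit; it reads those attributes defensively after a run and assumes the step and block stats are sequences, which is consistent with apply_stats() applying len() and sum() to them.

```python
def summarize_stats(module) -> dict:
    """Collect the stats that apply_stats() may have patched onto `module`."""
    return {
        # Dual Block Cache stats
        "cached_steps": len(getattr(module, "_cached_steps", [])),
        "cfg_cached_steps": len(getattr(module, "_cfg_cached_steps", [])),
        # Dynamic Block Prune stats (absent unless a PrunedContextManager was used)
        "pruned_steps": len(getattr(module, "_pruned_steps", [])),
        "pruned_blocks": sum(getattr(module, "_pruned_blocks", [])),
        "actual_blocks": sum(getattr(module, "_actual_blocks", [])),
        "pruned_ratio": getattr(module, "_pruned_ratio", None),
        "cfg_pruned_ratio": getattr(module, "_cfg_pruned_ratio", None),
    }
```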
cache_dit/cache_factory/cache_contexts/__init__.py

@@ -5,11 +5,24 @@ from cache_dit.cache_factory.cache_contexts.calibrators import (
     TaylorSeerCalibratorConfig,
     FoCaCalibratorConfig,
 )
+from cache_dit.cache_factory.cache_contexts.cache_config import (
+    BasicCacheConfig,
+    DBCacheConfig,
+)
 from cache_dit.cache_factory.cache_contexts.cache_context import (
     CachedContext,
-    BasicCacheConfig,
 )
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
     CachedContextManager,
-
+    ContextNotExistError,
+)
+from cache_dit.cache_factory.cache_contexts.prune_config import DBPruneConfig
+from cache_dit.cache_factory.cache_contexts.prune_context import (
+    PrunedContext,
+)
+from cache_dit.cache_factory.cache_contexts.prune_manager import (
+    PrunedContextManager,
+)
+from cache_dit.cache_factory.cache_contexts.context_manager import (
+    ContextManager,
 )