cache-dit 0.2.27__py3-none-any.whl → 0.2.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +2 -0
- cache_dit/cache_factory/block_adapters/__init__.py +22 -5
- cache_dit/cache_factory/block_adapters/block_adapters.py +230 -25
- cache_dit/cache_factory/cache_adapters.py +209 -94
- cache_dit/cache_factory/cache_blocks/__init__.py +55 -4
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +36 -37
- cache_dit/cache_factory/cache_blocks/pattern_base.py +83 -76
- cache_dit/cache_factory/cache_blocks/utils.py +10 -8
- cache_dit/cache_factory/cache_contexts/__init__.py +4 -1
- cache_dit/cache_factory/cache_contexts/cache_context.py +14 -876
- cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
- cache_dit/cache_factory/cache_interface.py +10 -13
- cache_dit/utils.py +7 -10
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/METADATA +30 -24
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/RECORD +21 -21
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
```diff
@@ -2,7 +2,10 @@ import inspect
 import torch
 import torch.distributed as dist
 
-from cache_dit.cache_factory import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.logger import init_logger
 
```
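This first hunk (from `cache_dit/cache_factory/cache_blocks/pattern_base.py`, judging by the `CachedBlocks_Pattern_Base` context lines in the hunks that follow) replaces the flat `CachedContext` import with the new `cache_contexts` submodule and pulls in the new `CachedContextManager`. A minimal sketch of the import surface this release exposes, assuming cache-dit 0.2.28 is installed:

```python
# Import paths exactly as they appear in this diff (cache-dit 0.2.28).
from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
from cache_dit.cache_factory.cache_contexts.cache_manager import CachedContextManager

# The package itself also imports the manager through the shorter alias
# (see the patch_cached_stats hunk further down), so this spelling works too:
from cache_dit.cache_factory import CachedContextManager
```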
```diff
@@ -18,29 +21,34 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
 
     def __init__(
         self,
+        # 0. Transformer blocks configuration
         transformer_blocks: torch.nn.ModuleList,
-        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
-        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
-        blocks_name: str,
-        # Usually, blocks_name, etc.
-        cache_context: str,
-        *,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         check_num_outputs: bool = True,
+        # 1. Cache context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: CachedContext | str = None,
+        cache_manager: CachedContextManager = None,
+        **kwargs,
     ):
         super().__init__()
 
+        # 0. Transformer blocks configuration
         self.transformer = transformer
         self.transformer_blocks = transformer_blocks
-        self.blocks_name = blocks_name
-        self.cache_context = cache_context
         self.forward_pattern = forward_pattern
         self.check_num_outputs = check_num_outputs
+        # 1. Cache context configuration
+        self.cache_prefix = cache_prefix
+        self.cache_context = cache_context
+        self.cache_manager = cache_manager
+
         self._check_forward_pattern()
         logger.info(
             f"Match Cached Blocks: {self.__class__.__name__}, for "
-            f"{self.
+            f"{self.cache_prefix}, cache_context: {self.cache_context}, "
+            f"cache_manager: {self.cache_manager.name}."
         )
 
     def _check_forward_pattern(self):
```
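The constructor drops the old `blocks_name` / `cache_context: str` pair in favor of a `cache_prefix`, a `cache_context` that may be either a `CachedContext` object or a name, and an explicit `CachedContextManager`. A small runtime check of that new signature, assuming the module path from the changed-file list above; nothing here is specific to any model:

```python
import inspect

from cache_dit.cache_factory.cache_blocks.pattern_base import (
    CachedBlocks_Pattern_Base,
)

# The new cache-configuration parameters introduced in this hunk.
params = inspect.signature(CachedBlocks_Pattern_Base.__init__).parameters
print({"cache_prefix", "cache_context", "cache_manager"} <= set(params))  # True on 0.2.28
```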
```diff
@@ -79,9 +87,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-
-            self.cache_context,
-        )
+        self.cache_manager.set_context(self.cache_context)
 
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
```
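Instead of pushing `self.cache_context` through a module-level helper, `forward()` now asks the manager to activate the configured context before any cache reads or writes. A toy model of that handshake (not the real `CachedContextManager`, whose internals live in the new `cache_manager.py`), just to show why the `CachedContext | str` annotation is convenient:

```python
class ToyContextManager:
    """Illustrative only: keeps named contexts and tracks the active one."""

    def __init__(self) -> None:
        self._contexts: dict[str, dict] = {}
        self._current: dict | None = None

    def set_context(self, context) -> dict:
        # Accept a name (str) or an already-built context object, mirroring
        # the `CachedContext | str` annotation used in the diff.
        if isinstance(context, str):
            self._current = self._contexts.setdefault(context, {})
        else:
            self._current = context
        return self._current


mgr = ToyContextManager()
mgr.set_context("transformer_blocks_0")  # what forward() now does up front
```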
```diff
@@ -96,39 +102,39 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states
 
-
+        self.cache_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache =
+        can_use_cache = self.cache_manager.can_cache(
             (
                 Fn_hidden_states_residual
-                if not
+                if not self.cache_manager.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
-                f"{self.
-                if not
-                else f"{self.
+                f"{self.cache_prefix}_Fn_residual"
+                if not self.cache_manager.is_l1_diff_enabled()
+                else f"{self.cache_prefix}_Fn_hidden_states"
             ),
         )
 
         torch._dynamo.graph_break()
         if can_use_cache:
-
+            self.cache_manager.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-
+                self.cache_manager.apply_cache(
                     hidden_states,
                     encoder_hidden_states,
                     prefix=(
-                        f"{self.
-                        if
-                        else f"{self.
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
-                        f"{self.
-                        if
-                        else f"{self.
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                 )
             )
```
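The cache decision itself is unchanged in spirit: compare the residual produced by the first `Fn` blocks against the one stored on the previous step, and if the difference is small enough, reuse the cached `Bn` output instead of running the middle blocks. A self-contained sketch of that criterion using a plain relative-L1 diff; the threshold value and function name are illustrative assumptions, not the library's defaults:

```python
import torch


def can_reuse(prev_residual, curr_residual, threshold=0.08) -> bool:
    # Relative L1 distance between this step's Fn residual and the previous one.
    if prev_residual is None or prev_residual.shape != curr_residual.shape:
        return False
    diff = (curr_residual - prev_residual).abs().mean()
    return (diff / prev_residual.abs().mean().clamp_min(1e-6)) < threshold


prev = torch.randn(2, 16, 64)
curr = prev + 0.01 * torch.randn_like(prev)  # nearly identical step -> cache hit
print(can_reuse(prev, curr), can_reuse(prev, torch.randn_like(prev)))  # True False
```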
```diff
@@ -142,15 +148,15 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
             )
         else:
-
+            self.cache_manager.set_Fn_buffer(
                 Fn_hidden_states_residual,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Fn_residual",
             )
-            if
+            if self.cache_manager.is_l1_diff_enabled():
                 # for hidden states L1 diff
-
+                self.cache_manager.set_Fn_buffer(
                     hidden_states,
-                    f"{self.
+                    f"{self.cache_prefix}_Fn_hidden_states",
                 )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
@@ -166,27 +172,27 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
             )
             torch._dynamo.graph_break()
-            if
-
+            if self.cache_manager.is_cache_residual():
+                self.cache_manager.set_Bn_buffer(
                     hidden_states_residual,
-                    prefix=f"{self.
+                    prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-
+                self.cache_manager.set_Bn_buffer(
                     hidden_states,
-                    prefix=f"{self.
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
-            if
-
+            if self.cache_manager.is_encoder_cache_residual():
+                self.cache_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states_residual,
-                    prefix=f"{self.
+                    prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-
+                self.cache_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states,
-                    prefix=f"{self.
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
```
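On a miss, the manager stores either the `Bn` residual (the default) or the raw hidden states (for the TaylorSeer variant); on a later hit, `apply_cache` adds the stored residual back onto the freshly computed `Fn` output. The arithmetic behind residual reuse, as a hedged sketch with made-up tensors:

```python
import torch

hidden_after_Fn = torch.randn(2, 16, 64)                       # output of the Fn blocks
hidden_after_Bn = hidden_after_Fn + 0.1 * torch.randn(2, 16, 64)

# Miss: store the Bn residual, keyed by something like f"{cache_prefix}_Bn_residual".
cached_Bn_residual = hidden_after_Bn - hidden_after_Fn

# Hit on a later step: skip the middle/Bn blocks and add the residual back.
approx_output = hidden_after_Fn + cached_Bn_residual
print(torch.allclose(approx_output, hidden_after_Bn))  # True by construction here
```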
```diff
@@ -232,10 +238,11 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # If so, we can skip some Bn blocks and directly
         # use the cached values.
         return (
-
+            self.cache_manager.get_current_step()
+            in self.cache_manager.get_cached_steps()
         ) or (
-
-            in
+            self.cache_manager.get_current_step()
+            in self.cache_manager.get_cfg_cached_steps()
         )
 
     @torch.compiler.disable
@@ -244,20 +251,20 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # more stable diff calculation.
         # Fn: [0,...,n-1]
         selected_Fn_blocks = self.transformer_blocks[
-            :
+            : self.cache_manager.Fn_compute_blocks()
         ]
         return selected_Fn_blocks
 
     @torch.compiler.disable
     def _Mn_blocks(self):  # middle blocks
         # M(N-2n): only transformer_blocks [n,...,N-n], middle
-        if
+        if self.cache_manager.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
             selected_Mn_blocks = self.transformer_blocks[
-
+                self.cache_manager.Fn_compute_blocks() :
             ]
         else:
             selected_Mn_blocks = self.transformer_blocks[
-
+                self.cache_manager.Fn_compute_blocks() : -self.cache_manager.Bn_compute_blocks()
             ]
         return selected_Mn_blocks
 
@@ -265,7 +272,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
     def _Bn_blocks(self):
         # Bn: transformer_blocks [N-n+1,...,N-1]
         selected_Bn_blocks = self.transformer_blocks[
-            -
+            -self.cache_manager.Bn_compute_blocks() :
         ]
         return selected_Bn_blocks
 
```
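The `Fn` / `Mn` / `Bn` helpers simply slice the same `transformer_blocks` list, now reading the block counts from the manager. The `# WARN: x[:-0] = []` comment is the reason for the `Bn == 0` special case; a plain-Python illustration with a hypothetical 10-block stack, `Fn = 2`, `Bn = 2`:

```python
blocks = list(range(10))  # stand-in for transformer_blocks
Fn, Bn = 2, 2

Fn_blocks = blocks[:Fn]                                  # [0, 1]
Mn_blocks = blocks[Fn:] if Bn == 0 else blocks[Fn:-Bn]   # [2, 3, 4, 5, 6, 7]
Bn_blocks = blocks[-Bn:] if Bn > 0 else []               # [8, 9]

print(Fn_blocks, Mn_blocks, Bn_blocks)
print(blocks[Fn:-0])  # [] -- the pitfall the WARN comment guards against
```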
```diff
@@ -276,10 +283,10 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        assert
+        assert self.cache_manager.Fn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Fn_compute_blocks {
+            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
         for block in self._Fn_blocks():
@@ -376,7 +383,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
             )
             # Cache residuals for the non-compute Bn blocks for
             # subsequent cache steps.
-            if block_id not in
+            if block_id not in self.cache_manager.Bn_compute_blocks_ids():
                 Bn_i_hidden_states_residual = (
                     hidden_states - Bn_i_original_hidden_states
                 )
@@ -385,22 +392,22 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 )
 
                 # Save original_hidden_states for diff calculation.
-
+                self.cache_manager.set_Bn_buffer(
                     Bn_i_original_hidden_states,
-                    prefix=f"{self.
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
                 )
-
+                self.cache_manager.set_Bn_encoder_buffer(
                     Bn_i_original_encoder_hidden_states,
-                    prefix=f"{self.
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
                 )
 
-
+                self.cache_manager.set_Bn_buffer(
                     Bn_i_hidden_states_residual,
-                    prefix=f"{self.
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
                 )
-
+                self.cache_manager.set_Bn_encoder_buffer(
                     Bn_i_encoder_hidden_states_residual,
-                    prefix=f"{self.
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
                 )
                 del Bn_i_hidden_states_residual
                 del Bn_i_encoder_hidden_states_residual
@@ -411,7 +418,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         else:
             # Cache steps: Reuse the cached residuals.
             # Check if the block is in the Bn_compute_blocks_ids.
-            if block_id in
+            if block_id in self.cache_manager.Bn_compute_blocks_ids():
                 hidden_states = block(
                     hidden_states,
                     encoder_hidden_states,
@@ -429,25 +436,25 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 # Skip the block if it is not in the Bn_compute_blocks_ids.
                 # Use the cached residuals instead.
                 # Check if can use the cached residuals.
-                if
+                if self.cache_manager.can_cache(
                     hidden_states,  # curr step
                     parallelized=self._is_parallelized(),
-                    threshold=
-                    prefix=f"{self.
+                    threshold=self.cache_manager.non_compute_blocks_diff_threshold(),
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",  # prev step
                 ):
                     hidden_states, encoder_hidden_states = (
-
+                        self.cache_manager.apply_cache(
                             hidden_states,
                             encoder_hidden_states,
                             prefix=(
-                                f"{self.
-                                if
-                                else f"{self.
+                                f"{self.cache_prefix}_Bn_{block_id}_residual"
+                                if self.cache_manager.is_cache_residual()
+                                else f"{self.cache_prefix}_Bn_{block_id}_original"
                             ),
                             encoder_prefix=(
-                                f"{self.
-                                if
-                                else f"{self.
+                                f"{self.cache_prefix}_Bn_{block_id}_residual"
+                                if self.cache_manager.is_encoder_cache_residual()
+                                else f"{self.cache_prefix}_Bn_{block_id}_original"
                             ),
                         )
                     )
```
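Within the `Bn` range, only the block ids listed in `Bn_compute_blocks_ids()` are recomputed during cache steps; the rest reuse their per-block cached residuals as long as the per-block diff stays under `non_compute_blocks_diff_threshold()`. A compact sketch of that control flow, with hypothetical names (`compute_ids`, `block_diff_ok`, `cached_residuals`) standing in for the manager calls; the real code also threads encoder states and uses the `f"{cache_prefix}_Bn_{block_id}_..."` buffer prefixes shown above:

```python
import torch


def run_Bn(blocks, hidden_states, compute_ids, cached_residuals, block_diff_ok):
    # Illustrative control flow only, not the library's implementation.
    for block_id, block in enumerate(blocks):
        if block_id in compute_ids:
            hidden_states = block(hidden_states)              # always recompute
        elif block_diff_ok(block_id, hidden_states):
            hidden_states = hidden_states + cached_residuals[block_id]
        else:
            hidden_states = block(hidden_states)              # fall back to compute
    return hidden_states


blocks = [torch.nn.Linear(8, 8) for _ in range(3)]
out = run_Bn(
    blocks,
    torch.randn(4, 8),
    compute_ids={0},
    cached_residuals={1: torch.zeros(4, 8), 2: torch.zeros(4, 8)},
    block_diff_ok=lambda i, h: True,
)
print(out.shape)  # torch.Size([4, 8])
```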
```diff
@@ -474,16 +481,16 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        if
+        if self.cache_manager.Bn_compute_blocks() == 0:
             return hidden_states, encoder_hidden_states
 
-        assert
+        assert self.cache_manager.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Bn_compute_blocks {
+            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        if len(
+        if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
             for i, block in enumerate(self._Bn_blocks()):
                 hidden_states, encoder_hidden_states = (
                     self._compute_or_cache_block(
```
```diff
@@ -2,22 +2,24 @@ import torch
 
 from typing import Any
 from cache_dit.cache_factory import CachedContext
+from cache_dit.cache_factory import CachedContextManager
 
 
-@torch.compiler.disable
 def patch_cached_stats(
-    module: torch.nn.Module | Any,
+    module: torch.nn.Module | Any,
+    cache_context: CachedContext | str = None,
+    cache_manager: CachedContextManager = None,
 ):
     # Patch the cached stats to the module, the cached stats
     # will be reset for each calling of pipe.__call__(**kwargs).
-    if module is None:
+    if module is None or cache_manager is None:
         return
 
     if cache_context is not None:
-
+        cache_manager.set_context(cache_context)
 
     # TODO: Patch more cached stats to the module
-    module._cached_steps =
-    module._residual_diffs =
-    module._cfg_cached_steps =
-    module._cfg_residual_diffs =
+    module._cached_steps = cache_manager.get_cached_steps()
+    module._residual_diffs = cache_manager.get_residual_diffs()
+    module._cfg_cached_steps = cache_manager.get_cfg_cached_steps()
+    module._cfg_residual_diffs = cache_manager.get_cfg_residual_diffs()
```
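`patch_cached_stats` (in `cache_dit/cache_factory/cache_blocks/utils.py`, per the file list) now requires a manager and copies its per-run statistics onto the module; the attribute names are visible in the hunk. A small, hedged helper for reading them back after a run, assuming a cache-dit-wrapped pipeline exists elsewhere; the exact container types of each attribute are defined in the new `cache_manager.py`, which is not shown in this diff:

```python
def read_cached_stats(transformer) -> dict:
    """Collect the attributes patched onto the module by patch_cached_stats."""
    names = (
        "_cached_steps",
        "_residual_diffs",
        "_cfg_cached_steps",
        "_cfg_residual_diffs",
    )
    return {name: getattr(transformer, name, None) for name in names}
```

Calling `read_cached_stats(pipe.transformer)` after a generation would return whatever the manager accumulated for that run, with values reset on each `pipe.__call__`.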
```diff
@@ -1,2 +1,5 @@
 # namespace alias: for _CachedContext and many others' cache context funcs.
-
+from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
```
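The final hunk turns `cache_dit/cache_factory/cache_contexts/__init__.py` into a small namespace alias so both classes are importable from the subpackage (and, per the earlier hunks, from `cache_dit.cache_factory` as well). For completeness, the equivalent user-side import, assuming 0.2.28:

```python
from cache_dit.cache_factory.cache_contexts import CachedContext, CachedContextManager
```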