cache-dit 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/__init__.py +2 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +3 -0
- cache_dit/cache_factory/block_adapters/__init__.py +105 -111
- cache_dit/cache_factory/block_adapters/block_adapters.py +314 -41
- cache_dit/cache_factory/block_adapters/block_registers.py +15 -6
- cache_dit/cache_factory/cache_adapters.py +244 -116
- cache_dit/cache_factory/cache_blocks/__init__.py +55 -4
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +36 -37
- cache_dit/cache_factory/cache_blocks/pattern_base.py +83 -76
- cache_dit/cache_factory/cache_blocks/utils.py +26 -8
- cache_dit/cache_factory/cache_contexts/__init__.py +4 -1
- cache_dit/cache_factory/cache_contexts/cache_context.py +14 -876
- cache_dit/cache_factory/cache_contexts/cache_manager.py +847 -0
- cache_dit/cache_factory/cache_interface.py +91 -24
- cache_dit/cache_factory/patch_functors/functor_chroma.py +1 -1
- cache_dit/cache_factory/patch_functors/functor_flux.py +1 -1
- cache_dit/utils.py +164 -58
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/METADATA +59 -34
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/RECORD +24 -24
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.27.dist-info → cache_dit-0.2.29.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py

@@ -1,6 +1,5 @@
 import torch

-from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
@@ -24,7 +23,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         **kwargs,
     ):
         # Use it's own cache context.
-
+        self.cache_manager.set_context(
             self.cache_context,
         )

@@ -41,40 +40,40 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states

-
+        self.cache_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache =
+        can_use_cache = self.cache_manager.can_cache(
            (
                Fn_hidden_states_residual
-                if not
+                if not self.cache_manager.is_l1_diff_enabled()
                else hidden_states
            ),
            parallelized=self._is_parallelized(),
            prefix=(
-                f"{self.
-                if not
-                else f"{self.
+                f"{self.cache_prefix}_Fn_residual"
+                if not self.cache_manager.is_l1_diff_enabled()
+                else f"{self.cache_prefix}_Fn_hidden_states"
            ),
        )

         torch._dynamo.graph_break()
         if can_use_cache:
-
+            self.cache_manager.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-
+                self.cache_manager.apply_cache(
                    hidden_states,
                    # None Pattern 3, else 4, 5
                    encoder_hidden_states,
                    prefix=(
-                        f"{self.
-                        if
-                        else f"{self.
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                    ),
                    encoder_prefix=(
-                        f"{self.
-                        if
-                        else f"{self.
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                    ),
                )
            )
@@ -88,15 +87,15 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 **kwargs,
             )
         else:
-
+            self.cache_manager.set_Fn_buffer(
                Fn_hidden_states_residual,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Fn_residual",
            )
-            if
+            if self.cache_manager.is_l1_diff_enabled():
                # for hidden states L1 diff
-
+                self.cache_manager.set_Fn_buffer(
                    hidden_states,
-                    f"{self.
+                    f"{self.cache_prefix}_Fn_hidden_states",
                )
         del Fn_hidden_states_residual
         torch._dynamo.graph_break()
@@ -114,29 +113,29 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             **kwargs,
         )
         torch._dynamo.graph_break()
-        if
-
+        if self.cache_manager.is_cache_residual():
+            self.cache_manager.set_Bn_buffer(
                hidden_states_residual,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Bn_residual",
            )
         else:
             # TaylorSeer
-
+            self.cache_manager.set_Bn_buffer(
                hidden_states,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Bn_hidden_states",
            )
-        if
-
+        if self.cache_manager.is_encoder_cache_residual():
+            self.cache_manager.set_Bn_encoder_buffer(
                # None Pattern 3, else 4, 5
                encoder_hidden_states_residual,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Bn_residual",
            )
         else:
             # TaylorSeer
-
+            self.cache_manager.set_Bn_encoder_buffer(
                # None Pattern 3, else 4, 5
                encoder_hidden_states,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Bn_hidden_states",
            )
         torch._dynamo.graph_break()
         # Call last `n` blocks to further process the hidden states
@@ -167,10 +166,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-        assert
+        assert self.cache_manager.Fn_compute_blocks() <= len(
            self.transformer_blocks
        ), (
-            f"Fn_compute_blocks {
+            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
            f"the number of transformer blocks {len(self.transformer_blocks)}"
        )
         encoder_hidden_states = None # Pattern 3
@@ -242,16 +241,16 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-        if
+        if self.cache_manager.Bn_compute_blocks() == 0:
             return hidden_states, encoder_hidden_states

-        assert
+        assert self.cache_manager.Bn_compute_blocks() <= len(
            self.transformer_blocks
        ), (
-            f"Bn_compute_blocks {
+            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
            f"the number of transformer blocks {len(self.transformer_blocks)}"
        )
-        if len(
+        if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
             raise ValueError(
                 f"Bn_compute_blocks_ids is not support for "
                 f"patterns: {self._supported_patterns}."
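Across these hunks, every cache read and write now goes through `self.cache_manager`, and buffer names are derived from `self.cache_prefix`. A minimal sketch of that naming scheme follows; `buffer_key` and the example prefix are illustrative stand-ins, not cache-dit API:

```python
# Illustrative only: buffer_key is NOT a cache-dit function; it just mirrors
# the key pattern visible in the added lines, e.g. "{cache_prefix}_Fn_residual"
# or "{cache_prefix}_Bn_{block_id}_original".
def buffer_key(cache_prefix: str, stage: str, kind: str, block_id: int | None = None) -> str:
    if block_id is None:
        return f"{cache_prefix}_{stage}_{kind}"
    return f"{cache_prefix}_{stage}_{block_id}_{kind}"


assert buffer_key("transformer_blocks", "Fn", "residual") == "transformer_blocks_Fn_residual"
assert buffer_key("transformer_blocks", "Bn", "original", block_id=3) == "transformer_blocks_Bn_3_original"
```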
cache_dit/cache_factory/cache_blocks/pattern_base.py

@@ -2,7 +2,10 @@ import inspect
 import torch
 import torch.distributed as dist

-from cache_dit.cache_factory import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.logger import init_logger

@@ -18,29 +21,34 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):

     def __init__(
         self,
+        # 0. Transformer blocks configuration
         transformer_blocks: torch.nn.ModuleList,
-        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
-        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
-        blocks_name: str,
-        # Usually, blocks_name, etc.
-        cache_context: str,
-        *,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         check_num_outputs: bool = True,
+        # 1. Cache context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: CachedContext | str = None,
+        cache_manager: CachedContextManager = None,
+        **kwargs,
     ):
         super().__init__()

+        # 0. Transformer blocks configuration
         self.transformer = transformer
         self.transformer_blocks = transformer_blocks
-        self.blocks_name = blocks_name
-        self.cache_context = cache_context
         self.forward_pattern = forward_pattern
         self.check_num_outputs = check_num_outputs
+        # 1. Cache context configuration
+        self.cache_prefix = cache_prefix
+        self.cache_context = cache_context
+        self.cache_manager = cache_manager
+
         self._check_forward_pattern()
         logger.info(
             f"Match Cached Blocks: {self.__class__.__name__}, for "
-            f"{self.
+            f"{self.cache_prefix}, cache_context: {self.cache_context}, "
+            f"cache_manager: {self.cache_manager.name}."
         )

     def _check_forward_pattern(self):
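The constructor change above replaces the positional `blocks_name`/`cache_context` pair with keyword-style cache configuration: a `cache_prefix` (which appears to take over the naming role of `blocks_name`), a `cache_context` that may be a `CachedContext` object or a plain name, and a shared `CachedContextManager`. A rough sketch of that shape using a stand-in manager; `FakeCacheManager` and all example values are hypothetical, not cache-dit code:

```python
from dataclasses import dataclass, field


@dataclass
class FakeCacheManager:
    # Stand-in exposing only what the diff shows: a .name and set_context().
    name: str = "demo"
    _contexts: dict = field(default_factory=dict)

    def set_context(self, cache_context):
        # The real manager presumably resolves a CachedContext by object or by name.
        key = cache_context if isinstance(cache_context, str) else id(cache_context)
        return self._contexts.setdefault(key, cache_context)


# Keyword-style configuration mirroring the new __init__ signature.
cached_blocks_kwargs = dict(
    cache_prefix="transformer_blocks",  # naming role formerly held by blocks_name
    cache_context="ctx_0",              # CachedContext | str in 0.2.29
    cache_manager=FakeCacheManager(name="demo"),
)
cached_blocks_kwargs["cache_manager"].set_context(cached_blocks_kwargs["cache_context"])
```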
@@ -79,9 +87,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-
-            self.cache_context,
-        )
+        self.cache_manager.set_context(self.cache_context)

         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
@@ -96,39 +102,39 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states

-
+        self.cache_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache =
+        can_use_cache = self.cache_manager.can_cache(
            (
                Fn_hidden_states_residual
-                if not
+                if not self.cache_manager.is_l1_diff_enabled()
                else hidden_states
            ),
            parallelized=self._is_parallelized(),
            prefix=(
-                f"{self.
-                if not
-                else f"{self.
+                f"{self.cache_prefix}_Fn_residual"
+                if not self.cache_manager.is_l1_diff_enabled()
+                else f"{self.cache_prefix}_Fn_hidden_states"
            ),
        )

         torch._dynamo.graph_break()
         if can_use_cache:
-
+            self.cache_manager.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-
+                self.cache_manager.apply_cache(
                    hidden_states,
                    encoder_hidden_states,
                    prefix=(
-                        f"{self.
-                        if
-                        else f"{self.
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                    ),
                    encoder_prefix=(
-                        f"{self.
-                        if
-                        else f"{self.
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                    ),
                )
            )
@@ -142,15 +148,15 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
             )
         else:
-
+            self.cache_manager.set_Fn_buffer(
                Fn_hidden_states_residual,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Fn_residual",
            )
-            if
+            if self.cache_manager.is_l1_diff_enabled():
                # for hidden states L1 diff
-
+                self.cache_manager.set_Fn_buffer(
                    hidden_states,
-                    f"{self.
+                    f"{self.cache_prefix}_Fn_hidden_states",
                )
         del Fn_hidden_states_residual
         torch._dynamo.graph_break()
@@ -166,27 +172,27 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
             **kwargs,
         )
         torch._dynamo.graph_break()
-        if
-
+        if self.cache_manager.is_cache_residual():
+            self.cache_manager.set_Bn_buffer(
                hidden_states_residual,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Bn_residual",
            )
         else:
             # TaylorSeer
-
+            self.cache_manager.set_Bn_buffer(
                hidden_states,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Bn_hidden_states",
            )
-        if
-
+        if self.cache_manager.is_encoder_cache_residual():
+            self.cache_manager.set_Bn_encoder_buffer(
                encoder_hidden_states_residual,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Bn_residual",
            )
         else:
             # TaylorSeer
-
+            self.cache_manager.set_Bn_encoder_buffer(
                encoder_hidden_states,
-                prefix=f"{self.
+                prefix=f"{self.cache_prefix}_Bn_hidden_states",
            )
         torch._dynamo.graph_break()
         # Call last `n` blocks to further process the hidden states
@@ -232,10 +238,11 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # If so, we can skip some Bn blocks and directly
         # use the cached values.
         return (
-
+            self.cache_manager.get_current_step()
+            in self.cache_manager.get_cached_steps()
         ) or (
-
-            in
+            self.cache_manager.get_current_step()
+            in self.cache_manager.get_cfg_cached_steps()
         )

     @torch.compiler.disable
@@ -244,20 +251,20 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         # more stable diff calculation.
         # Fn: [0,...,n-1]
         selected_Fn_blocks = self.transformer_blocks[
-            :
+            : self.cache_manager.Fn_compute_blocks()
        ]
         return selected_Fn_blocks

     @torch.compiler.disable
     def _Mn_blocks(self): # middle blocks
         # M(N-2n): only transformer_blocks [n,...,N-n], middle
-        if
+        if self.cache_manager.Bn_compute_blocks() == 0: # WARN: x[:-0] = []
             selected_Mn_blocks = self.transformer_blocks[
-
+                self.cache_manager.Fn_compute_blocks() :
            ]
         else:
             selected_Mn_blocks = self.transformer_blocks[
-
+                self.cache_manager.Fn_compute_blocks() : -self.cache_manager.Bn_compute_blocks()
            ]
         return selected_Mn_blocks

@@ -265,7 +272,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
     def _Bn_blocks(self):
         # Bn: transformer_blocks [N-n+1,...,N-1]
         selected_Bn_blocks = self.transformer_blocks[
-            -
+            -self.cache_manager.Bn_compute_blocks() :
        ]
         return selected_Bn_blocks

@@ -276,10 +283,10 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        assert
+        assert self.cache_manager.Fn_compute_blocks() <= len(
            self.transformer_blocks
        ), (
-            f"Fn_compute_blocks {
+            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
            f"the number of transformer blocks {len(self.transformer_blocks)}"
        )
         for block in self._Fn_blocks():
@@ -376,7 +383,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
             )
             # Cache residuals for the non-compute Bn blocks for
             # subsequent cache steps.
-            if block_id not in
+            if block_id not in self.cache_manager.Bn_compute_blocks_ids():
                 Bn_i_hidden_states_residual = (
                     hidden_states - Bn_i_original_hidden_states
                 )
@@ -385,22 +392,22 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 )

                 # Save original_hidden_states for diff calculation.
-
+                self.cache_manager.set_Bn_buffer(
                    Bn_i_original_hidden_states,
-                    prefix=f"{self.
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
                )
-
+                self.cache_manager.set_Bn_encoder_buffer(
                    Bn_i_original_encoder_hidden_states,
-                    prefix=f"{self.
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
                )

-
+                self.cache_manager.set_Bn_buffer(
                    Bn_i_hidden_states_residual,
-                    prefix=f"{self.
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
                )
-
+                self.cache_manager.set_Bn_encoder_buffer(
                    Bn_i_encoder_hidden_states_residual,
-                    prefix=f"{self.
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
                )
                 del Bn_i_hidden_states_residual
                 del Bn_i_encoder_hidden_states_residual
@@ -411,7 +418,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         else:
             # Cache steps: Reuse the cached residuals.
             # Check if the block is in the Bn_compute_blocks_ids.
-            if block_id in
+            if block_id in self.cache_manager.Bn_compute_blocks_ids():
                 hidden_states = block(
                     hidden_states,
                     encoder_hidden_states,
@@ -429,25 +436,25 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 # Skip the block if it is not in the Bn_compute_blocks_ids.
                 # Use the cached residuals instead.
                 # Check if can use the cached residuals.
-                if
+                if self.cache_manager.can_cache(
                    hidden_states, # curr step
                    parallelized=self._is_parallelized(),
-                    threshold=
-                    prefix=f"{self.
+                    threshold=self.cache_manager.non_compute_blocks_diff_threshold(),
+                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original", # prev step
                ):
                     hidden_states, encoder_hidden_states = (
-
+                        self.cache_manager.apply_cache(
                            hidden_states,
                            encoder_hidden_states,
                            prefix=(
-                                f"{self.
-                                if
-                                else f"{self.
+                                f"{self.cache_prefix}_Bn_{block_id}_residual"
+                                if self.cache_manager.is_cache_residual()
+                                else f"{self.cache_prefix}_Bn_{block_id}_original"
                            ),
                            encoder_prefix=(
-                                f"{self.
-                                if
-                                else f"{self.
+                                f"{self.cache_prefix}_Bn_{block_id}_residual"
+                                if self.cache_manager.is_encoder_cache_residual()
+                                else f"{self.cache_prefix}_Bn_{block_id}_original"
                            ),
                        )
                    )
@@ -474,16 +481,16 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        if
+        if self.cache_manager.Bn_compute_blocks() == 0:
             return hidden_states, encoder_hidden_states

-        assert
+        assert self.cache_manager.Bn_compute_blocks() <= len(
            self.transformer_blocks
        ), (
-            f"Bn_compute_blocks {
+            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
            f"the number of transformer blocks {len(self.transformer_blocks)}"
        )
-        if len(
+        if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
             for i, block in enumerate(self._Bn_blocks()):
                 hidden_states, encoder_hidden_states = (
                     self._compute_or_cache_block(
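Taken together, the forward-path hunks follow one per-step protocol against the manager: `mark_step_begin()`, then `can_cache(...)` on the Fn residual, then either `add_cached_step()` plus `apply_cache(...)` or a recompute that refreshes the Fn/Bn buffers. The mock below traces only that control flow; `MockManager` and its relative-difference threshold are assumptions for illustration, not the CachedContextManager implementation:

```python
import torch


class MockManager:
    # Stand-in that mimics just the call sequence used by the cached blocks.
    def __init__(self):
        self.buffers, self.cached_steps, self.step = {}, [], -1

    def mark_step_begin(self):
        self.step += 1

    def can_cache(self, residual, prefix, threshold=0.05, parallelized=False):
        prev = self.buffers.get(prefix)
        if prev is None:
            return False
        # Relative mean difference as a stand-in for the real diff metric.
        diff = (residual - prev).abs().mean() / (prev.abs().mean() + 1e-8)
        return bool(diff < threshold)

    def set_Fn_buffer(self, tensor, prefix):
        self.buffers[prefix] = tensor

    def add_cached_step(self):
        self.cached_steps.append(self.step)


mgr, prefix = MockManager(), "blocks_Fn_residual"
for _ in range(3):
    residual = torch.ones(4)  # stand-in for Fn_hidden_states_residual
    mgr.mark_step_begin()
    if mgr.can_cache(residual, prefix=prefix):
        mgr.add_cached_step()                       # reuse cached Bn residuals
    else:
        mgr.set_Fn_buffer(residual, prefix=prefix)  # recompute and refresh
print(mgr.cached_steps)  # steps after the first can hit the cache
```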
cache_dit/cache_factory/cache_blocks/utils.py

@@ -2,22 +2,40 @@ import torch

 from typing import Any
 from cache_dit.cache_factory import CachedContext
+from cache_dit.cache_factory import CachedContextManager


-@torch.compiler.disable
 def patch_cached_stats(
-    module: torch.nn.Module | Any,
+    module: torch.nn.Module | Any,
+    cache_context: CachedContext | str = None,
+    cache_manager: CachedContextManager = None,
 ):
     # Patch the cached stats to the module, the cached stats
     # will be reset for each calling of pipe.__call__(**kwargs).
-    if module is None:
+    if module is None or cache_manager is None:
         return

     if cache_context is not None:
-
+        cache_manager.set_context(cache_context)

     # TODO: Patch more cached stats to the module
-    module._cached_steps =
-    module._residual_diffs =
-    module._cfg_cached_steps =
-    module._cfg_residual_diffs =
+    module._cached_steps = cache_manager.get_cached_steps()
+    module._residual_diffs = cache_manager.get_residual_diffs()
+    module._cfg_cached_steps = cache_manager.get_cfg_cached_steps()
+    module._cfg_residual_diffs = cache_manager.get_cfg_residual_diffs()
+
+
+def remove_cached_stats(
+    module: torch.nn.Module | Any,
+):
+    if module is None:
+        return
+
+    if hasattr(module, "_cached_steps"):
+        del module._cached_steps
+    if hasattr(module, "_residual_diffs"):
+        del module._residual_diffs
+    if hasattr(module, "_cfg_cached_steps"):
+        del module._cfg_cached_steps
+    if hasattr(module, "_cfg_residual_diffs"):
+        del module._cfg_residual_diffs
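A minimal sketch of the attribute lifecycle that `patch_cached_stats` sets up and the new `remove_cached_stats` tears down; the target module and all stats values below are made-up examples, not output of cache-dit:

```python
import torch

module = torch.nn.Linear(4, 4)  # any torch.nn.Module works as the target

# What patch_cached_stats attaches (real values come from the cache manager).
module._cached_steps = [3, 5, 8]
module._residual_diffs = {3: 0.02, 5: 0.03, 8: 0.01}
module._cfg_cached_steps = []
module._cfg_residual_diffs = {}

# What remove_cached_stats undoes.
for attr in ("_cached_steps", "_residual_diffs",
             "_cfg_cached_steps", "_cfg_residual_diffs"):
    if hasattr(module, attr):
        delattr(module, attr)

assert not hasattr(module, "_cached_steps")
```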
cache_dit/cache_factory/cache_contexts/__init__.py

@@ -1,2 +1,5 @@
 # namespace alias: for _CachedContext and many others' cache context funcs.
-
+from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
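With these re-exports, both symbols should be importable straight from the `cache_contexts` package in 0.2.29 (assuming the wheel is installed); the fully qualified module paths used elsewhere in the diff keep working as well:

```python
from cache_dit.cache_factory.cache_contexts import (
    CachedContext,
    CachedContextManager,
)
```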