cache-dit 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/__init__.py +3 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +8 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +90 -76
- cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +256 -24
- cache_dit/cache_factory/cache_blocks/pattern_base.py +273 -38
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
- cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
- cache_dit/cache_factory/cache_contexts/cache_config.py +102 -0
- cache_dit/cache_factory/cache_contexts/cache_context.py +15 -93
- cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
- cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
- cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
- cache_dit/cache_factory/cache_contexts/prune_config.py +69 -0
- cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
- cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
- cache_dit/cache_factory/cache_interface.py +20 -14
- cache_dit/cache_factory/cache_types.py +19 -2
- cache_dit/cache_factory/params_modifier.py +7 -7
- cache_dit/cache_factory/utils.py +18 -7
- cache_dit/utils.py +191 -54
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/METADATA +9 -9
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/RECORD +29 -24
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/top_level.txt +0 -0

cache_dit/cache_factory/cache_blocks/__init__.py

@@ -1,28 +1,33 @@
 import torch
 
 from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
     CachedContextManager,
 )
+from cache_dit.cache_factory.cache_contexts.prune_manager import (
+    PrunedContextManager,
+)
 
 from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
     CachedBlocks_Pattern_0_1_2,
+    PrunedBlocks_Pattern_0_1_2,
 )
 from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
     CachedBlocks_Pattern_3_4_5,
+    PrunedBlocks_Pattern_3_4_5,
 )
-from cache_dit.cache_factory.cache_blocks.pattern_utils import (
-    remove_cached_stats,
-)
+from cache_dit.cache_factory.cache_blocks.pattern_utils import apply_stats
+from cache_dit.cache_factory.cache_blocks.pattern_utils import remove_stats
 
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-class CachedBlocks:
+class UnifiedBlocks:
     def __new__(
         cls,
         # 0. Transformer blocks configuration
@@ -36,16 +41,13 @@ class CachedBlocks:
         # 'layers', 'single_stream_blocks', 'double_stream_blocks'
         cache_prefix: str = None,  # cache_prefix maybe un-need.
         # Usually, blocks_name, etc.
-        cache_context: CachedContext | str = None,
+        cache_context: CachedContext | PrunedContext | str = None,
+        context_manager: CachedContextManager | PrunedContextManager = None,
+        cache_type: CacheType = CacheType.DBCache,
         **kwargs,
     ):
-        assert cache_context is not None, "cache_context can't be None."
-        assert cache_manager is not None, "cache_manager can't be None."
-        if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
-            return CachedBlocks_Pattern_0_1_2(
+        if cache_type == CacheType.DBCache:
+            return CachedBlocks(
                 # 0. Transformer blocks configuration
                 transformer_blocks,
                 transformer=transformer,
@@ -55,11 +57,12 @@ class CachedBlocks:
                 # 1. Cache context configuration
                 cache_prefix=cache_prefix,
                 cache_context=cache_context,
+                context_manager=context_manager,
+                cache_type=cache_type,
                 **kwargs,
             )
-        elif
-            return
+        elif cache_type == CacheType.DBPrune:
+            return PrunedBlocks(
                 # 0. Transformer blocks configuration
                 transformer_blocks,
                 transformer=transformer,
@@ -69,8 +72,155 @@ class CachedBlocks:
                 # 1. Cache context configuration
                 cache_prefix=cache_prefix,
                 cache_context=cache_context,
+                context_manager=context_manager,
+                cache_type=cache_type,
                 **kwargs,
             )
+        else:
+            raise ValueError(f"Cache type {cache_type} is not supported now!")
+
+
+class CachedBlocks:
+    def __new__(
+        cls,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = None,
+        check_forward_pattern: bool = True,
+        check_num_outputs: bool = True,
+        # 1. Cache context configuration
+        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
+        cache_prefix: str = None,  # cache_prefix maybe un-need.
+        # Usually, blocks_name, etc.
+        cache_context: CachedContext | PrunedContext | str = None,
+        context_manager: CachedContextManager | PrunedContextManager = None,
+        cache_type: CacheType = CacheType.DBCache,
+        **kwargs,
+    ):
+        assert transformer is not None, "transformer can't be None."
+        assert forward_pattern is not None, "forward_pattern can't be None."
+        assert cache_context is not None, "cache_context can't be None."
+        assert context_manager is not None, "context_manager can't be None."
+        if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
+            if cache_type == CacheType.DBCache:
+                assert isinstance(
+                    context_manager, CachedContextManager
+                ), "context_manager must be CachedContextManager for DBCache."
+                return CachedBlocks_Pattern_0_1_2(
+                    # 0. Transformer blocks configuration
+                    transformer_blocks,
+                    transformer=transformer,
+                    forward_pattern=forward_pattern,
+                    check_forward_pattern=check_forward_pattern,
+                    check_num_outputs=check_num_outputs,
+                    # 1. Cache context configuration
+                    cache_prefix=cache_prefix,
+                    cache_context=cache_context,
+                    context_manager=context_manager,
+                    cache_type=cache_type,
+                    **kwargs,
+                )
+            else:
+                raise ValueError(
+                    f"Cache type {cache_type} is not supported now!"
+                )
+        elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
+            if cache_type == CacheType.DBCache:
+                assert isinstance(
+                    context_manager, CachedContextManager
+                ), "context_manager must be CachedContextManager for DBCache."
+                return CachedBlocks_Pattern_3_4_5(
+                    # 0. Transformer blocks configuration
+                    transformer_blocks,
+                    transformer=transformer,
+                    forward_pattern=forward_pattern,
+                    check_forward_pattern=check_forward_pattern,
+                    check_num_outputs=check_num_outputs,
+                    # 1. Cache context configuration
+                    cache_prefix=cache_prefix,
+                    cache_context=cache_context,
+                    context_manager=context_manager,
+                    cache_type=cache_type,
+                    **kwargs,
+                )
+            else:
+                raise ValueError(
+                    f"Cache type {cache_type} is not supported now!"
+                )
+        else:
+            raise ValueError(f"Pattern {forward_pattern} is not supported now!")
+
+
+class PrunedBlocks:
+    def __new__(
+        cls,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = None,
+        check_forward_pattern: bool = True,
+        check_num_outputs: bool = True,
+        # 1. Cache context configuration
+        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
+        cache_prefix: str = None,  # cache_prefix maybe un-need.
+        # Usually, blocks_name, etc.
+        cache_context: CachedContext | PrunedContext | str = None,
+        context_manager: CachedContextManager | PrunedContextManager = None,
+        cache_type: CacheType = CacheType.DBCache,
+        **kwargs,
+    ):
+        assert transformer is not None, "transformer can't be None."
+        assert forward_pattern is not None, "forward_pattern can't be None."
+        assert cache_context is not None, "cache_context can't be None."
+        assert context_manager is not None, "context_manager can't be None."
+        if forward_pattern in PrunedBlocks_Pattern_0_1_2._supported_patterns:
+            if cache_type == CacheType.DBPrune:
+                assert isinstance(
+                    context_manager, PrunedContextManager
+                ), "context_manager must be PrunedContextManager for DBPrune."
+                return PrunedBlocks_Pattern_0_1_2(
+                    # 0. Transformer blocks configuration
+                    transformer_blocks,
+                    transformer=transformer,
+                    forward_pattern=forward_pattern,
+                    check_forward_pattern=check_forward_pattern,
+                    check_num_outputs=check_num_outputs,
+                    # 1. Cache context configuration
+                    cache_prefix=cache_prefix,
+                    cache_context=cache_context,
+                    context_manager=context_manager,
+                    cache_type=cache_type,
+                    **kwargs,
+                )
+            else:
+                raise ValueError(
+                    f"Cache type {cache_type} is not supported now!"
+                )
+        elif forward_pattern in PrunedBlocks_Pattern_3_4_5._supported_patterns:
+            if cache_type == CacheType.DBPrune:
+                assert isinstance(
+                    context_manager, PrunedContextManager
+                ), "context_manager must be PrunedContextManager for DBPrune."
+                return PrunedBlocks_Pattern_3_4_5(
+                    # 0. Transformer blocks configuration
+                    transformer_blocks,
+                    transformer=transformer,
+                    forward_pattern=forward_pattern,
+                    check_forward_pattern=check_forward_pattern,
+                    check_num_outputs=check_num_outputs,
+                    # 1. Cache context configuration
+                    cache_prefix=cache_prefix,
+                    cache_context=cache_context,
+                    context_manager=context_manager,
+                    cache_type=cache_type,
+                    **kwargs,
+                )
+            else:
+                raise ValueError(
+                    f"Cache type {cache_type} is not supported now!"
+                )
         else:
             raise ValueError(f"Pattern {forward_pattern} is not supported now!")
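
Taken together, the new __init__.py turns UnifiedBlocks into a small factory: DBCache requests go to CachedBlocks (and on to the CachedBlocks_Pattern_* classes), while DBPrune requests go to PrunedBlocks (and the PrunedBlocks_Pattern_* classes). The sketch below shows how a caller might drive that dispatch. It only assumes the constructor signature visible in the diff above; the transformer/blocks objects, the prefix string, and the context-manager instance are placeholders, not cache-dit's documented API.

import torch

from cache_dit.cache_factory import ForwardPattern
from cache_dit.cache_factory.cache_types import CacheType
from cache_dit.cache_factory.cache_blocks import UnifiedBlocks


def build_unified_blocks(
    transformer: torch.nn.Module,
    blocks: torch.nn.ModuleList,
    context_manager,  # CachedContextManager or PrunedContextManager, built elsewhere
    use_prune: bool = False,
):
    # DBPrune must be paired with a PrunedContextManager and DBCache with a
    # CachedContextManager; the factory asserts this pairing (see diff above).
    cache_type = CacheType.DBPrune if use_prune else CacheType.DBCache
    return UnifiedBlocks(
        blocks,                              # torch.nn.ModuleList of transformer blocks
        transformer=transformer,
        forward_pattern=ForwardPattern.Pattern_0,
        cache_prefix="transformer_blocks",   # placeholder blocks name
        cache_context="transformer_blocks",  # context object or name (str accepted)
        context_manager=context_manager,
        cache_type=cache_type,
    )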

cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py

@@ -1,6 +1,7 @@
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
+    PrunedBlocks_Pattern_Base,
 )
 from cache_dit.logger import init_logger
 
@@ -14,3 +15,12 @@ class CachedBlocks_Pattern_0_1_2(CachedBlocks_Pattern_Base):
         ForwardPattern.Pattern_2,
     ]
     ...
+
+
+class PrunedBlocks_Pattern_0_1_2(PrunedBlocks_Pattern_Base):
+    _supported_patterns = [
+        ForwardPattern.Pattern_0,
+        ForwardPattern.Pattern_1,
+        ForwardPattern.Pattern_2,
+    ]
+    ...

cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py

@@ -2,11 +2,17 @@ import torch
 
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    ContextNotExistError,
 )
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
 )
+from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext
+from cache_dit.cache_factory.cache_contexts.prune_manager import (
+    PrunedContextManager,
+)
+from cache_dit.cache_factory.cache_types import CacheType
+
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
@@ -91,10 +97,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
     ):
         # Use it's own cache context.
         try:
-            self.
+            self.context_manager.set_context(self.cache_context)
             self._check_cache_params()
-        except
-            logger.warning(f"
+        except ContextNotExistError as e:
+            logger.warning(f"context not exist: {e}, skip cache.")
             hidden_states, new_encoder_hidden_states = self.call_blocks(
                 hidden_states,
                 *args,
@@ -118,38 +124,38 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         )
         del original_hidden_states
 
-        self.
+        self.context_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache = self.
+        can_use_cache = self.context_manager.can_cache(
            (
                Fn_hidden_states_residual
-                if not self.
+                if not self.context_manager.is_l1_diff_enabled()
                else hidden_states
            ),
            parallelized=self._is_parallelized(),
            prefix=(
                f"{self.cache_prefix}_Fn_residual"
-                if not self.
+                if not self.context_manager.is_l1_diff_enabled()
                else f"{self.cache_prefix}_Fn_hidden_states"
            ),
        )

        torch._dynamo.graph_break()
        if can_use_cache:
-            self.
+            self.context_manager.add_cached_step()
            del Fn_hidden_states_residual
            hidden_states, new_encoder_hidden_states = (
-                self.
+                self.context_manager.apply_cache(
                    hidden_states,
                    new_encoder_hidden_states,  # encoder_hidden_states not use cache
                    prefix=(
                        f"{self.cache_prefix}_Bn_residual"
-                        if self.
+                        if self.context_manager.is_cache_residual()
                        else f"{self.cache_prefix}_Bn_hidden_states"
                    ),
                    encoder_prefix=(
                        f"{self.cache_prefix}_Bn_residual"
-                        if self.
+                        if self.context_manager.is_encoder_cache_residual()
                        else f"{self.cache_prefix}_Bn_hidden_states"
                    ),
                )
@@ -157,20 +163,20 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
            torch._dynamo.graph_break()
            # Call last `n` blocks to further process the hidden states
            # for higher precision.
-            if self.
+            if self.context_manager.Bn_compute_blocks() > 0:
                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
                    hidden_states,
                    *args,
                    **kwargs,
                )
        else:
-            self.
+            self.context_manager.set_Fn_buffer(
                Fn_hidden_states_residual,
                prefix=f"{self.cache_prefix}_Fn_residual",
            )
-            if self.
+            if self.context_manager.is_l1_diff_enabled():
                # for hidden states L1 diff
-                self.
+                self.context_manager.set_Fn_buffer(
                    hidden_states,
                    f"{self.cache_prefix}_Fn_hidden_states",
                )
@@ -188,13 +194,13 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
            )

            torch._dynamo.graph_break()
-            if self.
-                self.
+            if self.context_manager.is_cache_residual():
+                self.context_manager.set_Bn_buffer(
                    hidden_states_residual,
                    prefix=f"{self.cache_prefix}_Bn_residual",
                )
            else:
-                self.
+                self.context_manager.set_Bn_buffer(
                    hidden_states,
                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                )
@@ -203,22 +209,22 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
            new_encoder_hidden_states_residual = (
                new_encoder_hidden_states - old_encoder_hidden_states
            )
-            if self.
+            if self.context_manager.is_encoder_cache_residual():
                if new_encoder_hidden_states is not None:
-                    self.
+                    self.context_manager.set_Bn_encoder_buffer(
                        new_encoder_hidden_states_residual,
                        prefix=f"{self.cache_prefix}_Bn_residual",
                    )
            else:
                if new_encoder_hidden_states is not None:
-                    self.
+                    self.context_manager.set_Bn_encoder_buffer(
                        new_encoder_hidden_states_residual,
                        prefix=f"{self.cache_prefix}_Bn_hidden_states",
                    )
            torch._dynamo.graph_break()
            # Call last `n` blocks to further process the hidden states
            # for higher precision.
-            if self.
+            if self.context_manager.Bn_compute_blocks() > 0:
                hidden_states, new_encoder_hidden_states = self.call_Bn_blocks(
                    hidden_states,
                    *args,
@@ -289,7 +295,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
        **kwargs,
    ):
        new_encoder_hidden_states = None
-        if self.
+        if self.context_manager.Bn_compute_blocks() == 0:
            return hidden_states, new_encoder_hidden_states

        for block in self._Bn_blocks():
@@ -304,3 +310,229 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
            )

        return hidden_states, new_encoder_hidden_states
+
+
+class PrunedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_3_4_5):
+    _supported_patterns = [
+        ForwardPattern.Pattern_3,
+        ForwardPattern.Pattern_4,
+        ForwardPattern.Pattern_5,
+    ]
+    pruned_blocks_step: int = 0  # number of pruned blocks in current step
+
+    def __init__(
+        self,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_forward_pattern: bool = True,
+        check_num_outputs: bool = True,
+        # 1. Prune context configuration
+        cache_prefix: str = None,  # maybe un-need.
+        cache_context: PrunedContext | str = None,
+        context_manager: PrunedContextManager = None,
+        cache_type: CacheType = CacheType.DBPrune,
+        **kwargs,
+    ):
+        super().__init__(
+            # 0. Transformer blocks configuration
+            transformer_blocks,
+            transformer=transformer,
+            forward_pattern=forward_pattern,
+            check_forward_pattern=check_forward_pattern,
+            check_num_outputs=check_num_outputs,
+            # 1. Cache context configuration
+            cache_prefix=cache_prefix,
+            cache_context=cache_context,
+            context_manager=context_manager,
+            cache_type=cache_type,
+            **kwargs,
+        )
+        assert isinstance(
+            self.context_manager, PrunedContextManager
+        ), "context_manager must be PrunedContextManager for PrunedBlocks."
+        self.context_manager: PrunedContextManager = (
+            self.context_manager
+        )  # For type hint
+
+    @torch.compiler.disable
+    def _check_cache_type(self):
+        assert (
+            self.cache_type == CacheType.DBPrune
+        ), f"Cache type {self.cache_type} is not supported for PrunedBlocks."
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        self.pruned_blocks_step: int = 0  # reset for each step
+
+        # Use it's own cache context.
+        try:
+            self.context_manager.set_context(self.cache_context)
+            self._check_cache_params()
+        except ContextNotExistError as e:
+            logger.warning(f"context not exist: {e}, skip prune.")
+            hidden_states, new_encoder_hidden_states = self.call_blocks(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            return self._process_forward_outputs(
+                hidden_states, new_encoder_hidden_states
+            )
+
+        self.context_manager.mark_step_begin()
+
+        # Call all blocks with prune strategy to process the hidden states.
+        new_encoder_hidden_states = None
+        for i, block in enumerate(self.transformer_blocks):
+            hidden_states, new_encoder_hidden_states = self.compute_or_prune(
+                i,
+                block,
+                hidden_states,
+                new_encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        self.context_manager.add_pruned_block(self.pruned_blocks_step)
+        self.context_manager.add_actual_block(self.num_blocks)
+
+        return self._process_forward_outputs(
+            hidden_states,
+            new_encoder_hidden_states,
+        )
+
+    @property
+    @torch.compiler.disable
+    def num_blocks(self):
+        return len(self.transformer_blocks)
+
+    @torch.compiler.disable
+    def _skip_prune(self, block_id: int) -> bool:
+        # Wrap for non compiled mode.
+        return block_id in self.context_manager.get_non_prune_blocks_ids(
+            self.num_blocks
+        )
+
+    @torch.compiler.disable
+    def _maybe_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        hidden_states: torch.Tensor,  # hidden_states or residual
+        prefix: str = "Bn_original",  # prev step name for single blocks
+    ):
+        # Wrap for non compiled mode.
+        can_use_prune = False
+        if not self._skip_prune(block_id):
+            can_use_prune = self.context_manager.can_prune(
+                hidden_states,  # curr step
+                parallelized=self._is_parallelized(),
+                prefix=prefix,  # prev step
+            )
+        self.pruned_blocks_step += int(can_use_prune)
+        return can_use_prune
+
+    def compute_or_prune(
+        self,
+        block_id: int,  # Block index in the transformer blocks
+        # Below are the inputs to the block
+        block,  # The transformer block to be executed
+        hidden_states: torch.Tensor,
+        new_encoder_hidden_states: torch.Tensor | None,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = new_encoder_hidden_states
+
+        can_use_prune = self._maybe_prune(
+            block_id,
+            hidden_states,
+            prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+        )
+
+        # Prune steps: Prune current block and reuse the cached
+        # residuals for hidden states approximate.
+        torch._dynamo.graph_break()
+        if can_use_prune:
+            self.context_manager.add_pruned_step()
+            hidden_states, new_encoder_hidden_states = (
+                self.context_manager.apply_prune(
+                    hidden_states,
+                    new_encoder_hidden_states,
+                    prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_residual"
+                        if self.context_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        f"{self.cache_prefix}_{block_id}_Bn_encoder_residual"
+                        if self.context_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+        else:
+            # Normal steps: Compute the block and cache the residuals.
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            hidden_states, new_encoder_hidden_states = (
+                self._process_block_outputs(
+                    hidden_states, new_encoder_hidden_states
+                )
+            )
+            if not self._skip_prune(block_id):
+                hidden_states = hidden_states.contiguous()
+                hidden_states_residual = hidden_states - original_hidden_states
+
+                if (
+                    new_encoder_hidden_states is not None
+                    and original_encoder_hidden_states is not None
+                ):
+                    new_encoder_hidden_states = (
+                        new_encoder_hidden_states.contiguous()
+                    )
+                    new_encoder_hidden_states_residual = (
+                        new_encoder_hidden_states
+                        - original_encoder_hidden_states
+                    )
+                else:
+                    new_encoder_hidden_states_residual = None
+
+                self.context_manager.set_Fn_buffer(
+                    original_hidden_states,
+                    prefix=f"{self.cache_prefix}_{block_id}_Fn_original",
+                )
+                if self.context_manager.is_cache_residual():
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states_residual,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_residual",
+                    )
+                else:
+                    self.context_manager.set_Bn_buffer(
+                        hidden_states,
+                        prefix=f"{self.cache_prefix}_{block_id}_Bn_hidden_states",
+                    )
+                if new_encoder_hidden_states_residual is not None:
+                    if self.context_manager.is_encoder_cache_residual():
+                        self.context_manager.set_Bn_encoder_buffer(
+                            new_encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_residual",
+                        )
+                    else:
+                        self.context_manager.set_Bn_encoder_buffer(
+                            new_encoder_hidden_states_residual,
+                            prefix=f"{self.cache_prefix}_{block_id}_Bn_encoder_hidden_states",
+                        )
+            torch._dynamo.graph_break()
+
+        return hidden_states, new_encoder_hidden_states