cache-dit 0.2.26__py3-none-any.whl → 0.2.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/__init__.py +7 -6
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +15 -4
- cache_dit/cache_factory/block_adapters/__init__.py +538 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
- cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
- cache_dit/cache_factory/cache_adapters.py +120 -911
- cache_dit/cache_factory/cache_blocks/__init__.py +7 -9
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +46 -41
- cache_dit/cache_factory/cache_blocks/pattern_base.py +98 -79
- cache_dit/cache_factory/cache_blocks/utils.py +13 -9
- cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
- cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +89 -55
- cache_dit/cache_factory/cache_contexts/cache_manager.py +0 -0
- cache_dit/cache_factory/cache_interface.py +21 -18
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
- cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
- cache_dit/quantize/quantize_ao.py +1 -0
- cache_dit/utils.py +19 -16
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/METADATA +42 -12
- cache_dit-0.2.27.dist-info/RECORD +47 -0
- cache_dit-0.2.26.dist-info/RECORD +0 -42
- /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_blocks/__init__.py

@@ -1,20 +1,18 @@
 from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
-
+    CachedBlocks_Pattern_0_1_2,
 )
 from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
-
+    CachedBlocks_Pattern_3_4_5,
 )


-class
+class CachedBlocks:
     def __new__(cls, *args, **kwargs):
         forward_pattern = kwargs.get("forward_pattern", None)
         assert forward_pattern is not None, "forward_pattern can't be None."
-        if forward_pattern in
-            return
-        elif
-
-        ):
-            return DBCachedBlocks_Pattern_3_4_5(*args, **kwargs)
+        if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
+            return CachedBlocks_Pattern_0_1_2(*args, **kwargs)
+        elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
+            return CachedBlocks_Pattern_3_4_5(*args, **kwargs)
         else:
             raise ValueError(f"Pattern {forward_pattern} is not supported now!")
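The renamed CachedBlocks class above acts as a factory: __new__ inspects the forward_pattern keyword and returns an instance of the matching pattern class. A stand-in sketch of that dispatch technique (the _PatternA/_PatternB classes below are hypothetical placeholders, not cache-dit code):

class _PatternA:
    _supported_patterns = ["Pattern_A"]

    def __init__(self, *args, **kwargs):
        pass


class _PatternB:
    _supported_patterns = ["Pattern_B"]

    def __init__(self, *args, **kwargs):
        pass


class BlocksFactory:
    def __new__(cls, *args, **kwargs):
        # Mirror of the dispatch above: pick the class whose
        # _supported_patterns contains the requested pattern.
        forward_pattern = kwargs.get("forward_pattern", None)
        assert forward_pattern is not None, "forward_pattern can't be None."
        if forward_pattern in _PatternA._supported_patterns:
            return _PatternA(*args, **kwargs)
        elif forward_pattern in _PatternB._supported_patterns:
            return _PatternB(*args, **kwargs)
        raise ValueError(f"Pattern {forward_pattern} is not supported now!")


print(type(BlocksFactory(forward_pattern="Pattern_B")).__name__)  # _PatternB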
cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py

@@ -1,13 +1,13 @@
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
-
+    CachedBlocks_Pattern_Base,
 )
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)


-class
+class CachedBlocks_Pattern_0_1_2(CachedBlocks_Pattern_Base):
     _supported_patterns = [
         ForwardPattern.Pattern_0,
         ForwardPattern.Pattern_1,
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py

@@ -1,19 +1,16 @@
 import torch

-from cache_dit.cache_factory import
+from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory.cache_blocks.utils import (
-    patch_cached_stats,
-)
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
-
+    CachedBlocks_Pattern_Base,
 )
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)


-class
+class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
     _supported_patterns = [
         ForwardPattern.Pattern_3,
         ForwardPattern.Pattern_4,
@@ -26,6 +23,11 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
+        # Use it's own cache context.
+        CachedContext.set_cache_context(
+            self.cache_context,
+        )
+
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
@@ -39,40 +41,40 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states

-
+        CachedContext.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache =
+        can_use_cache = CachedContext.get_can_use_cache(
             (
                 Fn_hidden_states_residual
-                if not
+                if not CachedContext.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
-                "
-                if not
-                else "
+                f"{self.blocks_name}_Fn_residual"
+                if not CachedContext.is_l1_diff_enabled()
+                else f"{self.blocks_name}_Fn_hidden_states"
             ),
         )

         torch._dynamo.graph_break()
         if can_use_cache:
-
+            CachedContext.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-
+                CachedContext.apply_hidden_states_residual(
                     hidden_states,
                     # None Pattern 3, else 4, 5
                     encoder_hidden_states,
                     prefix=(
-                        "
-                        if
-                        else "
+                        f"{self.blocks_name}_Bn_residual"
+                        if CachedContext.is_cache_residual()
+                        else f"{self.blocks_name}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
-                        "
-                        if
-                        else "
+                        f"{self.blocks_name}_Bn_residual"
+                        if CachedContext.is_encoder_cache_residual()
+                        else f"{self.blocks_name}_Bn_hidden_states"
                     ),
                 )
             )
@@ -86,12 +88,16 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
                 **kwargs,
             )
         else:
-
-            Fn_hidden_states_residual,
+            CachedContext.set_Fn_buffer(
+                Fn_hidden_states_residual,
+                prefix=f"{self.blocks_name}_Fn_residual",
             )
-            if
+            if CachedContext.is_l1_diff_enabled():
                 # for hidden states L1 diff
-
+                CachedContext.set_Fn_buffer(
+                    hidden_states,
+                    f"{self.blocks_name}_Fn_hidden_states",
+                )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
             (
@@ -108,29 +114,29 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
                 **kwargs,
             )
             torch._dynamo.graph_break()
-            if
-
+            if CachedContext.is_cache_residual():
+                CachedContext.set_Bn_buffer(
                     hidden_states_residual,
-                    prefix="
+                    prefix=f"{self.blocks_name}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-
+                CachedContext.set_Bn_buffer(
                     hidden_states,
-                    prefix="
+                    prefix=f"{self.blocks_name}_Bn_hidden_states",
                 )
-            if
-
+            if CachedContext.is_encoder_cache_residual():
+                CachedContext.set_Bn_encoder_buffer(
                     # None Pattern 3, else 4, 5
                     encoder_hidden_states_residual,
-                    prefix="
+                    prefix=f"{self.blocks_name}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-
+                CachedContext.set_Bn_encoder_buffer(
                     # None Pattern 3, else 4, 5
                     encoder_hidden_states,
-                    prefix="
+                    prefix=f"{self.blocks_name}_Bn_hidden_states",
                 )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
@@ -143,7 +149,6 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
                 **kwargs,
             )

-        patch_cached_stats(self.transformer)
         torch._dynamo.graph_break()

         return (
@@ -162,10 +167,10 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-        assert
+        assert CachedContext.Fn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Fn_compute_blocks {
+            f"Fn_compute_blocks {CachedContext.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
         encoder_hidden_states = None  # Pattern 3
@@ -237,16 +242,16 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-        if
+        if CachedContext.Bn_compute_blocks() == 0:
             return hidden_states, encoder_hidden_states

-        assert
+        assert CachedContext.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Bn_compute_blocks {
+            f"Bn_compute_blocks {CachedContext.Bn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        if len(
+        if len(CachedContext.Bn_compute_blocks_ids()) > 0:
             raise ValueError(
                 f"Bn_compute_blocks_ids is not support for "
                 f"patterns: {self._supported_patterns}."
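A recurring change in the hunks above is that every cache buffer prefix is now namespaced with self.blocks_name (for example f"{self.blocks_name}_Bn_residual" instead of a fixed string). A minimal stand-in sketch of why that matters when one transformer exposes several block lists (the dict below is a hypothetical substitute for the real CachedContext buffer store, not cache-dit code):

# Hypothetical stand-in for the CachedContext buffer store; only the
# key-naming scheme mirrors the diff above.
buffers = {}

def set_Bn_buffer(value, prefix):
    buffers[prefix] = value

# With per-blocks_name prefixes, two block lists sharing one context no
# longer overwrite each other's cached residuals.
for blocks_name in ("transformer_blocks", "single_transformer_blocks"):
    set_Bn_buffer(f"residual from {blocks_name}", prefix=f"{blocks_name}_Bn_residual")

print(sorted(buffers.keys()))
# ['single_transformer_blocks_Bn_residual', 'transformer_blocks_Bn_residual']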
cache_dit/cache_factory/cache_blocks/pattern_base.py

@@ -2,17 +2,14 @@ import inspect
 import torch
 import torch.distributed as dist

-from cache_dit.cache_factory import
+from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory.cache_blocks.utils import (
-    patch_cached_stats,
-)
 from cache_dit.logger import init_logger

 logger = init_logger(__name__)


-class
+class CachedBlocks_Pattern_Base(torch.nn.Module):
     _supported_patterns = [
         ForwardPattern.Pattern_0,
         ForwardPattern.Pattern_1,
@@ -22,17 +19,29 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
     def __init__(
         self,
         transformer_blocks: torch.nn.ModuleList,
+        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
+        blocks_name: str,
+        # Usually, blocks_name, etc.
+        cache_context: str,
         *,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        check_num_outputs: bool = True,
     ):
         super().__init__()

         self.transformer = transformer
         self.transformer_blocks = transformer_blocks
+        self.blocks_name = blocks_name
+        self.cache_context = cache_context
         self.forward_pattern = forward_pattern
+        self.check_num_outputs = check_num_outputs
         self._check_forward_pattern()
-        logger.info(
+        logger.info(
+            f"Match Cached Blocks: {self.__class__.__name__}, for "
+            f"{self.blocks_name}, context: {self.cache_context}"
+        )

     def _check_forward_pattern(self):
         assert (
@@ -45,16 +54,18 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
             forward_parameters = set(
                 inspect.signature(block.forward).parameters.keys()
             )
-
-
-
-
-
-
-
-
-
-
+
+            if self.check_num_outputs:
+                num_outputs = str(
+                    inspect.signature(block.forward).return_annotation
+                ).count("torch.Tensor")
+
+                if num_outputs > 0:
+                    assert len(self.forward_pattern.Out) == num_outputs, (
+                        f"The number of block's outputs is {num_outputs} don't not "
+                        f"match the number of the pattern: {self.forward_pattern}, "
+                        f"Out: {len(self.forward_pattern.Out)}."
+                    )

         for required_param in self.forward_pattern.In:
             assert (
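The new check_num_outputs check above infers how many tensors a block returns by counting "torch.Tensor" occurrences in the string form of the forward return annotation. A small stand-in sketch of that technique (the Block class here is hypothetical, not cache-dit code):

import inspect
from typing import Tuple

import torch


class Block(torch.nn.Module):
    def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
        return hidden_states, hidden_states


# Count tensors in the annotation string, as the diff above does.
num_outputs = str(
    inspect.signature(Block.forward).return_annotation
).count("torch.Tensor")
print(num_outputs)  # 2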
@@ -68,6 +79,10 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
+        CachedContext.set_cache_context(
+            self.cache_context,
+        )
+
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
@@ -81,39 +96,39 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states

-
+        CachedContext.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache =
+        can_use_cache = CachedContext.get_can_use_cache(
             (
                 Fn_hidden_states_residual
-                if not
+                if not CachedContext.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
-                "
-                if not
-                else "
+                f"{self.blocks_name}_Fn_residual"
+                if not CachedContext.is_l1_diff_enabled()
+                else f"{self.blocks_name}_Fn_hidden_states"
             ),
         )

         torch._dynamo.graph_break()
         if can_use_cache:
-
+            CachedContext.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-
+                CachedContext.apply_hidden_states_residual(
                     hidden_states,
                     encoder_hidden_states,
                     prefix=(
-                        "
-                        if
-                        else "
+                        f"{self.blocks_name}_Bn_residual"
+                        if CachedContext.is_cache_residual()
+                        else f"{self.blocks_name}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
-                        "
-                        if
-                        else "
+                        f"{self.blocks_name}_Bn_residual"
+                        if CachedContext.is_encoder_cache_residual()
+                        else f"{self.blocks_name}_Bn_hidden_states"
                     ),
                 )
             )
@@ -127,12 +142,16 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                 **kwargs,
             )
         else:
-
-            Fn_hidden_states_residual,
+            CachedContext.set_Fn_buffer(
+                Fn_hidden_states_residual,
+                prefix=f"{self.blocks_name}_Fn_residual",
             )
-            if
+            if CachedContext.is_l1_diff_enabled():
                 # for hidden states L1 diff
-
+                CachedContext.set_Fn_buffer(
+                    hidden_states,
+                    f"{self.blocks_name}_Fn_hidden_states",
+                )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
             (
@@ -147,27 +166,27 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
             **kwargs,
         )
         torch._dynamo.graph_break()
-        if
-
+        if CachedContext.is_cache_residual():
+            CachedContext.set_Bn_buffer(
                 hidden_states_residual,
-                prefix="
+                prefix=f"{self.blocks_name}_Bn_residual",
             )
         else:
             # TaylorSeer
-
+            CachedContext.set_Bn_buffer(
                 hidden_states,
-                prefix="
+                prefix=f"{self.blocks_name}_Bn_hidden_states",
             )
-        if
-
+        if CachedContext.is_encoder_cache_residual():
+            CachedContext.set_Bn_encoder_buffer(
                 encoder_hidden_states_residual,
-                prefix="
+                prefix=f"{self.blocks_name}_Bn_residual",
             )
         else:
             # TaylorSeer
-
+            CachedContext.set_Bn_encoder_buffer(
                 encoder_hidden_states,
-                prefix="
+                prefix=f"{self.blocks_name}_Bn_hidden_states",
             )
         torch._dynamo.graph_break()
         # Call last `n` blocks to further process the hidden states
@@ -179,7 +198,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
             **kwargs,
         )

-
+        # patch cached stats for blocks or remove it.
        torch._dynamo.graph_break()

         return (
@@ -213,10 +232,10 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         # If so, we can skip some Bn blocks and directly
         # use the cached values.
         return (
-
+            CachedContext.get_current_step() in CachedContext.get_cached_steps()
         ) or (
-
-            in
+            CachedContext.get_current_step()
+            in CachedContext.get_cfg_cached_steps()
         )

     @torch.compiler.disable
@@ -225,20 +244,20 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         # more stable diff calculation.
         # Fn: [0,...,n-1]
         selected_Fn_blocks = self.transformer_blocks[
-            :
+            : CachedContext.Fn_compute_blocks()
         ]
         return selected_Fn_blocks

     @torch.compiler.disable
     def _Mn_blocks(self):  # middle blocks
         # M(N-2n): only transformer_blocks [n,...,N-n], middle
-        if
+        if CachedContext.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
             selected_Mn_blocks = self.transformer_blocks[
-
+                CachedContext.Fn_compute_blocks() :
             ]
         else:
             selected_Mn_blocks = self.transformer_blocks[
-
+                CachedContext.Fn_compute_blocks() : -CachedContext.Bn_compute_blocks()
             ]
         return selected_Mn_blocks
@@ -246,7 +265,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
     def _Bn_blocks(self):
         # Bn: transformer_blocks [N-n+1,...,N-1]
         selected_Bn_blocks = self.transformer_blocks[
-            -
+            -CachedContext.Bn_compute_blocks() :
         ]
         return selected_Bn_blocks
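The _Fn_blocks / _Mn_blocks / _Bn_blocks hunks above split the block list into a head, a middle, and a tail, and the WARN comment explains the special case: x[:-0] is empty in Python, so the middle slice must drop the negative bound when Bn is zero. A plain-list sketch of that behaviour (the Fn/Bn values are made up):

# Plain-list stand-in for transformer_blocks; Fn/Bn are hypothetical values.
blocks = list(range(10))

Fn, Bn = 2, 0
middle = blocks[Fn:] if Bn == 0 else blocks[Fn:-Bn]
assert middle == [2, 3, 4, 5, 6, 7, 8, 9]  # blocks[Fn:-0] would be [] instead

Fn, Bn = 2, 3
head, middle, tail = blocks[:Fn], blocks[Fn:-Bn], blocks[-Bn:]
assert head == [0, 1]
assert middle == [2, 3, 4, 5, 6]
assert tail == [7, 8, 9]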
@@ -257,10 +276,10 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        assert
+        assert CachedContext.Fn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Fn_compute_blocks {
+            f"Fn_compute_blocks {CachedContext.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
         for block in self._Fn_blocks():
@@ -357,7 +376,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
            )
         # Cache residuals for the non-compute Bn blocks for
         # subsequent cache steps.
-        if block_id not in
+        if block_id not in CachedContext.Bn_compute_blocks_ids():
             Bn_i_hidden_states_residual = (
                 hidden_states - Bn_i_original_hidden_states
             )
@@ -366,22 +385,22 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
             )

             # Save original_hidden_states for diff calculation.
-
+            CachedContext.set_Bn_buffer(
                 Bn_i_original_hidden_states,
-                prefix=f"
+                prefix=f"{self.blocks_name}_Bn_{block_id}_original",
             )
-
+            CachedContext.set_Bn_encoder_buffer(
                 Bn_i_original_encoder_hidden_states,
-                prefix=f"
+                prefix=f"{self.blocks_name}_Bn_{block_id}_original",
             )

-
+            CachedContext.set_Bn_buffer(
                 Bn_i_hidden_states_residual,
-                prefix=f"
+                prefix=f"{self.blocks_name}_Bn_{block_id}_residual",
             )
-
+            CachedContext.set_Bn_encoder_buffer(
                 Bn_i_encoder_hidden_states_residual,
-                prefix=f"
+                prefix=f"{self.blocks_name}_Bn_{block_id}_residual",
             )
             del Bn_i_hidden_states_residual
             del Bn_i_encoder_hidden_states_residual
@@ -392,7 +411,7 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         else:
             # Cache steps: Reuse the cached residuals.
             # Check if the block is in the Bn_compute_blocks_ids.
-            if block_id in
+            if block_id in CachedContext.Bn_compute_blocks_ids():
                 hidden_states = block(
                     hidden_states,
                     encoder_hidden_states,
@@ -410,25 +429,25 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
                 # Skip the block if it is not in the Bn_compute_blocks_ids.
                 # Use the cached residuals instead.
                 # Check if can use the cached residuals.
-                if
+                if CachedContext.get_can_use_cache(
                     hidden_states, # curr step
                     parallelized=self._is_parallelized(),
-                    threshold=
-                    prefix=f"
+                    threshold=CachedContext.non_compute_blocks_diff_threshold(),
+                    prefix=f"{self.blocks_name}_Bn_{block_id}_original", # prev step
                 ):
                     hidden_states, encoder_hidden_states = (
-
+                        CachedContext.apply_hidden_states_residual(
                             hidden_states,
                             encoder_hidden_states,
                             prefix=(
-                                f"
-                                if
-                                else f"
+                                f"{self.blocks_name}_Bn_{block_id}_residual"
+                                if CachedContext.is_cache_residual()
+                                else f"{self.blocks_name}_Bn_{block_id}_original"
                             ),
                             encoder_prefix=(
-                                f"
-                                if
-                                else f"
+                                f"{self.blocks_name}_Bn_{block_id}_residual"
+                                if CachedContext.is_encoder_cache_residual()
+                                else f"{self.blocks_name}_Bn_{block_id}_original"
                             ),
                         )
                     )
@@ -455,16 +474,16 @@ class DBCachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        if
+        if CachedContext.Bn_compute_blocks() == 0:
             return hidden_states, encoder_hidden_states

-        assert
+        assert CachedContext.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Bn_compute_blocks {
+            f"Bn_compute_blocks {CachedContext.Bn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        if len(
+        if len(CachedContext.Bn_compute_blocks_ids()) > 0:
             for i, block in enumerate(self._Bn_blocks()):
                 hidden_states, encoder_hidden_states = (
                     self._compute_or_cache_block(
cache_dit/cache_factory/cache_blocks/utils.py

@@ -1,19 +1,23 @@
 import torch

-from
+from typing import Any
+from cache_dit.cache_factory import CachedContext


 @torch.compiler.disable
 def patch_cached_stats(
-
+    module: torch.nn.Module | Any, cache_context: str = None
 ):
-    # Patch the cached stats to the
+    # Patch the cached stats to the module, the cached stats
     # will be reset for each calling of pipe.__call__(**kwargs).
-    if
+    if module is None:
         return

-
-
-
-
-
+    if cache_context is not None:
+        CachedContext.set_cache_context(cache_context)
+
+    # TODO: Patch more cached stats to the module
+    module._cached_steps = CachedContext.get_cached_steps()
+    module._residual_diffs = CachedContext.get_residual_diffs()
+    module._cfg_cached_steps = CachedContext.get_cfg_cached_steps()
+    module._cfg_residual_diffs = CachedContext.get_cfg_residual_diffs()