cache-dit 0.2.26-py3-none-any.whl → 0.2.28-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +8 -6
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +17 -4
- cache_dit/cache_factory/block_adapters/__init__.py +555 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +538 -0
- cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
- cache_dit/cache_factory/cache_adapters.py +262 -938
- cache_dit/cache_factory/cache_blocks/__init__.py +60 -11
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +45 -41
- cache_dit/cache_factory/cache_blocks/pattern_base.py +106 -80
- cache_dit/cache_factory/cache_blocks/utils.py +16 -10
- cache_dit/cache_factory/cache_contexts/__init__.py +5 -0
- cache_dit/cache_factory/cache_contexts/cache_context.py +327 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
- cache_dit/cache_factory/cache_interface.py +31 -31
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
- cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
- cache_dit/quantize/quantize_ao.py +1 -0
- cache_dit/utils.py +26 -26
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/METADATA +59 -23
- cache_dit-0.2.28.dist-info/RECORD +47 -0
- cache_dit/cache_factory/cache_context.py +0 -1155
- cache_dit-0.2.26.dist-info/RECORD +0 -42
- /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
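The headline change in this release is structural: the monolithic `cache_dit/cache_factory/cache_context.py` (-1155 lines) is split into a dedicated `cache_contexts/` package (`cache_context.py`, `cache_manager.py`, plus the relocated `taylorseer.py`), and adapter logic moves into the new `block_adapters/` package. Callers that imported the old module need the new paths; a minimal before/after sketch, using only import paths visible in the hunks below:

```python
# 0.2.26 (removed in this release):
# from cache_dit.cache_factory import cache_context

# 0.2.28 package layout:
from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
from cache_dit.cache_factory.cache_contexts.cache_manager import CachedContextManager
```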
cache_dit/cache_factory/cache_blocks/__init__.py

```diff
@@ -1,20 +1,69 @@
+import torch
+
+from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
+
 from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
-
+    CachedBlocks_Pattern_0_1_2,
 )
 from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
-
+    CachedBlocks_Pattern_3_4_5,
 )
 
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
 
-class
-    def __new__(
-
+class CachedBlocks:
+    def __new__(
+        cls,
+        # 0. Transformer blocks configuration
+        transformer_blocks: torch.nn.ModuleList,
+        transformer: torch.nn.Module = None,
+        forward_pattern: ForwardPattern = None,
+        check_num_outputs: bool = True,
+        # 1. Cache context configuration
+        # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
+        # 'layers', 'single_stream_blocks', 'double_stream_blocks'
+        cache_prefix: str = None,  # cache_prefix maybe un-need.
+        # Usually, blocks_name, etc.
+        cache_context: CachedContext | str = None,
+        cache_manager: CachedContextManager = None,
+        **kwargs,
+    ):
+        assert transformer is not None, "transformer can't be None."
         assert forward_pattern is not None, "forward_pattern can't be None."
-
-
-
-
-
-
+        assert cache_context is not None, "cache_context can't be None."
+        assert cache_manager is not None, "cache_manager can't be None."
+        if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
+            return CachedBlocks_Pattern_0_1_2(
+                # 0. Transformer blocks configuration
+                transformer_blocks,
+                transformer=transformer,
+                forward_pattern=forward_pattern,
+                check_num_outputs=check_num_outputs,
+                # 1. Cache context configuration
+                cache_prefix=cache_prefix,
+                cache_context=cache_context,
+                cache_manager=cache_manager,
+                **kwargs,
+            )
+        elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
+            return CachedBlocks_Pattern_3_4_5(
+                # 0. Transformer blocks configuration
+                transformer_blocks,
+                transformer=transformer,
+                forward_pattern=forward_pattern,
+                check_num_outputs=check_num_outputs,
+                # 1. Cache context configuration
+                cache_prefix=cache_prefix,
+                cache_context=cache_context,
+                cache_manager=cache_manager,
+                **kwargs,
+            )
         else:
             raise ValueError(f"Pattern {forward_pattern} is not supported now!")
```
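`CachedBlocks` is now a plain factory: `__new__` validates its inputs and returns one of the pattern-specific implementations, so callers never name the concrete classes. A self-contained sketch of this dispatch-via-`__new__` pattern (class and pattern names are illustrative stand-ins, not cache-dit's API):

```python
class _ImplA:
    _supported_patterns = {"Pattern_0", "Pattern_1", "Pattern_2"}

    def __init__(self, pattern: str):
        self.pattern = pattern


class _ImplB:
    _supported_patterns = {"Pattern_3", "Pattern_4", "Pattern_5"}

    def __init__(self, pattern: str):
        self.pattern = pattern


class Blocks:
    def __new__(cls, pattern: str):
        # Returning an instance of a *different* class from __new__ means
        # Blocks.__init__ is never called; the caller simply receives the
        # matching implementation.
        if pattern in _ImplA._supported_patterns:
            return _ImplA(pattern)
        elif pattern in _ImplB._supported_patterns:
            return _ImplB(pattern)
        raise ValueError(f"Pattern {pattern} is not supported now!")


assert isinstance(Blocks("Pattern_1"), _ImplA)
assert isinstance(Blocks("Pattern_4"), _ImplB)
```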
cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py

```diff
@@ -1,13 +1,13 @@
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
-
+    CachedBlocks_Pattern_Base,
 )
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-class
+class CachedBlocks_Pattern_0_1_2(CachedBlocks_Pattern_Base):
     _supported_patterns = [
         ForwardPattern.Pattern_0,
         ForwardPattern.Pattern_1,
```

cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py

```diff
@@ -1,19 +1,15 @@
 import torch
 
-from cache_dit.cache_factory import cache_context
 from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory.cache_blocks.utils import (
-    patch_cached_stats,
-)
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
-
+    CachedBlocks_Pattern_Base,
 )
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-class
+class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
     _supported_patterns = [
         ForwardPattern.Pattern_3,
         ForwardPattern.Pattern_4,
@@ -26,6 +22,11 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
+        # Use it's own cache context.
+        self.cache_manager.set_context(
+            self.cache_context,
+        )
+
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
```
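The first change to the forward path: each `CachedBlocks` instance activates its own named context on the shared manager before doing any cache reads or writes. That is what lets several block groups (e.g. `transformer_blocks` and `single_transformer_blocks`) share one `CachedContextManager` without clobbering each other's buffers. A toy model of that idea (illustrative only, not cache-dit's actual `CachedContextManager`):

```python
class Manager:
    """Toy context manager: one holder for many named buffer namespaces."""

    def __init__(self):
        self._contexts: dict[str, dict] = {}
        self._current: dict | None = None

    def set_context(self, name: str) -> None:
        # Select (creating on first use) the namespace for one blocks group.
        self._current = self._contexts.setdefault(name, {})

    def set_buffer(self, key: str, value) -> None:
        assert self._current is not None, "call set_context() first"
        self._current[key] = value

    def get_buffer(self, key: str):
        assert self._current is not None, "call set_context() first"
        return self._current.get(key)


m = Manager()
m.set_context("transformer_blocks")
m.set_buffer("Fn_residual", 1.0)
m.set_context("single_transformer_blocks")  # switch: buffers stay isolated
assert m.get_buffer("Fn_residual") is None
```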
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py (continued)

```diff
@@ -39,40 +40,40 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         Fn_hidden_states_residual = hidden_states - original_hidden_states
         del original_hidden_states
 
-
+        self.cache_manager.mark_step_begin()
         # Residual L1 diff or Hidden States L1 diff
-        can_use_cache =
+        can_use_cache = self.cache_manager.can_cache(
             (
                 Fn_hidden_states_residual
-                if not
+                if not self.cache_manager.is_l1_diff_enabled()
                 else hidden_states
             ),
             parallelized=self._is_parallelized(),
             prefix=(
-                "
-                if not
-                else "
+                f"{self.cache_prefix}_Fn_residual"
+                if not self.cache_manager.is_l1_diff_enabled()
+                else f"{self.cache_prefix}_Fn_hidden_states"
             ),
         )
 
         torch._dynamo.graph_break()
         if can_use_cache:
-
+            self.cache_manager.add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = (
-
+                self.cache_manager.apply_cache(
                     hidden_states,
                     # None Pattern 3, else 4, 5
                     encoder_hidden_states,
                     prefix=(
-                        "
-                        if
-                        else "
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                     encoder_prefix=(
-                        "
-                        if
-                        else "
+                        f"{self.cache_prefix}_Bn_residual"
+                        if self.cache_manager.is_encoder_cache_residual()
+                        else f"{self.cache_prefix}_Bn_hidden_states"
                     ),
                 )
             )
```
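This hunk is the cache decision itself: the residual left by the first `Fn` blocks is compared against the value stored on the previous step, and on a hit the middle blocks are skipped and the cached `Bn` output is applied instead. A self-contained sketch of such a residual-similarity test; the relative-L1 metric and the threshold value here are assumptions for illustration, not cache-dit's exact internals:

```python
import torch


def can_use_cache(curr: torch.Tensor, prev: torch.Tensor | None,
                  threshold: float = 0.08) -> bool:
    # Relative L1 distance between this step's Fn residual and the previous
    # one. A small change suggests the skipped blocks would produce nearly
    # the same update, so the cached Bn output can be reused.
    if prev is None:
        return False  # nothing cached yet
    rel_diff = (curr - prev).abs().mean() / (prev.abs().mean() + 1e-8)
    return rel_diff.item() < threshold


prev = torch.randn(2, 16, 64)
near = prev + 0.01 * torch.randn_like(prev)
assert can_use_cache(near, prev)                        # barely changed: hit
assert not can_use_cache(torch.randn_like(prev), prev)  # changed: miss
```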
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py (continued)

```diff
@@ -86,12 +87,16 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
                 **kwargs,
             )
         else:
-
-                Fn_hidden_states_residual,
+            self.cache_manager.set_Fn_buffer(
+                Fn_hidden_states_residual,
+                prefix=f"{self.cache_prefix}_Fn_residual",
             )
-            if
+            if self.cache_manager.is_l1_diff_enabled():
                 # for hidden states L1 diff
-
+                self.cache_manager.set_Fn_buffer(
+                    hidden_states,
+                    f"{self.cache_prefix}_Fn_hidden_states",
+                )
             del Fn_hidden_states_residual
             torch._dynamo.graph_break()
             (
@@ -108,29 +113,29 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
                 **kwargs,
             )
             torch._dynamo.graph_break()
-            if
-
+            if self.cache_manager.is_cache_residual():
+                self.cache_manager.set_Bn_buffer(
                     hidden_states_residual,
-                    prefix="
+                    prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-
+                self.cache_manager.set_Bn_buffer(
                     hidden_states,
-                    prefix="
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
-            if
-
+            if self.cache_manager.is_encoder_cache_residual():
+                self.cache_manager.set_Bn_encoder_buffer(
                     # None Pattern 3, else 4, 5
                     encoder_hidden_states_residual,
-                    prefix="
+                    prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
                 # TaylorSeer
-
+                self.cache_manager.set_Bn_encoder_buffer(
                     # None Pattern 3, else 4, 5
                     encoder_hidden_states,
-                    prefix="
+                    prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
             torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
```
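On the miss path above, what gets stored depends on the mode: by default the `Bn` residual (output minus input of the cached span) is saved so a later hit can simply add it back, while in TaylorSeer mode the raw hidden states are kept so the cached value can be extrapolated instead of replayed. A minimal sketch of the two storage modes and the residual replay, assuming the convention residual = output - input:

```python
import torch


def store(buffers: dict, hidden_states, residual, cache_residual: bool, prefix: str):
    if cache_residual:
        buffers[f"{prefix}_Bn_residual"] = residual            # default path
    else:
        buffers[f"{prefix}_Bn_hidden_states"] = hidden_states  # TaylorSeer path


def apply_cached(buffers: dict, hidden_states, prefix: str):
    # On a cache hit, add the stored residual back onto the fresh input.
    return hidden_states + buffers[f"{prefix}_Bn_residual"]


buf = {}
x_in = torch.randn(2, 8)
x_out = x_in + 0.1 * torch.randn(2, 8)
store(buf, x_out, x_out - x_in, cache_residual=True, prefix="blocks")
assert torch.allclose(apply_cached(buf, x_in, prefix="blocks"), x_out)
```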
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py (continued)

```diff
@@ -143,7 +148,6 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
            **kwargs,
        )
 
-        patch_cached_stats(self.transformer)
         torch._dynamo.graph_break()
 
         return (
@@ -162,10 +166,10 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-        assert
+        assert self.cache_manager.Fn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Fn_compute_blocks {
+            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
         encoder_hidden_states = None  # Pattern 3
@@ -237,16 +241,16 @@ class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-        if
+        if self.cache_manager.Bn_compute_blocks() == 0:
             return hidden_states, encoder_hidden_states
 
-        assert
+        assert self.cache_manager.Bn_compute_blocks() <= len(
             self.transformer_blocks
         ), (
-            f"Bn_compute_blocks {
+            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
             f"the number of transformer blocks {len(self.transformer_blocks)}"
         )
-        if len(
+        if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
             raise ValueError(
                 f"Bn_compute_blocks_ids is not support for "
                 f"patterns: {self._supported_patterns}."
```
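The guards in the last two hunks encode the Fn/Bn split: the first `Fn` and last `Bn` blocks always execute and only the middle span is ever skipped, so neither count may exceed the total number of blocks (and per-id `Bn` selection is rejected for patterns 3-5). A small sketch of that partition (illustrative, not cache-dit code):

```python
def split_blocks(blocks: list, Fn: int, Bn: int):
    # First Fn and last Bn blocks always run; only `middle` is cache-skippable.
    assert Fn <= len(blocks) and Bn <= len(blocks)
    return blocks[:Fn], blocks[Fn:len(blocks) - Bn], blocks[len(blocks) - Bn:]


front, middle, back = split_blocks(list(range(10)), Fn=2, Bn=1)
assert front == [0, 1] and back == [9] and len(middle) == 7
```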