cache-dit 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/__init__.py +9 -4
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +16 -3
- cache_dit/cache_factory/block_adapters/__init__.py +538 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
- cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
- cache_dit/cache_factory/cache_adapters.py +121 -563
- cache_dit/cache_factory/cache_blocks/__init__.py +18 -0
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +275 -0
- cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +100 -82
- cache_dit/cache_factory/cache_blocks/utils.py +23 -0
- cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
- cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +94 -56
- cache_dit/cache_factory/cache_interface.py +24 -16
- cache_dit/cache_factory/forward_pattern.py +45 -24
- cache_dit/cache_factory/patch_functors/__init__.py +5 -0
- cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +276 -0
- cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +49 -31
- cache_dit/quantize/quantize_ao.py +19 -4
- cache_dit/quantize/quantize_interface.py +2 -2
- cache_dit/utils.py +19 -15
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/METADATA +76 -19
- cache_dit-0.2.27.dist-info/RECORD +47 -0
- cache_dit-0.2.25.dist-info/RECORD +0 -36
- /cache_dit/cache_factory/{patch/__init__.py → cache_contexts/cache_manager.py} +0 -0
- /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
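Taken together, 0.2.27 reorganizes cache_factory from flat modules into subpackages (block_adapters/, cache_blocks/ split per forward pattern, cache_contexts/, patch_functors/), so downstream imports move with the files. A hedged sketch of the new import paths, using only locations and class names that appear in this listing (whether these names are also re-exported at a higher level is not visible in this diff):

# Import paths implied by the 0.2.27 layout above; module paths are taken
# from the file list, class names from the hunks shown further down.
from cache_dit.cache_factory.cache_blocks import CachedBlocks              # new in 0.2.27
from cache_dit.cache_factory.cache_blocks.pattern_base import (
    CachedBlocks_Pattern_Base,                                             # was cache_factory/cache_blocks.py
)
import cache_dit.cache_factory.cache_contexts.cache_context                # was cache_factory/cache_context.py
import cache_dit.cache_factory.cache_contexts.taylorseer                   # was cache_factory/taylorseer.py
import cache_dit.cache_factory.patch_functors.functor_flux                 # was cache_factory/patch/flux.py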
cache_dit/cache_factory/cache_blocks/__init__.py
@@ -0,0 +1,18 @@
+from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
+    CachedBlocks_Pattern_0_1_2,
+)
+from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
+    CachedBlocks_Pattern_3_4_5,
+)
+
+
+class CachedBlocks:
+    def __new__(cls, *args, **kwargs):
+        forward_pattern = kwargs.get("forward_pattern", None)
+        assert forward_pattern is not None, "forward_pattern can't be None."
+        if forward_pattern in CachedBlocks_Pattern_0_1_2._supported_patterns:
+            return CachedBlocks_Pattern_0_1_2(*args, **kwargs)
+        elif forward_pattern in CachedBlocks_Pattern_3_4_5._supported_patterns:
+            return CachedBlocks_Pattern_3_4_5(*args, **kwargs)
+        else:
+            raise ValueError(f"Pattern {forward_pattern} is not supported now!")
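The __new__ override above makes CachedBlocks a thin dispatcher: it never instantiates itself, it forwards to whichever pattern class lists the requested forward_pattern in _supported_patterns. A library-independent sketch of the same dispatch idiom (the names below are illustrative stand-ins, not cache-dit's actual classes):

# Minimal reproduction of the __new__-based dispatch shown in the hunk above.
import enum

class Pattern(enum.Enum):
    Pattern_0 = 0
    Pattern_3 = 3

class _Blocks_0:
    _supported_patterns = [Pattern.Pattern_0]
    def __init__(self, *args, **kwargs): ...

class _Blocks_3:
    _supported_patterns = [Pattern.Pattern_3]
    def __init__(self, *args, **kwargs): ...

class Blocks:
    def __new__(cls, *args, **kwargs):
        pattern = kwargs.get("forward_pattern", None)
        assert pattern is not None, "forward_pattern can't be None."
        for impl in (_Blocks_0, _Blocks_3):
            if pattern in impl._supported_patterns:
                return impl(*args, **kwargs)   # returns the concrete pattern class
        raise ValueError(f"Pattern {pattern} is not supported now!")

# Callers always construct Blocks(...); the concrete type comes back:
assert isinstance(Blocks(forward_pattern=Pattern.Pattern_3), _Blocks_3)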
cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py
@@ -0,0 +1,16 @@
+from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_blocks.pattern_base import (
+    CachedBlocks_Pattern_Base,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class CachedBlocks_Pattern_0_1_2(CachedBlocks_Pattern_Base):
+    _supported_patterns = [
+        ForwardPattern.Pattern_0,
+        ForwardPattern.Pattern_1,
+        ForwardPattern.Pattern_2,
+    ]
+    ...
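The trailing ... above is an Ellipsis expression used as a no-op statement (equivalent to pass here): the class contributes only the _supported_patterns whitelist that CachedBlocks.__new__ consults and inherits its entire forward path from CachedBlocks_Pattern_Base. A minimal stand-alone illustration of that idiom:

class Base:
    def forward(self, x):
        return x + 1                  # all behaviour lives in the base class

class Narrowed(Base):
    _supported_patterns = [0, 1, 2]   # only metadata is added
    ...                               # no-op body, same effect as `pass`

assert Narrowed().forward(1) == 2
assert 2 in Narrowed._supported_patterns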
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py
@@ -0,0 +1,275 @@
+import torch
+
+from cache_dit.cache_factory import CachedContext
+from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_blocks.pattern_base import (
+    CachedBlocks_Pattern_Base,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
+    _supported_patterns = [
+        ForwardPattern.Pattern_3,
+        ForwardPattern.Pattern_4,
+        ForwardPattern.Pattern_5,
+    ]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Use it's own cache context.
+        CachedContext.set_cache_context(
+            self.cache_context,
+        )
+
+        original_hidden_states = hidden_states
+        # Call first `n` blocks to process the hidden states for
+        # more stable diff calculation.
+        # encoder_hidden_states: None Pattern 3, else 4, 5
+        hidden_states, encoder_hidden_states = self.call_Fn_blocks(
+            hidden_states,
+            *args,
+            **kwargs,
+        )
+
+        Fn_hidden_states_residual = hidden_states - original_hidden_states
+        del original_hidden_states
+
+        CachedContext.mark_step_begin()
+        # Residual L1 diff or Hidden States L1 diff
+        can_use_cache = CachedContext.get_can_use_cache(
+            (
+                Fn_hidden_states_residual
+                if not CachedContext.is_l1_diff_enabled()
+                else hidden_states
+            ),
+            parallelized=self._is_parallelized(),
+            prefix=(
+                f"{self.blocks_name}_Fn_residual"
+                if not CachedContext.is_l1_diff_enabled()
+                else f"{self.blocks_name}_Fn_hidden_states"
+            ),
+        )
+
+        torch._dynamo.graph_break()
+        if can_use_cache:
+            CachedContext.add_cached_step()
+            del Fn_hidden_states_residual
+            hidden_states, encoder_hidden_states = (
+                CachedContext.apply_hidden_states_residual(
+                    hidden_states,
+                    # None Pattern 3, else 4, 5
+                    encoder_hidden_states,
+                    prefix=(
+                        f"{self.blocks_name}_Bn_residual"
+                        if CachedContext.is_cache_residual()
+                        else f"{self.blocks_name}_Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        f"{self.blocks_name}_Bn_residual"
+                        if CachedContext.is_encoder_cache_residual()
+                        else f"{self.blocks_name}_Bn_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+        else:
+            CachedContext.set_Fn_buffer(
+                Fn_hidden_states_residual,
+                prefix=f"{self.blocks_name}_Fn_residual",
+            )
+            if CachedContext.is_l1_diff_enabled():
+                # for hidden states L1 diff
+                CachedContext.set_Fn_buffer(
+                    hidden_states,
+                    f"{self.blocks_name}_Fn_hidden_states",
+                )
+            del Fn_hidden_states_residual
+            torch._dynamo.graph_break()
+            (
+                hidden_states,
+                encoder_hidden_states,
+                hidden_states_residual,
+                # None Pattern 3, else 4, 5
+                encoder_hidden_states_residual,
+            ) = self.call_Mn_blocks(  # middle
+                hidden_states,
+                # None Pattern 3, else 4, 5
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            torch._dynamo.graph_break()
+            if CachedContext.is_cache_residual():
+                CachedContext.set_Bn_buffer(
+                    hidden_states_residual,
+                    prefix=f"{self.blocks_name}_Bn_residual",
+                )
+            else:
+                # TaylorSeer
+                CachedContext.set_Bn_buffer(
+                    hidden_states,
+                    prefix=f"{self.blocks_name}_Bn_hidden_states",
+                )
+            if CachedContext.is_encoder_cache_residual():
+                CachedContext.set_Bn_encoder_buffer(
+                    # None Pattern 3, else 4, 5
+                    encoder_hidden_states_residual,
+                    prefix=f"{self.blocks_name}_Bn_residual",
+                )
+            else:
+                # TaylorSeer
+                CachedContext.set_Bn_encoder_buffer(
+                    # None Pattern 3, else 4, 5
+                    encoder_hidden_states,
+                    prefix=f"{self.blocks_name}_Bn_hidden_states",
+                )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                hidden_states,
+                # None Pattern 3, else 4, 5
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        torch._dynamo.graph_break()
+
+        return (
+            hidden_states
+            if self.forward_pattern.Return_H_Only
+            else (
+                (hidden_states, encoder_hidden_states)
+                if self.forward_pattern.Return_H_First
+                else (encoder_hidden_states, hidden_states)
+            )
+        )
+
+    def call_Fn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        assert CachedContext.Fn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Fn_compute_blocks {CachedContext.Fn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+        encoder_hidden_states = None  # Pattern 3
+        for block in self._Fn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.forward_pattern.Return_H_First:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+
+        return hidden_states, encoder_hidden_states
+
+    def call_Mn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        # None Pattern 3, else 4, 5
+        encoder_hidden_states: torch.Tensor | None,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = encoder_hidden_states
+        for block in self._Mn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.forward_pattern.Return_H_First:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+
+        # compute hidden_states residual
+        hidden_states = hidden_states.contiguous()
+        hidden_states_residual = hidden_states - original_hidden_states
+        if (
+            original_encoder_hidden_states is not None
+            and encoder_hidden_states is not None
+        ):  # Pattern 4, 5
+            encoder_hidden_states_residual = (
+                encoder_hidden_states - original_encoder_hidden_states
+            )
+        else:
+            encoder_hidden_states_residual = None  # Pattern 3
+
+        return (
+            hidden_states,
+            encoder_hidden_states,
+            hidden_states_residual,
+            encoder_hidden_states_residual,
+        )
+
+    def call_Bn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        # None Pattern 3, else 4, 5
+        encoder_hidden_states: torch.Tensor | None,
+        *args,
+        **kwargs,
+    ):
+        if CachedContext.Bn_compute_blocks() == 0:
+            return hidden_states, encoder_hidden_states
+
+        assert CachedContext.Bn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Bn_compute_blocks {CachedContext.Bn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+        if len(CachedContext.Bn_compute_blocks_ids()) > 0:
+            raise ValueError(
+                f"Bn_compute_blocks_ids is not support for "
+                f"patterns: {self._supported_patterns}."
+            )
+        else:
+            # Compute all Bn blocks if no specific Bn compute blocks ids are set.
+            for block in self._Bn_blocks():
+                hidden_states = block(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
+                if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
+                    hidden_states, encoder_hidden_states = hidden_states
+                    if not self.forward_pattern.Return_H_First:
+                        hidden_states, encoder_hidden_states = (
+                            encoder_hidden_states,
+                            hidden_states,
+                        )
+
+        return hidden_states, encoder_hidden_states
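The control flow in this new forward is the core of the dual-block cache: run the first Fn blocks, measure how much the hidden states moved relative to the previous step, then either re-apply the cached middle-block residual or run the middle blocks and refresh the cache, before always computing the last Bn blocks. A condensed, library-independent sketch of that decision (the ResidualCache class, its rel_l1_threshold knob, and the buffer handling are illustrative stand-ins, not cache-dit's actual CachedContext API):

import torch

class ResidualCache:
    """Minimal stand-in for CachedContext: keeps the previous Fn residual and
    the cached middle-block residual, and decides when the cache is reusable."""

    def __init__(self, rel_l1_threshold: float = 0.08):   # assumed knob, for illustration
        self.rel_l1_threshold = rel_l1_threshold
        self.prev_fn_residual = None
        self.cached_mn_residual = None

    def can_use_cache(self, fn_residual: torch.Tensor) -> bool:
        if self.prev_fn_residual is None or self.cached_mn_residual is None:
            return False                                    # nothing cached yet
        diff = (fn_residual - self.prev_fn_residual).abs().mean()
        rel = diff / self.prev_fn_residual.abs().mean()
        return rel.item() < self.rel_l1_threshold

def forward_with_cache(fn_blocks, mn_blocks, bn_blocks, cache, hidden_states):
    original = hidden_states
    for block in fn_blocks:                  # first `n` blocks, always computed
        hidden_states = block(hidden_states)
    fn_residual = hidden_states - original

    if cache.can_use_cache(fn_residual):     # reuse the cached middle-block residual
        hidden_states = hidden_states + cache.cached_mn_residual
    else:                                    # recompute middle blocks, refresh cache
        mn_input = hidden_states
        for block in mn_blocks:
            hidden_states = block(hidden_states)
        cache.cached_mn_residual = hidden_states - mn_input
    cache.prev_fn_residual = fn_residual

    for block in bn_blocks:                  # last `n` blocks, always computed
        hidden_states = block(hidden_states)
    return hidden_states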