cache-dit 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl
This diff shows the content of publicly available package versions as published to their respective public registries and is provided for informational purposes only.
Note: this version of cache-dit has been marked as a potentially problematic release.
- cache_dit/__init__.py +5 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +2 -0
- cache_dit/cache_factory/cache_adapters.py +375 -26
- cache_dit/cache_factory/cache_blocks/__init__.py +20 -0
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +270 -0
- cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +17 -18
- cache_dit/cache_factory/cache_blocks/utils.py +19 -0
- cache_dit/cache_factory/cache_context.py +32 -25
- cache_dit/cache_factory/cache_interface.py +8 -3
- cache_dit/cache_factory/forward_pattern.py +45 -24
- cache_dit/cache_factory/patch_functors/__init__.py +5 -0
- cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +273 -0
- cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +45 -31
- cache_dit/compile/utils.py +1 -1
- cache_dit/quantize/__init__.py +1 -0
- cache_dit/quantize/quantize_ao.py +196 -0
- cache_dit/quantize/quantize_interface.py +46 -0
- cache_dit/utils.py +49 -17
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/METADATA +43 -18
- cache_dit-0.2.26.dist-info/RECORD +42 -0
- cache_dit-0.2.24.dist-info/RECORD +0 -32
- /cache_dit/{cache_factory/patch/__init__.py → quantize/quantize_svdq.py} +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py (new file, +270)

@@ -0,0 +1,270 @@
+import torch
+
+from cache_dit.cache_factory import cache_context
+from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_blocks.utils import (
+    patch_cached_stats,
+)
+from cache_dit.cache_factory.cache_blocks.pattern_base import (
+    DBCachedBlocks_Pattern_Base,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
+    _supported_patterns = [
+        ForwardPattern.Pattern_3,
+        ForwardPattern.Pattern_4,
+        ForwardPattern.Pattern_5,
+    ]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        # Call first `n` blocks to process the hidden states for
+        # more stable diff calculation.
+        # encoder_hidden_states: None Pattern 3, else 4, 5
+        hidden_states, encoder_hidden_states = self.call_Fn_blocks(
+            hidden_states,
+            *args,
+            **kwargs,
+        )
+
+        Fn_hidden_states_residual = hidden_states - original_hidden_states
+        del original_hidden_states
+
+        cache_context.mark_step_begin()
+        # Residual L1 diff or Hidden States L1 diff
+        can_use_cache = cache_context.get_can_use_cache(
+            (
+                Fn_hidden_states_residual
+                if not cache_context.is_l1_diff_enabled()
+                else hidden_states
+            ),
+            parallelized=self._is_parallelized(),
+            prefix=(
+                "Fn_residual"
+                if not cache_context.is_l1_diff_enabled()
+                else "Fn_hidden_states"
+            ),
+        )
+
+        torch._dynamo.graph_break()
+        if can_use_cache:
+            cache_context.add_cached_step()
+            del Fn_hidden_states_residual
+            hidden_states, encoder_hidden_states = (
+                cache_context.apply_hidden_states_residual(
+                    hidden_states,
+                    # None Pattern 3, else 4, 5
+                    encoder_hidden_states,
+                    prefix=(
+                        "Bn_residual"
+                        if cache_context.is_cache_residual()
+                        else "Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        "Bn_residual"
+                        if cache_context.is_encoder_cache_residual()
+                        else "Bn_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+        else:
+            cache_context.set_Fn_buffer(
+                Fn_hidden_states_residual, prefix="Fn_residual"
+            )
+            if cache_context.is_l1_diff_enabled():
+                # for hidden states L1 diff
+                cache_context.set_Fn_buffer(hidden_states, "Fn_hidden_states")
+            del Fn_hidden_states_residual
+            torch._dynamo.graph_break()
+            (
+                hidden_states,
+                encoder_hidden_states,
+                hidden_states_residual,
+                # None Pattern 3, else 4, 5
+                encoder_hidden_states_residual,
+            ) = self.call_Mn_blocks(  # middle
+                hidden_states,
+                # None Pattern 3, else 4, 5
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            torch._dynamo.graph_break()
+            if cache_context.is_cache_residual():
+                cache_context.set_Bn_buffer(
+                    hidden_states_residual,
+                    prefix="Bn_residual",
+                )
+            else:
+                # TaylorSeer
+                cache_context.set_Bn_buffer(
+                    hidden_states,
+                    prefix="Bn_hidden_states",
+                )
+            if cache_context.is_encoder_cache_residual():
+                cache_context.set_Bn_encoder_buffer(
+                    # None Pattern 3, else 4, 5
+                    encoder_hidden_states_residual,
+                    prefix="Bn_residual",
+                )
+            else:
+                # TaylorSeer
+                cache_context.set_Bn_encoder_buffer(
+                    # None Pattern 3, else 4, 5
+                    encoder_hidden_states,
+                    prefix="Bn_hidden_states",
+                )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                hidden_states,
+                # None Pattern 3, else 4, 5
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        patch_cached_stats(self.transformer)
+        torch._dynamo.graph_break()
+
+        return (
+            hidden_states
+            if self.forward_pattern.Return_H_Only
+            else (
+                (hidden_states, encoder_hidden_states)
+                if self.forward_pattern.Return_H_First
+                else (encoder_hidden_states, hidden_states)
+            )
+        )
+
+    def call_Fn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        assert cache_context.Fn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Fn_compute_blocks {cache_context.Fn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+        encoder_hidden_states = None  # Pattern 3
+        for block in self._Fn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.forward_pattern.Return_H_First:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+
+        return hidden_states, encoder_hidden_states
+
+    def call_Mn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        # None Pattern 3, else 4, 5
+        encoder_hidden_states: torch.Tensor | None,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = encoder_hidden_states
+        for block in self._Mn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.forward_pattern.Return_H_First:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+
+        # compute hidden_states residual
+        hidden_states = hidden_states.contiguous()
+        hidden_states_residual = hidden_states - original_hidden_states
+        if (
+            original_encoder_hidden_states is not None
+            and encoder_hidden_states is not None
+        ):  # Pattern 4, 5
+            encoder_hidden_states_residual = (
+                encoder_hidden_states - original_encoder_hidden_states
+            )
+        else:
+            encoder_hidden_states_residual = None  # Pattern 3
+
+        return (
+            hidden_states,
+            encoder_hidden_states,
+            hidden_states_residual,
+            encoder_hidden_states_residual,
+        )
+
+    def call_Bn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        # None Pattern 3, else 4, 5
+        encoder_hidden_states: torch.Tensor | None,
+        *args,
+        **kwargs,
+    ):
+        if cache_context.Bn_compute_blocks() == 0:
+            return hidden_states, encoder_hidden_states
+
+        assert cache_context.Bn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Bn_compute_blocks {cache_context.Bn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+        if len(cache_context.Bn_compute_blocks_ids()) > 0:
+            raise ValueError(
+                f"Bn_compute_blocks_ids is not support for "
+                f"patterns: {self._supported_patterns}."
+            )
+        else:
+            # Compute all Bn blocks if no specific Bn compute blocks ids are set.
+            for block in self._Bn_blocks():
+                hidden_states = block(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
+                if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
+                    hidden_states, encoder_hidden_states = hidden_states
+                    if not self.forward_pattern.Return_H_First:
+                        hidden_states, encoder_hidden_states = (
+                            encoder_hidden_states,
+                            hidden_states,
+                        )
+
+        return hidden_states, encoder_hidden_states
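For orientation, the Pattern_3/4/5 blocks handled by this new class take only hidden_states as their leading input (no encoder_hidden_states argument) and may return either a single tensor or a two-tensor tuple, which is why call_Fn_blocks, call_Mn_blocks and call_Bn_blocks unpack non-tensor results and optionally swap their order. A minimal hypothetical block in that style (not part of the package, just a sketch of the interface the class expects):

```python
import torch


class ToyPattern4Block(torch.nn.Module):
    """Hypothetical block matching Pattern_4: takes hidden_states only and
    returns (hidden_states, encoder_hidden_states)."""

    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)
        self.encoder_proj = torch.nn.Linear(dim, dim)

    def forward(
        self, hidden_states: torch.Tensor, *args, **kwargs
    ) -> tuple[torch.Tensor, torch.Tensor]:
        hidden_states = self.proj(hidden_states)
        # Pattern_4/5 blocks emit encoder hidden states even though they do
        # not take them as input; a Pattern_3 block would return only
        # hidden_states here.
        encoder_hidden_states = self.encoder_proj(hidden_states)
        return hidden_states, encoder_hidden_states
```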
cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py}

@@ -4,12 +4,15 @@ import torch.distributed as dist
 
 from cache_dit.cache_factory import cache_context
 from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_blocks.utils import (
+    patch_cached_stats,
+)
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-class DBCachedTransformerBlocks(torch.nn.Module):
+class DBCachedBlocks_Pattern_Base(torch.nn.Module):
     _supported_patterns = [
         ForwardPattern.Pattern_0,
         ForwardPattern.Pattern_1,
@@ -29,18 +32,30 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         self.transformer_blocks = transformer_blocks
         self.forward_pattern = forward_pattern
         self._check_forward_pattern()
+        logger.info(f"Match Cached Blocks: {self.__class__.__name__}")
 
     def _check_forward_pattern(self):
         assert (
             self.forward_pattern.Supported
             and self.forward_pattern in self._supported_patterns
-        ), f"Pattern {self.forward_pattern} is not
+        ), f"Pattern {self.forward_pattern} is not supported now!"
 
         if self.transformer_blocks is not None:
             for block in self.transformer_blocks:
                 forward_parameters = set(
                     inspect.signature(block.forward).parameters.keys()
                 )
+                num_outputs = str(
+                    inspect.signature(block.forward).return_annotation
+                ).count("torch.Tensor")
+
+                if num_outputs > 0:
+                    assert len(self.forward_pattern.Out) == num_outputs, (
+                        f"The number of block's outputs is {num_outputs} don't not "
+                        f"match the number of the pattern: {self.forward_pattern}, "
+                        f"Out: {len(self.forward_pattern.Out)}."
+                    )
+
                 for required_param in self.forward_pattern.In:
                     assert (
                         required_param in forward_parameters
@@ -479,19 +494,3 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         )
 
         return hidden_states, encoder_hidden_states
-
-
-@torch.compiler.disable
-def patch_cached_stats(
-    transformer,
-):
-    # Patch the cached stats to the transformer, the cached stats
-    # will be reset for each calling of pipe.__call__(**kwargs).
-    if transformer is None:
-        return
-
-    # TODO: Patch more cached stats to the transformer
-    transformer._cached_steps = cache_context.get_cached_steps()
-    transformer._residual_diffs = cache_context.get_residual_diffs()
-    transformer._cfg_cached_steps = cache_context.get_cfg_cached_steps()
-    transformer._cfg_residual_diffs = cache_context.get_cfg_residual_diffs()
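The extended _check_forward_pattern counts how many times "torch.Tensor" appears in a block's forward return annotation and requires that count to match len(forward_pattern.Out). A standalone sketch of that check, using a hypothetical annotated block:

```python
import inspect
from typing import Tuple

import torch


class AnnotatedBlock(torch.nn.Module):
    def forward(
        self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return hidden_states, encoder_hidden_states


block = AnnotatedBlock()
# Same trick the diff uses: count "torch.Tensor" in the return annotation.
num_outputs = str(
    inspect.signature(block.forward).return_annotation
).count("torch.Tensor")
print(num_outputs)  # 2 -> matches a pattern with two Out entries,
                    # e.g. ("hidden_states", "encoder_hidden_states")
```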
cache_dit/cache_factory/cache_blocks/utils.py (new file, +19)

@@ -0,0 +1,19 @@
+import torch
+
+from cache_dit.cache_factory import cache_context
+
+
+@torch.compiler.disable
+def patch_cached_stats(
+    transformer,
+):
+    # Patch the cached stats to the transformer, the cached stats
+    # will be reset for each calling of pipe.__call__(**kwargs).
+    if transformer is None:
+        return
+
+    # TODO: Patch more cached stats to the transformer
+    transformer._cached_steps = cache_context.get_cached_steps()
+    transformer._residual_diffs = cache_context.get_residual_diffs()
+    transformer._cfg_cached_steps = cache_context.get_cfg_cached_steps()
+    transformer._cfg_residual_diffs = cache_context.get_cfg_residual_diffs()
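Because patch_cached_stats writes the statistics onto the transformer module itself, they can be read back after a pipeline call. A minimal sketch, assuming pipe is a diffusers pipeline with a .transformer attribute that has already been run with caching enabled (both assumptions, not shown in this diff):

```python
# Hypothetical inspection step after pipe(...) has executed with caching on.
transformer = pipe.transformer  # assumption: cached transformer module

cached_steps = getattr(transformer, "_cached_steps", [])
residual_diffs = getattr(transformer, "_residual_diffs", {})
print(f"cached steps this run: {len(cached_steps)}")
print(f"residual diffs recorded: {len(residual_diffs)}")
```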
cache_dit/cache_factory/cache_context.py

@@ -328,6 +328,33 @@ class DBCacheContext:
         return self.get_current_step() < self.max_warmup_steps
 
 
+# TODO: Support context manager for different cache_context
+
+
+def create_cache_context(*args, **kwargs):
+    return DBCacheContext(*args, **kwargs)
+
+
+def get_current_cache_context():
+    return _current_cache_context
+
+
+def set_current_cache_context(cache_context=None):
+    global _current_cache_context
+    _current_cache_context = cache_context
+
+
+@contextlib.contextmanager
+def cache_context(cache_context):
+    global _current_cache_context
+    old_cache_context = _current_cache_context
+    _current_cache_context = cache_context
+    try:
+        yield
+    finally:
+        _current_cache_context = old_cache_context
+
+
 @torch.compiler.disable
 def get_residual_diff_threshold():
     cache_context = get_current_cache_context()
@@ -657,19 +684,6 @@ def cfg_diff_compute_separate():
 _current_cache_context: DBCacheContext = None
 
 
-def create_cache_context(*args, **kwargs):
-    return DBCacheContext(*args, **kwargs)
-
-
-def get_current_cache_context():
-    return _current_cache_context
-
-
-def set_current_cache_context(cache_context=None):
-    global _current_cache_context
-    _current_cache_context = cache_context
-
-
 def collect_cache_kwargs(default_attrs: dict, **kwargs):
     # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
     # default_attrs: specific settings for different pipelines
@@ -716,17 +730,6 @@ def collect_cache_kwargs(default_attrs: dict, **kwargs):
     return cache_kwargs, kwargs
 
 
-@contextlib.contextmanager
-def cache_context(cache_context):
-    global _current_cache_context
-    old_cache_context = _current_cache_context
-    _current_cache_context = cache_context
-    try:
-        yield
-    finally:
-        _current_cache_context = old_cache_context
-
-
 @torch.compiler.disable
 def are_two_tensors_similar(
     t1: torch.Tensor,  # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
@@ -938,7 +941,11 @@ def get_Bn_buffer(prefix: str = "Bn"):
 
 
 @torch.compiler.disable
-def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
+def set_Bn_encoder_buffer(buffer: torch.Tensor | None, prefix: str = "Bn"):
+    # DON'T set None Buffer
+    if buffer is None:
+        return
+
     # This buffer is use for encoder hidden states approximation.
     if is_encoder_taylorseer_enabled():
         # taylorseer, encoder_taylorseer
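The helpers moved next to DBCacheContext keep their existing behavior: build a context with create_cache_context, then install it either unconditionally with set_current_cache_context or temporarily with the cache_context context manager, which restores the previous context on exit. A minimal sketch, assuming DBCacheContext can be constructed with its defaults:

```python
from cache_dit.cache_factory import cache_context as ctx_mod

# Build a fresh cache context (kwargs are forwarded to DBCacheContext).
db_ctx = ctx_mod.create_cache_context()

# Scoped activation: the previous context is restored on exit.
with ctx_mod.cache_context(db_ctx):
    assert ctx_mod.get_current_cache_context() is db_ctx

# Unscoped activation.
ctx_mod.set_current_cache_context(db_ctx)
```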
cache_dit/cache_factory/cache_interface.py

@@ -1,3 +1,4 @@
+from typing import Any, Tuple, List
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.cache_types import CacheType
@@ -9,9 +10,13 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
+def supported_pipelines() -> Tuple[int, List[str]]:
+    return UnifiedCacheAdapter.supported_pipelines()
+
+
 def enable_cache(
     # BlockAdapter & forward pattern
-    pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+    pipe_or_adapter: DiffusionPipeline | BlockAdapter | Any,
     forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     # Cache context kwargs
     Fn_compute_blocks: int = 8,
@@ -23,14 +28,14 @@ def enable_cache(
     # Cache CFG or not
     do_separate_cfg: bool = False,
     cfg_compute_first: bool = False,
-    cfg_diff_compute_separate: bool =
+    cfg_diff_compute_separate: bool = True,
     # Hybird TaylorSeer
     enable_taylorseer: bool = False,
     enable_encoder_taylorseer: bool = False,
     taylorseer_cache_type: str = "residual",
     taylorseer_order: int = 2,
     **other_cache_kwargs,
-) -> DiffusionPipeline:
+) -> DiffusionPipeline | Any:
     r"""
     Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
     that match the specific Input and Output patterns).
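With the widened signature, enable_cache accepts any pipeline-like object (or a BlockAdapter) and now defaults cfg_diff_compute_separate to True, while supported_pipelines() reports what the unified adapter covers. A hedged usage sketch; the checkpoint name is a placeholder, and the direct module imports assume the layout shown above rather than any top-level re-export:

```python
from diffusers import DiffusionPipeline

from cache_dit.cache_factory.cache_interface import (
    enable_cache,
    supported_pipelines,
)
from cache_dit.cache_factory.forward_pattern import ForwardPattern

count, names = supported_pipelines()
print(f"{count} supported pipeline families, e.g. {names[:5]}")

# Placeholder checkpoint; any pipeline whose transformer blocks match the
# chosen forward pattern should work.
pipe = DiffusionPipeline.from_pretrained("some/dit-checkpoint")
pipe = enable_cache(
    pipe,
    forward_pattern=ForwardPattern.Pattern_0,
    Fn_compute_blocks=8,
    cfg_diff_compute_separate=True,  # new default in 0.2.26
)
```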
cache_dit/cache_factory/forward_pattern.py

@@ -19,39 +19,57 @@ class ForwardPattern(Enum):
         self.Supported = Supported
 
     Pattern_0 = (
-        True,
-        False,
-        False,
-        ("hidden_states", "encoder_hidden_states"),
-        ("hidden_states", "encoder_hidden_states"),
-        True,
+        True,  # Return_H_First
+        False,  # Return_H_Only
+        False,  # Forward_H_only
+        ("hidden_states", "encoder_hidden_states"),  # In
+        ("hidden_states", "encoder_hidden_states"),  # Out
+        True,  # Supported
     )
 
     Pattern_1 = (
-        False,
-        False,
-        False,
-        ("hidden_states", "encoder_hidden_states"),
-        ("encoder_hidden_states", "hidden_states"),
-        True,
+        False,  # Return_H_First
+        False,  # Return_H_Only
+        False,  # Forward_H_only
+        ("hidden_states", "encoder_hidden_states"),  # In
+        ("encoder_hidden_states", "hidden_states"),  # Out
+        True,  # Supported
     )
 
     Pattern_2 = (
-        False,
-        True,
-        False,
-        ("hidden_states", "encoder_hidden_states"),
-        ("hidden_states",),
-        True,
+        False,  # Return_H_First
+        True,  # Return_H_Only
+        False,  # Forward_H_only
+        ("hidden_states", "encoder_hidden_states"),  # In
+        ("hidden_states",),  # Out
+        True,  # Supported
     )
 
     Pattern_3 = (
-        False,
-        True,
-
-        ("hidden_states",),
-        ("hidden_states",),
-
+        False,  # Return_H_First
+        True,  # Return_H_Only
+        True,  # Forward_H_only
+        ("hidden_states",),  # In
+        ("hidden_states",),  # Out
+        True,  # Supported
+    )
+
+    Pattern_4 = (
+        True,  # Return_H_First
+        False,  # Return_H_Only
+        True,  # Forward_H_only
+        ("hidden_states",),  # In
+        ("hidden_states", "encoder_hidden_states"),  # Out
+        True,  # Supported
+    )
+
+    Pattern_5 = (
+        False,  # Return_H_First
+        False,  # Return_H_Only
+        True,  # Forward_H_only
+        ("hidden_states",),  # In
+        ("encoder_hidden_states", "hidden_states"),  # Out
+        True,  # Supported
     )
 
     @staticmethod
@@ -60,4 +78,7 @@ class ForwardPattern(Enum):
             ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_1,
             ForwardPattern.Pattern_2,
+            ForwardPattern.Pattern_3,
+            ForwardPattern.Pattern_4,
+            ForwardPattern.Pattern_5,
         ]
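The added inline comments spell out the tuple layout: (Return_H_First, Return_H_Only, Forward_H_only, In, Out, Supported). A small sketch that inspects the newly supported patterns; the attribute names are taken from those comments and from their use in pattern_base.py above:

```python
from cache_dit.cache_factory.forward_pattern import ForwardPattern

for pattern in (
    ForwardPattern.Pattern_3,
    ForwardPattern.Pattern_4,
    ForwardPattern.Pattern_5,
):
    # All three feed only hidden_states into each block; they differ in
    # whether and how encoder_hidden_states come back out.
    print(pattern.name, pattern.In, pattern.Out, pattern.Return_H_Only)
```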
cache_dit/cache_factory/patch_functors/functor_base.py (new file, +18)

@@ -0,0 +1,18 @@
+import torch
+from abc import abstractmethod
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class PatchFunctor:
+
+    @abstractmethod
+    def apply(
+        self,
+        transformer: torch.nn.Module,
+        *args,
+        **kwargs,
+    ) -> torch.nn.Module:
+        raise NotImplementedError("apply method is not implemented.")
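PatchFunctor is the base class that the new functor_flux.py and functor_chroma.py build on. A minimal hypothetical subclass, just to show the contract that apply takes a transformer and returns the (possibly patched) module:

```python
import torch

from cache_dit.cache_factory.patch_functors.functor_base import PatchFunctor


class NoOpPatchFunctor(PatchFunctor):
    """Hypothetical functor: marks the transformer as patched and returns it."""

    def apply(
        self,
        transformer: torch.nn.Module,
        *args,
        **kwargs,
    ) -> torch.nn.Module:
        # Real functors (e.g. the Flux/Chroma ones) rewrite block forwards here.
        transformer._is_patched = True
        return transformer
```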