cache-dit 0.2.25__py3-none-any.whl → 0.2.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/__init__.py +4 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +2 -0
- cache_dit/cache_factory/cache_adapters.py +375 -26
- cache_dit/cache_factory/cache_blocks/__init__.py +20 -0
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +270 -0
- cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +17 -18
- cache_dit/cache_factory/cache_blocks/utils.py +19 -0
- cache_dit/cache_factory/cache_context.py +5 -1
- cache_dit/cache_factory/cache_interface.py +7 -2
- cache_dit/cache_factory/forward_pattern.py +45 -24
- cache_dit/cache_factory/patch_functors/__init__.py +5 -0
- cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +273 -0
- cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +45 -31
- cache_dit/quantize/quantize_ao.py +18 -4
- cache_dit/quantize/quantize_interface.py +2 -2
- cache_dit/utils.py +3 -2
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/METADATA +35 -8
- cache_dit-0.2.26.dist-info/RECORD +42 -0
- cache_dit/cache_factory/patch/__init__.py +0 -0
- cache_dit-0.2.25.dist-info/RECORD +0 -36
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.26.dist-info}/top_level.txt +0 -0

--- /dev/null
+++ b/cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py
@@ -0,0 +1,270 @@
+import torch
+
+from cache_dit.cache_factory import cache_context
+from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_blocks.utils import (
+    patch_cached_stats,
+)
+from cache_dit.cache_factory.cache_blocks.pattern_base import (
+    DBCachedBlocks_Pattern_Base,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class DBCachedBlocks_Pattern_3_4_5(DBCachedBlocks_Pattern_Base):
+    _supported_patterns = [
+        ForwardPattern.Pattern_3,
+        ForwardPattern.Pattern_4,
+        ForwardPattern.Pattern_5,
+    ]
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        # Call first `n` blocks to process the hidden states for
+        # more stable diff calculation.
+        # encoder_hidden_states: None Pattern 3, else 4, 5
+        hidden_states, encoder_hidden_states = self.call_Fn_blocks(
+            hidden_states,
+            *args,
+            **kwargs,
+        )
+
+        Fn_hidden_states_residual = hidden_states - original_hidden_states
+        del original_hidden_states
+
+        cache_context.mark_step_begin()
+        # Residual L1 diff or Hidden States L1 diff
+        can_use_cache = cache_context.get_can_use_cache(
+            (
+                Fn_hidden_states_residual
+                if not cache_context.is_l1_diff_enabled()
+                else hidden_states
+            ),
+            parallelized=self._is_parallelized(),
+            prefix=(
+                "Fn_residual"
+                if not cache_context.is_l1_diff_enabled()
+                else "Fn_hidden_states"
+            ),
+        )
+
+        torch._dynamo.graph_break()
+        if can_use_cache:
+            cache_context.add_cached_step()
+            del Fn_hidden_states_residual
+            hidden_states, encoder_hidden_states = (
+                cache_context.apply_hidden_states_residual(
+                    hidden_states,
+                    # None Pattern 3, else 4, 5
+                    encoder_hidden_states,
+                    prefix=(
+                        "Bn_residual"
+                        if cache_context.is_cache_residual()
+                        else "Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        "Bn_residual"
+                        if cache_context.is_encoder_cache_residual()
+                        else "Bn_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+        else:
+            cache_context.set_Fn_buffer(
+                Fn_hidden_states_residual, prefix="Fn_residual"
+            )
+            if cache_context.is_l1_diff_enabled():
+                # for hidden states L1 diff
+                cache_context.set_Fn_buffer(hidden_states, "Fn_hidden_states")
+            del Fn_hidden_states_residual
+            torch._dynamo.graph_break()
+            (
+                hidden_states,
+                encoder_hidden_states,
+                hidden_states_residual,
+                # None Pattern 3, else 4, 5
+                encoder_hidden_states_residual,
+            ) = self.call_Mn_blocks(  # middle
+                hidden_states,
+                # None Pattern 3, else 4, 5
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            torch._dynamo.graph_break()
+            if cache_context.is_cache_residual():
+                cache_context.set_Bn_buffer(
+                    hidden_states_residual,
+                    prefix="Bn_residual",
+                )
+            else:
+                # TaylorSeer
+                cache_context.set_Bn_buffer(
+                    hidden_states,
+                    prefix="Bn_hidden_states",
+                )
+            if cache_context.is_encoder_cache_residual():
+                cache_context.set_Bn_encoder_buffer(
+                    # None Pattern 3, else 4, 5
+                    encoder_hidden_states_residual,
+                    prefix="Bn_residual",
+                )
+            else:
+                # TaylorSeer
+                cache_context.set_Bn_encoder_buffer(
+                    # None Pattern 3, else 4, 5
+                    encoder_hidden_states,
+                    prefix="Bn_hidden_states",
+                )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                hidden_states,
+                # None Pattern 3, else 4, 5
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        patch_cached_stats(self.transformer)
+        torch._dynamo.graph_break()
+
+        return (
+            hidden_states
+            if self.forward_pattern.Return_H_Only
+            else (
+                (hidden_states, encoder_hidden_states)
+                if self.forward_pattern.Return_H_First
+                else (encoder_hidden_states, hidden_states)
+            )
+        )
+
+    def call_Fn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        assert cache_context.Fn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Fn_compute_blocks {cache_context.Fn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+        encoder_hidden_states = None  # Pattern 3
+        for block in self._Fn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.forward_pattern.Return_H_First:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+
+        return hidden_states, encoder_hidden_states
+
+    def call_Mn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        # None Pattern 3, else 4, 5
+        encoder_hidden_states: torch.Tensor | None,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = encoder_hidden_states
+        for block in self._Mn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4, 5
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.forward_pattern.Return_H_First:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+
+        # compute hidden_states residual
+        hidden_states = hidden_states.contiguous()
+        hidden_states_residual = hidden_states - original_hidden_states
+        if (
+            original_encoder_hidden_states is not None
+            and encoder_hidden_states is not None
+        ):  # Pattern 4, 5
+            encoder_hidden_states_residual = (
+                encoder_hidden_states - original_encoder_hidden_states
+            )
+        else:
+            encoder_hidden_states_residual = None  # Pattern 3
+
+        return (
+            hidden_states,
+            encoder_hidden_states,
+            hidden_states_residual,
+            encoder_hidden_states_residual,
+        )
+
+    def call_Bn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        # None Pattern 3, else 4, 5
+        encoder_hidden_states: torch.Tensor | None,
+        *args,
+        **kwargs,
+    ):
+        if cache_context.Bn_compute_blocks() == 0:
+            return hidden_states, encoder_hidden_states
+
+        assert cache_context.Bn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Bn_compute_blocks {cache_context.Bn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+        if len(cache_context.Bn_compute_blocks_ids()) > 0:
+            raise ValueError(
+                f"Bn_compute_blocks_ids is not support for "
+                f"patterns: {self._supported_patterns}."
+            )
+        else:
+            # Compute all Bn blocks if no specific Bn compute blocks ids are set.
+            for block in self._Bn_blocks():
+                hidden_states = block(
+                    hidden_states,
+                    *args,
+                    **kwargs,
+                )
+                if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
+                    hidden_states, encoder_hidden_states = hidden_states
+                    if not self.forward_pattern.Return_H_First:
+                        hidden_states, encoder_hidden_states = (
+                            encoder_hidden_states,
+                            hidden_states,
+                        )
+
+        return hidden_states, encoder_hidden_states
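Note: the new `DBCachedBlocks_Pattern_3_4_5` follows the same dual-block caching idea as the base class: run the first `Fn` blocks, compare their output residual against the previous step, and either replay the cached `Bn` residual or run the middle blocks and refresh the cache. The toy sketch below only illustrates that gating idea; the class name and threshold are hypothetical and not part of cache-dit's API.

```python
import torch


# Toy illustration of the Fn-residual gate (hypothetical, not cache-dit code).
class ToyResidualCache:
    def __init__(self, threshold: float = 0.05):
        self.threshold = threshold
        self.prev_fn_residual = None    # residual after the first `Fn` blocks
        self.cached_bn_residual = None  # residual contributed by the remaining blocks

    def can_use_cache(self, fn_residual: torch.Tensor) -> bool:
        if self.prev_fn_residual is None or self.cached_bn_residual is None:
            return False
        # Relative L1 difference between this step's Fn residual and the previous one.
        diff = (fn_residual - self.prev_fn_residual).abs().mean()
        return (diff / self.prev_fn_residual.abs().mean()).item() < self.threshold

    def update(self, fn_residual: torch.Tensor, bn_residual: torch.Tensor) -> None:
        self.prev_fn_residual = fn_residual
        self.cached_bn_residual = bn_residual

    def apply_cache(self, hidden_states: torch.Tensor) -> torch.Tensor:
        # Replay the cached residual instead of running the middle blocks.
        return hidden_states + self.cached_bn_residual
```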
--- a/cache_dit/cache_factory/cache_blocks.py
+++ b/cache_dit/cache_factory/cache_blocks/pattern_base.py
@@ -4,12 +4,15 @@ import torch.distributed as dist
 
 from cache_dit.cache_factory import cache_context
 from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory.cache_blocks.utils import (
+    patch_cached_stats,
+)
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-class DBCachedTransformerBlocks(torch.nn.Module):
+class DBCachedBlocks_Pattern_Base(torch.nn.Module):
     _supported_patterns = [
         ForwardPattern.Pattern_0,
         ForwardPattern.Pattern_1,
@@ -29,18 +32,30 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         self.transformer_blocks = transformer_blocks
         self.forward_pattern = forward_pattern
         self._check_forward_pattern()
+        logger.info(f"Match Cached Blocks: {self.__class__.__name__}")
 
     def _check_forward_pattern(self):
         assert (
             self.forward_pattern.Supported
             and self.forward_pattern in self._supported_patterns
-        ), f"Pattern {self.forward_pattern} is not
+        ), f"Pattern {self.forward_pattern} is not supported now!"
 
         if self.transformer_blocks is not None:
             for block in self.transformer_blocks:
                 forward_parameters = set(
                     inspect.signature(block.forward).parameters.keys()
                 )
+                num_outputs = str(
+                    inspect.signature(block.forward).return_annotation
+                ).count("torch.Tensor")
+
+                if num_outputs > 0:
+                    assert len(self.forward_pattern.Out) == num_outputs, (
+                        f"The number of block's outputs is {num_outputs} don't not "
+                        f"match the number of the pattern: {self.forward_pattern}, "
+                        f"Out: {len(self.forward_pattern.Out)}."
+                    )
+
                 for required_param in self.forward_pattern.In:
                     assert (
                         required_param in forward_parameters
@@ -479,19 +494,3 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         )
 
         return hidden_states, encoder_hidden_states
-
-
-@torch.compiler.disable
-def patch_cached_stats(
-    transformer,
-):
-    # Patch the cached stats to the transformer, the cached stats
-    # will be reset for each calling of pipe.__call__(**kwargs).
-    if transformer is None:
-        return
-
-    # TODO: Patch more cached stats to the transformer
-    transformer._cached_steps = cache_context.get_cached_steps()
-    transformer._residual_diffs = cache_context.get_residual_diffs()
-    transformer._cfg_cached_steps = cache_context.get_cfg_cached_steps()
-    transformer._cfg_residual_diffs = cache_context.get_cfg_residual_diffs()
--- /dev/null
+++ b/cache_dit/cache_factory/cache_blocks/utils.py
@@ -0,0 +1,19 @@
+import torch
+
+from cache_dit.cache_factory import cache_context
+
+
+@torch.compiler.disable
+def patch_cached_stats(
+    transformer,
+):
+    # Patch the cached stats to the transformer, the cached stats
+    # will be reset for each calling of pipe.__call__(**kwargs).
+    if transformer is None:
+        return
+
+    # TODO: Patch more cached stats to the transformer
+    transformer._cached_steps = cache_context.get_cached_steps()
+    transformer._residual_diffs = cache_context.get_residual_diffs()
+    transformer._cfg_cached_steps = cache_context.get_cfg_cached_steps()
+    transformer._cfg_residual_diffs = cache_context.get_cfg_residual_diffs()
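Note: `patch_cached_stats` now lives in `cache_blocks/utils.py` and attaches per-call statistics to the transformer. A minimal sketch of inspecting them after a run, assuming `pipe` is a diffusers pipeline whose transformer blocks were already wrapped by cache-dit (the setup and prompt are placeholders):

```python
# Hypothetical inspection snippet; assumes `pipe` was already wrapped by cache-dit.
image = pipe("a cabin in the woods", num_inference_steps=28).images[0]

transformer = pipe.transformer
print("cached steps:", transformer._cached_steps)            # steps that reused the cache
print("residual diffs:", transformer._residual_diffs)        # per-step residual differences
print("CFG cached steps:", transformer._cfg_cached_steps)    # same stats for the CFG branch
print("CFG residual diffs:", transformer._cfg_residual_diffs)
```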
--- a/cache_dit/cache_factory/cache_context.py
+++ b/cache_dit/cache_factory/cache_context.py
@@ -941,7 +941,11 @@ def get_Bn_buffer(prefix: str = "Bn"):
 
 
 @torch.compiler.disable
-def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
+def set_Bn_encoder_buffer(buffer: torch.Tensor | None, prefix: str = "Bn"):
+    # DON'T set None Buffer
+    if buffer is None:
+        return
+
     # This buffer is use for encoder hidden states approximation.
     if is_encoder_taylorseer_enabled():
         # taylorseer, encoder_taylorseer
--- a/cache_dit/cache_factory/cache_interface.py
+++ b/cache_dit/cache_factory/cache_interface.py
@@ -1,3 +1,4 @@
+from typing import Any, Tuple, List
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.cache_types import CacheType
@@ -9,9 +10,13 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
+def supported_pipelines() -> Tuple[int, List[str]]:
+    return UnifiedCacheAdapter.supported_pipelines()
+
+
 def enable_cache(
     # BlockAdapter & forward pattern
-    pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+    pipe_or_adapter: DiffusionPipeline | BlockAdapter | Any,
     forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     # Cache context kwargs
     Fn_compute_blocks: int = 8,
@@ -30,7 +35,7 @@
     taylorseer_cache_type: str = "residual",
     taylorseer_order: int = 2,
     **other_cache_kwargs,
-) -> DiffusionPipeline:
+) -> DiffusionPipeline | Any:
     r"""
     Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
     that match the specific Input and Output patterns).
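Note: a minimal usage sketch for the widened `enable_cache` signature, using only parameters visible in these hunks. It assumes `enable_cache` is re-exported at the package top level, and the checkpoint name is a placeholder; any `DiffusionPipeline` whose transformer blocks match a supported forward pattern should work.

```python
import cache_dit
from diffusers import DiffusionPipeline

# Placeholder checkpoint; swap in any pipeline with matching transformer blocks.
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

# forward_pattern defaults to ForwardPattern.Pattern_0; Fn_compute_blocks,
# taylorseer_cache_type and taylorseer_order keep the defaults shown in the diff.
pipe = cache_dit.enable_cache(pipe, Fn_compute_blocks=8)

image = pipe("a cabin in the woods", num_inference_steps=28).images[0]
```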
--- a/cache_dit/cache_factory/forward_pattern.py
+++ b/cache_dit/cache_factory/forward_pattern.py
@@ -19,39 +19,57 @@ class ForwardPattern(Enum):
         self.Supported = Supported
 
     Pattern_0 = (
-        True,
-        False,
-        False,
-        ("hidden_states", "encoder_hidden_states"),
-        ("hidden_states", "encoder_hidden_states"),
-        True,
+        True,  # Return_H_First
+        False,  # Return_H_Only
+        False,  # Forward_H_only
+        ("hidden_states", "encoder_hidden_states"),  # In
+        ("hidden_states", "encoder_hidden_states"),  # Out
+        True,  # Supported
     )
 
     Pattern_1 = (
-        False,
-        False,
-        False,
-        ("hidden_states", "encoder_hidden_states"),
-        ("encoder_hidden_states", "hidden_states"),
-        True,
+        False,  # Return_H_First
+        False,  # Return_H_Only
+        False,  # Forward_H_only
+        ("hidden_states", "encoder_hidden_states"),  # In
+        ("encoder_hidden_states", "hidden_states"),  # Out
+        True,  # Supported
     )
 
     Pattern_2 = (
-        False,
-        True,
-        False,
-        ("hidden_states", "encoder_hidden_states"),
-        ("hidden_states",),
-        True,
+        False,  # Return_H_First
+        True,  # Return_H_Only
+        False,  # Forward_H_only
+        ("hidden_states", "encoder_hidden_states"),  # In
+        ("hidden_states",),  # Out
+        True,  # Supported
     )
 
     Pattern_3 = (
-        False,
-        True,
-
-        ("hidden_states",),
-        ("hidden_states",),
-
+        False,  # Return_H_First
+        True,  # Return_H_Only
+        True,  # Forward_H_only
+        ("hidden_states",),  # In
+        ("hidden_states",),  # Out
+        True,  # Supported
+    )
+
+    Pattern_4 = (
+        True,  # Return_H_First
+        False,  # Return_H_Only
+        True,  # Forward_H_only
+        ("hidden_states",),  # In
+        ("hidden_states", "encoder_hidden_states"),  # Out
+        True,  # Supported
+    )
+
+    Pattern_5 = (
+        False,  # Return_H_First
+        False,  # Return_H_Only
+        True,  # Forward_H_only
+        ("hidden_states",),  # In
+        ("encoder_hidden_states", "hidden_states"),  # Out
+        True,  # Supported
     )
 
     @staticmethod
@@ -60,4 +78,7 @@ class ForwardPattern(Enum):
             ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_1,
             ForwardPattern.Pattern_2,
+            ForwardPattern.Pattern_3,
+            ForwardPattern.Pattern_4,
+            ForwardPattern.Pattern_5,
         ]
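Note: the three new patterns all feed only `hidden_states` into each block (`Forward_H_only`) and differ in what the block returns, per the `In`/`Out` tuples above. A toy sketch of block signatures that would match them (hypothetical modules, not taken from any real model):

```python
import torch


class Pattern3Block(torch.nn.Module):
    # In: ("hidden_states",)  Out: ("hidden_states",)
    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states


class Pattern4Block(torch.nn.Module):
    # In: ("hidden_states",)  Out: ("hidden_states", "encoder_hidden_states")
    def forward(self, hidden_states: torch.Tensor):
        encoder_hidden_states = hidden_states.clone()
        return hidden_states, encoder_hidden_states


class Pattern5Block(torch.nn.Module):
    # In: ("hidden_states",)  Out: ("encoder_hidden_states", "hidden_states")
    def forward(self, hidden_states: torch.Tensor):
        encoder_hidden_states = hidden_states.clone()
        return encoder_hidden_states, hidden_states
```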
--- /dev/null
+++ b/cache_dit/cache_factory/patch_functors/functor_base.py
@@ -0,0 +1,18 @@
+import torch
+from abc import abstractmethod
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class PatchFunctor:
+
+    @abstractmethod
+    def apply(
+        self,
+        transformer: torch.nn.Module,
+        *args,
+        **kwargs,
+    ) -> torch.nn.Module:
+        raise NotImplementedError("apply method is not implemented.")