cache-dit 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/adapters.py +47 -5
- cache_dit/cache_factory/dual_block_cache/__init__.py +4 -0
- cache_dit/cache_factory/dual_block_cache/cache_blocks.py +487 -0
- cache_dit/cache_factory/dual_block_cache/cache_context.py +10 -860
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +5 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +14 -4
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +5 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +5 -2
- cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py} +14 -12
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +7 -4
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +4 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +10 -4
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +5 -2
- cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py} +28 -23
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +276 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +228 -516
- cache_dit/cache_factory/patch/flux.py +241 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA +22 -80
- cache_dit-0.2.16.dist-info/RECORD +47 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
- cache_dit-0.2.14.dist-info/RECORD +0 -49
- /cache_dit/cache_factory/{first_block_cache → patch}/__init__.py +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED
cache_dit/cache_factory/adapters.py
CHANGED
@@ -4,12 +4,12 @@ from diffusers import DiffusionPipeline
 
 from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
     apply_db_cache_on_pipe,
+    apply_db_cache_on_transformer,
 )
-from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
-    apply_fb_cache_on_pipe,
-)
+
 from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters import (
     apply_db_prune_on_pipe,
+    apply_db_prune_on_transformer,
 )
 
 from cache_dit.logger import init_logger
@@ -93,7 +93,7 @@ class CacheType(Enum):
         }
 
         _Fn_compute_blocks = 8
-        _Bn_compute_blocks = 8
+        _Bn_compute_blocks = 0
 
         _db_options = {
             "cache_type": CacheType.DBCache,
@@ -155,7 +155,9 @@ def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
     cache_type = CacheType.type(cache_type)
 
     if cache_type == CacheType.FBCache:
-        return apply_fb_cache_on_pipe(pipe, *args, **kwargs)
+        raise ValueError(
+            "FBCache is removed from cache-dit, please use DBCache F1B0 instead."
+        )
     elif cache_type == CacheType.DBCache:
         return apply_db_cache_on_pipe(pipe, *args, **kwargs)
     elif cache_type == CacheType.DBPrune:
@@ -167,3 +169,43 @@ def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
         return pipe
     else:
         raise ValueError(f"Unknown cache type: {cache_type}")
+
+
+def apply_cache_on_transformer(transformer, *args, **kwargs):
+
+    if hasattr(transformer, "_is_cached") and transformer._is_cached:
+        return transformer
+
+    if hasattr(transformer, "_is_pruned") and transformer._is_pruned:
+        return transformer
+
+    cache_type = kwargs.pop("cache_type", None)
+    if cache_type is None:
+        logger.warning(
+            "No cache type specified, we will use DBCache by default. "
+            "Please specify the cache_type explicitly if you want to "
+            "use a different cache type."
+        )
+        # Force to use DBCache with default cache options
+        return apply_db_cache_on_transformer(
+            transformer,
+            **CacheType.default_options(CacheType.DBCache),
+        )
+
+    cache_type = CacheType.type(cache_type)
+
+    if cache_type == CacheType.FBCache:
+        raise ValueError(
+            "FBCache is removed from cache-dit, please use DBCache F1B0 instead."
+        )
+    elif cache_type == CacheType.DBCache:
+        return apply_db_cache_on_transformer(transformer, *args, **kwargs)
+    elif cache_type == CacheType.DBPrune:
+        return apply_db_prune_on_transformer(transformer, *args, **kwargs)
+    elif cache_type == CacheType.NONE:
+        logger.warning(
+            f"Cache type is {cache_type}, no caching will be applied."
+        )
+        return transformer
+    else:
+        raise ValueError(f"Unknown cache type: {cache_type}")
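A minimal usage sketch against the adapter entry points changed above. The host pipeline (diffusers' FluxPipeline), the model id, and the .to("cuda") call are assumptions for illustration; only apply_cache_on_pipe, CacheType, and CacheType.default_options are taken from the diff itself, and the import path assumes cache_dit.cache_factory.adapters is importable as listed in the file table.

# Sketch only: FluxPipeline and the model id are illustrative assumptions;
# the cache-dit calls are the entry points shown in the adapters.py diff above.
import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory.adapters import CacheType, apply_cache_on_pipe

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# DBCache with its default options (F8B0 per the defaults above: Fn=8, Bn=0).
apply_cache_on_pipe(pipe, **CacheType.default_options(CacheType.DBCache))

# Selecting the removed FBCache now raises:
#   ValueError: FBCache is removed from cache-dit, please use DBCache F1B0 instead.
# apply_cache_on_pipe(pipe, cache_type=CacheType.FBCache)

The same selection logic is now available at the transformer level through the new apply_cache_on_transformer, which warns and falls back to the DBCache defaults when no cache_type is passed.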
cache_dit/cache_factory/dual_block_cache/cache_blocks.py
ADDED
@@ -0,0 +1,487 @@
+import inspect
+import torch
+
+from cache_dit.cache_factory.dual_block_cache import cache_context
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class DBCachedTransformerBlocks(torch.nn.Module):
+    def __init__(
+        self,
+        transformer_blocks: torch.nn.ModuleList,
+        *,
+        transformer: torch.nn.Module = None,
+        return_hidden_states_first: bool = True,
+        return_hidden_states_only: bool = False,
+    ):
+        super().__init__()
+
+        self.transformer = transformer
+        self.transformer_blocks = transformer_blocks
+        self.return_hidden_states_first = return_hidden_states_first
+        self.return_hidden_states_only = return_hidden_states_only
+        self._check_forward_params()
+
+    def _check_forward_params(self):
+        # NOTE: DBCache only support blocks which have the pattern:
+        # IN/OUT: (hidden_states, encoder_hidden_states)
+        self.required_parameters = [
+            "hidden_states",
+            "encoder_hidden_states",
+        ]
+        if self.transformer_blocks is not None:
+            for block in self.transformer_blocks:
+                forward_parameters = set(
+                    inspect.signature(block.forward).parameters.keys()
+                )
+                for required_param in self.required_parameters:
+                    assert (
+                        required_param in forward_parameters
+                    ), f"The input parameters must contains: {required_param}."
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        # Call first `n` blocks to process the hidden states for
+        # more stable diff calculation.
+        hidden_states, encoder_hidden_states = self.call_Fn_blocks(
+            hidden_states,
+            encoder_hidden_states,
+            *args,
+            **kwargs,
+        )
+
+        Fn_hidden_states_residual = hidden_states - original_hidden_states
+        del original_hidden_states
+
+        cache_context.mark_step_begin()
+        # Residual L1 diff or Hidden States L1 diff
+        can_use_cache = cache_context.get_can_use_cache(
+            (
+                Fn_hidden_states_residual
+                if not cache_context.is_l1_diff_enabled()
+                else hidden_states
+            ),
+            parallelized=self._is_parallelized(),
+            prefix=(
+                "Fn_residual"
+                if not cache_context.is_l1_diff_enabled()
+                else "Fn_hidden_states"
+            ),
+        )
+
+        torch._dynamo.graph_break()
+        if can_use_cache:
+            cache_context.add_cached_step()
+            del Fn_hidden_states_residual
+            hidden_states, encoder_hidden_states = (
+                cache_context.apply_hidden_states_residual(
+                    hidden_states,
+                    encoder_hidden_states,
+                    prefix=(
+                        "Bn_residual"
+                        if cache_context.is_cache_residual()
+                        else "Bn_hidden_states"
+                    ),
+                    encoder_prefix=(
+                        "Bn_residual"
+                        if cache_context.is_encoder_cache_residual()
+                        else "Bn_hidden_states"
+                    ),
+                )
+            )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+        else:
+            cache_context.set_Fn_buffer(
+                Fn_hidden_states_residual, prefix="Fn_residual"
+            )
+            if cache_context.is_l1_diff_enabled():
+                # for hidden states L1 diff
+                cache_context.set_Fn_buffer(hidden_states, "Fn_hidden_states")
+            del Fn_hidden_states_residual
+            torch._dynamo.graph_break()
+            (
+                hidden_states,
+                encoder_hidden_states,
+                hidden_states_residual,
+                encoder_hidden_states_residual,
+            ) = self.call_Mn_blocks(  # middle
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            torch._dynamo.graph_break()
+            if cache_context.is_cache_residual():
+                cache_context.set_Bn_buffer(
+                    hidden_states_residual,
+                    prefix="Bn_residual",
+                )
+            else:
+                # TaylorSeer
+                cache_context.set_Bn_buffer(
+                    hidden_states,
+                    prefix="Bn_hidden_states",
+                )
+            if cache_context.is_encoder_cache_residual():
+                cache_context.set_Bn_encoder_buffer(
+                    encoder_hidden_states_residual,
+                    prefix="Bn_residual",
+                )
+            else:
+                # TaylorSeer
+                cache_context.set_Bn_encoder_buffer(
+                    encoder_hidden_states,
+                    prefix="Bn_hidden_states",
+                )
+            torch._dynamo.graph_break()
+            # Call last `n` blocks to further process the hidden states
+            # for higher precision.
+            hidden_states, encoder_hidden_states = self.call_Bn_blocks(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+
+        patch_cached_stats(self.transformer)
+        torch._dynamo.graph_break()
+
+        return (
+            hidden_states
+            if self.return_hidden_states_only
+            else (
+                (hidden_states, encoder_hidden_states)
+                if self.return_hidden_states_first
+                else (encoder_hidden_states, hidden_states)
+            )
+        )
+
+    @torch.compiler.disable
+    def _is_parallelized(self):
+        # Compatible with distributed inference.
+        return all(
+            (
+                self.transformer is not None,
+                getattr(self.transformer, "_is_parallelized", False),
+            )
+        )
+
+    @torch.compiler.disable
+    def _is_in_cache_step(self):
+        # Check if the current step is in cache steps.
+        # If so, we can skip some Bn blocks and directly
+        # use the cached values.
+        return (
+            cache_context.get_current_step() in cache_context.get_cached_steps()
+        ) or (
+            cache_context.get_current_step()
+            in cache_context.get_cfg_cached_steps()
+        )
+
+    @torch.compiler.disable
+    def _Fn_blocks(self):
+        # Select first `n` blocks to process the hidden states for
+        # more stable diff calculation.
+        # Fn: [0,...,n-1]
+        selected_Fn_blocks = self.transformer_blocks[
+            : cache_context.Fn_compute_blocks()
+        ]
+        return selected_Fn_blocks
+
+    @torch.compiler.disable
+    def _Mn_blocks(self):  # middle blocks
+        # M(N-2n): only transformer_blocks [n,...,N-n], middle
+        if cache_context.Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
+            selected_Mn_blocks = self.transformer_blocks[
+                cache_context.Fn_compute_blocks() :
+            ]
+        else:
+            selected_Mn_blocks = self.transformer_blocks[
+                cache_context.Fn_compute_blocks() : -cache_context.Bn_compute_blocks()
+            ]
+        return selected_Mn_blocks
+
+    @torch.compiler.disable
+    def _Bn_blocks(self):
+        # Bn: transformer_blocks [N-n+1,...,N-1]
+        selected_Bn_blocks = self.transformer_blocks[
+            -cache_context.Bn_compute_blocks() :
+        ]
+        return selected_Bn_blocks
+
+    def call_Fn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        assert cache_context.Fn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Fn_compute_blocks {cache_context.Fn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+        for block in self._Fn_blocks():
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.return_hidden_states_first:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+
+        return hidden_states, encoder_hidden_states
+
+    def call_Mn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        original_hidden_states = hidden_states
+        original_encoder_hidden_states = encoder_hidden_states
+        for block in self._Mn_blocks():
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.return_hidden_states_first:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+
+        # compute hidden_states residual
+        hidden_states = hidden_states.contiguous()
+        encoder_hidden_states = encoder_hidden_states.contiguous()
+
+        hidden_states_residual = hidden_states - original_hidden_states
+        encoder_hidden_states_residual = (
+            encoder_hidden_states - original_encoder_hidden_states
+        )
+
+        return (
+            hidden_states,
+            encoder_hidden_states,
+            hidden_states_residual,
+            encoder_hidden_states_residual,
+        )
+
+    def _compute_or_cache_block(
+        self,
+        # Block index in the transformer blocks
+        # Bn: 8, block_id should be in [0, 8)
+        block_id: int,
+        # Below are the inputs to the block
+        block,  # The transformer block to be executed
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        # Helper function for `call_Bn_blocks`
+        # Skip the blocks by reuse residual cache if they are not
+        # in the Bn_compute_blocks_ids. NOTE: We should only skip
+        # the specific Bn blocks in cache steps. Compute the block
+        # and cache the residuals in non-cache steps.
+
+        # Normal steps: Compute the block and cache the residuals.
+        if not self._is_in_cache_step():
+            Bn_i_original_hidden_states = hidden_states
+            Bn_i_original_encoder_hidden_states = encoder_hidden_states
+            hidden_states = block(
+                hidden_states,
+                encoder_hidden_states,
+                *args,
+                **kwargs,
+            )
+            if not isinstance(hidden_states, torch.Tensor):
+                hidden_states, encoder_hidden_states = hidden_states
+                if not self.return_hidden_states_first:
+                    hidden_states, encoder_hidden_states = (
+                        encoder_hidden_states,
+                        hidden_states,
+                    )
+            # Cache residuals for the non-compute Bn blocks for
+            # subsequent cache steps.
+            if block_id not in cache_context.Bn_compute_blocks_ids():
+                Bn_i_hidden_states_residual = (
+                    hidden_states - Bn_i_original_hidden_states
+                )
+                Bn_i_encoder_hidden_states_residual = (
+                    encoder_hidden_states - Bn_i_original_encoder_hidden_states
+                )
+
+                # Save original_hidden_states for diff calculation.
+                cache_context.set_Bn_buffer(
+                    Bn_i_original_hidden_states,
+                    prefix=f"Bn_{block_id}_original",
+                )
+                cache_context.set_Bn_encoder_buffer(
+                    Bn_i_original_encoder_hidden_states,
+                    prefix=f"Bn_{block_id}_original",
+                )
+
+                cache_context.set_Bn_buffer(
+                    Bn_i_hidden_states_residual,
+                    prefix=f"Bn_{block_id}_residual",
+                )
+                cache_context.set_Bn_encoder_buffer(
+                    Bn_i_encoder_hidden_states_residual,
+                    prefix=f"Bn_{block_id}_residual",
+                )
+                del Bn_i_hidden_states_residual
+                del Bn_i_encoder_hidden_states_residual
+
+            del Bn_i_original_hidden_states
+            del Bn_i_original_encoder_hidden_states
+
+        else:
+            # Cache steps: Reuse the cached residuals.
+            # Check if the block is in the Bn_compute_blocks_ids.
+            if block_id in cache_context.Bn_compute_blocks_ids():
+                hidden_states = block(
+                    hidden_states,
+                    encoder_hidden_states,
+                    *args,
+                    **kwargs,
+                )
+                if not isinstance(hidden_states, torch.Tensor):
+                    hidden_states, encoder_hidden_states = hidden_states
+                    if not self.return_hidden_states_first:
+                        hidden_states, encoder_hidden_states = (
+                            encoder_hidden_states,
+                            hidden_states,
+                        )
+            else:
+                # Skip the block if it is not in the Bn_compute_blocks_ids.
+                # Use the cached residuals instead.
+                # Check if can use the cached residuals.
+                if cache_context.get_can_use_cache(
+                    hidden_states,  # curr step
+                    parallelized=self._is_parallelized(),
+                    threshold=cache_context.non_compute_blocks_diff_threshold(),
+                    prefix=f"Bn_{block_id}_original",  # prev step
+                ):
+                    hidden_states, encoder_hidden_states = (
+                        cache_context.apply_hidden_states_residual(
+                            hidden_states,
+                            encoder_hidden_states,
+                            prefix=(
+                                f"Bn_{block_id}_residual"
+                                if cache_context.is_cache_residual()
+                                else f"Bn_{block_id}_original"
+                            ),
+                            encoder_prefix=(
+                                f"Bn_{block_id}_residual"
+                                if cache_context.is_encoder_cache_residual()
+                                else f"Bn_{block_id}_original"
+                            ),
+                        )
+                    )
+                else:
+                    hidden_states = block(
+                        hidden_states,
+                        encoder_hidden_states,
+                        *args,
+                        **kwargs,
+                    )
+                    if not isinstance(hidden_states, torch.Tensor):
+                        hidden_states, encoder_hidden_states = hidden_states
+                        if not self.return_hidden_states_first:
+                            hidden_states, encoder_hidden_states = (
+                                encoder_hidden_states,
+                                hidden_states,
+                            )
+        return hidden_states, encoder_hidden_states
+
+    def call_Bn_blocks(
+        self,
+        hidden_states: torch.Tensor,
+        encoder_hidden_states: torch.Tensor,
+        *args,
+        **kwargs,
+    ):
+        if cache_context.Bn_compute_blocks() == 0:
+            return hidden_states, encoder_hidden_states
+
+        assert cache_context.Bn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Bn_compute_blocks {cache_context.Bn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+        if len(cache_context.Bn_compute_blocks_ids()) > 0:
+            for i, block in enumerate(self._Bn_blocks()):
+                hidden_states, encoder_hidden_states = (
+                    self._compute_or_cache_block(
+                        i,
+                        block,
+                        hidden_states,
+                        encoder_hidden_states,
+                        *args,
+                        **kwargs,
+                    )
+                )
+        else:
+            # Compute all Bn blocks if no specific Bn compute blocks ids are set.
+            for block in self._Bn_blocks():
+                hidden_states = block(
+                    hidden_states,
+                    encoder_hidden_states,
+                    *args,
+                    **kwargs,
+                )
+                if not isinstance(hidden_states, torch.Tensor):
+                    hidden_states, encoder_hidden_states = hidden_states
+                    if not self.return_hidden_states_first:
+                        hidden_states, encoder_hidden_states = (
+                            encoder_hidden_states,
+                            hidden_states,
+                        )
+
+        return hidden_states, encoder_hidden_states
+
+
+@torch.compiler.disable
+def patch_cached_stats(
+    transformer,
+):
+    # Patch the cached stats to the transformer, the cached stats
+    # will be reset for each calling of pipe.__call__(**kwargs).
+    if transformer is None:
+        return
+
+    # TODO: Patch more cached stats to the transformer
+    transformer._cached_steps = cache_context.get_cached_steps()
+    transformer._residual_diffs = cache_context.get_residual_diffs()
+    transformer._cfg_cached_steps = cache_context.get_cfg_cached_steps()
+    transformer._cfg_residual_diffs = cache_context.get_cfg_residual_diffs()