cache-dit 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +8 -6
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +17 -4
- cache_dit/cache_factory/block_adapters/__init__.py +555 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +538 -0
- cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
- cache_dit/cache_factory/cache_adapters.py +262 -938
- cache_dit/cache_factory/cache_blocks/__init__.py +60 -11
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +45 -41
- cache_dit/cache_factory/cache_blocks/pattern_base.py +106 -80
- cache_dit/cache_factory/cache_blocks/utils.py +16 -10
- cache_dit/cache_factory/cache_contexts/__init__.py +5 -0
- cache_dit/cache_factory/cache_contexts/cache_context.py +327 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
- cache_dit/cache_factory/cache_interface.py +31 -31
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
- cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
- cache_dit/quantize/quantize_ao.py +1 -0
- cache_dit/utils.py +26 -26
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/METADATA +59 -23
- cache_dit-0.2.28.dist-info/RECORD +47 -0
- cache_dit/cache_factory/cache_context.py +0 -1155
- cache_dit-0.2.26.dist-info/RECORD +0 -42
- /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
```diff
@@ -1,1155 +0,0 @@
-import logging
-import contextlib
-import dataclasses
-from collections import defaultdict
-from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
-
-import torch
-import torch.distributed as dist
-
-from cache_dit.cache_factory.taylorseer import TaylorSeer
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-@dataclasses.dataclass
-class DBCacheContext:
-    # Dual Block Cache
-    # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
-    Fn_compute_blocks: int = 1
-    Bn_compute_blocks: int = 0
-    # We have added residual cache pattern for selected compute blocks
-    Fn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
-    Bn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
-    # non compute blocks diff threshold, we don't skip the non
-    # compute blocks if the diff >= threshold
-    non_compute_blocks_diff_threshold: float = 0.08
-    max_Fn_compute_blocks: int = -1
-    max_Bn_compute_blocks: int = -1
-    # L1 hidden states or residual diff threshold for Fn
-    residual_diff_threshold: Union[torch.Tensor, float] = 0.05
-    l1_hidden_states_diff_threshold: float = None
-    important_condition_threshold: float = 0.0
-
-    # Alter Cache Settings
-    # Pattern: 0 F 1 T 2 F 3 T 4 F 5 T ...
-    enable_alter_cache: bool = False
-    is_alter_cache: bool = True
-    # 1.0 means we always cache the residuals if alter_cache is enabled.
-    alter_residual_diff_threshold: Optional[Union[torch.Tensor, float]] = 1.0
-
-    # Buffer for storing the residuals and other tensors
-    buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
-    incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
-        default_factory=lambda: defaultdict(int),
-    )
-
-    # Other settings
-    downsample_factor: int = 1
-    num_inference_steps: int = -1  # for future use
-    max_warmup_steps: int = 0  # DON'T Cache in warmup steps
-    # DON'T Cache if the number of cached steps >= max_cached_steps
-    max_cached_steps: int = -1  # for both CFG and non-CFG
-    max_continuous_cached_steps: int = -1  # the max continuous cached steps
-
-    # Record the steps that have been cached, both cached and non-cache
-    executed_steps: int = 0  # cache + non-cache steps pippeline
-    # steps for transformer, for CFG, transformer_executed_steps will
-    # be double of executed_steps.
-    transformer_executed_steps: int = 0
-
-    # Support TaylorSeers in Dual Block Cache
-    # Title: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
-    # Url: https://arxiv.org/pdf/2503.06923
-    enable_taylorseer: bool = False
-    enable_encoder_taylorseer: bool = False
-    # NOTE: use residual cache for taylorseer may incur precision loss
-    taylorseer_cache_type: str = "hidden_states"  # residual or hidden_states
-    taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
-    taylorseer: Optional[TaylorSeer] = None
-    encoder_tarlorseer: Optional[TaylorSeer] = None
-
-    # Support do_separate_cfg, such as Wan 2.1,
-    # Qwen-Image. For model that fused CFG and non-CFG into single
-    # forward step, should set do_separate_cfg as False.
-    # For example: CogVideoX, HunyuanVideo, Mochi.
-    do_separate_cfg: bool = False
-    # Compute cfg forward first or not, default False, namely,
-    # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
-    cfg_compute_first: bool = False
-    # Compute spearate diff values for CFG and non-CFG step,
-    # default True. If False, we will use the computed diff from
-    # current non-CFG transformer step for current CFG step.
-    cfg_diff_compute_separate: bool = True
-    cfg_taylorseer: Optional[TaylorSeer] = None
-    cfg_encoder_taylorseer: Optional[TaylorSeer] = None
-
-    # CFG & non-CFG cached steps
-    cached_steps: List[int] = dataclasses.field(default_factory=list)
-    residual_diffs: DefaultDict[str, float] = dataclasses.field(
-        default_factory=lambda: defaultdict(float),
-    )
-    continuous_cached_steps: int = 0
-    cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
-    cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
-        default_factory=lambda: defaultdict(float),
-    )
-    cfg_continuous_cached_steps: int = 0
-
-    @torch.compiler.disable
-    def __post_init__(self):
-        # Some checks for settings
-        if self.do_separate_cfg:
-            assert self.enable_alter_cache is False, (
-                "enable_alter_cache must set as False if "
-                "do_separate_cfg is enabled."
-            )
-        if self.cfg_diff_compute_separate:
-            assert self.cfg_compute_first is False, (
-                "cfg_compute_first must set as False if "
-                "cfg_diff_compute_separate is enabled."
-            )
-
-        if "max_warmup_steps" not in self.taylorseer_kwargs:
-            # If max_warmup_steps is not set in taylorseer_kwargs,
-            # set the same as max_warmup_steps for DBCache
-            self.taylorseer_kwargs["max_warmup_steps"] = (
-                self.max_warmup_steps if self.max_warmup_steps > 0 else 1
-            )
-
-        # Only set n_derivatives as 2 or 3, which is enough for most cases.
-        if "n_derivatives" not in self.taylorseer_kwargs:
-            self.taylorseer_kwargs["n_derivatives"] = max(
-                2, min(3, self.taylorseer_kwargs["max_warmup_steps"])
-            )
-
-        if self.enable_taylorseer:
-            self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.do_separate_cfg:
-                self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
-
-        if self.enable_encoder_taylorseer:
-            self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.do_separate_cfg:
-                self.cfg_encoder_taylorseer = TaylorSeer(
-                    **self.taylorseer_kwargs
-                )
-
-    @torch.compiler.disable
-    def get_residual_diff_threshold(self):
-        if self.enable_alter_cache:
-            residual_diff_threshold = self.alter_residual_diff_threshold
-        else:
-            residual_diff_threshold = self.residual_diff_threshold
-            if self.l1_hidden_states_diff_threshold is not None:
-                # Use the L1 hidden states diff threshold if set
-                residual_diff_threshold = self.l1_hidden_states_diff_threshold
-        if isinstance(residual_diff_threshold, torch.Tensor):
-            residual_diff_threshold = residual_diff_threshold.item()
-        return residual_diff_threshold
-
-    @torch.compiler.disable
-    def get_buffer(self, name):
-        if self.enable_alter_cache and self.is_alter_cache:
-            name = f"{name}_alter"
-        return self.buffers.get(name)
-
-    @torch.compiler.disable
-    def set_buffer(self, name, buffer):
-        if self.enable_alter_cache and self.is_alter_cache:
-            name = f"{name}_alter"
-        self.buffers[name] = buffer
-
-    @torch.compiler.disable
-    def remove_buffer(self, name):
-        if self.enable_alter_cache and self.is_alter_cache:
-            name = f"{name}_alter"
-        if name in self.buffers:
-            del self.buffers[name]
-
-    @torch.compiler.disable
-    def clear_buffers(self):
-        self.buffers.clear()
-
-    @torch.compiler.disable
-    def mark_step_begin(self):
-        # Always increase transformer executed steps
-        # incr step: prev 0 -> 1; prev 1 -> 2
-        # current step: incr step - 1
-        self.transformer_executed_steps += 1
-        if not self.do_separate_cfg:
-            self.executed_steps += 1
-        else:
-            # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
-            if not self.cfg_compute_first:
-                if not self.is_separate_cfg_step():
-                    # transformer step: 0,2,4,...
-                    self.executed_steps += 1
-            else:
-                if self.is_separate_cfg_step():
-                    # transformer step: 0,2,4,...
-                    self.executed_steps += 1
-
-        if not self.enable_alter_cache:
-            # 0 F 1 T 2 F 3 T 4 F 5 T ...
-            self.is_alter_cache = not self.is_alter_cache
-
-        # Reset the cached steps and residual diffs at the beginning
-        # of each inference.
-        if self.get_current_transformer_step() == 0:
-            self.cached_steps.clear()
-            self.residual_diffs.clear()
-            self.cfg_cached_steps.clear()
-            self.cfg_residual_diffs.clear()
-            # Reset the TaylorSeers cache at the beginning of each inference.
-            # reset_cache will set the current step to -1 for TaylorSeer,
-            if self.enable_taylorseer or self.enable_encoder_taylorseer:
-                taylorseer, encoder_taylorseer = self.get_taylorseers()
-                if taylorseer is not None:
-                    taylorseer.reset_cache()
-                if encoder_taylorseer is not None:
-                    encoder_taylorseer.reset_cache()
-                cfg_taylorseer, cfg_encoder_taylorseer = (
-                    self.get_cfg_taylorseers()
-                )
-                if cfg_taylorseer is not None:
-                    cfg_taylorseer.reset_cache()
-                if cfg_encoder_taylorseer is not None:
-                    cfg_encoder_taylorseer.reset_cache()
-
-        # mark_step_begin of TaylorSeer must be called after the cache is reset.
-        if self.enable_taylorseer or self.enable_encoder_taylorseer:
-            if self.do_separate_cfg:
-                # Assume non-CFG steps: 0, 2, 4, 6, ...
-                if not self.is_separate_cfg_step():
-                    taylorseer, encoder_taylorseer = self.get_taylorseers()
-                    if taylorseer is not None:
-                        taylorseer.mark_step_begin()
-                    if encoder_taylorseer is not None:
-                        encoder_taylorseer.mark_step_begin()
-                else:
-                    cfg_taylorseer, cfg_encoder_taylorseer = (
-                        self.get_cfg_taylorseers()
-                    )
-                    if cfg_taylorseer is not None:
-                        cfg_taylorseer.mark_step_begin()
-                    if cfg_encoder_taylorseer is not None:
-                        cfg_encoder_taylorseer.mark_step_begin()
-            else:
-                taylorseer, encoder_taylorseer = self.get_taylorseers()
-                if taylorseer is not None:
-                    taylorseer.mark_step_begin()
-                if encoder_taylorseer is not None:
-                    encoder_taylorseer.mark_step_begin()
-
-    def get_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
-        return self.taylorseer, self.encoder_tarlorseer
-
-    def get_cfg_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
-        return self.cfg_taylorseer, self.cfg_encoder_taylorseer
-
-    @torch.compiler.disable
-    def add_residual_diff(self, diff):
-        # step: executed_steps - 1, not transformer_steps - 1
-        step = str(self.get_current_step())
-        # Only add the diff if it is not already recorded for this step
-        if not self.is_separate_cfg_step():
-            if step not in self.residual_diffs:
-                self.residual_diffs[step] = diff
-        else:
-            if step not in self.cfg_residual_diffs:
-                self.cfg_residual_diffs[step] = diff
-
-    @torch.compiler.disable
-    def get_residual_diffs(self):
-        return self.residual_diffs.copy()
-
-    @torch.compiler.disable
-    def get_cfg_residual_diffs(self):
-        return self.cfg_residual_diffs.copy()
-
-    @torch.compiler.disable
-    def add_cached_step(self):
-        curr_cached_step = self.get_current_step()
-        if not self.is_separate_cfg_step():
-            if self.cached_steps:
-                prev_cached_step = self.cached_steps[-1]
-                if curr_cached_step - prev_cached_step == 1:
-                    if self.continuous_cached_steps == 0:
-                        self.continuous_cached_steps += 2
-                    else:
-                        self.continuous_cached_steps += 1
-            else:
-                self.continuous_cached_steps += 1
-
-            self.cached_steps.append(curr_cached_step)
-        else:
-            if self.cfg_cached_steps:
-                prev_cfg_cached_step = self.cfg_cached_steps[-1]
-                if curr_cached_step - prev_cfg_cached_step == 1:
-                    if self.cfg_continuous_cached_steps == 0:
-                        self.cfg_continuous_cached_steps += 2
-                    else:
-                        self.cfg_continuous_cached_steps += 1
-            else:
-                self.cfg_continuous_cached_steps += 1
-
-            self.cfg_cached_steps.append(curr_cached_step)
-
-    @torch.compiler.disable
-    def get_cached_steps(self):
-        return self.cached_steps.copy()
-
-    @torch.compiler.disable
-    def get_cfg_cached_steps(self):
-        return self.cfg_cached_steps.copy()
-
-    @torch.compiler.disable
-    def get_current_step(self):
-        return self.executed_steps - 1
-
-    @torch.compiler.disable
-    def get_current_transformer_step(self):
-        return self.transformer_executed_steps - 1
-
-    @torch.compiler.disable
-    def is_separate_cfg_step(self):
-        if not self.do_separate_cfg:
-            return False
-        if self.cfg_compute_first:
-            # CFG steps: 0, 2, 4, 6, ...
-            return self.get_current_transformer_step() % 2 == 0
-        # CFG steps: 1, 3, 5, 7, ...
-        return self.get_current_transformer_step() % 2 != 0
-
-    @torch.compiler.disable
-    def is_in_warmup(self):
-        return self.get_current_step() < self.max_warmup_steps
-
-
-# TODO: Support context manager for different cache_context
-
-
-def create_cache_context(*args, **kwargs):
-    return DBCacheContext(*args, **kwargs)
-
-
-def get_current_cache_context():
-    return _current_cache_context
-
-
-def set_current_cache_context(cache_context=None):
-    global _current_cache_context
-    _current_cache_context = cache_context
-
-
-@contextlib.contextmanager
-def cache_context(cache_context):
-    global _current_cache_context
-    old_cache_context = _current_cache_context
-    _current_cache_context = cache_context
-    try:
-        yield
-    finally:
-        _current_cache_context = old_cache_context
-
-
-@torch.compiler.disable
-def get_residual_diff_threshold():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_residual_diff_threshold()
-
-
-@torch.compiler.disable
-def get_buffer(name):
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_buffer(name)
-
-
-@torch.compiler.disable
-def set_buffer(name, buffer):
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    cache_context.set_buffer(name, buffer)
-
-
-@torch.compiler.disable
-def remove_buffer(name):
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    cache_context.remove_buffer(name)
-
-
-@torch.compiler.disable
-def mark_step_begin():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    cache_context.mark_step_begin()
-
-
-@torch.compiler.disable
-def get_current_step():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_current_step()
-
-
-@torch.compiler.disable
-def get_current_step_residual_diff():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    step = str(get_current_step())
-    residual_diffs = get_residual_diffs()
-    if step in residual_diffs:
-        return residual_diffs[step]
-    return None
-
-
-@torch.compiler.disable
-def get_current_step_cfg_residual_diff():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    step = str(get_current_step())
-    cfg_residual_diffs = get_cfg_residual_diffs()
-    if step in cfg_residual_diffs:
-        return cfg_residual_diffs[step]
-    return None
-
-
-@torch.compiler.disable
-def get_current_transformer_step():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_current_transformer_step()
-
-
-@torch.compiler.disable
-def get_cached_steps():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_cached_steps()
-
-
-@torch.compiler.disable
-def get_cfg_cached_steps():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_cfg_cached_steps()
-
-
-@torch.compiler.disable
-def get_max_cached_steps():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.max_cached_steps
-
-
-@torch.compiler.disable
-def get_max_continuous_cached_steps():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.max_continuous_cached_steps
-
-
-@torch.compiler.disable
-def get_continuous_cached_steps():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.continuous_cached_steps
-
-
-@torch.compiler.disable
-def get_cfg_continuous_cached_steps():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.cfg_continuous_cached_steps
-
-
-@torch.compiler.disable
-def add_cached_step():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    cache_context.add_cached_step()
-
-
-@torch.compiler.disable
-def add_residual_diff(diff):
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    cache_context.add_residual_diff(diff)
-
-
-@torch.compiler.disable
-def get_residual_diffs():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_residual_diffs()
-
-
-@torch.compiler.disable
-def get_cfg_residual_diffs():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_cfg_residual_diffs()
-
-
-@torch.compiler.disable
-def is_taylorseer_enabled():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.enable_taylorseer
-
-
-@torch.compiler.disable
-def is_encoder_taylorseer_enabled():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.enable_encoder_taylorseer
-
-
-def get_taylorseers() -> Tuple[TaylorSeer, TaylorSeer]:
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_taylorseers()
-
-
-def get_cfg_taylorseers() -> Tuple[TaylorSeer, TaylorSeer]:
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_cfg_taylorseers()
-
-
-@torch.compiler.disable
-def is_taylorseer_cache_residual():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.taylorseer_cache_type == "residual"
-
-
-@torch.compiler.disable
-def is_cache_residual():
-    if is_taylorseer_enabled():
-        # residual or hidden_states
-        return is_taylorseer_cache_residual()
-    return True
-
-
-@torch.compiler.disable
-def is_encoder_cache_residual():
-    if is_encoder_taylorseer_enabled():
-        # residual or hidden_states
-        return is_taylorseer_cache_residual()
-    return True
-
-
-@torch.compiler.disable
-def is_alter_cache_enabled():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.enable_alter_cache
-
-
-@torch.compiler.disable
-def is_alter_cache():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.is_alter_cache
-
-
-@torch.compiler.disable
-def is_in_warmup():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.is_in_warmup()
-
-
-@torch.compiler.disable
-def is_l1_diff_enabled():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return (
-        cache_context.l1_hidden_states_diff_threshold is not None
-        and cache_context.l1_hidden_states_diff_threshold > 0.0
-    )
-
-
-@torch.compiler.disable
-def get_important_condition_threshold():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.important_condition_threshold
-
-
-@torch.compiler.disable
-def non_compute_blocks_diff_threshold():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.non_compute_blocks_diff_threshold
-
-
-@torch.compiler.disable
-def Fn_compute_blocks():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    assert (
-        cache_context.Fn_compute_blocks >= 1
-    ), "Fn_compute_blocks must be >= 1"
-    if cache_context.max_Fn_compute_blocks > 0:
-        # NOTE: Fn_compute_blocks can be 1, which means FB Cache
-        # but it must be less than or equal to max_Fn_compute_blocks
-        assert (
-            cache_context.Fn_compute_blocks
-            <= cache_context.max_Fn_compute_blocks
-        ), (
-            f"Fn_compute_blocks must be <= {cache_context.max_Fn_compute_blocks}, "
-            f"but got {cache_context.Fn_compute_blocks}"
-        )
-    return cache_context.Fn_compute_blocks
-
-
-@torch.compiler.disable
-def Fn_compute_blocks_ids():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    assert (
-        len(cache_context.Fn_compute_blocks_ids)
-        <= cache_context.Fn_compute_blocks
-    ), (
-        "The num of Fn_compute_blocks_ids must be <= Fn_compute_blocks "
-        f"{cache_context.Fn_compute_blocks}, but got "
-        f"{len(cache_context.Fn_compute_blocks_ids)}"
-    )
-    return cache_context.Fn_compute_blocks_ids
-
-
-@torch.compiler.disable
-def Bn_compute_blocks():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    assert (
-        cache_context.Bn_compute_blocks >= 0
-    ), "Bn_compute_blocks must be >= 0"
-    if cache_context.max_Bn_compute_blocks > 0:
-        # NOTE: Bn_compute_blocks can be 0, which means FB Cache
-        # but it must be less than or equal to max_Bn_compute_blocks
-        assert (
-            cache_context.Bn_compute_blocks
-            <= cache_context.max_Bn_compute_blocks
-        ), (
-            f"Bn_compute_blocks must be <= {cache_context.max_Bn_compute_blocks}, "
-            f"but got {cache_context.Bn_compute_blocks}"
-        )
-    return cache_context.Bn_compute_blocks
-
-
-@torch.compiler.disable
-def Bn_compute_blocks_ids():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    assert (
-        len(cache_context.Bn_compute_blocks_ids)
-        <= cache_context.Bn_compute_blocks
-    ), (
-        "The num of Bn_compute_blocks_ids must be <= Bn_compute_blocks "
-        f"{cache_context.Bn_compute_blocks}, but got "
-        f"{len(cache_context.Bn_compute_blocks_ids)}"
-    )
-    return cache_context.Bn_compute_blocks_ids
-
-
-@torch.compiler.disable
-def do_separate_cfg():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.do_separate_cfg
-
-
-@torch.compiler.disable
-def is_separate_cfg_step():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.is_separate_cfg_step()
-
-
-@torch.compiler.disable
-def cfg_diff_compute_separate():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.cfg_diff_compute_separate
-
-
-_current_cache_context: DBCacheContext = None
-
-
-def collect_cache_kwargs(default_attrs: dict, **kwargs):
-    # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
-    # default_attrs: specific settings for different pipelines
-    cache_attrs = dataclasses.fields(DBCacheContext)
-    cache_attrs = [
-        attr
-        for attr in cache_attrs
-        if hasattr(
-            DBCacheContext,
-            attr.name,
-        )
-    ]
-    cache_kwargs = {
-        attr.name: kwargs.pop(
-            attr.name,
-            getattr(DBCacheContext, attr.name),
-        )
-        for attr in cache_attrs
-    }
-
-    def _safe_set_sequence_field(
-        field_name: str,
-        default_value: Any = None,
-    ):
-        if field_name not in cache_kwargs:
-            cache_kwargs[field_name] = kwargs.pop(
-                field_name,
-                default_value,
-            )
-
-    # Manually set sequence fields, namely, Fn_compute_blocks_ids
-    # and Bn_compute_blocks_ids, which are lists or sets.
-    _safe_set_sequence_field("Fn_compute_blocks_ids", [])
-    _safe_set_sequence_field("Bn_compute_blocks_ids", [])
-    _safe_set_sequence_field("taylorseer_kwargs", {})
-
-    for attr in cache_attrs:
-        if attr.name in default_attrs:  # can be empty {}
-            cache_kwargs[attr.name] = default_attrs[attr.name]
-
-    if logger.isEnabledFor(logging.DEBUG):
-        logger.debug(f"Collected DBCache kwargs: {cache_kwargs}")
-
-    return cache_kwargs, kwargs
-
-
-@torch.compiler.disable
-def are_two_tensors_similar(
-    t1: torch.Tensor,  # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
-    t2: torch.Tensor,  # curr residual R(t  ,n) = H(t  ,n) - H(t  ,0)
-    *,
-    threshold: float,
-    parallelized: bool = False,
-    prefix: str = "Fn",  # for debugging
-):
-    # Special case for threshold, 0.0 means the threshold is disabled, -1.0 means
-    # the threshold is always enabled, -2.0 means the shape is not matched.
-    if threshold <= 0.0:
-        add_residual_diff(-0.0)
-        return False
-
-    if threshold >= 1.0:
-        # If threshold is 1.0 or more, we consider them always similar.
-        add_residual_diff(-1.0)
-        return True
-
-    if t1.shape != t2.shape:
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(f"{prefix}, shape error: {t1.shape} != {t2.shape}")
-        add_residual_diff(-2.0)
-        return False
-
-    if all(
-        (
-            do_separate_cfg(),
-            is_separate_cfg_step(),
-            not cfg_diff_compute_separate(),
-            get_current_step_residual_diff() is not None,
-        )
-    ):
-        # Reuse computed diff value from non-CFG step
-        diff = get_current_step_residual_diff()
-    else:
-        # Find the most significant token through t1 and t2, and
-        # consider the diff of the significant token. The more significant,
-        # the more important.
-        condition_thresh = get_important_condition_threshold()
-        if condition_thresh > 0.0:
-            raw_diff = (t1 - t2).abs()  # [B, seq_len, d]
-            token_m_df = raw_diff.mean(dim=-1)  # [B, seq_len]
-            token_m_t1 = t1.abs().mean(dim=-1)  # [B, seq_len]
-            # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
-            token_diff = token_m_df / token_m_t1  # [B, seq_len]
-            condition = token_diff > condition_thresh  # [B, seq_len]
-            if condition.sum() > 0:
-                condition = condition.unsqueeze(-1)  # [B, seq_len, 1]
-                condition = condition.expand_as(raw_diff)  # [B, seq_len, d]
-                mean_diff = raw_diff[condition].mean()
-                mean_t1 = t1[condition].abs().mean()
-            else:
-                mean_diff = (t1 - t2).abs().mean()
-                mean_t1 = t1.abs().mean()
-        else:
-            # Use the mean of the absolute difference of the tensors
-            mean_diff = (t1 - t2).abs().mean()
-            mean_t1 = t1.abs().mean()
-
-        if parallelized:
-            dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
-            dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
-
-        # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
-        # Futher, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
-        # H(t-1,n) ~ H(t ,n), which means the hidden states are similar.
-        diff = (mean_diff / mean_t1).item()
-
-    if logger.isEnabledFor(logging.DEBUG):
-        logger.debug(f"{prefix}, diff: {diff:.6f}, threshold: {threshold:.6f}")
-
-    add_residual_diff(diff)
-
-    return diff < threshold
-
-
-@torch.compiler.disable
-def _debugging_set_buffer(prefix):
-    if logger.isEnabledFor(logging.DEBUG):
-        logger.debug(
-            f"set {prefix}, "
-            f"transformer step: {get_current_transformer_step()}, "
-            f"executed step: {get_current_step()}"
-        )
-
-
-@torch.compiler.disable
-def _debugging_get_buffer(prefix):
-    if logger.isEnabledFor(logging.DEBUG):
-        logger.debug(
-            f"get {prefix}, "
-            f"transformer step: {get_current_transformer_step()}, "
-            f"executed step: {get_current_step()}"
-        )
-
-
-# Fn buffers
-@torch.compiler.disable
-def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
-    # Set hidden_states or residual for Fn blocks.
-    # This buffer is only use for L1 diff calculation.
-    downsample_factor = get_downsample_factor()
-    if downsample_factor > 1:
-        buffer = buffer[..., ::downsample_factor]
-        buffer = buffer.contiguous()
-    if is_separate_cfg_step():
-        _debugging_set_buffer(f"{prefix}_buffer_cfg")
-        set_buffer(f"{prefix}_buffer_cfg", buffer)
-    else:
-        _debugging_set_buffer(f"{prefix}_buffer")
-        set_buffer(f"{prefix}_buffer", buffer)
-
-
-@torch.compiler.disable
-def get_Fn_buffer(prefix: str = "Fn"):
-    if is_separate_cfg_step():
-        _debugging_get_buffer(f"{prefix}_buffer_cfg")
-        return get_buffer(f"{prefix}_buffer_cfg")
-    _debugging_get_buffer(f"{prefix}_buffer")
-    return get_buffer(f"{prefix}_buffer")
-
-
-@torch.compiler.disable
-def set_Fn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
-    if is_separate_cfg_step():
-        _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
-        set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
-    else:
-        _debugging_set_buffer(f"{prefix}_encoder_buffer")
-        set_buffer(f"{prefix}_encoder_buffer", buffer)
-
-
-@torch.compiler.disable
-def get_Fn_encoder_buffer(prefix: str = "Fn"):
-    if is_separate_cfg_step():
-        _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
-        return get_buffer(f"{prefix}_encoder_buffer_cfg")
-    _debugging_get_buffer(f"{prefix}_encoder_buffer")
-    return get_buffer(f"{prefix}_encoder_buffer")
-
-
-# Bn buffers
-@torch.compiler.disable
-def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
-    # Set hidden_states or residual for Bn blocks.
-    # This buffer is use for hidden states approximation.
-    if is_taylorseer_enabled():
-        # taylorseer, encoder_taylorseer
-        if is_separate_cfg_step():
-            taylorseer, _ = get_cfg_taylorseers()
-        else:
-            taylorseer, _ = get_taylorseers()
-
-        if taylorseer is not None:
-            # Use TaylorSeer to update the buffer
-            taylorseer.update(buffer)
-        else:
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(
-                    "TaylorSeer is enabled but not set in the cache context. "
-                    "Falling back to default buffer retrieval."
-                )
-            if is_separate_cfg_step():
-                _debugging_set_buffer(f"{prefix}_buffer_cfg")
-                set_buffer(f"{prefix}_buffer_cfg", buffer)
-            else:
-                _debugging_set_buffer(f"{prefix}_buffer")
-                set_buffer(f"{prefix}_buffer", buffer)
-    else:
-        if is_separate_cfg_step():
-            _debugging_set_buffer(f"{prefix}_buffer_cfg")
-            set_buffer(f"{prefix}_buffer_cfg", buffer)
-        else:
-            _debugging_set_buffer(f"{prefix}_buffer")
-            set_buffer(f"{prefix}_buffer", buffer)
-
-
-@torch.compiler.disable
-def get_Bn_buffer(prefix: str = "Bn"):
-    if is_taylorseer_enabled():
-        # taylorseer, encoder_taylorseer
-        if is_separate_cfg_step():
-            taylorseer, _ = get_cfg_taylorseers()
-        else:
-            taylorseer, _ = get_taylorseers()
-
-        if taylorseer is not None:
-            return taylorseer.approximate_value()
-        else:
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(
-                    "TaylorSeer is enabled but not set in the cache context. "
-                    "Falling back to default buffer retrieval."
-                )
-            # Fallback to default buffer retrieval
-            if is_separate_cfg_step():
-                _debugging_get_buffer(f"{prefix}_buffer_cfg")
-                return get_buffer(f"{prefix}_buffer_cfg")
-            _debugging_get_buffer(f"{prefix}_buffer")
-            return get_buffer(f"{prefix}_buffer")
-    else:
-        if is_separate_cfg_step():
-            _debugging_get_buffer(f"{prefix}_buffer_cfg")
-            return get_buffer(f"{prefix}_buffer_cfg")
-        _debugging_get_buffer(f"{prefix}_buffer")
-        return get_buffer(f"{prefix}_buffer")
-
-
-@torch.compiler.disable
-def set_Bn_encoder_buffer(buffer: torch.Tensor | None, prefix: str = "Bn"):
-    # DON'T set None Buffer
-    if buffer is None:
-        return
-
-    # This buffer is use for encoder hidden states approximation.
-    if is_encoder_taylorseer_enabled():
-        # taylorseer, encoder_taylorseer
-        if is_separate_cfg_step():
-            _, encoder_taylorseer = get_cfg_taylorseers()
-        else:
-            _, encoder_taylorseer = get_taylorseers()
-
-        if encoder_taylorseer is not None:
-            # Use TaylorSeer to update the buffer
-            encoder_taylorseer.update(buffer)
-        else:
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(
-                    "TaylorSeer is enabled but not set in the cache context. "
-                    "Falling back to default buffer retrieval."
-                )
-            if is_separate_cfg_step():
-                _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
-                set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
-            else:
-                _debugging_set_buffer(f"{prefix}_encoder_buffer")
-                set_buffer(f"{prefix}_encoder_buffer", buffer)
-    else:
-        if is_separate_cfg_step():
-            _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
-            set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
-        else:
-            _debugging_set_buffer(f"{prefix}_encoder_buffer")
-            set_buffer(f"{prefix}_encoder_buffer", buffer)
-
-
-@torch.compiler.disable
-def get_Bn_encoder_buffer(prefix: str = "Bn"):
-    if is_encoder_taylorseer_enabled():
-        if is_separate_cfg_step():
-            _, encoder_taylorseer = get_cfg_taylorseers()
-        else:
-            _, encoder_taylorseer = get_taylorseers()
-
-        if encoder_taylorseer is not None:
-            # Use TaylorSeer to approximate the value
-            return encoder_taylorseer.approximate_value()
-        else:
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(
-                    "TaylorSeer is enabled but not set in the cache context. "
-                    "Falling back to default buffer retrieval."
-                )
-            # Fallback to default buffer retrieval
-            if is_separate_cfg_step():
-                _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
-                return get_buffer(f"{prefix}_encoder_buffer_cfg")
-            _debugging_get_buffer(f"{prefix}_encoder_buffer")
-            return get_buffer(f"{prefix}_encoder_buffer")
-    else:
-        if is_separate_cfg_step():
-            _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
-            return get_buffer(f"{prefix}_encoder_buffer_cfg")
-        _debugging_get_buffer(f"{prefix}_encoder_buffer")
-        return get_buffer(f"{prefix}_encoder_buffer")
-
-
-@torch.compiler.disable
-def apply_hidden_states_residual(
-    hidden_states: torch.Tensor,
-    encoder_hidden_states: torch.Tensor = None,
-    prefix: str = "Bn",
-    encoder_prefix: str = "Bn_encoder",
-):
-    # Allow Bn and Fn prefix to be used for residual cache.
-    if "Bn" in prefix:
-        hidden_states_prev = get_Bn_buffer(prefix)
-    else:
-        hidden_states_prev = get_Fn_buffer(prefix)
-
-    assert hidden_states_prev is not None, f"{prefix}_buffer must be set before"
-
-    if is_cache_residual():
-        hidden_states = hidden_states_prev + hidden_states
-    else:
-        # If cache is not residual, we use the hidden states directly
-        hidden_states = hidden_states_prev
-
-    hidden_states = hidden_states.contiguous()
-
-    if encoder_hidden_states is not None:
-        if "Bn" in encoder_prefix:
-            encoder_hidden_states_prev = get_Bn_encoder_buffer(encoder_prefix)
-        else:
-            encoder_hidden_states_prev = get_Fn_encoder_buffer(encoder_prefix)
-
-        assert (
-            encoder_hidden_states_prev is not None
-        ), f"{prefix}_encoder_buffer must be set before"
-
-        if is_encoder_cache_residual():
-            encoder_hidden_states = (
-                encoder_hidden_states_prev + encoder_hidden_states
-            )
-        else:
-            # If encoder cache is not residual, we use the encoder hidden states directly
-            encoder_hidden_states = encoder_hidden_states_prev
-
-        encoder_hidden_states = encoder_hidden_states.contiguous()
-
-    return hidden_states, encoder_hidden_states
-
-
-@torch.compiler.disable
-def get_downsample_factor():
-    cache_context = get_current_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.downsample_factor
-
-
-@torch.compiler.disable
-def get_can_use_cache(
-    states_tensor: torch.Tensor,  # hidden_states or residual
-    parallelized: bool = False,
-    threshold: Optional[float] = None,  # can manually set threshold
-    prefix: str = "Fn",
-):
-    if is_in_warmup():
-        return False
-
-    # max cached steps
-    max_cached_steps = get_max_cached_steps()
-    if not is_separate_cfg_step():
-        cached_steps = get_cached_steps()
-    else:
-        cached_steps = get_cfg_cached_steps()
-
-    if max_cached_steps >= 0 and (len(cached_steps) >= max_cached_steps):
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(
-                f"{prefix}, max_cached_steps reached: {max_cached_steps}, "
-                "can not use cache."
-            )
-        return False
-
-    # max continuous cached steps
-    max_continuous_cached_steps = get_max_continuous_cached_steps()
-    if not is_separate_cfg_step():
-        continuous_cached_steps = get_continuous_cached_steps()
-    else:
-        continuous_cached_steps = get_cfg_continuous_cached_steps()
-
-    if max_continuous_cached_steps >= 0 and (
-        continuous_cached_steps >= max_continuous_cached_steps
-    ):
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(
-                f"{prefix}, max_continuous_cached_steps "
-                f"reached: {max_continuous_cached_steps}, "
-                "can not use cache."
-            )
-        # reset continuous cached steps stats
-        cache_context = get_current_cache_context()
-        if not is_separate_cfg_step():
-            cache_context.continuous_cached_steps = 0
-        else:
-            cache_context.cfg_continuous_cached_steps = 0
-        return False
-
-    if threshold is None or threshold <= 0.0:
-        threshold = get_residual_diff_threshold()
-    if threshold <= 0.0:
-        return False
-
-    downsample_factor = get_downsample_factor()
-    if downsample_factor > 1 and "Bn" not in prefix:
-        states_tensor = states_tensor[..., ::downsample_factor]
-        states_tensor = states_tensor.contiguous()
-
-    # Allow Bn and Fn prefix to be used for diff calculation.
-    if "Bn" in prefix:
-        prev_states_tensor = get_Bn_buffer(prefix)
-    else:
-        prev_states_tensor = get_Fn_buffer(prefix)
-
-    if not is_alter_cache_enabled():
-        # Dynamic cache according to the residual diff
-        can_use_cache = (
-            prev_states_tensor is not None
-            and are_two_tensors_similar(
-                prev_states_tensor,
-                states_tensor,
-                threshold=threshold,
-                parallelized=parallelized,
-                prefix=prefix,
-            )
-        )
-    else:
-        # Only cache in the alter cache steps
-        can_use_cache = (
-            prev_states_tensor is not None
-            and are_two_tensors_similar(
-                prev_states_tensor,
-                states_tensor,
-                threshold=threshold,
-                parallelized=parallelized,
-                prefix=prefix,
-            )
-            and is_alter_cache()
-        )
-    return can_use_cache
```