cache-dit 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +3 -6
- cache_dit/cache_factory/block_adapters/block_adapters.py +8 -64
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +47 -14
- cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
- cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
- cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
- cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
- cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
- cache_dit/cache_factory/cache_interface.py +128 -111
- cache_dit/cache_factory/params_modifier.py +87 -0
- cache_dit/metrics/__init__.py +3 -1
- cache_dit/utils.py +12 -21
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/METADATA +78 -64
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/RECORD +23 -28
- cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
- cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
- cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
- cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
- cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
- cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
- cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
- /cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/top_level.txt +0 -0
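The notable structural change in 0.3.2 is the removal of the `v2` namespace: `cache_adapter_v2.py`, `cache_context_v2.py`, `cache_manager_v2.py`, and the old `taylorseer.py` are deleted, while the calibrators move from `cache_contexts/v2/calibrators/` up to `cache_contexts/calibrators/`. A minimal sketch of the import-path migration implied by the renames above (the new path is taken from the rename entries; whether `CalibratorBase` is re-exported from the relocated package `__init__.py` is an assumption):

```python
# Import-path migration implied by the file renames above (sketch).
# cache-dit 0.3.1 (v2 namespace), as used by the deleted module shown below:
#   from cache_dit.cache_factory.cache_contexts.v2.calibrators import CalibratorBase
# cache-dit 0.3.2 (v2 namespace removed); assumes CalibratorBase is still
# exported from the relocated calibrators package:
from cache_dit.cache_factory.cache_contexts.calibrators import CalibratorBase
```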
cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py (removed in 0.3.2)
@@ -1,799 +0,0 @@
-import logging
-import contextlib
-import dataclasses
-from typing import Any, Dict, Optional, Tuple, Union, List
-
-import torch
-import torch.distributed as dist
-
-from cache_dit.cache_factory.cache_contexts.v2.calibrators import CalibratorBase
-from cache_dit.cache_factory.cache_contexts.v2.cache_context_v2 import (
-    CachedContextV2,
-)
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-class CachedContextManagerV2:
-    # Each Pipeline should have it's own context manager instance.
-
-    def __init__(self, name: str = None):
-        self.name = name
-        self._current_context: CachedContextV2 = None
-        self._cached_context_manager: Dict[str, CachedContextV2] = {}
-
-    def new_context(self, *args, **kwargs) -> CachedContextV2:
-        _context = CachedContextV2(*args, **kwargs)
-        self._cached_context_manager[_context.name] = _context
-        return _context
-
-    def set_context(self, cached_context: CachedContextV2 | str):
-        if isinstance(cached_context, CachedContextV2):
-            self._current_context = cached_context
-        else:
-            self._current_context = self._cached_context_manager[cached_context]
-
-    def get_context(self, name: str = None) -> CachedContextV2:
-        if name is not None:
-            if name not in self._cached_context_manager:
-                raise ValueError("Context not exist!")
-            return self._cached_context_manager[name]
-        return self._current_context
-
-    def reset_context(
-        self,
-        cached_context: CachedContextV2 | str,
-        *args,
-        **kwargs,
-    ) -> CachedContextV2:
-        if isinstance(cached_context, CachedContextV2):
-            old_context_name = cached_context.name
-            if cached_context.name in self._cached_context_manager:
-                cached_context.clear_buffers()
-                del self._cached_context_manager[cached_context.name]
-            # force use old_context name
-            kwargs["name"] = old_context_name
-            _context = self.new_context(*args, **kwargs)
-        else:
-            old_context_name = cached_context
-            if cached_context in self._cached_context_manager:
-                self._cached_context_manager[cached_context].clear_buffers()
-                del self._cached_context_manager[cached_context]
-            # force use old_context name
-            kwargs["name"] = old_context_name
-            _context = self.new_context(*args, **kwargs)
-        return _context
-
-    def remove_context(self, cached_context: CachedContextV2 | str):
-        if isinstance(cached_context, CachedContextV2):
-            cached_context.clear_buffers()
-            if cached_context.name in self._cached_context_manager:
-                del self._cached_context_manager[cached_context.name]
-        else:
-            if cached_context in self._cached_context_manager:
-                self._cached_context_manager[cached_context].clear_buffers()
-                del self._cached_context_manager[cached_context]
-
-    def clear_contexts(self):
-        for context_name in list(self._cached_context_manager.keys()):
-            self.remove_context(context_name)
-
-    @contextlib.contextmanager
-    def enter_context(self, cached_context: CachedContextV2 | str):
-        old_cached_context = self._current_context
-        if isinstance(cached_context, CachedContextV2):
-            self._current_context = cached_context
-        else:
-            self._current_context = self._cached_context_manager[cached_context]
-        try:
-            yield
-        finally:
-            self._current_context = old_cached_context
-
-    @staticmethod
-    def collect_cache_kwargs(
-        default_attrs: dict, **kwargs
-    ) -> Tuple[Dict, Dict]:
-        # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
-        # default_attrs: specific settings for different pipelines
-        cache_attrs = dataclasses.fields(CachedContextV2)
-        cache_attrs = [
-            attr
-            for attr in cache_attrs
-            if hasattr(
-                CachedContextV2,
-                attr.name,
-            )
-        ]
-        cache_kwargs = {
-            attr.name: kwargs.pop(
-                attr.name,
-                getattr(CachedContextV2, attr.name),
-            )
-            for attr in cache_attrs
-        }
-
-        def _safe_set_sequence_field(
-            field_name: str,
-            default_value: Any = None,
-        ):
-            if field_name not in cache_kwargs:
-                cache_kwargs[field_name] = kwargs.pop(
-                    field_name,
-                    default_value,
-                )
-
-        # Manually set sequence fields
-        _safe_set_sequence_field("calibrator_kwargs", {})
-
-        for attr in cache_attrs:
-            if attr.name in default_attrs:  # can be empty {}
-                cache_kwargs[attr.name] = default_attrs[attr.name]
-
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(f"Collected Cache kwargs: {cache_kwargs}")
-
-        return cache_kwargs, kwargs
-
-    @torch.compiler.disable
-    def get_residual_diff_threshold(self) -> float:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.get_residual_diff_threshold()
-
-    @torch.compiler.disable
-    def get_buffer(self, name) -> torch.Tensor:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.get_buffer(name)
-
-    @torch.compiler.disable
-    def set_buffer(self, name, buffer):
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        cached_context.set_buffer(name, buffer)
-
-    @torch.compiler.disable
-    def remove_buffer(self, name):
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        cached_context.remove_buffer(name)
-
-    @torch.compiler.disable
-    def mark_step_begin(self):
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        cached_context.mark_step_begin()
-
-    @torch.compiler.disable
-    def get_current_step(self) -> int:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.get_current_step()
-
-    @torch.compiler.disable
-    def get_current_step_residual_diff(self) -> float:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        step = str(self.get_current_step())
-        residual_diffs = self.get_residual_diffs()
-        if step in residual_diffs:
-            return residual_diffs[step]
-        return None
-
-    @torch.compiler.disable
-    def get_current_step_cfg_residual_diff(self) -> float:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        step = str(self.get_current_step())
-        cfg_residual_diffs = self.get_cfg_residual_diffs()
-        if step in cfg_residual_diffs:
-            return cfg_residual_diffs[step]
-        return None
-
-    @torch.compiler.disable
-    def get_current_transformer_step(self) -> int:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.get_current_transformer_step()
-
-    @torch.compiler.disable
-    def get_cached_steps(self) -> List[int]:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.get_cached_steps()
-
-    @torch.compiler.disable
-    def get_cfg_cached_steps(self) -> List[int]:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.get_cfg_cached_steps()
-
-    @torch.compiler.disable
-    def get_max_cached_steps(self) -> int:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.max_cached_steps
-
-    @torch.compiler.disable
-    def get_max_continuous_cached_steps(self) -> int:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.max_continuous_cached_steps
-
-    @torch.compiler.disable
-    def get_continuous_cached_steps(self) -> int:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.continuous_cached_steps
-
-    @torch.compiler.disable
-    def get_cfg_continuous_cached_steps(self) -> int:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.cfg_continuous_cached_steps
-
-    @torch.compiler.disable
-    def add_cached_step(self):
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        cached_context.add_cached_step()
-
-    @torch.compiler.disable
-    def add_residual_diff(self, diff):
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        cached_context.add_residual_diff(diff)
-
-    @torch.compiler.disable
-    def get_residual_diffs(self) -> Dict[str, float]:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.get_residual_diffs()
-
-    @torch.compiler.disable
-    def get_cfg_residual_diffs(self) -> Dict[str, float]:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.get_cfg_residual_diffs()
-
-    @torch.compiler.disable
-    def is_calibrator_enabled(self) -> bool:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.enable_calibrator()
-
-    @torch.compiler.disable
-    def is_encoder_calibrator_enabled(self) -> bool:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.enable_encoder_calibrator()
-
-    def get_calibrator(self) -> Tuple[CalibratorBase, CalibratorBase]:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.get_calibrators()
-
-    def get_cfg_calibrator(self) -> Tuple[CalibratorBase, CalibratorBase]:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.get_cfg_calibrators()
-
-    @torch.compiler.disable
-    def is_calibrator_cache_residual(self) -> bool:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.calibrator_cache_type() == "residual"
-
-    @torch.compiler.disable
-    def is_cache_residual(self) -> bool:
-        if self.is_calibrator_enabled():
-            # residual or hidden_states
-            return self.is_calibrator_cache_residual()
-        return True
-
-    @torch.compiler.disable
-    def is_encoder_cache_residual(self) -> bool:
-        if self.is_encoder_calibrator_enabled():
-            # residual or hidden_states
-            return self.is_calibrator_cache_residual()
-        return True
-
-    @torch.compiler.disable
-    def is_in_warmup(self) -> bool:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.is_in_warmup()
-
-    @torch.compiler.disable
-    def is_l1_diff_enabled(self) -> bool:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return (
-            cached_context.l1_hidden_states_diff_threshold is not None
-            and cached_context.l1_hidden_states_diff_threshold > 0.0
-        )
-
-    @torch.compiler.disable
-    def get_important_condition_threshold(self) -> float:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.important_condition_threshold
-
-    @torch.compiler.disable
-    def non_compute_blocks_diff_threshold(self) -> float:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.non_compute_blocks_diff_threshold
-
-    @torch.compiler.disable
-    def Fn_compute_blocks(self) -> int:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        assert (
-            cached_context.Fn_compute_blocks >= 1
-        ), "Fn_compute_blocks must be >= 1"
-        if cached_context.max_Fn_compute_blocks > 0:
-            # NOTE: Fn_compute_blocks can be 1, which means FB Cache
-            # but it must be less than or equal to max_Fn_compute_blocks
-            assert (
-                cached_context.Fn_compute_blocks
-                <= cached_context.max_Fn_compute_blocks
-            ), (
-                f"Fn_compute_blocks must be <= {cached_context.max_Fn_compute_blocks}, "
-                f"but got {cached_context.Fn_compute_blocks}"
-            )
-        return cached_context.Fn_compute_blocks
-
-    @torch.compiler.disable
-    def Bn_compute_blocks(self) -> int:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        assert (
-            cached_context.Bn_compute_blocks >= 0
-        ), "Bn_compute_blocks must be >= 0"
-        if cached_context.max_Bn_compute_blocks > 0:
-            # NOTE: Bn_compute_blocks can be 0, which means FB Cache
-            # but it must be less than or equal to max_Bn_compute_blocks
-            assert (
-                cached_context.Bn_compute_blocks
-                <= cached_context.max_Bn_compute_blocks
-            ), (
-                f"Bn_compute_blocks must be <= {cached_context.max_Bn_compute_blocks}, "
-                f"but got {cached_context.Bn_compute_blocks}"
-            )
-        return cached_context.Bn_compute_blocks
-
-    @torch.compiler.disable
-    def enable_separate_cfg(self) -> bool:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.enable_separate_cfg
-
-    @torch.compiler.disable
-    def is_separate_cfg_step(self) -> bool:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.is_separate_cfg_step()
-
-    @torch.compiler.disable
-    def cfg_diff_compute_separate(self) -> bool:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.cfg_diff_compute_separate
-
-    @torch.compiler.disable
-    def similarity(
-        self,
-        t1: torch.Tensor,  # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
-        t2: torch.Tensor,  # curr residual R(t ,n) = H(t ,n) - H(t ,0)
-        *,
-        threshold: float,
-        parallelized: bool = False,
-        prefix: str = "Fn",  # for debugging
-    ) -> bool:
-        # Special case for threshold, 0.0 means the threshold is disabled, -1.0 means
-        # the threshold is always enabled, -2.0 means the shape is not matched.
-        if threshold <= 0.0:
-            self.add_residual_diff(-0.0)
-            return False
-
-        if threshold >= 1.0:
-            # If threshold is 1.0 or more, we consider them always similar.
-            self.add_residual_diff(-1.0)
-            return True
-
-        if t1.shape != t2.shape:
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(f"{prefix}, shape error: {t1.shape} != {t2.shape}")
-            self.add_residual_diff(-2.0)
-            return False
-
-        if all(
-            (
-                self.enable_separate_cfg(),
-                self.is_separate_cfg_step(),
-                not self.cfg_diff_compute_separate(),
-                self.get_current_step_residual_diff() is not None,
-            )
-        ):
-            # Reuse computed diff value from non-CFG step
-            diff = self.get_current_step_residual_diff()
-        else:
-            # Find the most significant token through t1 and t2, and
-            # consider the diff of the significant token. The more significant,
-            # the more important.
-            condition_thresh = self.get_important_condition_threshold()
-            if condition_thresh > 0.0:
-                raw_diff = (t1 - t2).abs()  # [B, seq_len, d]
-                token_m_df = raw_diff.mean(dim=-1)  # [B, seq_len]
-                token_m_t1 = t1.abs().mean(dim=-1)  # [B, seq_len]
-                # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
-                token_diff = token_m_df / token_m_t1  # [B, seq_len]
-                condition: torch.Tensor = (
-                    token_diff > condition_thresh
-                )  # [B, seq_len]
-                if condition.sum() > 0:
-                    condition = condition.unsqueeze(-1)  # [B, seq_len, 1]
-                    condition = condition.expand_as(raw_diff)  # [B, seq_len, d]
-                    mean_diff = raw_diff[condition].mean()
-                    mean_t1 = t1[condition].abs().mean()
-                else:
-                    mean_diff = (t1 - t2).abs().mean()
-                    mean_t1 = t1.abs().mean()
-            else:
-                # Use the mean of the absolute difference of the tensors
-                mean_diff = (t1 - t2).abs().mean()
-                mean_t1 = t1.abs().mean()
-
-            if parallelized:
-                dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
-                dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
-
-            # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
-            # Futher, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
-            # H(t-1,n) ~ H(t ,n), which means the hidden states are similar.
-            diff = (mean_diff / mean_t1).item()
-
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(
-                f"{prefix}, diff: {diff:.6f}, threshold: {threshold:.6f}"
-            )
-
-        self.add_residual_diff(diff)
-
-        return diff < threshold
-
-    def _debugging_set_buffer(self, prefix):
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(
-                f"set {prefix}, "
-                f"transformer step: {self.get_current_transformer_step()}, "
-                f"executed step: {self.get_current_step()}"
-            )
-
-    def _debugging_get_buffer(self, prefix):
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(
-                f"get {prefix}, "
-                f"transformer step: {self.get_current_transformer_step()}, "
-                f"executed step: {self.get_current_step()}"
-            )
-
-    # Fn buffers
-    @torch.compiler.disable
-    def set_Fn_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
-        # DON'T set None Buffer
-        if buffer is None:
-            return
-        # Set hidden_states or residual for Fn blocks.
-        # This buffer is only use for L1 diff calculation.
-        downsample_factor = self.get_downsample_factor()
-        if downsample_factor > 1:
-            buffer = buffer[..., ::downsample_factor]
-            buffer = buffer.contiguous()
-        if self.is_separate_cfg_step():
-            self._debugging_set_buffer(f"{prefix}_buffer_cfg")
-            self.set_buffer(f"{prefix}_buffer_cfg", buffer)
-        else:
-            self._debugging_set_buffer(f"{prefix}_buffer")
-            self.set_buffer(f"{prefix}_buffer", buffer)
-
-    @torch.compiler.disable
-    def get_Fn_buffer(self, prefix: str = "Fn") -> torch.Tensor:
-        if self.is_separate_cfg_step():
-            self._debugging_get_buffer(f"{prefix}_buffer_cfg")
-            return self.get_buffer(f"{prefix}_buffer_cfg")
-        self._debugging_get_buffer(f"{prefix}_buffer")
-        return self.get_buffer(f"{prefix}_buffer")
-
-    @torch.compiler.disable
-    def set_Fn_encoder_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
-        # DON'T set None Buffer
-        if buffer is None:
-            return
-        if self.is_separate_cfg_step():
-            self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
-            self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
-        else:
-            self._debugging_set_buffer(f"{prefix}_encoder_buffer")
-            self.set_buffer(f"{prefix}_encoder_buffer", buffer)
-
-    @torch.compiler.disable
-    def get_Fn_encoder_buffer(self, prefix: str = "Fn") -> torch.Tensor:
-        if self.is_separate_cfg_step():
-            self._debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
-            return self.get_buffer(f"{prefix}_encoder_buffer_cfg")
-        self._debugging_get_buffer(f"{prefix}_encoder_buffer")
-        return self.get_buffer(f"{prefix}_encoder_buffer")
-
-    # Bn buffers
-    @torch.compiler.disable
-    def set_Bn_buffer(self, buffer: torch.Tensor, prefix: str = "Bn"):
-        # DON'T set None Buffer
-        if buffer is None:
-            return
-        # Set hidden_states or residual for Bn blocks.
-        # This buffer is use for hidden states approximation.
-        if self.is_calibrator_enabled():
-            # calibrator, encoder_calibrator
-            if self.is_separate_cfg_step():
-                calibrator, _ = self.get_cfg_calibrator()
-            else:
-                calibrator, _ = self.get_calibrator()
-
-            if calibrator is not None:
-                # Use calibrator to update the buffer
-                calibrator.update(buffer)
-            else:
-                if logger.isEnabledFor(logging.DEBUG):
-                    logger.debug(
-                        "calibrator is enabled but not set in the cache context. "
-                        "Falling back to default buffer retrieval."
-                    )
-                if self.is_separate_cfg_step():
-                    self._debugging_set_buffer(f"{prefix}_buffer_cfg")
-                    self.set_buffer(f"{prefix}_buffer_cfg", buffer)
-                else:
-                    self._debugging_set_buffer(f"{prefix}_buffer")
-                    self.set_buffer(f"{prefix}_buffer", buffer)
-        else:
-            if self.is_separate_cfg_step():
-                self._debugging_set_buffer(f"{prefix}_buffer_cfg")
-                self.set_buffer(f"{prefix}_buffer_cfg", buffer)
-            else:
-                self._debugging_set_buffer(f"{prefix}_buffer")
-                self.set_buffer(f"{prefix}_buffer", buffer)
-
-    @torch.compiler.disable
-    def get_Bn_buffer(self, prefix: str = "Bn") -> torch.Tensor:
-        if self.is_calibrator_enabled():
-            # calibrator, encoder_calibrator
-            if self.is_separate_cfg_step():
-                calibrator, _ = self.get_cfg_calibrator()
-            else:
-                calibrator, _ = self.get_calibrator()
-
-            if calibrator is not None:
-                return calibrator.approximate()
-            else:
-                if logger.isEnabledFor(logging.DEBUG):
-                    logger.debug(
-                        "calibrator is enabled but not set in the cache context. "
-                        "Falling back to default buffer retrieval."
-                    )
-                # Fallback to default buffer retrieval
-                if self.is_separate_cfg_step():
-                    self._debugging_get_buffer(f"{prefix}_buffer_cfg")
-                    return self.get_buffer(f"{prefix}_buffer_cfg")
-                self._debugging_get_buffer(f"{prefix}_buffer")
-                return self.get_buffer(f"{prefix}_buffer")
-        else:
-            if self.is_separate_cfg_step():
-                self._debugging_get_buffer(f"{prefix}_buffer_cfg")
-                return self.get_buffer(f"{prefix}_buffer_cfg")
-            self._debugging_get_buffer(f"{prefix}_buffer")
-            return self.get_buffer(f"{prefix}_buffer")
-
-    @torch.compiler.disable
-    def set_Bn_encoder_buffer(
-        self, buffer: torch.Tensor | None, prefix: str = "Bn"
-    ):
-        # DON'T set None Buffer
-        if buffer is None:
-            return
-
-        # This buffer is use for encoder hidden states approximation.
-        if self.is_encoder_calibrator_enabled():
-            # calibrator, encoder_calibrator
-            if self.is_separate_cfg_step():
-                _, encoder_calibrator = self.get_cfg_calibrator()
-            else:
-                _, encoder_calibrator = self.get_calibrator()
-
-            if encoder_calibrator is not None:
-                # Use CalibratorBase to update the buffer
-                encoder_calibrator.update(buffer)
-            else:
-                if logger.isEnabledFor(logging.DEBUG):
-                    logger.debug(
-                        "CalibratorBase is enabled but not set in the cache context. "
-                        "Falling back to default buffer retrieval."
-                    )
-                if self.is_separate_cfg_step():
-                    self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
-                    self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
-                else:
-                    self._debugging_set_buffer(f"{prefix}_encoder_buffer")
-                    self.set_buffer(f"{prefix}_encoder_buffer", buffer)
-        else:
-            if self.is_separate_cfg_step():
-                self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
-                self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
-            else:
-                self._debugging_set_buffer(f"{prefix}_encoder_buffer")
-                self.set_buffer(f"{prefix}_encoder_buffer", buffer)
-
-    @torch.compiler.disable
-    def get_Bn_encoder_buffer(self, prefix: str = "Bn") -> torch.Tensor:
-        if self.is_encoder_calibrator_enabled():
-            if self.is_separate_cfg_step():
-                _, encoder_calibrator = self.get_cfg_calibrator()
-            else:
-                _, encoder_calibrator = self.get_calibrator()
-
-            if encoder_calibrator is not None:
-                # Use calibrator to approximate the value
-                return encoder_calibrator.approximate()
-            else:
-                if logger.isEnabledFor(logging.DEBUG):
-                    logger.debug(
-                        "calibrator is enabled but not set in the cache context. "
-                        "Falling back to default buffer retrieval."
-                    )
-                # Fallback to default buffer retrieval
-                if self.is_separate_cfg_step():
-                    self._debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
-                    return self.get_buffer(f"{prefix}_encoder_buffer_cfg")
-                self._debugging_get_buffer(f"{prefix}_encoder_buffer")
-                return self.get_buffer(f"{prefix}_encoder_buffer")
-        else:
-            if self.is_separate_cfg_step():
-                self._debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
-                return self.get_buffer(f"{prefix}_encoder_buffer_cfg")
-            self._debugging_get_buffer(f"{prefix}_encoder_buffer")
-            return self.get_buffer(f"{prefix}_encoder_buffer")
-
-    @torch.compiler.disable
-    def apply_cache(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor = None,
-        prefix: str = "Bn",
-        encoder_prefix: str = "Bn_encoder",
-    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
-        # Allow Bn and Fn prefix to be used for residual cache.
-        if "Bn" in prefix:
-            hidden_states_prev = self.get_Bn_buffer(prefix)
-        else:
-            hidden_states_prev = self.get_Fn_buffer(prefix)
-
-        assert (
-            hidden_states_prev is not None
-        ), f"{prefix}_buffer must be set before"
-
-        if self.is_cache_residual():
-            hidden_states = hidden_states_prev + hidden_states
-        else:
-            # If cache is not residual, we use the hidden states directly
-            hidden_states = hidden_states_prev
-
-        hidden_states = hidden_states.contiguous()
-
-        if encoder_hidden_states is not None:
-            if "Bn" in encoder_prefix:
-                encoder_hidden_states_prev = self.get_Bn_encoder_buffer(
-                    encoder_prefix
-                )
-            else:
-                encoder_hidden_states_prev = self.get_Fn_encoder_buffer(
-                    encoder_prefix
-                )
-
-            if encoder_hidden_states_prev is not None:
-
-                if self.is_encoder_cache_residual():
-                    encoder_hidden_states = (
-                        encoder_hidden_states_prev + encoder_hidden_states
-                    )
-                else:
-                    # If encoder cache is not residual, we use the encoder hidden states directly
-                    encoder_hidden_states = encoder_hidden_states_prev
-
-                encoder_hidden_states = encoder_hidden_states.contiguous()
-
-        return hidden_states, encoder_hidden_states
-
-    @torch.compiler.disable
-    def get_downsample_factor(self) -> float:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.downsample_factor
-
-    @torch.compiler.disable
-    def can_cache(
-        self,
-        states_tensor: torch.Tensor,  # hidden_states or residual
-        parallelized: bool = False,
-        threshold: Optional[float] = None,  # can manually set threshold
-        prefix: str = "Fn",
-    ) -> bool:
-
-        if self.is_in_warmup():
-            return False
-
-        # max cached steps
-        max_cached_steps = self.get_max_cached_steps()
-        if not self.is_separate_cfg_step():
-            cached_steps = self.get_cached_steps()
-        else:
-            cached_steps = self.get_cfg_cached_steps()
-
-        if max_cached_steps >= 0 and (len(cached_steps) >= max_cached_steps):
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(
-                    f"{prefix}, max_cached_steps reached: {max_cached_steps}, "
-                    "can not use cache."
-                )
-            return False
-
-        # max continuous cached steps
-        max_continuous_cached_steps = self.get_max_continuous_cached_steps()
-        if not self.is_separate_cfg_step():
-            continuous_cached_steps = self.get_continuous_cached_steps()
-        else:
-            continuous_cached_steps = self.get_cfg_continuous_cached_steps()
-
-        if max_continuous_cached_steps >= 0 and (
-            continuous_cached_steps >= max_continuous_cached_steps
-        ):
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(
-                    f"{prefix}, max_continuous_cached_steps "
-                    f"reached: {max_continuous_cached_steps}, "
-                    "can not use cache."
-                )
-            # reset continuous cached steps stats
-            cached_context = self.get_context()
-            if not self.is_separate_cfg_step():
-                cached_context.continuous_cached_steps = 0
-            else:
-                cached_context.cfg_continuous_cached_steps = 0
-            return False
-
-        if threshold is None or threshold <= 0.0:
-            threshold = self.get_residual_diff_threshold()
-        if threshold <= 0.0:
-            return False
-
-        downsample_factor = self.get_downsample_factor()
-        if downsample_factor > 1 and "Bn" not in prefix:
-            states_tensor = states_tensor[..., ::downsample_factor]
-            states_tensor = states_tensor.contiguous()
-
-        # Allow Bn and Fn prefix to be used for diff calculation.
-        if "Bn" in prefix:
-            prev_states_tensor = self.get_Bn_buffer(prefix)
-        else:
-            prev_states_tensor = self.get_Fn_buffer(prefix)
-
-        # Dynamic cache according to the residual diff
-        can_cache = prev_states_tensor is not None and self.similarity(
-            prev_states_tensor,
-            states_tensor,
-            threshold=threshold,
-            parallelized=parallelized,
-            prefix=prefix,
-        )
-        return can_cache
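The caching decision implemented by the deleted `similarity()` method above reduces to a relative mean-absolute-difference test between the previous and current residuals, diff = mean(|t1 - t2|) / mean(|t1|), with the step treated as cacheable when diff < threshold. A self-contained sketch of just that criterion (illustrative only; `relative_l1_diff` is not a cache-dit API):

```python
import torch


def relative_l1_diff(t1: torch.Tensor, t2: torch.Tensor) -> float:
    # D = mean(|t1 - t2|) / mean(|t1|); D close to 0 means the residuals
    # barely changed between steps, so cached outputs can be reused.
    return ((t1 - t2).abs().mean() / t1.abs().mean()).item()


# Residuals that differ by ~1% pass a 0.05 threshold, i.e. the step is cacheable.
prev_residual = torch.randn(2, 1024, 64)
curr_residual = prev_residual + 0.01 * torch.randn_like(prev_residual)
print(relative_l1_diff(prev_residual, curr_residual) < 0.05)  # True
```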