cache-dit 0.2.37__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of cache-dit might be problematic.

Files changed (24)
  1. cache_dit/__init__.py +3 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +7 -0
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +16 -6
  5. cache_dit/cache_factory/cache_adapters/__init__.py +2 -0
  6. cache_dit/cache_factory/{cache_adapters.py → cache_adapters/cache_adapter.py} +6 -6
  7. cache_dit/cache_factory/cache_adapters/v2/__init__.py +3 -0
  8. cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +524 -0
  9. cache_dit/cache_factory/cache_contexts/__init__.py +7 -0
  10. cache_dit/cache_factory/cache_contexts/v2/__init__.py +13 -0
  11. cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +288 -0
  12. cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +799 -0
  13. cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +81 -0
  14. cache_dit/cache_factory/cache_contexts/v2/calibrators/base.py +27 -0
  15. cache_dit/cache_factory/cache_contexts/v2/calibrators/foca.py +26 -0
  16. cache_dit/cache_factory/cache_contexts/v2/calibrators/taylorseer.py +105 -0
  17. cache_dit/cache_factory/cache_interface.py +39 -12
  18. cache_dit/utils.py +17 -7
  19. {cache_dit-0.2.37.dist-info → cache_dit-0.3.1.dist-info}/METADATA +57 -42
  20. {cache_dit-0.2.37.dist-info → cache_dit-0.3.1.dist-info}/RECORD +24 -14
  21. {cache_dit-0.2.37.dist-info → cache_dit-0.3.1.dist-info}/WHEEL +0 -0
  22. {cache_dit-0.2.37.dist-info → cache_dit-0.3.1.dist-info}/entry_points.txt +0 -0
  23. {cache_dit-0.2.37.dist-info → cache_dit-0.3.1.dist-info}/licenses/LICENSE +0 -0
  24. {cache_dit-0.2.37.dist-info → cache_dit-0.3.1.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py (new file)
@@ -0,0 +1,799 @@
+ import logging
+ import contextlib
+ import dataclasses
+ from typing import Any, Dict, Optional, Tuple, Union, List
+
+ import torch
+ import torch.distributed as dist
+
+ from cache_dit.cache_factory.cache_contexts.v2.calibrators import CalibratorBase
+ from cache_dit.cache_factory.cache_contexts.v2.cache_context_v2 import (
+     CachedContextV2,
+ )
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class CachedContextManagerV2:
+     # Each pipeline should have its own context manager instance.
+
+     def __init__(self, name: str = None):
+         self.name = name
+         self._current_context: CachedContextV2 = None
+         self._cached_context_manager: Dict[str, CachedContextV2] = {}
+
+     def new_context(self, *args, **kwargs) -> CachedContextV2:
+         _context = CachedContextV2(*args, **kwargs)
+         self._cached_context_manager[_context.name] = _context
+         return _context
+
+     def set_context(self, cached_context: CachedContextV2 | str):
+         if isinstance(cached_context, CachedContextV2):
+             self._current_context = cached_context
+         else:
+             self._current_context = self._cached_context_manager[cached_context]
+
+     def get_context(self, name: str = None) -> CachedContextV2:
+         if name is not None:
+             if name not in self._cached_context_manager:
+                 raise ValueError("Context does not exist!")
+             return self._cached_context_manager[name]
+         return self._current_context
+
+     def reset_context(
+         self,
+         cached_context: CachedContextV2 | str,
+         *args,
+         **kwargs,
+     ) -> CachedContextV2:
+         if isinstance(cached_context, CachedContextV2):
+             old_context_name = cached_context.name
+             if cached_context.name in self._cached_context_manager:
+                 cached_context.clear_buffers()
+                 del self._cached_context_manager[cached_context.name]
+             # Force reuse of the old context name
+             kwargs["name"] = old_context_name
+             _context = self.new_context(*args, **kwargs)
+         else:
+             old_context_name = cached_context
+             if cached_context in self._cached_context_manager:
+                 self._cached_context_manager[cached_context].clear_buffers()
+                 del self._cached_context_manager[cached_context]
+             # Force reuse of the old context name
+             kwargs["name"] = old_context_name
+             _context = self.new_context(*args, **kwargs)
+         return _context
+
+     def remove_context(self, cached_context: CachedContextV2 | str):
+         if isinstance(cached_context, CachedContextV2):
+             cached_context.clear_buffers()
+             if cached_context.name in self._cached_context_manager:
+                 del self._cached_context_manager[cached_context.name]
+         else:
+             if cached_context in self._cached_context_manager:
+                 self._cached_context_manager[cached_context].clear_buffers()
+                 del self._cached_context_manager[cached_context]
+
+     def clear_contexts(self):
+         for context_name in list(self._cached_context_manager.keys()):
+             self.remove_context(context_name)
+
+     @contextlib.contextmanager
+     def enter_context(self, cached_context: CachedContextV2 | str):
+         old_cached_context = self._current_context
+         if isinstance(cached_context, CachedContextV2):
+             self._current_context = cached_context
+         else:
+             self._current_context = self._cached_context_manager[cached_context]
+         try:
+             yield
+         finally:
+             self._current_context = old_cached_context
+
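A minimal lifecycle sketch for this manager, illustrative only; it assumes CachedContextV2 accepts a name keyword and exposes name/clear_buffers(), as the reset_context and remove_context code above implies:

    from cache_dit.cache_factory.cache_contexts.v2.cache_manager_v2 import (
        CachedContextManagerV2,
    )

    manager = CachedContextManagerV2(name="pipeline_manager")
    ctx = manager.new_context(name="transformer")  # register a context by name
    with manager.enter_context("transformer"):     # temporarily make it current
        assert manager.get_context() is ctx
    manager.reset_context("transformer")           # recreate it under the same name
    manager.clear_contexts()                       # drop all registered contexts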
+     @staticmethod
+     def collect_cache_kwargs(
+         default_attrs: dict, **kwargs
+     ) -> Tuple[Dict, Dict]:
+         # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
+         # default_attrs: specific settings for different pipelines
+         cache_attrs = dataclasses.fields(CachedContextV2)
+         cache_attrs = [
+             attr
+             for attr in cache_attrs
+             if hasattr(
+                 CachedContextV2,
+                 attr.name,
+             )
+         ]
+         cache_kwargs = {
+             attr.name: kwargs.pop(
+                 attr.name,
+                 getattr(CachedContextV2, attr.name),
+             )
+             for attr in cache_attrs
+         }
+
+         def _safe_set_sequence_field(
+             field_name: str,
+             default_value: Any = None,
+         ):
+             if field_name not in cache_kwargs:
+                 cache_kwargs[field_name] = kwargs.pop(
+                     field_name,
+                     default_value,
+                 )
+
+         # Manually set sequence fields
+         _safe_set_sequence_field("calibrator_kwargs", {})
+
+         for attr in cache_attrs:
+             if attr.name in default_attrs:  # can be empty {}
+                 cache_kwargs[attr.name] = default_attrs[attr.name]
+
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(f"Collected Cache kwargs: {cache_kwargs}")
+
+         return cache_kwargs, kwargs
+
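A sketch of how the split behaves, assuming max_cached_steps is among the CachedContextV2 dataclass fields with a class-level default (its use elsewhere in this file suggests it is); num_inference_steps stands in for any kwarg CachedContextV2 does not declare:

    attrs = {"max_cached_steps": 8}  # per-pipeline overrides, may be empty {}
    cache_kwargs, other_kwargs = CachedContextManagerV2.collect_cache_kwargs(
        default_attrs=attrs,
        max_cached_steps=4,      # popped into cache_kwargs, then overridden by attrs
        num_inference_steps=28,  # unknown to CachedContextV2, stays in other_kwargs
    )
    # cache_kwargs holds every CachedContextV2 field, with max_cached_steps == 8;
    # other_kwargs == {"num_inference_steps": 28}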
+     @torch.compiler.disable
+     def get_residual_diff_threshold(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_residual_diff_threshold()
+
+     @torch.compiler.disable
+     def get_buffer(self, name) -> torch.Tensor:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_buffer(name)
+
+     @torch.compiler.disable
+     def set_buffer(self, name, buffer):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.set_buffer(name, buffer)
+
+     @torch.compiler.disable
+     def remove_buffer(self, name):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.remove_buffer(name)
+
+     @torch.compiler.disable
+     def mark_step_begin(self):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.mark_step_begin()
+
+     @torch.compiler.disable
+     def get_current_step(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_current_step()
+
+     @torch.compiler.disable
+     def get_current_step_residual_diff(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         step = str(self.get_current_step())
+         residual_diffs = self.get_residual_diffs()
+         if step in residual_diffs:
+             return residual_diffs[step]
+         return None
+
+     @torch.compiler.disable
+     def get_current_step_cfg_residual_diff(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         step = str(self.get_current_step())
+         cfg_residual_diffs = self.get_cfg_residual_diffs()
+         if step in cfg_residual_diffs:
+             return cfg_residual_diffs[step]
+         return None
+
+     @torch.compiler.disable
+     def get_current_transformer_step(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_current_transformer_step()
+
+     @torch.compiler.disable
+     def get_cached_steps(self) -> List[int]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_cached_steps()
+
+     @torch.compiler.disable
+     def get_cfg_cached_steps(self) -> List[int]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_cfg_cached_steps()
+
+     @torch.compiler.disable
+     def get_max_cached_steps(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.max_cached_steps
+
+     @torch.compiler.disable
+     def get_max_continuous_cached_steps(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.max_continuous_cached_steps
+
+     @torch.compiler.disable
+     def get_continuous_cached_steps(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.continuous_cached_steps
+
+     @torch.compiler.disable
+     def get_cfg_continuous_cached_steps(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.cfg_continuous_cached_steps
+
+     @torch.compiler.disable
+     def add_cached_step(self):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.add_cached_step()
+
+     @torch.compiler.disable
+     def add_residual_diff(self, diff):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.add_residual_diff(diff)
+
+     @torch.compiler.disable
+     def get_residual_diffs(self) -> Dict[str, float]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_residual_diffs()
+
+     @torch.compiler.disable
+     def get_cfg_residual_diffs(self) -> Dict[str, float]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_cfg_residual_diffs()
+
+     @torch.compiler.disable
+     def is_calibrator_enabled(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.enable_calibrator()
+
+     @torch.compiler.disable
+     def is_encoder_calibrator_enabled(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.enable_encoder_calibrator()
+
+     def get_calibrator(self) -> Tuple[CalibratorBase, CalibratorBase]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_calibrators()
+
+     def get_cfg_calibrator(self) -> Tuple[CalibratorBase, CalibratorBase]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_cfg_calibrators()
+
+     @torch.compiler.disable
+     def is_calibrator_cache_residual(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.calibrator_cache_type() == "residual"
+
+     @torch.compiler.disable
+     def is_cache_residual(self) -> bool:
+         if self.is_calibrator_enabled():
+             # residual or hidden_states
+             return self.is_calibrator_cache_residual()
+         return True
+
+     @torch.compiler.disable
+     def is_encoder_cache_residual(self) -> bool:
+         if self.is_encoder_calibrator_enabled():
+             # residual or hidden_states
+             return self.is_calibrator_cache_residual()
+         return True
+
+     @torch.compiler.disable
+     def is_in_warmup(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.is_in_warmup()
+
+     @torch.compiler.disable
+     def is_l1_diff_enabled(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return (
+             cached_context.l1_hidden_states_diff_threshold is not None
+             and cached_context.l1_hidden_states_diff_threshold > 0.0
+         )
+
+     @torch.compiler.disable
+     def get_important_condition_threshold(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.important_condition_threshold
+
+     @torch.compiler.disable
+     def non_compute_blocks_diff_threshold(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.non_compute_blocks_diff_threshold
+
+     @torch.compiler.disable
+     def Fn_compute_blocks(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         assert (
+             cached_context.Fn_compute_blocks >= 1
+         ), "Fn_compute_blocks must be >= 1"
+         if cached_context.max_Fn_compute_blocks > 0:
+             # NOTE: Fn_compute_blocks can be 1, which means FB Cache,
+             # but it must be less than or equal to max_Fn_compute_blocks.
+             assert (
+                 cached_context.Fn_compute_blocks
+                 <= cached_context.max_Fn_compute_blocks
+             ), (
+                 f"Fn_compute_blocks must be <= {cached_context.max_Fn_compute_blocks}, "
+                 f"but got {cached_context.Fn_compute_blocks}"
+             )
+         return cached_context.Fn_compute_blocks
+
+     @torch.compiler.disable
+     def Bn_compute_blocks(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         assert (
+             cached_context.Bn_compute_blocks >= 0
+         ), "Bn_compute_blocks must be >= 0"
+         if cached_context.max_Bn_compute_blocks > 0:
+             # NOTE: Bn_compute_blocks can be 0, which means FB Cache,
+             # but it must be less than or equal to max_Bn_compute_blocks.
+             assert (
+                 cached_context.Bn_compute_blocks
+                 <= cached_context.max_Bn_compute_blocks
+             ), (
+                 f"Bn_compute_blocks must be <= {cached_context.max_Bn_compute_blocks}, "
+                 f"but got {cached_context.Bn_compute_blocks}"
+             )
+         return cached_context.Bn_compute_blocks
+
+     @torch.compiler.disable
+     def enable_separate_cfg(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.enable_separate_cfg
+
+     @torch.compiler.disable
+     def is_separate_cfg_step(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.is_separate_cfg_step()
+
+     @torch.compiler.disable
+     def cfg_diff_compute_separate(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.cfg_diff_compute_separate
+
+     @torch.compiler.disable
+     def similarity(
+         self,
+         t1: torch.Tensor,  # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
+         t2: torch.Tensor,  # curr residual R(t,  n) = H(t,  n) - H(t,  0)
+         *,
+         threshold: float,
+         parallelized: bool = False,
+         prefix: str = "Fn",  # for debugging
+     ) -> bool:
+         # Sentinel diff values: -0.0 means the threshold is disabled, -1.0 means
+         # the tensors are always treated as similar (threshold >= 1.0), and
+         # -2.0 means the shapes did not match.
+         if threshold <= 0.0:
+             self.add_residual_diff(-0.0)
+             return False
+
+         if threshold >= 1.0:
+             # If threshold is 1.0 or more, we consider them always similar.
+             self.add_residual_diff(-1.0)
+             return True
+
+         if t1.shape != t2.shape:
+             if logger.isEnabledFor(logging.DEBUG):
+                 logger.debug(f"{prefix}, shape error: {t1.shape} != {t2.shape}")
+             self.add_residual_diff(-2.0)
+             return False
+
+         if all(
+             (
+                 self.enable_separate_cfg(),
+                 self.is_separate_cfg_step(),
+                 not self.cfg_diff_compute_separate(),
+                 self.get_current_step_residual_diff() is not None,
+             )
+         ):
+             # Reuse the diff value already computed for the non-CFG step
+             diff = self.get_current_step_residual_diff()
+         else:
+             # Find the most significant tokens in t1 and t2 and base the diff
+             # on those tokens only: the more significant, the more important.
+             condition_thresh = self.get_important_condition_threshold()
+             if condition_thresh > 0.0:
+                 raw_diff = (t1 - t2).abs()  # [B, seq_len, d]
+                 token_m_df = raw_diff.mean(dim=-1)  # [B, seq_len]
+                 token_m_t1 = t1.abs().mean(dim=-1)  # [B, seq_len]
+                 # D = (t1 - t2) / t1 = 1 - (t2 / t1); if D = 0, then t1 = t2.
+                 token_diff = token_m_df / token_m_t1  # [B, seq_len]
+                 condition: torch.Tensor = (
+                     token_diff > condition_thresh
+                 )  # [B, seq_len]
+                 if condition.sum() > 0:
+                     condition = condition.unsqueeze(-1)  # [B, seq_len, 1]
+                     condition = condition.expand_as(raw_diff)  # [B, seq_len, d]
+                     mean_diff = raw_diff[condition].mean()
+                     mean_t1 = t1[condition].abs().mean()
+                 else:
+                     mean_diff = (t1 - t2).abs().mean()
+                     mean_t1 = t1.abs().mean()
+             else:
+                 # Use the mean of the absolute difference of the tensors
+                 mean_diff = (t1 - t2).abs().mean()
+                 mean_t1 = t1.abs().mean()
+
+             if parallelized:
+                 dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
+                 dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
+
+             # D = (t1 - t2) / t1 = 1 - (t2 / t1); if D = 0, then t1 = t2.
+             # Further, if we assume that (H(t,0) - H(t-1,0)) ~ 0, then
+             # H(t-1,n) ~ H(t,n), which means the hidden states are similar.
+             diff = (mean_diff / mean_t1).item()
+
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(
+                 f"{prefix}, diff: {diff:.6f}, threshold: {threshold:.6f}"
+             )
+
+         self.add_residual_diff(diff)
+
+         return diff < threshold
+
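In the non-CFG branch the metric reduces to a relative L1 difference, diff = mean(|t1 - t2|) / mean(|t1|), compared against the threshold. A self-contained numeric sketch (torch only, made-up tensors):

    import torch

    t1 = torch.tensor([[1.0, 2.0, 3.0]])  # previous residual / hidden states
    t2 = torch.tensor([[1.1, 2.1, 3.1]])  # current residual / hidden states
    diff = (t1 - t2).abs().mean() / t1.abs().mean()
    print(diff.item())         # ~0.05: mean |delta| 0.1 over mean |t1| 2.0
    print(diff.item() < 0.08)  # True: this step would be considered "similar"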
+     def _debugging_set_buffer(self, prefix):
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(
+                 f"set {prefix}, "
+                 f"transformer step: {self.get_current_transformer_step()}, "
+                 f"executed step: {self.get_current_step()}"
+             )
+
+     def _debugging_get_buffer(self, prefix):
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(
+                 f"get {prefix}, "
+                 f"transformer step: {self.get_current_transformer_step()}, "
+                 f"executed step: {self.get_current_step()}"
+             )
+
+     # Fn buffers
+     @torch.compiler.disable
+     def set_Fn_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
+         # DON'T set a None buffer
+         if buffer is None:
+             return
+         # Set hidden_states or residual for Fn blocks.
+         # This buffer is only used for the L1 diff calculation.
+         downsample_factor = self.get_downsample_factor()
+         if downsample_factor > 1:
+             buffer = buffer[..., ::downsample_factor]
+             buffer = buffer.contiguous()
+         if self.is_separate_cfg_step():
+             self._debugging_set_buffer(f"{prefix}_buffer_cfg")
+             self.set_buffer(f"{prefix}_buffer_cfg", buffer)
+         else:
+             self._debugging_set_buffer(f"{prefix}_buffer")
+             self.set_buffer(f"{prefix}_buffer", buffer)
+
+     @torch.compiler.disable
+     def get_Fn_buffer(self, prefix: str = "Fn") -> torch.Tensor:
+         if self.is_separate_cfg_step():
+             self._debugging_get_buffer(f"{prefix}_buffer_cfg")
+             return self.get_buffer(f"{prefix}_buffer_cfg")
+         self._debugging_get_buffer(f"{prefix}_buffer")
+         return self.get_buffer(f"{prefix}_buffer")
+
+     @torch.compiler.disable
+     def set_Fn_encoder_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
+         # DON'T set a None buffer
+         if buffer is None:
+             return
+         if self.is_separate_cfg_step():
+             self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
+             self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+         else:
+             self._debugging_set_buffer(f"{prefix}_encoder_buffer")
+             self.set_buffer(f"{prefix}_encoder_buffer", buffer)
+
+     @torch.compiler.disable
+     def get_Fn_encoder_buffer(self, prefix: str = "Fn") -> torch.Tensor:
+         if self.is_separate_cfg_step():
+             self._debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
+             return self.get_buffer(f"{prefix}_encoder_buffer_cfg")
+         self._debugging_get_buffer(f"{prefix}_encoder_buffer")
+         return self.get_buffer(f"{prefix}_encoder_buffer")
+
+     # Bn buffers
+     @torch.compiler.disable
+     def set_Bn_buffer(self, buffer: torch.Tensor, prefix: str = "Bn"):
+         # DON'T set a None buffer
+         if buffer is None:
+             return
+         # Set hidden_states or residual for Bn blocks.
+         # This buffer is used for hidden states approximation.
+         if self.is_calibrator_enabled():
+             # calibrator, encoder_calibrator
+             if self.is_separate_cfg_step():
+                 calibrator, _ = self.get_cfg_calibrator()
+             else:
+                 calibrator, _ = self.get_calibrator()
+
+             if calibrator is not None:
+                 # Use the calibrator to update the buffer
+                 calibrator.update(buffer)
+             else:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         "calibrator is enabled but not set in the cache context. "
+                         "Falling back to default buffer retrieval."
+                     )
+                 if self.is_separate_cfg_step():
+                     self._debugging_set_buffer(f"{prefix}_buffer_cfg")
+                     self.set_buffer(f"{prefix}_buffer_cfg", buffer)
+                 else:
+                     self._debugging_set_buffer(f"{prefix}_buffer")
+                     self.set_buffer(f"{prefix}_buffer", buffer)
+         else:
+             if self.is_separate_cfg_step():
+                 self._debugging_set_buffer(f"{prefix}_buffer_cfg")
+                 self.set_buffer(f"{prefix}_buffer_cfg", buffer)
+             else:
+                 self._debugging_set_buffer(f"{prefix}_buffer")
+                 self.set_buffer(f"{prefix}_buffer", buffer)
+
+     @torch.compiler.disable
+     def get_Bn_buffer(self, prefix: str = "Bn") -> torch.Tensor:
+         if self.is_calibrator_enabled():
+             # calibrator, encoder_calibrator
+             if self.is_separate_cfg_step():
+                 calibrator, _ = self.get_cfg_calibrator()
+             else:
+                 calibrator, _ = self.get_calibrator()
+
+             if calibrator is not None:
+                 return calibrator.approximate()
+             else:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         "calibrator is enabled but not set in the cache context. "
+                         "Falling back to default buffer retrieval."
+                     )
+                 # Fall back to default buffer retrieval
+                 if self.is_separate_cfg_step():
+                     self._debugging_get_buffer(f"{prefix}_buffer_cfg")
+                     return self.get_buffer(f"{prefix}_buffer_cfg")
+                 self._debugging_get_buffer(f"{prefix}_buffer")
+                 return self.get_buffer(f"{prefix}_buffer")
+         else:
+             if self.is_separate_cfg_step():
+                 self._debugging_get_buffer(f"{prefix}_buffer_cfg")
+                 return self.get_buffer(f"{prefix}_buffer_cfg")
+             self._debugging_get_buffer(f"{prefix}_buffer")
+             return self.get_buffer(f"{prefix}_buffer")
+
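set_Bn_buffer/get_Bn_buffer dispatch between a calibrator (update()/approximate(), e.g. the TaylorSeer calibrator added under cache_contexts/v2/calibrators/ in this release) and plain buffer storage. A toy stand-in showing only the update/approximate contract those calls rely on; ToyCalibrator is hypothetical and for illustration, the real classes derive from CalibratorBase:

    import torch

    class ToyCalibrator:
        def __init__(self):
            self.state = None

        def update(self, buffer: torch.Tensor):
            # receives what set_Bn_buffer would otherwise cache verbatim
            self.state = buffer.clone()

        def approximate(self) -> torch.Tensor:
            # returned by get_Bn_buffer in place of a raw cached buffer
            return self.state

    calib = ToyCalibrator()
    calib.update(torch.ones(2, 4))
    assert torch.equal(calib.approximate(), torch.ones(2, 4))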
+     @torch.compiler.disable
+     def set_Bn_encoder_buffer(
+         self, buffer: torch.Tensor | None, prefix: str = "Bn"
+     ):
+         # DON'T set a None buffer
+         if buffer is None:
+             return
+
+         # This buffer is used for encoder hidden states approximation.
+         if self.is_encoder_calibrator_enabled():
+             # calibrator, encoder_calibrator
+             if self.is_separate_cfg_step():
+                 _, encoder_calibrator = self.get_cfg_calibrator()
+             else:
+                 _, encoder_calibrator = self.get_calibrator()
+
+             if encoder_calibrator is not None:
+                 # Use the calibrator to update the buffer
+                 encoder_calibrator.update(buffer)
+             else:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         "calibrator is enabled but not set in the cache context. "
+                         "Falling back to default buffer retrieval."
+                     )
+                 if self.is_separate_cfg_step():
+                     self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
+                     self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+                 else:
+                     self._debugging_set_buffer(f"{prefix}_encoder_buffer")
+                     self.set_buffer(f"{prefix}_encoder_buffer", buffer)
+         else:
+             if self.is_separate_cfg_step():
+                 self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
+                 self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+             else:
+                 self._debugging_set_buffer(f"{prefix}_encoder_buffer")
+                 self.set_buffer(f"{prefix}_encoder_buffer", buffer)
+
+     @torch.compiler.disable
+     def get_Bn_encoder_buffer(self, prefix: str = "Bn") -> torch.Tensor:
+         if self.is_encoder_calibrator_enabled():
+             if self.is_separate_cfg_step():
+                 _, encoder_calibrator = self.get_cfg_calibrator()
+             else:
+                 _, encoder_calibrator = self.get_calibrator()
+
+             if encoder_calibrator is not None:
+                 # Use the calibrator to approximate the value
+                 return encoder_calibrator.approximate()
+             else:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         "calibrator is enabled but not set in the cache context. "
+                         "Falling back to default buffer retrieval."
+                     )
+                 # Fall back to default buffer retrieval
+                 if self.is_separate_cfg_step():
+                     self._debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
+                     return self.get_buffer(f"{prefix}_encoder_buffer_cfg")
+                 self._debugging_get_buffer(f"{prefix}_encoder_buffer")
+                 return self.get_buffer(f"{prefix}_encoder_buffer")
+         else:
+             if self.is_separate_cfg_step():
+                 self._debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
+                 return self.get_buffer(f"{prefix}_encoder_buffer_cfg")
+             self._debugging_get_buffer(f"{prefix}_encoder_buffer")
+             return self.get_buffer(f"{prefix}_encoder_buffer")
+
+     @torch.compiler.disable
+     def apply_cache(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor = None,
+         prefix: str = "Bn",
+         encoder_prefix: str = "Bn_encoder",
+     ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
+         # Allow both Bn and Fn prefixes to be used for the residual cache.
+         if "Bn" in prefix:
+             hidden_states_prev = self.get_Bn_buffer(prefix)
+         else:
+             hidden_states_prev = self.get_Fn_buffer(prefix)
+
+         assert (
+             hidden_states_prev is not None
+         ), f"{prefix}_buffer must be set before"
+
+         if self.is_cache_residual():
+             hidden_states = hidden_states_prev + hidden_states
+         else:
+             # If the cache is not residual, use the cached hidden states directly
+             hidden_states = hidden_states_prev
+
+         hidden_states = hidden_states.contiguous()
+
+         if encoder_hidden_states is not None:
+             if "Bn" in encoder_prefix:
+                 encoder_hidden_states_prev = self.get_Bn_encoder_buffer(
+                     encoder_prefix
+                 )
+             else:
+                 encoder_hidden_states_prev = self.get_Fn_encoder_buffer(
+                     encoder_prefix
+                 )
+
+             if encoder_hidden_states_prev is not None:
+                 if self.is_encoder_cache_residual():
+                     encoder_hidden_states = (
+                         encoder_hidden_states_prev + encoder_hidden_states
+                     )
+                 else:
+                     # If the encoder cache is not residual, use the cached
+                     # encoder hidden states directly
+                     encoder_hidden_states = encoder_hidden_states_prev
+
+                 encoder_hidden_states = encoder_hidden_states.contiguous()
+
+         return hidden_states, encoder_hidden_states
+
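apply_cache therefore has two modes: when is_cache_residual() is true, the cached tensor is a residual added onto the freshly computed states; otherwise the cached hidden states are reused outright. A standalone sketch of that arithmetic with made-up tensors:

    import torch

    fresh = torch.randn(1, 16, 8)   # states computed by the always-run Fn blocks
    cached = torch.randn(1, 16, 8)  # previously stored Bn residual or hidden states

    cache_residual = True
    if cache_residual:
        out = cached + fresh        # residual mode: cached R(t-1) + current states
    else:
        out = cached                # hidden-states mode: reuse cached activations directly
    out = out.contiguous()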
+     @torch.compiler.disable
+     def get_downsample_factor(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.downsample_factor
+
+     @torch.compiler.disable
+     def can_cache(
+         self,
+         states_tensor: torch.Tensor,  # hidden_states or residual
+         parallelized: bool = False,
+         threshold: Optional[float] = None,  # can manually set threshold
+         prefix: str = "Fn",
+     ) -> bool:
+
+         if self.is_in_warmup():
+             return False
+
+         # max cached steps
+         max_cached_steps = self.get_max_cached_steps()
+         if not self.is_separate_cfg_step():
+             cached_steps = self.get_cached_steps()
+         else:
+             cached_steps = self.get_cfg_cached_steps()
+
+         if max_cached_steps >= 0 and (len(cached_steps) >= max_cached_steps):
+             if logger.isEnabledFor(logging.DEBUG):
+                 logger.debug(
+                     f"{prefix}, max_cached_steps reached: {max_cached_steps}, "
+                     "can not use cache."
+                 )
+             return False
+
+         # max continuous cached steps
+         max_continuous_cached_steps = self.get_max_continuous_cached_steps()
+         if not self.is_separate_cfg_step():
+             continuous_cached_steps = self.get_continuous_cached_steps()
+         else:
+             continuous_cached_steps = self.get_cfg_continuous_cached_steps()
+
+         if max_continuous_cached_steps >= 0 and (
+             continuous_cached_steps >= max_continuous_cached_steps
+         ):
+             if logger.isEnabledFor(logging.DEBUG):
+                 logger.debug(
+                     f"{prefix}, max_continuous_cached_steps "
+                     f"reached: {max_continuous_cached_steps}, "
+                     "can not use cache."
+                 )
+             # reset continuous cached steps stats
+             cached_context = self.get_context()
+             if not self.is_separate_cfg_step():
+                 cached_context.continuous_cached_steps = 0
+             else:
+                 cached_context.cfg_continuous_cached_steps = 0
+             return False
+
+         if threshold is None or threshold <= 0.0:
+             threshold = self.get_residual_diff_threshold()
+         if threshold <= 0.0:
+             return False
+
+         downsample_factor = self.get_downsample_factor()
+         if downsample_factor > 1 and "Bn" not in prefix:
+             states_tensor = states_tensor[..., ::downsample_factor]
+             states_tensor = states_tensor.contiguous()
+
+         # Allow both Bn and Fn prefixes to be used for the diff calculation.
+         if "Bn" in prefix:
+             prev_states_tensor = self.get_Bn_buffer(prefix)
+         else:
+             prev_states_tensor = self.get_Fn_buffer(prefix)
+
+         # Dynamic caching based on the residual diff
+         can_cache = prev_states_tensor is not None and self.similarity(
+             prev_states_tensor,
+             states_tensor,
+             threshold=threshold,
+             parallelized=parallelized,
+             prefix=prefix,
+         )
+         return can_cache
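Taken together, a caller gates each denoising step on can_cache and then either reuses the cache or recomputes and refreshes it. A rough, hedged pseudo-flow; the real driver is cache_adapter_v2.py, run_remaining_blocks is a placeholder, and prefixes/residual handling are simplified here:

    def forward_step(manager, fn_states, hidden_states, encoder_hidden_states):
        # manager: a CachedContextManagerV2 whose current context is already set
        manager.mark_step_begin()
        if manager.can_cache(fn_states, prefix="Fn"):
            manager.add_cached_step()
            # Reuse (or calibrate) the cached Bn values instead of running the blocks
            return manager.apply_cache(hidden_states, encoder_hidden_states)
        # Cache miss: remember the Fn states, run the remaining blocks, refresh Bn
        manager.set_Fn_buffer(fn_states, prefix="Fn")
        hidden_states, encoder_hidden_states = run_remaining_blocks(  # placeholder
            hidden_states, encoder_hidden_states
        )
        manager.set_Bn_buffer(hidden_states, prefix="Bn")
        return hidden_states, encoder_hidden_states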