cache-dit 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -0,0 +1,847 @@
+ import logging
+ import contextlib
+ import dataclasses
+ from typing import Any, Dict, Optional, Tuple, Union, List
+
+ import torch
+ import torch.distributed as dist
+
+ from cache_dit.cache_factory.cache_contexts.taylorseer import TaylorSeer
+ from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class CachedContextManager:
+     # Each Pipeline should have its own context manager instance.
+
+     def __init__(self, name: str = None):
+         self.name = name
+         self._current_context: CachedContext = None
+         self._cached_context_manager: Dict[str, CachedContext] = {}
+
+     def new_context(self, *args, **kwargs) -> CachedContext:
+         _context = CachedContext(*args, **kwargs)
+         self._cached_context_manager[_context.name] = _context
+         return _context
+
+     def set_context(self, cached_context: CachedContext | str):
+         if isinstance(cached_context, CachedContext):
+             self._current_context = cached_context
+         else:
+             self._current_context = self._cached_context_manager[cached_context]
+
+     def get_context(self, name: str = None) -> CachedContext:
+         if name is not None:
+             if name not in self._cached_context_manager:
+                 raise ValueError("Context does not exist!")
+             return self._cached_context_manager[name]
+         return self._current_context
+
+     def reset_context(
+         self,
+         cached_context: CachedContext | str,
+         *args,
+         **kwargs,
+     ) -> CachedContext:
+         if isinstance(cached_context, CachedContext):
+             old_context_name = cached_context.name
+             if cached_context.name in self._cached_context_manager:
+                 cached_context.clear_buffers()
+                 del self._cached_context_manager[cached_context.name]
+             # force reuse of the old context name
+             kwargs["name"] = old_context_name
+             _context = self.new_context(*args, **kwargs)
+         else:
+             old_context_name = cached_context
+             if cached_context in self._cached_context_manager:
+                 self._cached_context_manager[cached_context].clear_buffers()
+                 del self._cached_context_manager[cached_context]
+             # force reuse of the old context name
+             kwargs["name"] = old_context_name
+             _context = self.new_context(*args, **kwargs)
+         return _context
+
+     def remove_context(self, cached_context: CachedContext | str):
+         if isinstance(cached_context, CachedContext):
+             cached_context.clear_buffers()
+             if cached_context.name in self._cached_context_manager:
+                 del self._cached_context_manager[cached_context.name]
+         else:
+             if cached_context in self._cached_context_manager:
+                 self._cached_context_manager[cached_context].clear_buffers()
+                 del self._cached_context_manager[cached_context]
+
+     def clear_contexts(self):
+         for cached_context in list(self._cached_context_manager):  # copy: remove_context() mutates the dict
+             self.remove_context(cached_context)
+
+     @contextlib.contextmanager
+     def enter_context(self, cached_context: CachedContext | str):
+         old_cached_context = self._current_context
+         if isinstance(cached_context, CachedContext):
+             self._current_context = cached_context
+         else:
+             self._current_context = self._cached_context_manager[cached_context]
+         try:
+             yield
+         finally:
+             self._current_context = old_cached_context
+
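For orientation, the context-management methods above keep one named CachedContext per pipeline and swap the "current" context in and out. A minimal usage sketch (hypothetical names; it assumes CachedContext accepts a name keyword argument, as reset_context implies):

    manager = CachedContextManager(name="flux_manager")   # hypothetical manager name
    ctx = manager.new_context(name="transformer")         # register a named context
    with manager.enter_context(ctx):                      # temporarily make it current
        manager.mark_step_begin()                         # per-step bookkeeping, defined below
    manager.reset_context("transformer")                  # rebuild it under the same name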
+     @staticmethod
+     def collect_cache_kwargs(
+         default_attrs: dict, **kwargs
+     ) -> Tuple[Dict, Dict]:
+         # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
+         # default_attrs: specific settings for different pipelines
+         cache_attrs = dataclasses.fields(CachedContext)
+         cache_attrs = [
+             attr
+             for attr in cache_attrs
+             if hasattr(
+                 CachedContext,
+                 attr.name,
+             )
+         ]
+         cache_kwargs = {
+             attr.name: kwargs.pop(
+                 attr.name,
+                 getattr(CachedContext, attr.name),
+             )
+             for attr in cache_attrs
+         }
+
+         def _safe_set_sequence_field(
+             field_name: str,
+             default_value: Any = None,
+         ):
+             if field_name not in cache_kwargs:
+                 cache_kwargs[field_name] = kwargs.pop(
+                     field_name,
+                     default_value,
+                 )
+
+         # Manually set sequence fields, namely, Fn_compute_blocks_ids
+         # and Bn_compute_blocks_ids, which are lists or sets.
+         _safe_set_sequence_field("Fn_compute_blocks_ids", [])
+         _safe_set_sequence_field("Bn_compute_blocks_ids", [])
+         _safe_set_sequence_field("taylorseer_kwargs", {})
+
+         for attr in cache_attrs:
+             if attr.name in default_attrs:  # can be empty {}
+                 cache_kwargs[attr.name] = default_attrs[attr.name]
+
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(f"Collected Cache kwargs: {cache_kwargs}")
+
+         return cache_kwargs, kwargs
+
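Continuing the sketch above, collect_cache_kwargs separates CachedContext fields from everything else, with default_attrs taking precedence over user kwargs. A rough illustration (Fn_compute_blocks and enable_taylorseer are CachedContext fields referenced elsewhere in this file; num_inference_steps is a hypothetical non-cache keyword):

    cache_kwargs, other_kwargs = CachedContextManager.collect_cache_kwargs(
        default_attrs={"enable_taylorseer": True},  # pipeline-specific defaults win
        Fn_compute_blocks=8,                        # recognized field -> cache_kwargs
        num_inference_steps=28,                     # unrecognized kwarg -> other_kwargs
    )
    ctx = manager.reset_context("transformer", **cache_kwargs)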
+     @torch.compiler.disable
+     def get_residual_diff_threshold(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_residual_diff_threshold()
+
+     @torch.compiler.disable
+     def get_buffer(self, name) -> torch.Tensor:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_buffer(name)
+
+     @torch.compiler.disable
+     def set_buffer(self, name, buffer):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.set_buffer(name, buffer)
+
+     @torch.compiler.disable
+     def remove_buffer(self, name):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.remove_buffer(name)
+
+     @torch.compiler.disable
+     def mark_step_begin(self):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.mark_step_begin()
+
+     @torch.compiler.disable
+     def get_current_step(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_current_step()
+
+     @torch.compiler.disable
+     def get_current_step_residual_diff(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         step = str(self.get_current_step())
+         residual_diffs = self.get_residual_diffs()
+         if step in residual_diffs:
+             return residual_diffs[step]
+         return None
+
+     @torch.compiler.disable
+     def get_current_step_cfg_residual_diff(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         step = str(self.get_current_step())
+         cfg_residual_diffs = self.get_cfg_residual_diffs()
+         if step in cfg_residual_diffs:
+             return cfg_residual_diffs[step]
+         return None
+
+     @torch.compiler.disable
+     def get_current_transformer_step(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_current_transformer_step()
+
+     @torch.compiler.disable
+     def get_cached_steps(self) -> List[int]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_cached_steps()
+
+     @torch.compiler.disable
+     def get_cfg_cached_steps(self) -> List[int]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_cfg_cached_steps()
+
+     @torch.compiler.disable
+     def get_max_cached_steps(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.max_cached_steps
+
+     @torch.compiler.disable
+     def get_max_continuous_cached_steps(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.max_continuous_cached_steps
+
+     @torch.compiler.disable
+     def get_continuous_cached_steps(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.continuous_cached_steps
+
+     @torch.compiler.disable
+     def get_cfg_continuous_cached_steps(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.cfg_continuous_cached_steps
+
+     @torch.compiler.disable
+     def add_cached_step(self):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.add_cached_step()
+
+     @torch.compiler.disable
+     def add_residual_diff(self, diff):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.add_residual_diff(diff)
+
+     @torch.compiler.disable
+     def get_residual_diffs(self) -> Dict[str, float]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_residual_diffs()
+
+     @torch.compiler.disable
+     def get_cfg_residual_diffs(self) -> Dict[str, float]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_cfg_residual_diffs()
+
+     @torch.compiler.disable
+     def is_taylorseer_enabled(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.enable_taylorseer
+
+     @torch.compiler.disable
+     def is_encoder_taylorseer_enabled(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.enable_encoder_taylorseer
+
+     def get_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_taylorseers()
+
+     def get_cfg_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_cfg_taylorseers()
+
+     @torch.compiler.disable
+     def is_taylorseer_cache_residual(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.taylorseer_cache_type == "residual"
+
+     @torch.compiler.disable
+     def is_cache_residual(self) -> bool:
+         if self.is_taylorseer_enabled():
+             # residual or hidden_states
+             return self.is_taylorseer_cache_residual()
+         return True
+
+     @torch.compiler.disable
+     def is_encoder_cache_residual(self) -> bool:
+         if self.is_encoder_taylorseer_enabled():
+             # residual or hidden_states
+             return self.is_taylorseer_cache_residual()
+         return True
+
+     @torch.compiler.disable
+     def is_alter_cache_enabled(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.enable_alter_cache
+
+     @torch.compiler.disable
+     def is_alter_cache(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.is_alter_cache
+
+     @torch.compiler.disable
+     def is_in_warmup(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.is_in_warmup()
+
+     @torch.compiler.disable
+     def is_l1_diff_enabled(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return (
+             cached_context.l1_hidden_states_diff_threshold is not None
+             and cached_context.l1_hidden_states_diff_threshold > 0.0
+         )
+
+     @torch.compiler.disable
+     def get_important_condition_threshold(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.important_condition_threshold
+
+     @torch.compiler.disable
+     def non_compute_blocks_diff_threshold(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.non_compute_blocks_diff_threshold
+
+     @torch.compiler.disable
+     def Fn_compute_blocks(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         assert (
+             cached_context.Fn_compute_blocks >= 1
+         ), "Fn_compute_blocks must be >= 1"
+         if cached_context.max_Fn_compute_blocks > 0:
+             # NOTE: Fn_compute_blocks can be 1, which means FB Cache
+             # but it must be less than or equal to max_Fn_compute_blocks
+             assert (
+                 cached_context.Fn_compute_blocks
+                 <= cached_context.max_Fn_compute_blocks
+             ), (
+                 f"Fn_compute_blocks must be <= {cached_context.max_Fn_compute_blocks}, "
+                 f"but got {cached_context.Fn_compute_blocks}"
+             )
+         return cached_context.Fn_compute_blocks
+
+     @torch.compiler.disable
+     def Fn_compute_blocks_ids(self) -> List[int]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         assert (
+             len(cached_context.Fn_compute_blocks_ids)
+             <= cached_context.Fn_compute_blocks
+         ), (
+             "The num of Fn_compute_blocks_ids must be <= Fn_compute_blocks "
+             f"{cached_context.Fn_compute_blocks}, but got "
+             f"{len(cached_context.Fn_compute_blocks_ids)}"
+         )
+         return cached_context.Fn_compute_blocks_ids
+
+     @torch.compiler.disable
+     def Bn_compute_blocks(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         assert (
+             cached_context.Bn_compute_blocks >= 0
+         ), "Bn_compute_blocks must be >= 0"
+         if cached_context.max_Bn_compute_blocks > 0:
+             # NOTE: Bn_compute_blocks can be 0, which means FB Cache
+             # but it must be less than or equal to max_Bn_compute_blocks
+             assert (
+                 cached_context.Bn_compute_blocks
+                 <= cached_context.max_Bn_compute_blocks
+             ), (
+                 f"Bn_compute_blocks must be <= {cached_context.max_Bn_compute_blocks}, "
+                 f"but got {cached_context.Bn_compute_blocks}"
+             )
+         return cached_context.Bn_compute_blocks
+
+     @torch.compiler.disable
+     def Bn_compute_blocks_ids(self) -> List[int]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         assert (
+             len(cached_context.Bn_compute_blocks_ids)
+             <= cached_context.Bn_compute_blocks
+         ), (
+             "The num of Bn_compute_blocks_ids must be <= Bn_compute_blocks "
+             f"{cached_context.Bn_compute_blocks}, but got "
+             f"{len(cached_context.Bn_compute_blocks_ids)}"
+         )
+         return cached_context.Bn_compute_blocks_ids
+
+     @torch.compiler.disable
+     def enable_spearate_cfg(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.enable_spearate_cfg
+
+     @torch.compiler.disable
+     def is_separate_cfg_step(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.is_separate_cfg_step()
+
+     @torch.compiler.disable
+     def cfg_diff_compute_separate(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.cfg_diff_compute_separate
+
+     @torch.compiler.disable
+     def similarity(
+         self,
+         t1: torch.Tensor,  # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
+         t2: torch.Tensor,  # curr residual R(t  ,n) = H(t  ,n) - H(t  ,0)
+         *,
+         threshold: float,
+         parallelized: bool = False,
+         prefix: str = "Fn",  # for debugging
+     ) -> bool:
+         # Sentinel diff values recorded below: -0.0 means the threshold is disabled,
+         # -1.0 means the inputs are always treated as similar, -2.0 means the shapes did not match.
+         if threshold <= 0.0:
+             self.add_residual_diff(-0.0)
+             return False
+
+         if threshold >= 1.0:
+             # If threshold is 1.0 or more, we consider them always similar.
+             self.add_residual_diff(-1.0)
+             return True
+
+         if t1.shape != t2.shape:
+             if logger.isEnabledFor(logging.DEBUG):
+                 logger.debug(f"{prefix}, shape error: {t1.shape} != {t2.shape}")
+             self.add_residual_diff(-2.0)
+             return False
+
+         if all(
+             (
+                 self.enable_spearate_cfg(),
+                 self.is_separate_cfg_step(),
+                 not self.cfg_diff_compute_separate(),
+                 self.get_current_step_residual_diff() is not None,
+             )
+         ):
+             # Reuse computed diff value from non-CFG step
+             diff = self.get_current_step_residual_diff()
+         else:
+             # Find the most significant tokens in t1 and t2, and
+             # consider only the diff of those tokens. The more significant,
+             # the more important.
+             condition_thresh = self.get_important_condition_threshold()
+             if condition_thresh > 0.0:
+                 raw_diff = (t1 - t2).abs()  # [B, seq_len, d]
+                 token_m_df = raw_diff.mean(dim=-1)  # [B, seq_len]
+                 token_m_t1 = t1.abs().mean(dim=-1)  # [B, seq_len]
+                 # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
+                 token_diff = token_m_df / token_m_t1  # [B, seq_len]
+                 condition: torch.Tensor = (
+                     token_diff > condition_thresh
+                 )  # [B, seq_len]
+                 if condition.sum() > 0:
+                     condition = condition.unsqueeze(-1)  # [B, seq_len, 1]
+                     condition = condition.expand_as(raw_diff)  # [B, seq_len, d]
+                     mean_diff = raw_diff[condition].mean()
+                     mean_t1 = t1[condition].abs().mean()
+                 else:
+                     mean_diff = (t1 - t2).abs().mean()
+                     mean_t1 = t1.abs().mean()
+             else:
+                 # Use the mean of the absolute difference of the tensors
+                 mean_diff = (t1 - t2).abs().mean()
+                 mean_t1 = t1.abs().mean()
+
+             if parallelized:
+                 dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
+                 dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
+
+             # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
+             # Further, if we assume that (H(t,0) - H(t-1,0)) ~ 0, then
+             # H(t-1,n) ~ H(t,n), which means the hidden states are similar.
+             diff = (mean_diff / mean_t1).item()
+
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(
+                 f"{prefix}, diff: {diff:.6f}, threshold: {threshold:.6f}"
+             )
+
+         self.add_residual_diff(diff)
+
+         return diff < threshold
+
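At its core, the decision above is a relative L1 metric, diff = mean|t1 - t2| / mean|t1|, compared against the threshold. A small self-contained sketch of the default branch (no token filtering, no CFG reuse; the 0.08 threshold is purely illustrative):

    import torch

    def relative_l1(t1: torch.Tensor, t2: torch.Tensor) -> float:
        # mean absolute change, normalized by the magnitude of the previous residual
        return ((t1 - t2).abs().mean() / t1.abs().mean()).item()

    prev_residual = torch.randn(1, 4096, 64)
    curr_residual = prev_residual + 0.01 * torch.randn_like(prev_residual)
    print(relative_l1(prev_residual, curr_residual) < 0.08)  # small change -> similar enough to cache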
+     def _debugging_set_buffer(self, prefix):
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(
+                 f"set {prefix}, "
+                 f"transformer step: {self.get_current_transformer_step()}, "
+                 f"executed step: {self.get_current_step()}"
+             )
+
+     def _debugging_get_buffer(self, prefix):
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(
+                 f"get {prefix}, "
+                 f"transformer step: {self.get_current_transformer_step()}, "
+                 f"executed step: {self.get_current_step()}"
+             )
+
+     # Fn buffers
+     @torch.compiler.disable
+     def set_Fn_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
+         # Set hidden_states or residual for Fn blocks.
+         # This buffer is only used for L1 diff calculation.
+         downsample_factor = self.get_downsample_factor()
+         if downsample_factor > 1:
+             buffer = buffer[..., ::downsample_factor]
+             buffer = buffer.contiguous()
+         if self.is_separate_cfg_step():
+             self._debugging_set_buffer(f"{prefix}_buffer_cfg")
+             self.set_buffer(f"{prefix}_buffer_cfg", buffer)
+         else:
+             self._debugging_set_buffer(f"{prefix}_buffer")
+             self.set_buffer(f"{prefix}_buffer", buffer)
+
+     @torch.compiler.disable
+     def get_Fn_buffer(self, prefix: str = "Fn") -> torch.Tensor:
+         if self.is_separate_cfg_step():
+             self._debugging_get_buffer(f"{prefix}_buffer_cfg")
+             return self.get_buffer(f"{prefix}_buffer_cfg")
+         self._debugging_get_buffer(f"{prefix}_buffer")
+         return self.get_buffer(f"{prefix}_buffer")
+
+     @torch.compiler.disable
+     def set_Fn_encoder_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
+         if self.is_separate_cfg_step():
+             self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
+             self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+         else:
+             self._debugging_set_buffer(f"{prefix}_encoder_buffer")
+             self.set_buffer(f"{prefix}_encoder_buffer", buffer)
+
+     @torch.compiler.disable
+     def get_Fn_encoder_buffer(self, prefix: str = "Fn") -> torch.Tensor:
+         if self.is_separate_cfg_step():
+             self._debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
+             return self.get_buffer(f"{prefix}_encoder_buffer_cfg")
+         self._debugging_get_buffer(f"{prefix}_encoder_buffer")
+         return self.get_buffer(f"{prefix}_encoder_buffer")
+
+     # Bn buffers
+     @torch.compiler.disable
+     def set_Bn_buffer(self, buffer: torch.Tensor, prefix: str = "Bn"):
+         # Set hidden_states or residual for Bn blocks.
+         # This buffer is used for hidden states approximation.
+         if self.is_taylorseer_enabled():
+             # taylorseer, encoder_taylorseer
+             if self.is_separate_cfg_step():
+                 taylorseer, _ = self.get_cfg_taylorseers()
+             else:
+                 taylorseer, _ = self.get_taylorseers()
+
+             if taylorseer is not None:
+                 # Use TaylorSeer to update the buffer
+                 taylorseer.update(buffer)
+             else:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         "TaylorSeer is enabled but not set in the cache context. "
+                         "Falling back to default buffer caching."
+                     )
+                 if self.is_separate_cfg_step():
+                     self._debugging_set_buffer(f"{prefix}_buffer_cfg")
+                     self.set_buffer(f"{prefix}_buffer_cfg", buffer)
+                 else:
+                     self._debugging_set_buffer(f"{prefix}_buffer")
+                     self.set_buffer(f"{prefix}_buffer", buffer)
+         else:
+             if self.is_separate_cfg_step():
+                 self._debugging_set_buffer(f"{prefix}_buffer_cfg")
+                 self.set_buffer(f"{prefix}_buffer_cfg", buffer)
+             else:
+                 self._debugging_set_buffer(f"{prefix}_buffer")
+                 self.set_buffer(f"{prefix}_buffer", buffer)
+
+     @torch.compiler.disable
+     def get_Bn_buffer(self, prefix: str = "Bn") -> torch.Tensor:
+         if self.is_taylorseer_enabled():
+             # taylorseer, encoder_taylorseer
+             if self.is_separate_cfg_step():
+                 taylorseer, _ = self.get_cfg_taylorseers()
+             else:
+                 taylorseer, _ = self.get_taylorseers()
+
+             if taylorseer is not None:
+                 return taylorseer.approximate_value()
+             else:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         "TaylorSeer is enabled but not set in the cache context. "
+                         "Falling back to default buffer retrieval."
+                     )
+                 # Fallback to default buffer retrieval
+                 if self.is_separate_cfg_step():
+                     self._debugging_get_buffer(f"{prefix}_buffer_cfg")
+                     return self.get_buffer(f"{prefix}_buffer_cfg")
+                 self._debugging_get_buffer(f"{prefix}_buffer")
+                 return self.get_buffer(f"{prefix}_buffer")
+         else:
+             if self.is_separate_cfg_step():
+                 self._debugging_get_buffer(f"{prefix}_buffer_cfg")
+                 return self.get_buffer(f"{prefix}_buffer_cfg")
+             self._debugging_get_buffer(f"{prefix}_buffer")
+             return self.get_buffer(f"{prefix}_buffer")
+
+     @torch.compiler.disable
+     def set_Bn_encoder_buffer(
+         self, buffer: torch.Tensor | None, prefix: str = "Bn"
+     ):
+         # DON'T set None buffer
+         if buffer is None:
+             return
+
+         # This buffer is used for encoder hidden states approximation.
+         if self.is_encoder_taylorseer_enabled():
+             # taylorseer, encoder_taylorseer
+             if self.is_separate_cfg_step():
+                 _, encoder_taylorseer = self.get_cfg_taylorseers()
+             else:
+                 _, encoder_taylorseer = self.get_taylorseers()
+
+             if encoder_taylorseer is not None:
+                 # Use TaylorSeer to update the buffer
+                 encoder_taylorseer.update(buffer)
+             else:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         "TaylorSeer is enabled but not set in the cache context. "
+                         "Falling back to default buffer caching."
+                     )
+                 if self.is_separate_cfg_step():
+                     self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
+                     self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+                 else:
+                     self._debugging_set_buffer(f"{prefix}_encoder_buffer")
+                     self.set_buffer(f"{prefix}_encoder_buffer", buffer)
+         else:
+             if self.is_separate_cfg_step():
+                 self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
+                 self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+             else:
+                 self._debugging_set_buffer(f"{prefix}_encoder_buffer")
+                 self.set_buffer(f"{prefix}_encoder_buffer", buffer)
+
+     @torch.compiler.disable
+     def get_Bn_encoder_buffer(self, prefix: str = "Bn") -> torch.Tensor:
+         if self.is_encoder_taylorseer_enabled():
+             if self.is_separate_cfg_step():
+                 _, encoder_taylorseer = self.get_cfg_taylorseers()
+             else:
+                 _, encoder_taylorseer = self.get_taylorseers()
+
+             if encoder_taylorseer is not None:
+                 # Use TaylorSeer to approximate the value
+                 return encoder_taylorseer.approximate_value()
+             else:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         "TaylorSeer is enabled but not set in the cache context. "
+                         "Falling back to default buffer retrieval."
+                     )
+                 # Fallback to default buffer retrieval
+                 if self.is_separate_cfg_step():
+                     self._debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
+                     return self.get_buffer(f"{prefix}_encoder_buffer_cfg")
+                 self._debugging_get_buffer(f"{prefix}_encoder_buffer")
+                 return self.get_buffer(f"{prefix}_encoder_buffer")
+         else:
+             if self.is_separate_cfg_step():
+                 self._debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
+                 return self.get_buffer(f"{prefix}_encoder_buffer_cfg")
+             self._debugging_get_buffer(f"{prefix}_encoder_buffer")
+             return self.get_buffer(f"{prefix}_encoder_buffer")
+
+     @torch.compiler.disable
+     def apply_cache(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor = None,
+         prefix: str = "Bn",
+         encoder_prefix: str = "Bn_encoder",
+     ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
+         # Allow Bn and Fn prefix to be used for residual cache.
+         if "Bn" in prefix:
+             hidden_states_prev = self.get_Bn_buffer(prefix)
+         else:
+             hidden_states_prev = self.get_Fn_buffer(prefix)
+
+         assert (
+             hidden_states_prev is not None
+         ), f"{prefix}_buffer must be set before"
+
+         if self.is_cache_residual():
+             hidden_states = hidden_states_prev + hidden_states
+         else:
+             # If cache is not residual, we use the hidden states directly
+             hidden_states = hidden_states_prev
+
+         hidden_states = hidden_states.contiguous()
+
+         if encoder_hidden_states is not None:
+             if "Bn" in encoder_prefix:
+                 encoder_hidden_states_prev = self.get_Bn_encoder_buffer(
+                     encoder_prefix
+                 )
+             else:
+                 encoder_hidden_states_prev = self.get_Fn_encoder_buffer(
+                     encoder_prefix
+                 )
+
+             assert (
+                 encoder_hidden_states_prev is not None
+             ), f"{prefix}_encoder_buffer must be set before"
+
+             if self.is_encoder_cache_residual():
+                 encoder_hidden_states = (
+                     encoder_hidden_states_prev + encoder_hidden_states
+                 )
+             else:
+                 # If encoder cache is not residual, we use the encoder hidden states directly
+                 encoder_hidden_states = encoder_hidden_states_prev
+
+             encoder_hidden_states = encoder_hidden_states.contiguous()
+
+         return hidden_states, encoder_hidden_states
+
+     @torch.compiler.disable
+     def get_downsample_factor(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.downsample_factor
+
+     @torch.compiler.disable
+     def can_cache(
+         self,
+         states_tensor: torch.Tensor,  # hidden_states or residual
+         parallelized: bool = False,
+         threshold: Optional[float] = None,  # can manually set threshold
+         prefix: str = "Fn",
+     ) -> bool:
+
+         if self.is_in_warmup():
+             return False
+
+         # max cached steps
+         max_cached_steps = self.get_max_cached_steps()
+         if not self.is_separate_cfg_step():
+             cached_steps = self.get_cached_steps()
+         else:
+             cached_steps = self.get_cfg_cached_steps()
+
+         if max_cached_steps >= 0 and (len(cached_steps) >= max_cached_steps):
+             if logger.isEnabledFor(logging.DEBUG):
+                 logger.debug(
+                     f"{prefix}, max_cached_steps reached: {max_cached_steps}, "
+                     "can not use cache."
+                 )
+             return False
+
+         # max continuous cached steps
+         max_continuous_cached_steps = self.get_max_continuous_cached_steps()
+         if not self.is_separate_cfg_step():
+             continuous_cached_steps = self.get_continuous_cached_steps()
+         else:
+             continuous_cached_steps = self.get_cfg_continuous_cached_steps()
+
+         if max_continuous_cached_steps >= 0 and (
+             continuous_cached_steps >= max_continuous_cached_steps
+         ):
+             if logger.isEnabledFor(logging.DEBUG):
+                 logger.debug(
+                     f"{prefix}, max_continuous_cached_steps "
+                     f"reached: {max_continuous_cached_steps}, "
+                     "can not use cache."
+                 )
+             # reset continuous cached steps stats
+             cached_context = self.get_context()
+             if not self.is_separate_cfg_step():
+                 cached_context.continuous_cached_steps = 0
+             else:
+                 cached_context.cfg_continuous_cached_steps = 0
+             return False
+
+         if threshold is None or threshold <= 0.0:
+             threshold = self.get_residual_diff_threshold()
+         if threshold <= 0.0:
+             return False
+
+         downsample_factor = self.get_downsample_factor()
+         if downsample_factor > 1 and "Bn" not in prefix:
+             states_tensor = states_tensor[..., ::downsample_factor]
+             states_tensor = states_tensor.contiguous()
+
+         # Allow Bn and Fn prefix to be used for diff calculation.
+         if "Bn" in prefix:
+             prev_states_tensor = self.get_Bn_buffer(prefix)
+         else:
+             prev_states_tensor = self.get_Fn_buffer(prefix)
+
+         if not self.is_alter_cache_enabled():
+             # Dynamic cache according to the residual diff
+             can_cache = prev_states_tensor is not None and self.similarity(
+                 prev_states_tensor,
+                 states_tensor,
+                 threshold=threshold,
+                 parallelized=parallelized,
+                 prefix=prefix,
+             )
+         else:
+             # Only cache in the alter cache steps
+             can_cache = (
+                 prev_states_tensor is not None
+                 and self.similarity(
+                     prev_states_tensor,
+                     states_tensor,
+                     threshold=threshold,
+                     parallelized=parallelized,
+                     prefix=prefix,
+                 )
+                 and self.is_alter_cache()
+             )
+         return can_cache
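Taken together, can_cache and apply_cache are the two entry points a cached transformer forward pass would call around its block loop. The pattern below is only a hedged sketch, not code from cache-dit's pipelines: the variable names are placeholders and the "Bn_residual" buffer prefix is illustrative.

    with manager.enter_context("transformer"):
        manager.mark_step_begin()
        fn_residual = fn_hidden_states - original_hidden_states  # residual after the first Fn blocks
        if manager.can_cache(fn_residual, prefix="Fn"):
            # Dynamic branch: skip the remaining blocks and reuse the cached residual.
            manager.add_cached_step()
            hidden_states, encoder_hidden_states = manager.apply_cache(
                fn_hidden_states,
                encoder_hidden_states,
                prefix="Bn_residual",          # "Bn" in the prefix selects the Bn buffer path
                encoder_prefix="Bn_residual",
            )
        else:
            # Compute branch: refresh the reference buffers for later steps.
            manager.set_Fn_buffer(fn_residual, prefix="Fn")
            # ... run the remaining transformer blocks ...
            manager.set_Bn_buffer(final_hidden_states - fn_hidden_states, prefix="Bn_residual")
            manager.set_Bn_encoder_buffer(
                final_encoder_states - encoder_hidden_states, prefix="Bn_residual"
            )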