cache-dit 0.2.27__py3-none-any.whl → 0.2.28__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
@@ -0,0 +1,833 @@
+ import logging
+ import contextlib
+ import dataclasses
+ from typing import Any, Dict, Optional, Tuple, Union, List
+
+ import torch
+ import torch.distributed as dist
+
+ from cache_dit.cache_factory.cache_contexts.taylorseer import TaylorSeer
+ from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class CachedContextManager:
+     # Each pipeline should have its own context manager instance.
+
+     def __init__(self, name: str = None):
+         self.name = name
+         self._current_context: CachedContext = None
+         self._cached_context_manager: Dict[str, CachedContext] = {}
+
+     def new_context(self, *args, **kwargs) -> CachedContext:
+         _context = CachedContext(*args, **kwargs)
+         self._cached_context_manager[_context.name] = _context
+         return _context
+
+     def set_context(self, cached_context: CachedContext | str):
+         if isinstance(cached_context, CachedContext):
+             self._current_context = cached_context
+         else:
+             self._current_context = self._cached_context_manager[cached_context]
+
+     def get_context(self, name: str = None) -> CachedContext:
+         if name is not None:
+             if name not in self._cached_context_manager:
+                 raise ValueError("Context does not exist!")
+             return self._cached_context_manager[name]
+         return self._current_context
+
+     def reset_context(
+         self,
+         cached_context: CachedContext | str,
+         *args,
+         **kwargs,
+     ) -> CachedContext:
+         if isinstance(cached_context, CachedContext):
+             old_context_name = cached_context.name
+             if cached_context.name in self._cached_context_manager:
+                 cached_context.clear_buffers()
+                 del self._cached_context_manager[cached_context.name]
+             # force reuse of the old context name
+             kwargs["name"] = old_context_name
+             _context = self.new_context(*args, **kwargs)
+         else:
+             old_context_name = cached_context
+             if cached_context in self._cached_context_manager:
+                 self._cached_context_manager[cached_context].clear_buffers()
+                 del self._cached_context_manager[cached_context]
+             # force reuse of the old context name
+             kwargs["name"] = old_context_name
+             _context = self.new_context(*args, **kwargs)
+         return _context
+
+     @contextlib.contextmanager
+     def enter_context(self, cached_context: CachedContext | str):
+         old_cached_context = self._current_context
+         if isinstance(cached_context, CachedContext):
+             self._current_context = cached_context
+         else:
+             self._current_context = self._cached_context_manager[cached_context]
+         try:
+             yield
+         finally:
+             self._current_context = old_cached_context
+
+     @staticmethod
+     def collect_cache_kwargs(
+         default_attrs: dict, **kwargs
+     ) -> Tuple[Dict, Dict]:
+         # NOTE: This API splits kwargs into cache_kwargs and other_kwargs.
+         # default_attrs: specific settings for different pipelines.
+         cache_attrs = dataclasses.fields(CachedContext)
+         cache_attrs = [
+             attr
+             for attr in cache_attrs
+             if hasattr(
+                 CachedContext,
+                 attr.name,
+             )
+         ]
+         cache_kwargs = {
+             attr.name: kwargs.pop(
+                 attr.name,
+                 getattr(CachedContext, attr.name),
+             )
+             for attr in cache_attrs
+         }
+
+         def _safe_set_sequence_field(
+             field_name: str,
+             default_value: Any = None,
+         ):
+             if field_name not in cache_kwargs:
+                 cache_kwargs[field_name] = kwargs.pop(
+                     field_name,
+                     default_value,
+                 )
+
+         # Manually set the mutable fields, namely Fn_compute_blocks_ids and
+         # Bn_compute_blocks_ids (lists or sets) and taylorseer_kwargs (a dict).
+         _safe_set_sequence_field("Fn_compute_blocks_ids", [])
+         _safe_set_sequence_field("Bn_compute_blocks_ids", [])
+         _safe_set_sequence_field("taylorseer_kwargs", {})
+
+         for attr in cache_attrs:
+             if attr.name in default_attrs:  # can be empty {}
+                 cache_kwargs[attr.name] = default_attrs[attr.name]
+
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(f"Collected Cache kwargs: {cache_kwargs}")
+
+         return cache_kwargs, kwargs
+
+     @torch.compiler.disable
+     def get_residual_diff_threshold(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_residual_diff_threshold()
+
+     @torch.compiler.disable
+     def get_buffer(self, name) -> torch.Tensor:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_buffer(name)
+
+     @torch.compiler.disable
+     def set_buffer(self, name, buffer):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.set_buffer(name, buffer)
+
+     @torch.compiler.disable
+     def remove_buffer(self, name):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.remove_buffer(name)
+
+     @torch.compiler.disable
+     def mark_step_begin(self):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.mark_step_begin()
+
+     @torch.compiler.disable
+     def get_current_step(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_current_step()
+
+     @torch.compiler.disable
+     def get_current_step_residual_diff(self) -> Optional[float]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         step = str(self.get_current_step())
+         residual_diffs = self.get_residual_diffs()
+         if step in residual_diffs:
+             return residual_diffs[step]
+         return None
+
+     @torch.compiler.disable
+     def get_current_step_cfg_residual_diff(self) -> Optional[float]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         step = str(self.get_current_step())
+         cfg_residual_diffs = self.get_cfg_residual_diffs()
+         if step in cfg_residual_diffs:
+             return cfg_residual_diffs[step]
+         return None
+
+     @torch.compiler.disable
+     def get_current_transformer_step(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_current_transformer_step()
+
+     @torch.compiler.disable
+     def get_cached_steps(self) -> List[int]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_cached_steps()
+
+     @torch.compiler.disable
+     def get_cfg_cached_steps(self) -> List[int]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_cfg_cached_steps()
+
+     @torch.compiler.disable
+     def get_max_cached_steps(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.max_cached_steps
+
+     @torch.compiler.disable
+     def get_max_continuous_cached_steps(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.max_continuous_cached_steps
+
+     @torch.compiler.disable
+     def get_continuous_cached_steps(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.continuous_cached_steps
+
+     @torch.compiler.disable
+     def get_cfg_continuous_cached_steps(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.cfg_continuous_cached_steps
+
+     @torch.compiler.disable
+     def add_cached_step(self):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.add_cached_step()
+
+     @torch.compiler.disable
+     def add_residual_diff(self, diff):
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         cached_context.add_residual_diff(diff)
+
+     @torch.compiler.disable
+     def get_residual_diffs(self) -> Dict[str, float]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_residual_diffs()
+
+     @torch.compiler.disable
+     def get_cfg_residual_diffs(self) -> Dict[str, float]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_cfg_residual_diffs()
+
+     @torch.compiler.disable
+     def is_taylorseer_enabled(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.enable_taylorseer
+
+     @torch.compiler.disable
+     def is_encoder_taylorseer_enabled(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.enable_encoder_taylorseer
+
+     def get_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_taylorseers()
+
+     def get_cfg_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.get_cfg_taylorseers()
+
+     @torch.compiler.disable
+     def is_taylorseer_cache_residual(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.taylorseer_cache_type == "residual"
+
+     @torch.compiler.disable
+     def is_cache_residual(self) -> bool:
+         if self.is_taylorseer_enabled():
+             # residual or hidden_states
+             return self.is_taylorseer_cache_residual()
+         return True
+
+     @torch.compiler.disable
+     def is_encoder_cache_residual(self) -> bool:
+         if self.is_encoder_taylorseer_enabled():
+             # residual or hidden_states
+             return self.is_taylorseer_cache_residual()
+         return True
+
+     @torch.compiler.disable
+     def is_alter_cache_enabled(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.enable_alter_cache
+
+     @torch.compiler.disable
+     def is_alter_cache(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.is_alter_cache
+
+     @torch.compiler.disable
+     def is_in_warmup(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.is_in_warmup()
+
+     @torch.compiler.disable
+     def is_l1_diff_enabled(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return (
+             cached_context.l1_hidden_states_diff_threshold is not None
+             and cached_context.l1_hidden_states_diff_threshold > 0.0
+         )
+
+     @torch.compiler.disable
+     def get_important_condition_threshold(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.important_condition_threshold
+
+     @torch.compiler.disable
+     def non_compute_blocks_diff_threshold(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.non_compute_blocks_diff_threshold
+
+     @torch.compiler.disable
+     def Fn_compute_blocks(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         assert (
+             cached_context.Fn_compute_blocks >= 1
+         ), "Fn_compute_blocks must be >= 1"
+         if cached_context.max_Fn_compute_blocks > 0:
+             # NOTE: Fn_compute_blocks can be 1, which means FB Cache,
+             # but it must be less than or equal to max_Fn_compute_blocks.
+             assert (
+                 cached_context.Fn_compute_blocks
+                 <= cached_context.max_Fn_compute_blocks
+             ), (
+                 f"Fn_compute_blocks must be <= {cached_context.max_Fn_compute_blocks}, "
+                 f"but got {cached_context.Fn_compute_blocks}"
+             )
+         return cached_context.Fn_compute_blocks
+
+     @torch.compiler.disable
+     def Fn_compute_blocks_ids(self) -> List[int]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         assert (
+             len(cached_context.Fn_compute_blocks_ids)
+             <= cached_context.Fn_compute_blocks
+         ), (
+             "The num of Fn_compute_blocks_ids must be <= Fn_compute_blocks "
+             f"{cached_context.Fn_compute_blocks}, but got "
+             f"{len(cached_context.Fn_compute_blocks_ids)}"
+         )
+         return cached_context.Fn_compute_blocks_ids
+
+     @torch.compiler.disable
+     def Bn_compute_blocks(self) -> int:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         assert (
+             cached_context.Bn_compute_blocks >= 0
+         ), "Bn_compute_blocks must be >= 0"
+         if cached_context.max_Bn_compute_blocks > 0:
+             # NOTE: Bn_compute_blocks can be 0, which means FB Cache,
+             # but it must be less than or equal to max_Bn_compute_blocks.
+             assert (
+                 cached_context.Bn_compute_blocks
+                 <= cached_context.max_Bn_compute_blocks
+             ), (
+                 f"Bn_compute_blocks must be <= {cached_context.max_Bn_compute_blocks}, "
+                 f"but got {cached_context.Bn_compute_blocks}"
+             )
+         return cached_context.Bn_compute_blocks
+
+     @torch.compiler.disable
+     def Bn_compute_blocks_ids(self) -> List[int]:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         assert (
+             len(cached_context.Bn_compute_blocks_ids)
+             <= cached_context.Bn_compute_blocks
+         ), (
+             "The num of Bn_compute_blocks_ids must be <= Bn_compute_blocks "
+             f"{cached_context.Bn_compute_blocks}, but got "
+             f"{len(cached_context.Bn_compute_blocks_ids)}"
+         )
+         return cached_context.Bn_compute_blocks_ids
+
+     @torch.compiler.disable
+     def enable_spearate_cfg(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.enable_spearate_cfg
+
+     @torch.compiler.disable
+     def is_separate_cfg_step(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.is_separate_cfg_step()
+
+     @torch.compiler.disable
+     def cfg_diff_compute_separate(self) -> bool:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.cfg_diff_compute_separate
+
+     @torch.compiler.disable
+     def similarity(
+         self,
+         t1: torch.Tensor,  # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
+         t2: torch.Tensor,  # curr residual R(t,n)   = H(t,n)   - H(t,0)
+         *,
+         threshold: float,
+         parallelized: bool = False,
+         prefix: str = "Fn",  # for debugging
+     ) -> bool:
+         # Special cases for threshold: <= 0.0 disables the check and >= 1.0
+         # always treats the tensors as similar. The sentinel diffs recorded
+         # below are -0.0 (disabled), -1.0 (always similar) and -2.0 (shape
+         # mismatch).
+         if threshold <= 0.0:
+             self.add_residual_diff(-0.0)
+             return False
+
+         if threshold >= 1.0:
+             # If threshold is 1.0 or more, we consider them always similar.
+             self.add_residual_diff(-1.0)
+             return True
+
+         if t1.shape != t2.shape:
+             if logger.isEnabledFor(logging.DEBUG):
+                 logger.debug(f"{prefix}, shape error: {t1.shape} != {t2.shape}")
+             self.add_residual_diff(-2.0)
+             return False
+
+         if all(
+             (
+                 self.enable_spearate_cfg(),
+                 self.is_separate_cfg_step(),
+                 not self.cfg_diff_compute_separate(),
+                 self.get_current_step_residual_diff() is not None,
+             )
+         ):
+             # Reuse the diff value already computed for the non-CFG step
+             diff = self.get_current_step_residual_diff()
+         else:
+             # Find the most significant tokens in t1 and t2 and compute the
+             # diff over those tokens only; the more significant a token is,
+             # the more it contributes to the decision.
+             condition_thresh = self.get_important_condition_threshold()
+             if condition_thresh > 0.0:
+                 raw_diff = (t1 - t2).abs()  # [B, seq_len, d]
+                 token_m_df = raw_diff.mean(dim=-1)  # [B, seq_len]
+                 token_m_t1 = t1.abs().mean(dim=-1)  # [B, seq_len]
+                 # D = (t1 - t2) / t1 = 1 - (t2 / t1); if D = 0, then t1 = t2.
+                 token_diff = token_m_df / token_m_t1  # [B, seq_len]
+                 condition: torch.Tensor = (
+                     token_diff > condition_thresh
+                 )  # [B, seq_len]
+                 if condition.sum() > 0:
+                     condition = condition.unsqueeze(-1)  # [B, seq_len, 1]
+                     condition = condition.expand_as(raw_diff)  # [B, seq_len, d]
+                     mean_diff = raw_diff[condition].mean()
+                     mean_t1 = t1[condition].abs().mean()
+                 else:
+                     mean_diff = (t1 - t2).abs().mean()
+                     mean_t1 = t1.abs().mean()
+             else:
+                 # Use the mean of the absolute difference of the tensors
+                 mean_diff = (t1 - t2).abs().mean()
+                 mean_t1 = t1.abs().mean()
+
+             if parallelized:
+                 dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
+                 dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
+
+             # D = (t1 - t2) / t1 = 1 - (t2 / t1); if D = 0, then t1 = t2.
+             # Further, if we assume that (H(t,0) - H(t-1,0)) ~ 0, then
+             # H(t-1,n) ~ H(t,n), which means the hidden states are similar.
+             diff = (mean_diff / mean_t1).item()
+
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(
+                 f"{prefix}, diff: {diff:.6f}, threshold: {threshold:.6f}"
+             )
+
+         self.add_residual_diff(diff)
+
+         return diff < threshold
+
+     def _debugging_set_buffer(self, prefix):
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(
+                 f"set {prefix}, "
+                 f"transformer step: {self.get_current_transformer_step()}, "
+                 f"executed step: {self.get_current_step()}"
+             )
+
+     def _debugging_get_buffer(self, prefix):
+         if logger.isEnabledFor(logging.DEBUG):
+             logger.debug(
+                 f"get {prefix}, "
+                 f"transformer step: {self.get_current_transformer_step()}, "
+                 f"executed step: {self.get_current_step()}"
+             )
+
+     # Fn buffers
+     @torch.compiler.disable
+     def set_Fn_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
+         # Set hidden_states or residual for Fn blocks.
+         # This buffer is only used for L1 diff calculation.
+         downsample_factor = self.get_downsample_factor()
+         if downsample_factor > 1:
+             buffer = buffer[..., ::downsample_factor]
+             buffer = buffer.contiguous()
+         if self.is_separate_cfg_step():
+             self._debugging_set_buffer(f"{prefix}_buffer_cfg")
+             self.set_buffer(f"{prefix}_buffer_cfg", buffer)
+         else:
+             self._debugging_set_buffer(f"{prefix}_buffer")
+             self.set_buffer(f"{prefix}_buffer", buffer)
+
+     @torch.compiler.disable
+     def get_Fn_buffer(self, prefix: str = "Fn") -> torch.Tensor:
+         if self.is_separate_cfg_step():
+             self._debugging_get_buffer(f"{prefix}_buffer_cfg")
+             return self.get_buffer(f"{prefix}_buffer_cfg")
+         self._debugging_get_buffer(f"{prefix}_buffer")
+         return self.get_buffer(f"{prefix}_buffer")
+
+     @torch.compiler.disable
+     def set_Fn_encoder_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
+         if self.is_separate_cfg_step():
+             self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
+             self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+         else:
+             self._debugging_set_buffer(f"{prefix}_encoder_buffer")
+             self.set_buffer(f"{prefix}_encoder_buffer", buffer)
+
+     @torch.compiler.disable
+     def get_Fn_encoder_buffer(self, prefix: str = "Fn") -> torch.Tensor:
+         if self.is_separate_cfg_step():
+             self._debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
+             return self.get_buffer(f"{prefix}_encoder_buffer_cfg")
+         self._debugging_get_buffer(f"{prefix}_encoder_buffer")
+         return self.get_buffer(f"{prefix}_encoder_buffer")
+
+     # Bn buffers
+     @torch.compiler.disable
+     def set_Bn_buffer(self, buffer: torch.Tensor, prefix: str = "Bn"):
+         # Set hidden_states or residual for Bn blocks.
+         # This buffer is used for hidden states approximation.
+         if self.is_taylorseer_enabled():
+             # taylorseer, encoder_taylorseer
+             if self.is_separate_cfg_step():
+                 taylorseer, _ = self.get_cfg_taylorseers()
+             else:
+                 taylorseer, _ = self.get_taylorseers()
+
+             if taylorseer is not None:
+                 # Use TaylorSeer to update the buffer
+                 taylorseer.update(buffer)
+             else:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         "TaylorSeer is enabled but not set in the cache context. "
+                         "Falling back to default buffer storage."
+                     )
+                 if self.is_separate_cfg_step():
+                     self._debugging_set_buffer(f"{prefix}_buffer_cfg")
+                     self.set_buffer(f"{prefix}_buffer_cfg", buffer)
+                 else:
+                     self._debugging_set_buffer(f"{prefix}_buffer")
+                     self.set_buffer(f"{prefix}_buffer", buffer)
+         else:
+             if self.is_separate_cfg_step():
+                 self._debugging_set_buffer(f"{prefix}_buffer_cfg")
+                 self.set_buffer(f"{prefix}_buffer_cfg", buffer)
+             else:
+                 self._debugging_set_buffer(f"{prefix}_buffer")
+                 self.set_buffer(f"{prefix}_buffer", buffer)
+
+     @torch.compiler.disable
+     def get_Bn_buffer(self, prefix: str = "Bn") -> torch.Tensor:
+         if self.is_taylorseer_enabled():
+             # taylorseer, encoder_taylorseer
+             if self.is_separate_cfg_step():
+                 taylorseer, _ = self.get_cfg_taylorseers()
+             else:
+                 taylorseer, _ = self.get_taylorseers()
+
+             if taylorseer is not None:
+                 return taylorseer.approximate_value()
+             else:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         "TaylorSeer is enabled but not set in the cache context. "
+                         "Falling back to default buffer retrieval."
+                     )
+                 # Fallback to default buffer retrieval
+                 if self.is_separate_cfg_step():
+                     self._debugging_get_buffer(f"{prefix}_buffer_cfg")
+                     return self.get_buffer(f"{prefix}_buffer_cfg")
+                 self._debugging_get_buffer(f"{prefix}_buffer")
+                 return self.get_buffer(f"{prefix}_buffer")
+         else:
+             if self.is_separate_cfg_step():
+                 self._debugging_get_buffer(f"{prefix}_buffer_cfg")
+                 return self.get_buffer(f"{prefix}_buffer_cfg")
+             self._debugging_get_buffer(f"{prefix}_buffer")
+             return self.get_buffer(f"{prefix}_buffer")
+
+     @torch.compiler.disable
+     def set_Bn_encoder_buffer(
+         self, buffer: torch.Tensor | None, prefix: str = "Bn"
+     ):
+         # Don't set a None buffer
+         if buffer is None:
+             return
+
+         # This buffer is used for encoder hidden states approximation.
+         if self.is_encoder_taylorseer_enabled():
+             # taylorseer, encoder_taylorseer
+             if self.is_separate_cfg_step():
+                 _, encoder_taylorseer = self.get_cfg_taylorseers()
+             else:
+                 _, encoder_taylorseer = self.get_taylorseers()
+
+             if encoder_taylorseer is not None:
+                 # Use TaylorSeer to update the buffer
+                 encoder_taylorseer.update(buffer)
+             else:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         "TaylorSeer is enabled but not set in the cache context. "
+                         "Falling back to default buffer storage."
+                     )
+                 if self.is_separate_cfg_step():
+                     self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
+                     self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+                 else:
+                     self._debugging_set_buffer(f"{prefix}_encoder_buffer")
+                     self.set_buffer(f"{prefix}_encoder_buffer", buffer)
+         else:
+             if self.is_separate_cfg_step():
+                 self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
+                 self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
+             else:
+                 self._debugging_set_buffer(f"{prefix}_encoder_buffer")
+                 self.set_buffer(f"{prefix}_encoder_buffer", buffer)
+
+     @torch.compiler.disable
+     def get_Bn_encoder_buffer(self, prefix: str = "Bn") -> torch.Tensor:
+         if self.is_encoder_taylorseer_enabled():
+             if self.is_separate_cfg_step():
+                 _, encoder_taylorseer = self.get_cfg_taylorseers()
+             else:
+                 _, encoder_taylorseer = self.get_taylorseers()
+
+             if encoder_taylorseer is not None:
+                 # Use TaylorSeer to approximate the value
+                 return encoder_taylorseer.approximate_value()
+             else:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         "TaylorSeer is enabled but not set in the cache context. "
+                         "Falling back to default buffer retrieval."
+                     )
+                 # Fallback to default buffer retrieval
+                 if self.is_separate_cfg_step():
+                     self._debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
+                     return self.get_buffer(f"{prefix}_encoder_buffer_cfg")
+                 self._debugging_get_buffer(f"{prefix}_encoder_buffer")
+                 return self.get_buffer(f"{prefix}_encoder_buffer")
+         else:
+             if self.is_separate_cfg_step():
+                 self._debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
+                 return self.get_buffer(f"{prefix}_encoder_buffer_cfg")
+             self._debugging_get_buffer(f"{prefix}_encoder_buffer")
+             return self.get_buffer(f"{prefix}_encoder_buffer")
+
+     @torch.compiler.disable
+     def apply_cache(
+         self,
+         hidden_states: torch.Tensor,
+         encoder_hidden_states: torch.Tensor = None,
+         prefix: str = "Bn",
+         encoder_prefix: str = "Bn_encoder",
+     ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
+         # Allow Bn and Fn prefixes to be used for the residual cache.
+         if "Bn" in prefix:
+             hidden_states_prev = self.get_Bn_buffer(prefix)
+         else:
+             hidden_states_prev = self.get_Fn_buffer(prefix)
+
+         assert (
+             hidden_states_prev is not None
+         ), f"{prefix}_buffer must be set before"
+
+         if self.is_cache_residual():
+             hidden_states = hidden_states_prev + hidden_states
+         else:
+             # If the cache is not residual, we use the hidden states directly
+             hidden_states = hidden_states_prev
+
+         hidden_states = hidden_states.contiguous()
+
+         if encoder_hidden_states is not None:
+             if "Bn" in encoder_prefix:
+                 encoder_hidden_states_prev = self.get_Bn_encoder_buffer(
+                     encoder_prefix
+                 )
+             else:
+                 encoder_hidden_states_prev = self.get_Fn_encoder_buffer(
+                     encoder_prefix
+                 )
+
+             assert (
+                 encoder_hidden_states_prev is not None
+             ), f"{encoder_prefix}_buffer must be set before"
+
+             if self.is_encoder_cache_residual():
+                 encoder_hidden_states = (
+                     encoder_hidden_states_prev + encoder_hidden_states
+                 )
+             else:
+                 # If the encoder cache is not residual, we use the encoder hidden states directly
+                 encoder_hidden_states = encoder_hidden_states_prev
+
+             encoder_hidden_states = encoder_hidden_states.contiguous()
+
+         return hidden_states, encoder_hidden_states
+
+     @torch.compiler.disable
+     def get_downsample_factor(self) -> float:
+         cached_context = self.get_context()
+         assert cached_context is not None, "cached_context must be set before"
+         return cached_context.downsample_factor
+
+     @torch.compiler.disable
+     def can_cache(
+         self,
+         states_tensor: torch.Tensor,  # hidden_states or residual
+         parallelized: bool = False,
+         threshold: Optional[float] = None,  # can manually set threshold
+         prefix: str = "Fn",
+     ) -> bool:
+
+         if self.is_in_warmup():
+             return False
+
+         # max cached steps
+         max_cached_steps = self.get_max_cached_steps()
+         if not self.is_separate_cfg_step():
+             cached_steps = self.get_cached_steps()
+         else:
+             cached_steps = self.get_cfg_cached_steps()
+
+         if max_cached_steps >= 0 and (len(cached_steps) >= max_cached_steps):
+             if logger.isEnabledFor(logging.DEBUG):
+                 logger.debug(
+                     f"{prefix}, max_cached_steps reached: {max_cached_steps}, "
+                     "cannot use cache."
+                 )
+             return False
+
+         # max continuous cached steps
+         max_continuous_cached_steps = self.get_max_continuous_cached_steps()
+         if not self.is_separate_cfg_step():
+             continuous_cached_steps = self.get_continuous_cached_steps()
+         else:
+             continuous_cached_steps = self.get_cfg_continuous_cached_steps()
+
+         if max_continuous_cached_steps >= 0 and (
+             continuous_cached_steps >= max_continuous_cached_steps
+         ):
+             if logger.isEnabledFor(logging.DEBUG):
+                 logger.debug(
+                     f"{prefix}, max_continuous_cached_steps "
+                     f"reached: {max_continuous_cached_steps}, "
+                     "cannot use cache."
+                 )
+             # reset continuous cached steps stats
+             cached_context = self.get_context()
+             if not self.is_separate_cfg_step():
+                 cached_context.continuous_cached_steps = 0
+             else:
+                 cached_context.cfg_continuous_cached_steps = 0
+             return False
+
+         if threshold is None or threshold <= 0.0:
+             threshold = self.get_residual_diff_threshold()
+         if threshold <= 0.0:
+             return False
+
+         downsample_factor = self.get_downsample_factor()
+         if downsample_factor > 1 and "Bn" not in prefix:
+             states_tensor = states_tensor[..., ::downsample_factor]
+             states_tensor = states_tensor.contiguous()
+
+         # Allow Bn and Fn prefixes to be used for diff calculation.
+         if "Bn" in prefix:
+             prev_states_tensor = self.get_Bn_buffer(prefix)
+         else:
+             prev_states_tensor = self.get_Fn_buffer(prefix)
+
+         if not self.is_alter_cache_enabled():
+             # Dynamic cache according to the residual diff
+             can_cache = prev_states_tensor is not None and self.similarity(
+                 prev_states_tensor,
+                 states_tensor,
+                 threshold=threshold,
+                 parallelized=parallelized,
+                 prefix=prefix,
+             )
+         else:
+             # Only cache in the alter cache steps
+             can_cache = (
+                 prev_states_tensor is not None
+                 and self.similarity(
+                     prev_states_tensor,
+                     states_tensor,
+                     threshold=threshold,
+                     parallelized=parallelized,
+                     prefix=prefix,
+                 )
+                 and self.is_alter_cache()
+             )
+         return can_cache
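
For orientation, here is a minimal, hypothetical usage sketch of the CachedContextManager added in this release. It is not part of the package diff: the CachedContext keyword residual_diff_threshold, the tensor shapes, and the placement of the calls inside a denoising loop are assumptions made purely for illustration; real integrations wire these calls into a transformer's forward pass.

import torch
# Assumes CachedContextManager (and CachedContext) are importable from the
# module shown in this diff; exact paths and kwargs are not confirmed here.

manager = CachedContextManager(name="demo")
# "residual_diff_threshold" is a guessed CachedContext field, not taken from the diff.
ctx = manager.new_context(name="pipe", residual_diff_threshold=0.08)

with manager.enter_context(ctx):
    for step in range(4):
        manager.mark_step_begin()
        # Stand-in for the residual produced by the first Fn transformer blocks.
        fn_residual = torch.randn(1, 16, 64)
        if manager.can_cache(fn_residual, prefix="Fn"):
            # Cache hit: record the cached step and replay the stored Bn buffer.
            manager.add_cached_step()
            hidden_states, _ = manager.apply_cache(fn_residual, prefix="Bn")
        else:
            # Cache miss: remember the Fn residual for future diff checks and,
            # after computing the remaining blocks, store their output for reuse.
            manager.set_Fn_buffer(fn_residual, prefix="Fn")
            hidden_states = torch.randn(1, 16, 64)  # stand-in for full compute
            manager.set_Bn_buffer(hidden_states, prefix="Bn")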