cache-dit 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl

This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (29)
  1. cache_dit/__init__.py +8 -6
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +17 -4
  4. cache_dit/cache_factory/block_adapters/__init__.py +555 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +538 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +262 -938
  8. cache_dit/cache_factory/cache_blocks/__init__.py +60 -11
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +45 -41
  11. cache_dit/cache_factory/cache_blocks/pattern_base.py +106 -80
  12. cache_dit/cache_factory/cache_blocks/utils.py +16 -10
  13. cache_dit/cache_factory/cache_contexts/__init__.py +5 -0
  14. cache_dit/cache_factory/cache_contexts/cache_context.py +327 -0
  15. cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
  16. cache_dit/cache_factory/cache_interface.py +31 -31
  17. cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
  18. cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
  19. cache_dit/quantize/quantize_ao.py +1 -0
  20. cache_dit/utils.py +26 -26
  21. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/METADATA +59 -23
  22. cache_dit-0.2.28.dist-info/RECORD +47 -0
  23. cache_dit/cache_factory/cache_context.py +0 -1155
  24. cache_dit-0.2.26.dist-info/RECORD +0 -42
  25. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  26. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
  27. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
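The central change in this release is the removal of the monolithic cache_dit/cache_factory/cache_context.py (-1155 lines, shown below) in favor of the new cache_contexts package (cache_context.py and cache_manager.py), together with a new block_adapters package and the move of taylorseer.py into cache_contexts/. A minimal, unverified sketch of the import-path change implied by that move (the 0.2.28 path is an assumption inferred from the rename in item 25; it is not confirmed by this diff):

# Hypothetical compatibility import; the 0.2.28 path is assumed from the
# taylorseer.py -> cache_contexts/taylorseer.py rename listed above.
try:
    # 0.2.28 layout (assumed)
    from cache_dit.cache_factory.cache_contexts.taylorseer import TaylorSeer
except ImportError:
    # 0.2.26 layout (the path imported by the removed cache_context.py below)
    from cache_dit.cache_factory.taylorseer import TaylorSeer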
cache_dit/cache_factory/cache_context.py
@@ -1,1155 +0,0 @@
1
- import logging
2
- import contextlib
3
- import dataclasses
4
- from collections import defaultdict
5
- from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
6
-
7
- import torch
8
- import torch.distributed as dist
9
-
10
- from cache_dit.cache_factory.taylorseer import TaylorSeer
11
- from cache_dit.logger import init_logger
12
-
13
- logger = init_logger(__name__)
14
-
15
-
16
- @dataclasses.dataclass
17
- class DBCacheContext:
18
- # Dual Block Cache
19
- # Fn=1, Bn=0 means FB Cache; otherwise, Dual Block Cache
20
- Fn_compute_blocks: int = 1
21
- Bn_compute_blocks: int = 0
22
- # We have added a residual cache pattern for selected compute blocks
23
- Fn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
24
- Bn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
25
- # non compute blocks diff threshold, we don't skip the non
26
- # compute blocks if the diff >= threshold
27
- non_compute_blocks_diff_threshold: float = 0.08
28
- max_Fn_compute_blocks: int = -1
29
- max_Bn_compute_blocks: int = -1
30
- # L1 hidden states or residual diff threshold for Fn
31
- residual_diff_threshold: Union[torch.Tensor, float] = 0.05
32
- l1_hidden_states_diff_threshold: float = None
33
- important_condition_threshold: float = 0.0
34
-
35
- # Alter Cache Settings
36
- # Pattern: 0 F 1 T 2 F 3 T 4 F 5 T ...
37
- enable_alter_cache: bool = False
38
- is_alter_cache: bool = True
39
- # 1.0 means we always cache the residuals if alter_cache is enabled.
40
- alter_residual_diff_threshold: Optional[Union[torch.Tensor, float]] = 1.0
41
-
42
- # Buffer for storing the residuals and other tensors
43
- buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
44
- incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
45
- default_factory=lambda: defaultdict(int),
46
- )
47
-
48
- # Other settings
49
- downsample_factor: int = 1
50
- num_inference_steps: int = -1 # for future use
51
- max_warmup_steps: int = 0 # DON'T Cache in warmup steps
52
- # DON'T Cache if the number of cached steps >= max_cached_steps
53
- max_cached_steps: int = -1 # for both CFG and non-CFG
54
- max_continuous_cached_steps: int = -1 # the max continuous cached steps
55
-
56
- # Record the executed steps, both cached and non-cached
57
- executed_steps: int = 0 # cached + non-cached pipeline steps
58
- # steps for transformer, for CFG, transformer_executed_steps will
59
- # be double of executed_steps.
60
- transformer_executed_steps: int = 0
61
-
62
- # Support TaylorSeers in Dual Block Cache
63
- # Title: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
64
- # Url: https://arxiv.org/pdf/2503.06923
65
- enable_taylorseer: bool = False
66
- enable_encoder_taylorseer: bool = False
67
- # NOTE: using residual cache for taylorseer may incur precision loss
68
- taylorseer_cache_type: str = "hidden_states" # residual or hidden_states
69
- taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
70
- taylorseer: Optional[TaylorSeer] = None
71
- encoder_tarlorseer: Optional[TaylorSeer] = None
72
-
73
- # Support do_separate_cfg, such as Wan 2.1,
74
- # Qwen-Image. For models that fuse CFG and non-CFG into a single
75
- # forward step, do_separate_cfg should be set to False.
76
- # For example: CogVideoX, HunyuanVideo, Mochi.
77
- do_separate_cfg: bool = False
78
- # Compute cfg forward first or not, default False, namely,
79
- # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
80
- cfg_compute_first: bool = False
81
- # Compute separate diff values for CFG and non-CFG steps,
82
- # default True. If False, we will use the computed diff from
83
- # the current non-CFG transformer step for the current CFG step.
84
- cfg_diff_compute_separate: bool = True
85
- cfg_taylorseer: Optional[TaylorSeer] = None
86
- cfg_encoder_taylorseer: Optional[TaylorSeer] = None
87
-
88
- # CFG & non-CFG cached steps
89
- cached_steps: List[int] = dataclasses.field(default_factory=list)
90
- residual_diffs: DefaultDict[str, float] = dataclasses.field(
91
- default_factory=lambda: defaultdict(float),
92
- )
93
- continuous_cached_steps: int = 0
94
- cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
95
- cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
96
- default_factory=lambda: defaultdict(float),
97
- )
98
- cfg_continuous_cached_steps: int = 0
99
-
100
- @torch.compiler.disable
101
- def __post_init__(self):
102
- # Some checks for settings
103
- if self.do_separate_cfg:
104
- assert self.enable_alter_cache is False, (
105
- "enable_alter_cache must set as False if "
106
- "do_separate_cfg is enabled."
107
- )
108
- if self.cfg_diff_compute_separate:
109
- assert self.cfg_compute_first is False, (
110
- "cfg_compute_first must set as False if "
111
- "cfg_diff_compute_separate is enabled."
112
- )
113
-
114
- if "max_warmup_steps" not in self.taylorseer_kwargs:
115
- # If max_warmup_steps is not set in taylorseer_kwargs,
116
- # set the same as max_warmup_steps for DBCache
117
- self.taylorseer_kwargs["max_warmup_steps"] = (
118
- self.max_warmup_steps if self.max_warmup_steps > 0 else 1
119
- )
120
-
121
- # Only set n_derivatives as 2 or 3, which is enough for most cases.
122
- if "n_derivatives" not in self.taylorseer_kwargs:
123
- self.taylorseer_kwargs["n_derivatives"] = max(
124
- 2, min(3, self.taylorseer_kwargs["max_warmup_steps"])
125
- )
126
-
127
- if self.enable_taylorseer:
128
- self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
129
- if self.do_separate_cfg:
130
- self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
131
-
132
- if self.enable_encoder_taylorseer:
133
- self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
134
- if self.do_separate_cfg:
135
- self.cfg_encoder_taylorseer = TaylorSeer(
136
- **self.taylorseer_kwargs
137
- )
138
-
139
- @torch.compiler.disable
140
- def get_residual_diff_threshold(self):
141
- if self.enable_alter_cache:
142
- residual_diff_threshold = self.alter_residual_diff_threshold
143
- else:
144
- residual_diff_threshold = self.residual_diff_threshold
145
- if self.l1_hidden_states_diff_threshold is not None:
146
- # Use the L1 hidden states diff threshold if set
147
- residual_diff_threshold = self.l1_hidden_states_diff_threshold
148
- if isinstance(residual_diff_threshold, torch.Tensor):
149
- residual_diff_threshold = residual_diff_threshold.item()
150
- return residual_diff_threshold
151
-
152
- @torch.compiler.disable
153
- def get_buffer(self, name):
154
- if self.enable_alter_cache and self.is_alter_cache:
155
- name = f"{name}_alter"
156
- return self.buffers.get(name)
157
-
158
- @torch.compiler.disable
159
- def set_buffer(self, name, buffer):
160
- if self.enable_alter_cache and self.is_alter_cache:
161
- name = f"{name}_alter"
162
- self.buffers[name] = buffer
163
-
164
- @torch.compiler.disable
165
- def remove_buffer(self, name):
166
- if self.enable_alter_cache and self.is_alter_cache:
167
- name = f"{name}_alter"
168
- if name in self.buffers:
169
- del self.buffers[name]
170
-
171
- @torch.compiler.disable
172
- def clear_buffers(self):
173
- self.buffers.clear()
174
-
175
- @torch.compiler.disable
176
- def mark_step_begin(self):
177
- # Always increase transformer executed steps
178
- # incr step: prev 0 -> 1; prev 1 -> 2
179
- # current step: incr step - 1
180
- self.transformer_executed_steps += 1
181
- if not self.do_separate_cfg:
182
- self.executed_steps += 1
183
- else:
184
- # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
185
- if not self.cfg_compute_first:
186
- if not self.is_separate_cfg_step():
187
- # transformer step: 0,2,4,...
188
- self.executed_steps += 1
189
- else:
190
- if self.is_separate_cfg_step():
191
- # transformer step: 0,2,4,...
192
- self.executed_steps += 1
193
-
194
- if not self.enable_alter_cache:
195
- # 0 F 1 T 2 F 3 T 4 F 5 T ...
196
- self.is_alter_cache = not self.is_alter_cache
197
-
198
- # Reset the cached steps and residual diffs at the beginning
199
- # of each inference.
200
- if self.get_current_transformer_step() == 0:
201
- self.cached_steps.clear()
202
- self.residual_diffs.clear()
203
- self.cfg_cached_steps.clear()
204
- self.cfg_residual_diffs.clear()
205
- # Reset the TaylorSeers cache at the beginning of each inference.
206
- # reset_cache will set the current step to -1 for TaylorSeer,
207
- if self.enable_taylorseer or self.enable_encoder_taylorseer:
208
- taylorseer, encoder_taylorseer = self.get_taylorseers()
209
- if taylorseer is not None:
210
- taylorseer.reset_cache()
211
- if encoder_taylorseer is not None:
212
- encoder_taylorseer.reset_cache()
213
- cfg_taylorseer, cfg_encoder_taylorseer = (
214
- self.get_cfg_taylorseers()
215
- )
216
- if cfg_taylorseer is not None:
217
- cfg_taylorseer.reset_cache()
218
- if cfg_encoder_taylorseer is not None:
219
- cfg_encoder_taylorseer.reset_cache()
220
-
221
- # mark_step_begin of TaylorSeer must be called after the cache is reset.
222
- if self.enable_taylorseer or self.enable_encoder_taylorseer:
223
- if self.do_separate_cfg:
224
- # Assume non-CFG steps: 0, 2, 4, 6, ...
225
- if not self.is_separate_cfg_step():
226
- taylorseer, encoder_taylorseer = self.get_taylorseers()
227
- if taylorseer is not None:
228
- taylorseer.mark_step_begin()
229
- if encoder_taylorseer is not None:
230
- encoder_taylorseer.mark_step_begin()
231
- else:
232
- cfg_taylorseer, cfg_encoder_taylorseer = (
233
- self.get_cfg_taylorseers()
234
- )
235
- if cfg_taylorseer is not None:
236
- cfg_taylorseer.mark_step_begin()
237
- if cfg_encoder_taylorseer is not None:
238
- cfg_encoder_taylorseer.mark_step_begin()
239
- else:
240
- taylorseer, encoder_taylorseer = self.get_taylorseers()
241
- if taylorseer is not None:
242
- taylorseer.mark_step_begin()
243
- if encoder_taylorseer is not None:
244
- encoder_taylorseer.mark_step_begin()
245
-
246
- def get_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
247
- return self.taylorseer, self.encoder_tarlorseer
248
-
249
- def get_cfg_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
250
- return self.cfg_taylorseer, self.cfg_encoder_taylorseer
251
-
252
- @torch.compiler.disable
253
- def add_residual_diff(self, diff):
254
- # step: executed_steps - 1, not transformer_steps - 1
255
- step = str(self.get_current_step())
256
- # Only add the diff if it is not already recorded for this step
257
- if not self.is_separate_cfg_step():
258
- if step not in self.residual_diffs:
259
- self.residual_diffs[step] = diff
260
- else:
261
- if step not in self.cfg_residual_diffs:
262
- self.cfg_residual_diffs[step] = diff
263
-
264
- @torch.compiler.disable
265
- def get_residual_diffs(self):
266
- return self.residual_diffs.copy()
267
-
268
- @torch.compiler.disable
269
- def get_cfg_residual_diffs(self):
270
- return self.cfg_residual_diffs.copy()
271
-
272
- @torch.compiler.disable
273
- def add_cached_step(self):
274
- curr_cached_step = self.get_current_step()
275
- if not self.is_separate_cfg_step():
276
- if self.cached_steps:
277
- prev_cached_step = self.cached_steps[-1]
278
- if curr_cached_step - prev_cached_step == 1:
279
- if self.continuous_cached_steps == 0:
280
- self.continuous_cached_steps += 2
281
- else:
282
- self.continuous_cached_steps += 1
283
- else:
284
- self.continuous_cached_steps += 1
285
-
286
- self.cached_steps.append(curr_cached_step)
287
- else:
288
- if self.cfg_cached_steps:
289
- prev_cfg_cached_step = self.cfg_cached_steps[-1]
290
- if curr_cached_step - prev_cfg_cached_step == 1:
291
- if self.cfg_continuous_cached_steps == 0:
292
- self.cfg_continuous_cached_steps += 2
293
- else:
294
- self.cfg_continuous_cached_steps += 1
295
- else:
296
- self.cfg_continuous_cached_steps += 1
297
-
298
- self.cfg_cached_steps.append(curr_cached_step)
299
-
300
- @torch.compiler.disable
301
- def get_cached_steps(self):
302
- return self.cached_steps.copy()
303
-
304
- @torch.compiler.disable
305
- def get_cfg_cached_steps(self):
306
- return self.cfg_cached_steps.copy()
307
-
308
- @torch.compiler.disable
309
- def get_current_step(self):
310
- return self.executed_steps - 1
311
-
312
- @torch.compiler.disable
313
- def get_current_transformer_step(self):
314
- return self.transformer_executed_steps - 1
315
-
316
- @torch.compiler.disable
317
- def is_separate_cfg_step(self):
318
- if not self.do_separate_cfg:
319
- return False
320
- if self.cfg_compute_first:
321
- # CFG steps: 0, 2, 4, 6, ...
322
- return self.get_current_transformer_step() % 2 == 0
323
- # CFG steps: 1, 3, 5, 7, ...
324
- return self.get_current_transformer_step() % 2 != 0
325
-
326
- @torch.compiler.disable
327
- def is_in_warmup(self):
328
- return self.get_current_step() < self.max_warmup_steps
329
-
330
-
331
- # TODO: Support context manager for different cache_context
332
-
333
-
334
- def create_cache_context(*args, **kwargs):
335
- return DBCacheContext(*args, **kwargs)
336
-
337
-
338
- def get_current_cache_context():
339
- return _current_cache_context
340
-
341
-
342
- def set_current_cache_context(cache_context=None):
343
- global _current_cache_context
344
- _current_cache_context = cache_context
345
-
346
-
347
- @contextlib.contextmanager
348
- def cache_context(cache_context):
349
- global _current_cache_context
350
- old_cache_context = _current_cache_context
351
- _current_cache_context = cache_context
352
- try:
353
- yield
354
- finally:
355
- _current_cache_context = old_cache_context
356
-
357
-
358
- @torch.compiler.disable
359
- def get_residual_diff_threshold():
360
- cache_context = get_current_cache_context()
361
- assert cache_context is not None, "cache_context must be set before"
362
- return cache_context.get_residual_diff_threshold()
363
-
364
-
365
- @torch.compiler.disable
366
- def get_buffer(name):
367
- cache_context = get_current_cache_context()
368
- assert cache_context is not None, "cache_context must be set before"
369
- return cache_context.get_buffer(name)
370
-
371
-
372
- @torch.compiler.disable
373
- def set_buffer(name, buffer):
374
- cache_context = get_current_cache_context()
375
- assert cache_context is not None, "cache_context must be set before"
376
- cache_context.set_buffer(name, buffer)
377
-
378
-
379
- @torch.compiler.disable
380
- def remove_buffer(name):
381
- cache_context = get_current_cache_context()
382
- assert cache_context is not None, "cache_context must be set before"
383
- cache_context.remove_buffer(name)
384
-
385
-
386
- @torch.compiler.disable
387
- def mark_step_begin():
388
- cache_context = get_current_cache_context()
389
- assert cache_context is not None, "cache_context must be set before"
390
- cache_context.mark_step_begin()
391
-
392
-
393
- @torch.compiler.disable
394
- def get_current_step():
395
- cache_context = get_current_cache_context()
396
- assert cache_context is not None, "cache_context must be set before"
397
- return cache_context.get_current_step()
398
-
399
-
400
- @torch.compiler.disable
401
- def get_current_step_residual_diff():
402
- cache_context = get_current_cache_context()
403
- assert cache_context is not None, "cache_context must be set before"
404
- step = str(get_current_step())
405
- residual_diffs = get_residual_diffs()
406
- if step in residual_diffs:
407
- return residual_diffs[step]
408
- return None
409
-
410
-
411
- @torch.compiler.disable
412
- def get_current_step_cfg_residual_diff():
413
- cache_context = get_current_cache_context()
414
- assert cache_context is not None, "cache_context must be set before"
415
- step = str(get_current_step())
416
- cfg_residual_diffs = get_cfg_residual_diffs()
417
- if step in cfg_residual_diffs:
418
- return cfg_residual_diffs[step]
419
- return None
420
-
421
-
422
- @torch.compiler.disable
423
- def get_current_transformer_step():
424
- cache_context = get_current_cache_context()
425
- assert cache_context is not None, "cache_context must be set before"
426
- return cache_context.get_current_transformer_step()
427
-
428
-
429
- @torch.compiler.disable
430
- def get_cached_steps():
431
- cache_context = get_current_cache_context()
432
- assert cache_context is not None, "cache_context must be set before"
433
- return cache_context.get_cached_steps()
434
-
435
-
436
- @torch.compiler.disable
437
- def get_cfg_cached_steps():
438
- cache_context = get_current_cache_context()
439
- assert cache_context is not None, "cache_context must be set before"
440
- return cache_context.get_cfg_cached_steps()
441
-
442
-
443
- @torch.compiler.disable
444
- def get_max_cached_steps():
445
- cache_context = get_current_cache_context()
446
- assert cache_context is not None, "cache_context must be set before"
447
- return cache_context.max_cached_steps
448
-
449
-
450
- @torch.compiler.disable
451
- def get_max_continuous_cached_steps():
452
- cache_context = get_current_cache_context()
453
- assert cache_context is not None, "cache_context must be set before"
454
- return cache_context.max_continuous_cached_steps
455
-
456
-
457
- @torch.compiler.disable
458
- def get_continuous_cached_steps():
459
- cache_context = get_current_cache_context()
460
- assert cache_context is not None, "cache_context must be set before"
461
- return cache_context.continuous_cached_steps
462
-
463
-
464
- @torch.compiler.disable
465
- def get_cfg_continuous_cached_steps():
466
- cache_context = get_current_cache_context()
467
- assert cache_context is not None, "cache_context must be set before"
468
- return cache_context.cfg_continuous_cached_steps
469
-
470
-
471
- @torch.compiler.disable
472
- def add_cached_step():
473
- cache_context = get_current_cache_context()
474
- assert cache_context is not None, "cache_context must be set before"
475
- cache_context.add_cached_step()
476
-
477
-
478
- @torch.compiler.disable
479
- def add_residual_diff(diff):
480
- cache_context = get_current_cache_context()
481
- assert cache_context is not None, "cache_context must be set before"
482
- cache_context.add_residual_diff(diff)
483
-
484
-
485
- @torch.compiler.disable
486
- def get_residual_diffs():
487
- cache_context = get_current_cache_context()
488
- assert cache_context is not None, "cache_context must be set before"
489
- return cache_context.get_residual_diffs()
490
-
491
-
492
- @torch.compiler.disable
493
- def get_cfg_residual_diffs():
494
- cache_context = get_current_cache_context()
495
- assert cache_context is not None, "cache_context must be set before"
496
- return cache_context.get_cfg_residual_diffs()
497
-
498
-
499
- @torch.compiler.disable
500
- def is_taylorseer_enabled():
501
- cache_context = get_current_cache_context()
502
- assert cache_context is not None, "cache_context must be set before"
503
- return cache_context.enable_taylorseer
504
-
505
-
506
- @torch.compiler.disable
507
- def is_encoder_taylorseer_enabled():
508
- cache_context = get_current_cache_context()
509
- assert cache_context is not None, "cache_context must be set before"
510
- return cache_context.enable_encoder_taylorseer
511
-
512
-
513
- def get_taylorseers() -> Tuple[TaylorSeer, TaylorSeer]:
514
- cache_context = get_current_cache_context()
515
- assert cache_context is not None, "cache_context must be set before"
516
- return cache_context.get_taylorseers()
517
-
518
-
519
- def get_cfg_taylorseers() -> Tuple[TaylorSeer, TaylorSeer]:
520
- cache_context = get_current_cache_context()
521
- assert cache_context is not None, "cache_context must be set before"
522
- return cache_context.get_cfg_taylorseers()
523
-
524
-
525
- @torch.compiler.disable
526
- def is_taylorseer_cache_residual():
527
- cache_context = get_current_cache_context()
528
- assert cache_context is not None, "cache_context must be set before"
529
- return cache_context.taylorseer_cache_type == "residual"
530
-
531
-
532
- @torch.compiler.disable
533
- def is_cache_residual():
534
- if is_taylorseer_enabled():
535
- # residual or hidden_states
536
- return is_taylorseer_cache_residual()
537
- return True
538
-
539
-
540
- @torch.compiler.disable
541
- def is_encoder_cache_residual():
542
- if is_encoder_taylorseer_enabled():
543
- # residual or hidden_states
544
- return is_taylorseer_cache_residual()
545
- return True
546
-
547
-
548
- @torch.compiler.disable
549
- def is_alter_cache_enabled():
550
- cache_context = get_current_cache_context()
551
- assert cache_context is not None, "cache_context must be set before"
552
- return cache_context.enable_alter_cache
553
-
554
-
555
- @torch.compiler.disable
556
- def is_alter_cache():
557
- cache_context = get_current_cache_context()
558
- assert cache_context is not None, "cache_context must be set before"
559
- return cache_context.is_alter_cache
560
-
561
-
562
- @torch.compiler.disable
563
- def is_in_warmup():
564
- cache_context = get_current_cache_context()
565
- assert cache_context is not None, "cache_context must be set before"
566
- return cache_context.is_in_warmup()
567
-
568
-
569
- @torch.compiler.disable
570
- def is_l1_diff_enabled():
571
- cache_context = get_current_cache_context()
572
- assert cache_context is not None, "cache_context must be set before"
573
- return (
574
- cache_context.l1_hidden_states_diff_threshold is not None
575
- and cache_context.l1_hidden_states_diff_threshold > 0.0
576
- )
577
-
578
-
579
- @torch.compiler.disable
580
- def get_important_condition_threshold():
581
- cache_context = get_current_cache_context()
582
- assert cache_context is not None, "cache_context must be set before"
583
- return cache_context.important_condition_threshold
584
-
585
-
586
- @torch.compiler.disable
587
- def non_compute_blocks_diff_threshold():
588
- cache_context = get_current_cache_context()
589
- assert cache_context is not None, "cache_context must be set before"
590
- return cache_context.non_compute_blocks_diff_threshold
591
-
592
-
593
- @torch.compiler.disable
594
- def Fn_compute_blocks():
595
- cache_context = get_current_cache_context()
596
- assert cache_context is not None, "cache_context must be set before"
597
- assert (
598
- cache_context.Fn_compute_blocks >= 1
599
- ), "Fn_compute_blocks must be >= 1"
600
- if cache_context.max_Fn_compute_blocks > 0:
601
- # NOTE: Fn_compute_blocks can be 1, which means FB Cache
602
- # but it must be less than or equal to max_Fn_compute_blocks
603
- assert (
604
- cache_context.Fn_compute_blocks
605
- <= cache_context.max_Fn_compute_blocks
606
- ), (
607
- f"Fn_compute_blocks must be <= {cache_context.max_Fn_compute_blocks}, "
608
- f"but got {cache_context.Fn_compute_blocks}"
609
- )
610
- return cache_context.Fn_compute_blocks
611
-
612
-
613
- @torch.compiler.disable
614
- def Fn_compute_blocks_ids():
615
- cache_context = get_current_cache_context()
616
- assert cache_context is not None, "cache_context must be set before"
617
- assert (
618
- len(cache_context.Fn_compute_blocks_ids)
619
- <= cache_context.Fn_compute_blocks
620
- ), (
621
- "The num of Fn_compute_blocks_ids must be <= Fn_compute_blocks "
622
- f"{cache_context.Fn_compute_blocks}, but got "
623
- f"{len(cache_context.Fn_compute_blocks_ids)}"
624
- )
625
- return cache_context.Fn_compute_blocks_ids
626
-
627
-
628
- @torch.compiler.disable
629
- def Bn_compute_blocks():
630
- cache_context = get_current_cache_context()
631
- assert cache_context is not None, "cache_context must be set before"
632
- assert (
633
- cache_context.Bn_compute_blocks >= 0
634
- ), "Bn_compute_blocks must be >= 0"
635
- if cache_context.max_Bn_compute_blocks > 0:
636
- # NOTE: Bn_compute_blocks can be 0, which means FB Cache
637
- # but it must be less than or equal to max_Bn_compute_blocks
638
- assert (
639
- cache_context.Bn_compute_blocks
640
- <= cache_context.max_Bn_compute_blocks
641
- ), (
642
- f"Bn_compute_blocks must be <= {cache_context.max_Bn_compute_blocks}, "
643
- f"but got {cache_context.Bn_compute_blocks}"
644
- )
645
- return cache_context.Bn_compute_blocks
646
-
647
-
648
- @torch.compiler.disable
649
- def Bn_compute_blocks_ids():
650
- cache_context = get_current_cache_context()
651
- assert cache_context is not None, "cache_context must be set before"
652
- assert (
653
- len(cache_context.Bn_compute_blocks_ids)
654
- <= cache_context.Bn_compute_blocks
655
- ), (
656
- "The num of Bn_compute_blocks_ids must be <= Bn_compute_blocks "
657
- f"{cache_context.Bn_compute_blocks}, but got "
658
- f"{len(cache_context.Bn_compute_blocks_ids)}"
659
- )
660
- return cache_context.Bn_compute_blocks_ids
661
-
662
-
663
- @torch.compiler.disable
664
- def do_separate_cfg():
665
- cache_context = get_current_cache_context()
666
- assert cache_context is not None, "cache_context must be set before"
667
- return cache_context.do_separate_cfg
668
-
669
-
670
- @torch.compiler.disable
671
- def is_separate_cfg_step():
672
- cache_context = get_current_cache_context()
673
- assert cache_context is not None, "cache_context must be set before"
674
- return cache_context.is_separate_cfg_step()
675
-
676
-
677
- @torch.compiler.disable
678
- def cfg_diff_compute_separate():
679
- cache_context = get_current_cache_context()
680
- assert cache_context is not None, "cache_context must be set before"
681
- return cache_context.cfg_diff_compute_separate
682
-
683
-
684
- _current_cache_context: DBCacheContext = None
685
-
686
-
687
- def collect_cache_kwargs(default_attrs: dict, **kwargs):
688
- # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
689
- # default_attrs: specific settings for different pipelines
690
- cache_attrs = dataclasses.fields(DBCacheContext)
691
- cache_attrs = [
692
- attr
693
- for attr in cache_attrs
694
- if hasattr(
695
- DBCacheContext,
696
- attr.name,
697
- )
698
- ]
699
- cache_kwargs = {
700
- attr.name: kwargs.pop(
701
- attr.name,
702
- getattr(DBCacheContext, attr.name),
703
- )
704
- for attr in cache_attrs
705
- }
706
-
707
- def _safe_set_sequence_field(
708
- field_name: str,
709
- default_value: Any = None,
710
- ):
711
- if field_name not in cache_kwargs:
712
- cache_kwargs[field_name] = kwargs.pop(
713
- field_name,
714
- default_value,
715
- )
716
-
717
- # Manually set sequence fields, namely, Fn_compute_blocks_ids
718
- # and Bn_compute_blocks_ids, which are lists or sets.
719
- _safe_set_sequence_field("Fn_compute_blocks_ids", [])
720
- _safe_set_sequence_field("Bn_compute_blocks_ids", [])
721
- _safe_set_sequence_field("taylorseer_kwargs", {})
722
-
723
- for attr in cache_attrs:
724
- if attr.name in default_attrs: # can be empty {}
725
- cache_kwargs[attr.name] = default_attrs[attr.name]
726
-
727
- if logger.isEnabledFor(logging.DEBUG):
728
- logger.debug(f"Collected DBCache kwargs: {cache_kwargs}")
729
-
730
- return cache_kwargs, kwargs
731
-
732
-
733
- @torch.compiler.disable
734
- def are_two_tensors_similar(
735
- t1: torch.Tensor, # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
736
- t2: torch.Tensor, # curr residual R(t ,n) = H(t ,n) - H(t ,0)
737
- *,
738
- threshold: float,
739
- parallelized: bool = False,
740
- prefix: str = "Fn", # for debugging
741
- ):
742
- # Special diff values recorded: -0.0 means the threshold is disabled, -1.0 means
743
- # the tensors are always treated as similar, -2.0 means the shapes do not match.
744
- if threshold <= 0.0:
745
- add_residual_diff(-0.0)
746
- return False
747
-
748
- if threshold >= 1.0:
749
- # If threshold is 1.0 or more, we consider them always similar.
750
- add_residual_diff(-1.0)
751
- return True
752
-
753
- if t1.shape != t2.shape:
754
- if logger.isEnabledFor(logging.DEBUG):
755
- logger.debug(f"{prefix}, shape error: {t1.shape} != {t2.shape}")
756
- add_residual_diff(-2.0)
757
- return False
758
-
759
- if all(
760
- (
761
- do_separate_cfg(),
762
- is_separate_cfg_step(),
763
- not cfg_diff_compute_separate(),
764
- get_current_step_residual_diff() is not None,
765
- )
766
- ):
767
- # Reuse computed diff value from non-CFG step
768
- diff = get_current_step_residual_diff()
769
- else:
770
- # Find the most significant token through t1 and t2, and
771
- # consider the diff of the significant token. The more significant,
772
- # the more important.
773
- condition_thresh = get_important_condition_threshold()
774
- if condition_thresh > 0.0:
775
- raw_diff = (t1 - t2).abs() # [B, seq_len, d]
776
- token_m_df = raw_diff.mean(dim=-1) # [B, seq_len]
777
- token_m_t1 = t1.abs().mean(dim=-1) # [B, seq_len]
778
- # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
779
- token_diff = token_m_df / token_m_t1 # [B, seq_len]
780
- condition = token_diff > condition_thresh # [B, seq_len]
781
- if condition.sum() > 0:
782
- condition = condition.unsqueeze(-1) # [B, seq_len, 1]
783
- condition = condition.expand_as(raw_diff) # [B, seq_len, d]
784
- mean_diff = raw_diff[condition].mean()
785
- mean_t1 = t1[condition].abs().mean()
786
- else:
787
- mean_diff = (t1 - t2).abs().mean()
788
- mean_t1 = t1.abs().mean()
789
- else:
790
- # Use the mean of the absolute difference of the tensors
791
- mean_diff = (t1 - t2).abs().mean()
792
- mean_t1 = t1.abs().mean()
793
-
794
- if parallelized:
795
- dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
796
- dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
797
-
798
- # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
799
- # Further, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
800
- # H(t-1,n) ~ H(t ,n), which means the hidden states are similar.
801
- diff = (mean_diff / mean_t1).item()
802
-
803
- if logger.isEnabledFor(logging.DEBUG):
804
- logger.debug(f"{prefix}, diff: {diff:.6f}, threshold: {threshold:.6f}")
805
-
806
- add_residual_diff(diff)
807
-
808
- return diff < threshold
809
-
810
-
811
- @torch.compiler.disable
812
- def _debugging_set_buffer(prefix):
813
- if logger.isEnabledFor(logging.DEBUG):
814
- logger.debug(
815
- f"set {prefix}, "
816
- f"transformer step: {get_current_transformer_step()}, "
817
- f"executed step: {get_current_step()}"
818
- )
819
-
820
-
821
- @torch.compiler.disable
822
- def _debugging_get_buffer(prefix):
823
- if logger.isEnabledFor(logging.DEBUG):
824
- logger.debug(
825
- f"get {prefix}, "
826
- f"transformer step: {get_current_transformer_step()}, "
827
- f"executed step: {get_current_step()}"
828
- )
829
-
830
-
831
- # Fn buffers
832
- @torch.compiler.disable
833
- def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
834
- # Set hidden_states or residual for Fn blocks.
835
- # This buffer is only used for L1 diff calculation.
836
- downsample_factor = get_downsample_factor()
837
- if downsample_factor > 1:
838
- buffer = buffer[..., ::downsample_factor]
839
- buffer = buffer.contiguous()
840
- if is_separate_cfg_step():
841
- _debugging_set_buffer(f"{prefix}_buffer_cfg")
842
- set_buffer(f"{prefix}_buffer_cfg", buffer)
843
- else:
844
- _debugging_set_buffer(f"{prefix}_buffer")
845
- set_buffer(f"{prefix}_buffer", buffer)
846
-
847
-
848
- @torch.compiler.disable
849
- def get_Fn_buffer(prefix: str = "Fn"):
850
- if is_separate_cfg_step():
851
- _debugging_get_buffer(f"{prefix}_buffer_cfg")
852
- return get_buffer(f"{prefix}_buffer_cfg")
853
- _debugging_get_buffer(f"{prefix}_buffer")
854
- return get_buffer(f"{prefix}_buffer")
855
-
856
-
857
- @torch.compiler.disable
858
- def set_Fn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
859
- if is_separate_cfg_step():
860
- _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
861
- set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
862
- else:
863
- _debugging_set_buffer(f"{prefix}_encoder_buffer")
864
- set_buffer(f"{prefix}_encoder_buffer", buffer)
865
-
866
-
867
- @torch.compiler.disable
868
- def get_Fn_encoder_buffer(prefix: str = "Fn"):
869
- if is_separate_cfg_step():
870
- _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
871
- return get_buffer(f"{prefix}_encoder_buffer_cfg")
872
- _debugging_get_buffer(f"{prefix}_encoder_buffer")
873
- return get_buffer(f"{prefix}_encoder_buffer")
874
-
875
-
876
- # Bn buffers
877
- @torch.compiler.disable
878
- def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
879
- # Set hidden_states or residual for Bn blocks.
880
- # This buffer is used for hidden states approximation.
881
- if is_taylorseer_enabled():
882
- # taylorseer, encoder_taylorseer
883
- if is_separate_cfg_step():
884
- taylorseer, _ = get_cfg_taylorseers()
885
- else:
886
- taylorseer, _ = get_taylorseers()
887
-
888
- if taylorseer is not None:
889
- # Use TaylorSeer to update the buffer
890
- taylorseer.update(buffer)
891
- else:
892
- if logger.isEnabledFor(logging.DEBUG):
893
- logger.debug(
894
- "TaylorSeer is enabled but not set in the cache context. "
895
- "Falling back to default buffer retrieval."
896
- )
897
- if is_separate_cfg_step():
898
- _debugging_set_buffer(f"{prefix}_buffer_cfg")
899
- set_buffer(f"{prefix}_buffer_cfg", buffer)
900
- else:
901
- _debugging_set_buffer(f"{prefix}_buffer")
902
- set_buffer(f"{prefix}_buffer", buffer)
903
- else:
904
- if is_separate_cfg_step():
905
- _debugging_set_buffer(f"{prefix}_buffer_cfg")
906
- set_buffer(f"{prefix}_buffer_cfg", buffer)
907
- else:
908
- _debugging_set_buffer(f"{prefix}_buffer")
909
- set_buffer(f"{prefix}_buffer", buffer)
910
-
911
-
912
- @torch.compiler.disable
913
- def get_Bn_buffer(prefix: str = "Bn"):
914
- if is_taylorseer_enabled():
915
- # taylorseer, encoder_taylorseer
916
- if is_separate_cfg_step():
917
- taylorseer, _ = get_cfg_taylorseers()
918
- else:
919
- taylorseer, _ = get_taylorseers()
920
-
921
- if taylorseer is not None:
922
- return taylorseer.approximate_value()
923
- else:
924
- if logger.isEnabledFor(logging.DEBUG):
925
- logger.debug(
926
- "TaylorSeer is enabled but not set in the cache context. "
927
- "Falling back to default buffer retrieval."
928
- )
929
- # Fallback to default buffer retrieval
930
- if is_separate_cfg_step():
931
- _debugging_get_buffer(f"{prefix}_buffer_cfg")
932
- return get_buffer(f"{prefix}_buffer_cfg")
933
- _debugging_get_buffer(f"{prefix}_buffer")
934
- return get_buffer(f"{prefix}_buffer")
935
- else:
936
- if is_separate_cfg_step():
937
- _debugging_get_buffer(f"{prefix}_buffer_cfg")
938
- return get_buffer(f"{prefix}_buffer_cfg")
939
- _debugging_get_buffer(f"{prefix}_buffer")
940
- return get_buffer(f"{prefix}_buffer")
941
-
942
-
943
- @torch.compiler.disable
944
- def set_Bn_encoder_buffer(buffer: torch.Tensor | None, prefix: str = "Bn"):
945
- # DON'T set None Buffer
946
- if buffer is None:
947
- return
948
-
949
- # This buffer is used for encoder hidden states approximation.
950
- if is_encoder_taylorseer_enabled():
951
- # taylorseer, encoder_taylorseer
952
- if is_separate_cfg_step():
953
- _, encoder_taylorseer = get_cfg_taylorseers()
954
- else:
955
- _, encoder_taylorseer = get_taylorseers()
956
-
957
- if encoder_taylorseer is not None:
958
- # Use TaylorSeer to update the buffer
959
- encoder_taylorseer.update(buffer)
960
- else:
961
- if logger.isEnabledFor(logging.DEBUG):
962
- logger.debug(
963
- "TaylorSeer is enabled but not set in the cache context. "
964
- "Falling back to default buffer retrieval."
965
- )
966
- if is_separate_cfg_step():
967
- _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
968
- set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
969
- else:
970
- _debugging_set_buffer(f"{prefix}_encoder_buffer")
971
- set_buffer(f"{prefix}_encoder_buffer", buffer)
972
- else:
973
- if is_separate_cfg_step():
974
- _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
975
- set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
976
- else:
977
- _debugging_set_buffer(f"{prefix}_encoder_buffer")
978
- set_buffer(f"{prefix}_encoder_buffer", buffer)
979
-
980
-
981
- @torch.compiler.disable
982
- def get_Bn_encoder_buffer(prefix: str = "Bn"):
983
- if is_encoder_taylorseer_enabled():
984
- if is_separate_cfg_step():
985
- _, encoder_taylorseer = get_cfg_taylorseers()
986
- else:
987
- _, encoder_taylorseer = get_taylorseers()
988
-
989
- if encoder_taylorseer is not None:
990
- # Use TaylorSeer to approximate the value
991
- return encoder_taylorseer.approximate_value()
992
- else:
993
- if logger.isEnabledFor(logging.DEBUG):
994
- logger.debug(
995
- "TaylorSeer is enabled but not set in the cache context. "
996
- "Falling back to default buffer retrieval."
997
- )
998
- # Fallback to default buffer retrieval
999
- if is_separate_cfg_step():
1000
- _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
1001
- return get_buffer(f"{prefix}_encoder_buffer_cfg")
1002
- _debugging_get_buffer(f"{prefix}_encoder_buffer")
1003
- return get_buffer(f"{prefix}_encoder_buffer")
1004
- else:
1005
- if is_separate_cfg_step():
1006
- _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
1007
- return get_buffer(f"{prefix}_encoder_buffer_cfg")
1008
- _debugging_get_buffer(f"{prefix}_encoder_buffer")
1009
- return get_buffer(f"{prefix}_encoder_buffer")
1010
-
1011
-
1012
- @torch.compiler.disable
1013
- def apply_hidden_states_residual(
1014
- hidden_states: torch.Tensor,
1015
- encoder_hidden_states: torch.Tensor = None,
1016
- prefix: str = "Bn",
1017
- encoder_prefix: str = "Bn_encoder",
1018
- ):
1019
- # Allow Bn and Fn prefix to be used for residual cache.
1020
- if "Bn" in prefix:
1021
- hidden_states_prev = get_Bn_buffer(prefix)
1022
- else:
1023
- hidden_states_prev = get_Fn_buffer(prefix)
1024
-
1025
- assert hidden_states_prev is not None, f"{prefix}_buffer must be set before"
1026
-
1027
- if is_cache_residual():
1028
- hidden_states = hidden_states_prev + hidden_states
1029
- else:
1030
- # If cache is not residual, we use the hidden states directly
1031
- hidden_states = hidden_states_prev
1032
-
1033
- hidden_states = hidden_states.contiguous()
1034
-
1035
- if encoder_hidden_states is not None:
1036
- if "Bn" in encoder_prefix:
1037
- encoder_hidden_states_prev = get_Bn_encoder_buffer(encoder_prefix)
1038
- else:
1039
- encoder_hidden_states_prev = get_Fn_encoder_buffer(encoder_prefix)
1040
-
1041
- assert (
1042
- encoder_hidden_states_prev is not None
1043
- ), f"{prefix}_encoder_buffer must be set before"
1044
-
1045
- if is_encoder_cache_residual():
1046
- encoder_hidden_states = (
1047
- encoder_hidden_states_prev + encoder_hidden_states
1048
- )
1049
- else:
1050
- # If encoder cache is not residual, we use the encoder hidden states directly
1051
- encoder_hidden_states = encoder_hidden_states_prev
1052
-
1053
- encoder_hidden_states = encoder_hidden_states.contiguous()
1054
-
1055
- return hidden_states, encoder_hidden_states
1056
-
1057
-
1058
- @torch.compiler.disable
1059
- def get_downsample_factor():
1060
- cache_context = get_current_cache_context()
1061
- assert cache_context is not None, "cache_context must be set before"
1062
- return cache_context.downsample_factor
1063
-
1064
-
1065
- @torch.compiler.disable
1066
- def get_can_use_cache(
1067
- states_tensor: torch.Tensor, # hidden_states or residual
1068
- parallelized: bool = False,
1069
- threshold: Optional[float] = None, # can manually set threshold
1070
- prefix: str = "Fn",
1071
- ):
1072
- if is_in_warmup():
1073
- return False
1074
-
1075
- # max cached steps
1076
- max_cached_steps = get_max_cached_steps()
1077
- if not is_separate_cfg_step():
1078
- cached_steps = get_cached_steps()
1079
- else:
1080
- cached_steps = get_cfg_cached_steps()
1081
-
1082
- if max_cached_steps >= 0 and (len(cached_steps) >= max_cached_steps):
1083
- if logger.isEnabledFor(logging.DEBUG):
1084
- logger.debug(
1085
- f"{prefix}, max_cached_steps reached: {max_cached_steps}, "
1086
- "can not use cache."
1087
- )
1088
- return False
1089
-
1090
- # max continuous cached steps
1091
- max_continuous_cached_steps = get_max_continuous_cached_steps()
1092
- if not is_separate_cfg_step():
1093
- continuous_cached_steps = get_continuous_cached_steps()
1094
- else:
1095
- continuous_cached_steps = get_cfg_continuous_cached_steps()
1096
-
1097
- if max_continuous_cached_steps >= 0 and (
1098
- continuous_cached_steps >= max_continuous_cached_steps
1099
- ):
1100
- if logger.isEnabledFor(logging.DEBUG):
1101
- logger.debug(
1102
- f"{prefix}, max_continuous_cached_steps "
1103
- f"reached: {max_continuous_cached_steps}, "
1104
- "can not use cache."
1105
- )
1106
- # reset continuous cached steps stats
1107
- cache_context = get_current_cache_context()
1108
- if not is_separate_cfg_step():
1109
- cache_context.continuous_cached_steps = 0
1110
- else:
1111
- cache_context.cfg_continuous_cached_steps = 0
1112
- return False
1113
-
1114
- if threshold is None or threshold <= 0.0:
1115
- threshold = get_residual_diff_threshold()
1116
- if threshold <= 0.0:
1117
- return False
1118
-
1119
- downsample_factor = get_downsample_factor()
1120
- if downsample_factor > 1 and "Bn" not in prefix:
1121
- states_tensor = states_tensor[..., ::downsample_factor]
1122
- states_tensor = states_tensor.contiguous()
1123
-
1124
- # Allow Bn and Fn prefix to be used for diff calculation.
1125
- if "Bn" in prefix:
1126
- prev_states_tensor = get_Bn_buffer(prefix)
1127
- else:
1128
- prev_states_tensor = get_Fn_buffer(prefix)
1129
-
1130
- if not is_alter_cache_enabled():
1131
- # Dynamic cache according to the residual diff
1132
- can_use_cache = (
1133
- prev_states_tensor is not None
1134
- and are_two_tensors_similar(
1135
- prev_states_tensor,
1136
- states_tensor,
1137
- threshold=threshold,
1138
- parallelized=parallelized,
1139
- prefix=prefix,
1140
- )
1141
- )
1142
- else:
1143
- # Only cache in the alter cache steps
1144
- can_use_cache = (
1145
- prev_states_tensor is not None
1146
- and are_two_tensors_similar(
1147
- prev_states_tensor,
1148
- states_tensor,
1149
- threshold=threshold,
1150
- parallelized=parallelized,
1151
- prefix=prefix,
1152
- )
1153
- and is_alter_cache()
1154
- )
1155
- return can_use_cache
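For reference, the cache-hit test implemented by this removed module in are_two_tensors_similar() reduces to a relative L1 check between the previous and current residuals; presumably the equivalent logic now lives in cache_contexts/cache_manager.py, which this diff does not show. A minimal standalone sketch (relative_l1_similar is a hypothetical helper, not part of cache-dit's API):

import torch

def relative_l1_similar(t1: torch.Tensor, t2: torch.Tensor, threshold: float) -> bool:
    # mean |t1 - t2| / mean |t1| < threshold, mirroring the removed are_two_tensors_similar()
    if t1.shape != t2.shape:
        return False
    diff = ((t1 - t2).abs().mean() / t1.abs().mean()).item()
    return diff < threshold

# Two residuals that differ by roughly 1% pass the default 0.05 threshold.
prev_residual = torch.randn(2, 1024, 64)
curr_residual = prev_residual + 0.01 * torch.randn_like(prev_residual)
print(relative_l1_similar(prev_residual, curr_residual, threshold=0.05))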