cache-dit 0.2.27__py3-none-any.whl → 0.2.29__py3-none-any.whl

This diff shows the content of two publicly released versions of the package, as published to their public registry. It is provided for informational purposes only.

Potentially problematic release: this version of cache-dit might be problematic.

@@ -1,11 +1,9 @@
 import logging
-import contextlib
 import dataclasses
 from collections import defaultdict
 from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
 
 import torch
-import torch.distributed as dist
 
 from cache_dit.cache_factory.cache_contexts.taylorseer import TaylorSeer
 from cache_dit.logger import init_logger
@@ -14,7 +12,7 @@ logger = init_logger(__name__)
 
 
 @dataclasses.dataclass
-class _CachedContext: # Internal CachedContext Impl class
+class CachedContext: # Internal CachedContext Impl class
     name: str = "default"
     # Dual Block Cache
     # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
@@ -67,15 +65,16 @@ class _CachedContext: # Internal CachedContext Impl class
     enable_encoder_taylorseer: bool = False
     # NOTE: use residual cache for taylorseer may incur precision loss
     taylorseer_cache_type: str = "hidden_states" # residual or hidden_states
+    taylorseer_order: int = 2 # The order for TaylorSeer
     taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
     taylorseer: Optional[TaylorSeer] = None
     encoder_tarlorseer: Optional[TaylorSeer] = None
 
-    # Support do_separate_cfg, such as Wan 2.1,
+    # Support enable_spearate_cfg, such as Wan 2.1,
     # Qwen-Image. For model that fused CFG and non-CFG into single
-    # forward step, should set do_separate_cfg as False.
+    # forward step, should set enable_spearate_cfg as False.
     # For example: CogVideoX, HunyuanVideo, Mochi.
-    do_separate_cfg: bool = False
+    enable_spearate_cfg: bool = False
     # Compute cfg forward first or not, default False, namely,
     # 0, 2, 4, ..., -> non-CFG step; 1, 3, 5, ... -> CFG step.
     cfg_compute_first: bool = False
@@ -103,10 +102,10 @@ class _CachedContext: # Internal CachedContext Impl class
         if logger.isEnabledFor(logging.DEBUG):
             logger.info(f"Created _CacheContext: {self.name}")
         # Some checks for settings
-        if self.do_separate_cfg:
+        if self.enable_spearate_cfg:
             assert self.enable_alter_cache is False, (
                 "enable_alter_cache must set as False if "
-                "do_separate_cfg is enabled."
+                "enable_spearate_cfg is enabled."
             )
         if self.cfg_diff_compute_separate:
             assert self.cfg_compute_first is False, (
@@ -121,20 +120,17 @@ class _CachedContext: # Internal CachedContext Impl class
             self.max_warmup_steps if self.max_warmup_steps > 0 else 1
         )
 
-        # Only set n_derivatives as 2 or 3, which is enough for most cases.
-        if "n_derivatives" not in self.taylorseer_kwargs:
-            self.taylorseer_kwargs["n_derivatives"] = max(
-                2, min(3, self.taylorseer_kwargs["max_warmup_steps"])
-            )
+        # Overwrite the 'n_derivatives' by 'taylorseer_order', default: 2.
+        self.taylorseer_kwargs["n_derivatives"] = self.taylorseer_order
 
         if self.enable_taylorseer:
             self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.do_separate_cfg:
+            if self.enable_spearate_cfg:
                 self.cfg_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
 
         if self.enable_encoder_taylorseer:
             self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
-            if self.do_separate_cfg:
+            if self.enable_spearate_cfg:
                 self.cfg_encoder_taylorseer = TaylorSeer(
                     **self.taylorseer_kwargs
                 )
@@ -181,7 +177,7 @@ class _CachedContext: # Internal CachedContext Impl class
         # incr step: prev 0 -> 1; prev 1 -> 2
         # current step: incr step - 1
         self.transformer_executed_steps += 1
-        if not self.do_separate_cfg:
+        if not self.enable_spearate_cfg:
             self.executed_steps += 1
         else:
             # 0,1 -> 0 + 1, 2,3 -> 1 + 1, ...
@@ -223,7 +219,7 @@ class _CachedContext: # Internal CachedContext Impl class
 
         # mark_step_begin of TaylorSeer must be called after the cache is reset.
         if self.enable_taylorseer or self.enable_encoder_taylorseer:
-            if self.do_separate_cfg:
+            if self.enable_spearate_cfg:
                 # Assume non-CFG steps: 0, 2, 4, 6, ...
                 if not self.is_separate_cfg_step():
                     taylorseer, encoder_taylorseer = self.get_taylorseers()
@@ -318,7 +314,7 @@ class _CachedContext: # Internal CachedContext Impl class
 
     @torch.compiler.disable
     def is_separate_cfg_step(self):
-        if not self.do_separate_cfg:
+        if not self.enable_spearate_cfg:
             return False
         if self.cfg_compute_first:
             # CFG steps: 0, 2, 4, 6, ...
@@ -329,861 +325,3 @@ class _CachedContext: # Internal CachedContext Impl class
     @torch.compiler.disable
     def is_in_warmup(self):
         return self.get_current_step() < self.max_warmup_steps
-
-
-# TODO: Support context manager for different cache_context
-_current_cache_context: _CachedContext = None
-
-_cache_context_manager: Dict[str, _CachedContext] = {}
-
-
-def create_cache_context(*args, **kwargs):
-    global _cache_context_manager
-    _context = _CachedContext(*args, **kwargs)
-    _cache_context_manager[_context.name] = _context
-    return _context
-
-
-def get_cache_context():
-    return _current_cache_context
-
-
-def set_cache_context(cache_context: _CachedContext | str):
-    global _current_cache_context, _cache_context_manager
-    if isinstance(cache_context, _CachedContext):
-        _current_cache_context = cache_context
-    else:
-        _current_cache_context = _cache_context_manager[cache_context]
-
-
-def reset_cache_context(cache_context: _CachedContext | str, *args, **kwargs):
-    global _cache_context_manager
-    if isinstance(cache_context, _CachedContext):
-        old_context_name = cache_context.name
-        if cache_context.name in _cache_context_manager:
-            del _cache_context_manager[cache_context.name]
-        # force use old_context name
-        kwargs["name"] = old_context_name
-        _context = _CachedContext(*args, **kwargs)
-        _cache_context_manager[_context.name] = _context
-    else:
-        old_context_name = cache_context
-        if cache_context in _cache_context_manager:
-            del _cache_context_manager[cache_context]
-        # force use old_context name
-        kwargs["name"] = old_context_name
-        _context = _CachedContext(*args, **kwargs)
-        _cache_context_manager[_context.name] = _context
-
-    return _context
-
-
-@contextlib.contextmanager
-def cache_context(cache_context: _CachedContext | str):
-    global _current_cache_context, _cache_context_manager
-    old_cache_context = _current_cache_context
-    if isinstance(cache_context, _CachedContext):
-        _current_cache_context = cache_context
-    else:
-        _current_cache_context = _cache_context_manager[cache_context]
-    try:
-        yield
-    finally:
-        _current_cache_context = old_cache_context
-
-
-@torch.compiler.disable
-def get_residual_diff_threshold():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_residual_diff_threshold()
-
-
-@torch.compiler.disable
-def get_buffer(name):
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_buffer(name)
-
-
-@torch.compiler.disable
-def set_buffer(name, buffer):
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    cache_context.set_buffer(name, buffer)
-
-
-@torch.compiler.disable
-def remove_buffer(name):
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    cache_context.remove_buffer(name)
-
-
-@torch.compiler.disable
-def mark_step_begin():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    cache_context.mark_step_begin()
-
-
-@torch.compiler.disable
-def get_current_step():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_current_step()
-
-
-@torch.compiler.disable
-def get_current_step_residual_diff():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    step = str(get_current_step())
-    residual_diffs = get_residual_diffs()
-    if step in residual_diffs:
-        return residual_diffs[step]
-    return None
-
-
-@torch.compiler.disable
-def get_current_step_cfg_residual_diff():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    step = str(get_current_step())
-    cfg_residual_diffs = get_cfg_residual_diffs()
-    if step in cfg_residual_diffs:
-        return cfg_residual_diffs[step]
-    return None
-
-
-@torch.compiler.disable
-def get_current_transformer_step():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_current_transformer_step()
-
-
-@torch.compiler.disable
-def get_cached_steps():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_cached_steps()
-
-
-@torch.compiler.disable
-def get_cfg_cached_steps():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_cfg_cached_steps()
-
-
-@torch.compiler.disable
-def get_max_cached_steps():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.max_cached_steps
-
-
-@torch.compiler.disable
-def get_max_continuous_cached_steps():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.max_continuous_cached_steps
-
-
-@torch.compiler.disable
-def get_continuous_cached_steps():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.continuous_cached_steps
-
-
-@torch.compiler.disable
-def get_cfg_continuous_cached_steps():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.cfg_continuous_cached_steps
-
-
-@torch.compiler.disable
-def add_cached_step():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    cache_context.add_cached_step()
-
-
-@torch.compiler.disable
-def add_residual_diff(diff):
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    cache_context.add_residual_diff(diff)
-
-
-@torch.compiler.disable
-def get_residual_diffs():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_residual_diffs()
-
-
-@torch.compiler.disable
-def get_cfg_residual_diffs():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_cfg_residual_diffs()
-
-
-@torch.compiler.disable
-def is_taylorseer_enabled():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.enable_taylorseer
-
-
-@torch.compiler.disable
-def is_encoder_taylorseer_enabled():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.enable_encoder_taylorseer
-
-
-def get_taylorseers() -> Tuple[TaylorSeer, TaylorSeer]:
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_taylorseers()
-
-
-def get_cfg_taylorseers() -> Tuple[TaylorSeer, TaylorSeer]:
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.get_cfg_taylorseers()
-
-
-@torch.compiler.disable
-def is_taylorseer_cache_residual():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.taylorseer_cache_type == "residual"
-
-
-@torch.compiler.disable
-def is_cache_residual():
-    if is_taylorseer_enabled():
-        # residual or hidden_states
-        return is_taylorseer_cache_residual()
-    return True
-
-
-@torch.compiler.disable
-def is_encoder_cache_residual():
-    if is_encoder_taylorseer_enabled():
-        # residual or hidden_states
-        return is_taylorseer_cache_residual()
-    return True
-
-
-@torch.compiler.disable
-def is_alter_cache_enabled():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.enable_alter_cache
-
-
-@torch.compiler.disable
-def is_alter_cache():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.is_alter_cache
-
-
-@torch.compiler.disable
-def is_in_warmup():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.is_in_warmup()
-
-
-@torch.compiler.disable
-def is_l1_diff_enabled():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return (
-        cache_context.l1_hidden_states_diff_threshold is not None
-        and cache_context.l1_hidden_states_diff_threshold > 0.0
-    )
-
-
-@torch.compiler.disable
-def get_important_condition_threshold():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.important_condition_threshold
-
-
-@torch.compiler.disable
-def non_compute_blocks_diff_threshold():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.non_compute_blocks_diff_threshold
-
-
-@torch.compiler.disable
-def Fn_compute_blocks():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    assert (
-        cache_context.Fn_compute_blocks >= 1
-    ), "Fn_compute_blocks must be >= 1"
-    if cache_context.max_Fn_compute_blocks > 0:
-        # NOTE: Fn_compute_blocks can be 1, which means FB Cache
-        # but it must be less than or equal to max_Fn_compute_blocks
-        assert (
-            cache_context.Fn_compute_blocks
-            <= cache_context.max_Fn_compute_blocks
-        ), (
-            f"Fn_compute_blocks must be <= {cache_context.max_Fn_compute_blocks}, "
-            f"but got {cache_context.Fn_compute_blocks}"
-        )
-    return cache_context.Fn_compute_blocks
-
-
-@torch.compiler.disable
-def Fn_compute_blocks_ids():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    assert (
-        len(cache_context.Fn_compute_blocks_ids)
-        <= cache_context.Fn_compute_blocks
-    ), (
-        "The num of Fn_compute_blocks_ids must be <= Fn_compute_blocks "
-        f"{cache_context.Fn_compute_blocks}, but got "
-        f"{len(cache_context.Fn_compute_blocks_ids)}"
-    )
-    return cache_context.Fn_compute_blocks_ids
-
-
-@torch.compiler.disable
-def Bn_compute_blocks():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    assert (
-        cache_context.Bn_compute_blocks >= 0
-    ), "Bn_compute_blocks must be >= 0"
-    if cache_context.max_Bn_compute_blocks > 0:
-        # NOTE: Bn_compute_blocks can be 0, which means FB Cache
-        # but it must be less than or equal to max_Bn_compute_blocks
-        assert (
-            cache_context.Bn_compute_blocks
-            <= cache_context.max_Bn_compute_blocks
-        ), (
-            f"Bn_compute_blocks must be <= {cache_context.max_Bn_compute_blocks}, "
-            f"but got {cache_context.Bn_compute_blocks}"
-        )
-    return cache_context.Bn_compute_blocks
-
-
-@torch.compiler.disable
-def Bn_compute_blocks_ids():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    assert (
-        len(cache_context.Bn_compute_blocks_ids)
-        <= cache_context.Bn_compute_blocks
-    ), (
-        "The num of Bn_compute_blocks_ids must be <= Bn_compute_blocks "
-        f"{cache_context.Bn_compute_blocks}, but got "
-        f"{len(cache_context.Bn_compute_blocks_ids)}"
-    )
-    return cache_context.Bn_compute_blocks_ids
-
-
-@torch.compiler.disable
-def do_separate_cfg():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.do_separate_cfg
-
-
-@torch.compiler.disable
-def is_separate_cfg_step():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.is_separate_cfg_step()
-
-
-@torch.compiler.disable
-def cfg_diff_compute_separate():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.cfg_diff_compute_separate
-
-
-def collect_cache_kwargs(default_attrs: dict, **kwargs):
-    # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
-    # default_attrs: specific settings for different pipelines
-    cache_attrs = dataclasses.fields(_CachedContext)
-    cache_attrs = [
-        attr
-        for attr in cache_attrs
-        if hasattr(
-            _CachedContext,
-            attr.name,
-        )
-    ]
-    cache_kwargs = {
-        attr.name: kwargs.pop(
-            attr.name,
-            getattr(_CachedContext, attr.name),
-        )
-        for attr in cache_attrs
-    }
-
-    def _safe_set_sequence_field(
-        field_name: str,
-        default_value: Any = None,
-    ):
-        if field_name not in cache_kwargs:
-            cache_kwargs[field_name] = kwargs.pop(
-                field_name,
-                default_value,
-            )
-
-    # Manually set sequence fields, namely, Fn_compute_blocks_ids
-    # and Bn_compute_blocks_ids, which are lists or sets.
-    _safe_set_sequence_field("Fn_compute_blocks_ids", [])
-    _safe_set_sequence_field("Bn_compute_blocks_ids", [])
-    _safe_set_sequence_field("taylorseer_kwargs", {})
-
-    for attr in cache_attrs:
-        if attr.name in default_attrs: # can be empty {}
-            cache_kwargs[attr.name] = default_attrs[attr.name]
-
-    if logger.isEnabledFor(logging.DEBUG):
-        logger.debug(f"Collected DBCache kwargs: {cache_kwargs}")
-
-    return cache_kwargs, kwargs
-
-
-@torch.compiler.disable
-def are_two_tensors_similar(
-    t1: torch.Tensor, # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
-    t2: torch.Tensor, # curr residual R(t ,n) = H(t ,n) - H(t ,0)
-    *,
-    threshold: float,
-    parallelized: bool = False,
-    prefix: str = "Fn", # for debugging
-):
-    # Special case for threshold, 0.0 means the threshold is disabled, -1.0 means
-    # the threshold is always enabled, -2.0 means the shape is not matched.
-    if threshold <= 0.0:
-        add_residual_diff(-0.0)
-        return False
-
-    if threshold >= 1.0:
-        # If threshold is 1.0 or more, we consider them always similar.
-        add_residual_diff(-1.0)
-        return True
-
-    if t1.shape != t2.shape:
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(f"{prefix}, shape error: {t1.shape} != {t2.shape}")
-        add_residual_diff(-2.0)
-        return False
-
-    if all(
-        (
-            do_separate_cfg(),
-            is_separate_cfg_step(),
-            not cfg_diff_compute_separate(),
-            get_current_step_residual_diff() is not None,
-        )
-    ):
-        # Reuse computed diff value from non-CFG step
-        diff = get_current_step_residual_diff()
-    else:
-        # Find the most significant token through t1 and t2, and
-        # consider the diff of the significant token. The more significant,
-        # the more important.
-        condition_thresh = get_important_condition_threshold()
-        if condition_thresh > 0.0:
-            raw_diff = (t1 - t2).abs() # [B, seq_len, d]
-            token_m_df = raw_diff.mean(dim=-1) # [B, seq_len]
-            token_m_t1 = t1.abs().mean(dim=-1) # [B, seq_len]
-            # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
-            token_diff = token_m_df / token_m_t1 # [B, seq_len]
-            condition = token_diff > condition_thresh # [B, seq_len]
-            if condition.sum() > 0:
-                condition = condition.unsqueeze(-1) # [B, seq_len, 1]
-                condition = condition.expand_as(raw_diff) # [B, seq_len, d]
-                mean_diff = raw_diff[condition].mean()
-                mean_t1 = t1[condition].abs().mean()
-            else:
-                mean_diff = (t1 - t2).abs().mean()
-                mean_t1 = t1.abs().mean()
-        else:
-            # Use the mean of the absolute difference of the tensors
-            mean_diff = (t1 - t2).abs().mean()
-            mean_t1 = t1.abs().mean()
-
-        if parallelized:
-            dist.all_reduce(mean_diff, op=dist.ReduceOp.AVG)
-            dist.all_reduce(mean_t1, op=dist.ReduceOp.AVG)
-
-        # D = (t1 - t2) / t1 = 1 - (t2 / t1), if D = 0, then t1 = t2.
-        # Futher, if we assume that (H(t, 0) - H(t-1,0)) ~ 0, then,
-        # H(t-1,n) ~ H(t ,n), which means the hidden states are similar.
-        diff = (mean_diff / mean_t1).item()
-
-    if logger.isEnabledFor(logging.DEBUG):
-        logger.debug(f"{prefix}, diff: {diff:.6f}, threshold: {threshold:.6f}")
-
-    add_residual_diff(diff)
-
-    return diff < threshold
-
-
-@torch.compiler.disable
-def _debugging_set_buffer(prefix):
-    if logger.isEnabledFor(logging.DEBUG):
-        logger.debug(
-            f"set {prefix}, "
-            f"transformer step: {get_current_transformer_step()}, "
-            f"executed step: {get_current_step()}"
-        )
-
-
-@torch.compiler.disable
-def _debugging_get_buffer(prefix):
-    if logger.isEnabledFor(logging.DEBUG):
-        logger.debug(
-            f"get {prefix}, "
-            f"transformer step: {get_current_transformer_step()}, "
-            f"executed step: {get_current_step()}"
-        )
-
-
-# Fn buffers
-@torch.compiler.disable
-def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
-    # Set hidden_states or residual for Fn blocks.
-    # This buffer is only use for L1 diff calculation.
-    downsample_factor = get_downsample_factor()
-    if downsample_factor > 1:
-        buffer = buffer[..., ::downsample_factor]
-        buffer = buffer.contiguous()
-    if is_separate_cfg_step():
-        _debugging_set_buffer(f"{prefix}_buffer_cfg")
-        set_buffer(f"{prefix}_buffer_cfg", buffer)
-    else:
-        _debugging_set_buffer(f"{prefix}_buffer")
-        set_buffer(f"{prefix}_buffer", buffer)
-
-
-@torch.compiler.disable
-def get_Fn_buffer(prefix: str = "Fn"):
-    if is_separate_cfg_step():
-        _debugging_get_buffer(f"{prefix}_buffer_cfg")
-        return get_buffer(f"{prefix}_buffer_cfg")
-    _debugging_get_buffer(f"{prefix}_buffer")
-    return get_buffer(f"{prefix}_buffer")
-
-
-@torch.compiler.disable
-def set_Fn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
-    if is_separate_cfg_step():
-        _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
-        set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
-    else:
-        _debugging_set_buffer(f"{prefix}_encoder_buffer")
-        set_buffer(f"{prefix}_encoder_buffer", buffer)
-
-
-@torch.compiler.disable
-def get_Fn_encoder_buffer(prefix: str = "Fn"):
-    if is_separate_cfg_step():
-        _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
-        return get_buffer(f"{prefix}_encoder_buffer_cfg")
-    _debugging_get_buffer(f"{prefix}_encoder_buffer")
-    return get_buffer(f"{prefix}_encoder_buffer")
-
-
-# Bn buffers
-@torch.compiler.disable
-def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
-    # Set hidden_states or residual for Bn blocks.
-    # This buffer is use for hidden states approximation.
-    if is_taylorseer_enabled():
-        # taylorseer, encoder_taylorseer
-        if is_separate_cfg_step():
-            taylorseer, _ = get_cfg_taylorseers()
-        else:
-            taylorseer, _ = get_taylorseers()
-
-        if taylorseer is not None:
-            # Use TaylorSeer to update the buffer
-            taylorseer.update(buffer)
-        else:
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(
-                    "TaylorSeer is enabled but not set in the cache context. "
-                    "Falling back to default buffer retrieval."
-                )
-            if is_separate_cfg_step():
-                _debugging_set_buffer(f"{prefix}_buffer_cfg")
-                set_buffer(f"{prefix}_buffer_cfg", buffer)
-            else:
-                _debugging_set_buffer(f"{prefix}_buffer")
-                set_buffer(f"{prefix}_buffer", buffer)
-    else:
-        if is_separate_cfg_step():
-            _debugging_set_buffer(f"{prefix}_buffer_cfg")
-            set_buffer(f"{prefix}_buffer_cfg", buffer)
-        else:
-            _debugging_set_buffer(f"{prefix}_buffer")
-            set_buffer(f"{prefix}_buffer", buffer)
-
-
-@torch.compiler.disable
-def get_Bn_buffer(prefix: str = "Bn"):
-    if is_taylorseer_enabled():
-        # taylorseer, encoder_taylorseer
-        if is_separate_cfg_step():
-            taylorseer, _ = get_cfg_taylorseers()
-        else:
-            taylorseer, _ = get_taylorseers()
-
-        if taylorseer is not None:
-            return taylorseer.approximate_value()
-        else:
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(
-                    "TaylorSeer is enabled but not set in the cache context. "
-                    "Falling back to default buffer retrieval."
-                )
-            # Fallback to default buffer retrieval
-            if is_separate_cfg_step():
-                _debugging_get_buffer(f"{prefix}_buffer_cfg")
-                return get_buffer(f"{prefix}_buffer_cfg")
-            _debugging_get_buffer(f"{prefix}_buffer")
-            return get_buffer(f"{prefix}_buffer")
-    else:
-        if is_separate_cfg_step():
-            _debugging_get_buffer(f"{prefix}_buffer_cfg")
-            return get_buffer(f"{prefix}_buffer_cfg")
-        _debugging_get_buffer(f"{prefix}_buffer")
-        return get_buffer(f"{prefix}_buffer")
-
-
-@torch.compiler.disable
-def set_Bn_encoder_buffer(buffer: torch.Tensor | None, prefix: str = "Bn"):
-    # DON'T set None Buffer
-    if buffer is None:
-        return
-
-    # This buffer is use for encoder hidden states approximation.
-    if is_encoder_taylorseer_enabled():
-        # taylorseer, encoder_taylorseer
-        if is_separate_cfg_step():
-            _, encoder_taylorseer = get_cfg_taylorseers()
-        else:
-            _, encoder_taylorseer = get_taylorseers()
-
-        if encoder_taylorseer is not None:
-            # Use TaylorSeer to update the buffer
-            encoder_taylorseer.update(buffer)
-        else:
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(
-                    "TaylorSeer is enabled but not set in the cache context. "
-                    "Falling back to default buffer retrieval."
-                )
-            if is_separate_cfg_step():
-                _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
-                set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
-            else:
-                _debugging_set_buffer(f"{prefix}_encoder_buffer")
-                set_buffer(f"{prefix}_encoder_buffer", buffer)
-    else:
-        if is_separate_cfg_step():
-            _debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
-            set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
-        else:
-            _debugging_set_buffer(f"{prefix}_encoder_buffer")
-            set_buffer(f"{prefix}_encoder_buffer", buffer)
-
-
-@torch.compiler.disable
-def get_Bn_encoder_buffer(prefix: str = "Bn"):
-    if is_encoder_taylorseer_enabled():
-        if is_separate_cfg_step():
-            _, encoder_taylorseer = get_cfg_taylorseers()
-        else:
-            _, encoder_taylorseer = get_taylorseers()
-
-        if encoder_taylorseer is not None:
-            # Use TaylorSeer to approximate the value
-            return encoder_taylorseer.approximate_value()
-        else:
-            if logger.isEnabledFor(logging.DEBUG):
-                logger.debug(
-                    "TaylorSeer is enabled but not set in the cache context. "
-                    "Falling back to default buffer retrieval."
-                )
-            # Fallback to default buffer retrieval
-            if is_separate_cfg_step():
-                _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
-                return get_buffer(f"{prefix}_encoder_buffer_cfg")
-            _debugging_get_buffer(f"{prefix}_encoder_buffer")
-            return get_buffer(f"{prefix}_encoder_buffer")
-    else:
-        if is_separate_cfg_step():
-            _debugging_get_buffer(f"{prefix}_encoder_buffer_cfg")
-            return get_buffer(f"{prefix}_encoder_buffer_cfg")
-        _debugging_get_buffer(f"{prefix}_encoder_buffer")
-        return get_buffer(f"{prefix}_encoder_buffer")
-
-
-@torch.compiler.disable
-def apply_hidden_states_residual(
-    hidden_states: torch.Tensor,
-    encoder_hidden_states: torch.Tensor = None,
-    prefix: str = "Bn",
-    encoder_prefix: str = "Bn_encoder",
-):
-    # Allow Bn and Fn prefix to be used for residual cache.
-    if "Bn" in prefix:
-        hidden_states_prev = get_Bn_buffer(prefix)
-    else:
-        hidden_states_prev = get_Fn_buffer(prefix)
-
-    assert hidden_states_prev is not None, f"{prefix}_buffer must be set before"
-
-    if is_cache_residual():
-        hidden_states = hidden_states_prev + hidden_states
-    else:
-        # If cache is not residual, we use the hidden states directly
-        hidden_states = hidden_states_prev
-
-    hidden_states = hidden_states.contiguous()
-
-    if encoder_hidden_states is not None:
-        if "Bn" in encoder_prefix:
-            encoder_hidden_states_prev = get_Bn_encoder_buffer(encoder_prefix)
-        else:
-            encoder_hidden_states_prev = get_Fn_encoder_buffer(encoder_prefix)
-
-        assert (
-            encoder_hidden_states_prev is not None
-        ), f"{prefix}_encoder_buffer must be set before"
-
-        if is_encoder_cache_residual():
-            encoder_hidden_states = (
-                encoder_hidden_states_prev + encoder_hidden_states
-            )
-        else:
-            # If encoder cache is not residual, we use the encoder hidden states directly
-            encoder_hidden_states = encoder_hidden_states_prev
-
-        encoder_hidden_states = encoder_hidden_states.contiguous()
-
-    return hidden_states, encoder_hidden_states
-
-
-@torch.compiler.disable
-def get_downsample_factor():
-    cache_context = get_cache_context()
-    assert cache_context is not None, "cache_context must be set before"
-    return cache_context.downsample_factor
-
-
-@torch.compiler.disable
-def get_can_use_cache(
-    states_tensor: torch.Tensor, # hidden_states or residual
-    parallelized: bool = False,
-    threshold: Optional[float] = None, # can manually set threshold
-    prefix: str = "Fn",
-):
-    if is_in_warmup():
-        return False
-
-    # max cached steps
-    max_cached_steps = get_max_cached_steps()
-    if not is_separate_cfg_step():
-        cached_steps = get_cached_steps()
-    else:
-        cached_steps = get_cfg_cached_steps()
-
-    if max_cached_steps >= 0 and (len(cached_steps) >= max_cached_steps):
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(
-                f"{prefix}, max_cached_steps reached: {max_cached_steps}, "
-                "can not use cache."
-            )
-        return False
-
-    # max continuous cached steps
-    max_continuous_cached_steps = get_max_continuous_cached_steps()
-    if not is_separate_cfg_step():
-        continuous_cached_steps = get_continuous_cached_steps()
-    else:
-        continuous_cached_steps = get_cfg_continuous_cached_steps()
-
-    if max_continuous_cached_steps >= 0 and (
-        continuous_cached_steps >= max_continuous_cached_steps
-    ):
-        if logger.isEnabledFor(logging.DEBUG):
-            logger.debug(
-                f"{prefix}, max_continuous_cached_steps "
-                f"reached: {max_continuous_cached_steps}, "
-                "can not use cache."
-            )
-        # reset continuous cached steps stats
-        cache_context = get_cache_context()
-        if not is_separate_cfg_step():
-            cache_context.continuous_cached_steps = 0
-        else:
-            cache_context.cfg_continuous_cached_steps = 0
-        return False
-
-    if threshold is None or threshold <= 0.0:
-        threshold = get_residual_diff_threshold()
-    if threshold <= 0.0:
-        return False
-
-    downsample_factor = get_downsample_factor()
-    if downsample_factor > 1 and "Bn" not in prefix:
-        states_tensor = states_tensor[..., ::downsample_factor]
-        states_tensor = states_tensor.contiguous()
-
-    # Allow Bn and Fn prefix to be used for diff calculation.
-    if "Bn" in prefix:
-        prev_states_tensor = get_Bn_buffer(prefix)
-    else:
-        prev_states_tensor = get_Fn_buffer(prefix)
-
-    if not is_alter_cache_enabled():
-        # Dynamic cache according to the residual diff
-        can_use_cache = (
-            prev_states_tensor is not None
-            and are_two_tensors_similar(
-                prev_states_tensor,
-                states_tensor,
-                threshold=threshold,
-                parallelized=parallelized,
-                prefix=prefix,
-            )
-        )
-    else:
-        # Only cache in the alter cache steps
-        can_use_cache = (
-            prev_states_tensor is not None
-            and are_two_tensors_similar(
-                prev_states_tensor,
-                states_tensor,
-                threshold=threshold,
-                parallelized=parallelized,
-                prefix=prefix,
-            )
-            and is_alter_cache()
-        )
-    return can_use_cache
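Note on the renamed configuration fields above: the sketch below only illustrates the field names visible on the 0.2.29 side of this diff (CachedContext, taylorseer_order, enable_spearate_cfg, spelled as in the source). The import path and the surrounding usage are assumptions for illustration, not taken from the package documentation.

    # Hypothetical sketch: constructing the now-public dataclass with the
    # 0.2.29 field names. In 0.2.27 the class was named _CachedContext and the
    # CFG flag was called do_separate_cfg; n_derivatives was clamped to 2..3,
    # whereas taylorseer_order now overwrites it directly.
    from cache_dit.cache_factory.cache_contexts.cache_context import (  # assumed module path
        CachedContext,
    )

    ctx = CachedContext(
        name="default",
        enable_taylorseer=True,     # referenced in __post_init__ in this diff
        taylorseer_order=2,         # new in 0.2.29; feeds taylorseer_kwargs["n_derivatives"]
        enable_spearate_cfg=True,   # was do_separate_cfg in 0.2.27
    )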