cache_dit-0.2.1-py3-none-any.whl → cache_dit-0.2.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of cache-dit was flagged as a potentially problematic release.

Files changed (27)
  1. cache_dit/_version.py +2 -2
  2. cache_dit/cache_factory/dual_block_cache/cache_context.py +282 -57
  3. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -2
  4. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -2
  5. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -2
  6. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +0 -1
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -1
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +1 -3
  9. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -2
  10. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -2
  11. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -2
  12. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +0 -1
  13. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -2
  14. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +0 -2
  15. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +23 -23
  16. cache_dit/cache_factory/first_block_cache/cache_context.py +3 -11
  17. cache_dit/cache_factory/taylorseer.py +29 -0
  18. cache_dit/compile/__init__.py +1 -0
  19. cache_dit/compile/utils.py +94 -0
  20. cache_dit/custom_ops/__init__.py +0 -0
  21. cache_dit/custom_ops/triton_taylorseer.py +0 -0
  22. cache_dit/logger.py +28 -0
  23. {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/METADATA +76 -39
  24. {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/RECORD +27 -23
  25. {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/WHEEL +0 -0
  26. {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/licenses/LICENSE +0 -0
  27. {cache_dit-0.2.1.dist-info → cache_dit-0.2.3.dist-info}/top_level.txt +0 -0
cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.2.1'
- __version_tuple__ = version_tuple = (0, 2, 1)
+ __version__ = version = '0.2.3'
+ __version_tuple__ = version_tuple = (0, 2, 3)
cache_dit/cache_factory/dual_block_cache/cache_context.py CHANGED
@@ -61,19 +61,55 @@ class DBCacheContext:
      residual_diffs: DefaultDict[str, float] = dataclasses.field(
          default_factory=lambda: defaultdict(float),
      )
-     # TODO: Support TaylorSeers and SLG in Dual Block Cache
-     # TaylorSeers:
+     # Support TaylorSeers in Dual Block Cache
      # Title: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
      # Url: https://arxiv.org/pdf/2503.06923
+     enable_taylorseer: bool = False
+     enable_encoder_taylorseer: bool = False
+     # NOTE: use residual cache for taylorseer may incur precision loss
+     taylorseer_cache_type: str = "hidden_states"  # residual or hidden_states
+     taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
      taylorseer: Optional[TaylorSeer] = None
+     encoder_tarlorseer: Optional[TaylorSeer] = None
      alter_taylorseer: Optional[TaylorSeer] = None
+     alter_encoder_taylorseer: Optional[TaylorSeer] = None

+     # TODO: Support SLG in Dual Block Cache
      # Skip Layer Guidance, SLG
      # https://github.com/huggingface/candle/issues/2588
      slg_layers: Optional[List[int]] = None
      slg_start: float = 0.0
      slg_end: float = 0.1

+     @torch.compiler.disable
+     def __post_init__(self):
+
+         if "warmup_steps" not in self.taylorseer_kwargs:
+             # If warmup_steps is not set in taylorseer_kwargs,
+             # set the same as warmup_steps for DBCache
+             self.taylorseer_kwargs["warmup_steps"] = (
+                 self.warmup_steps if self.warmup_steps > 0 else 1
+             )
+
+         # Only set n_derivatives as 2 or 3, which is enough for most cases.
+         if "n_derivatives" not in self.taylorseer_kwargs:
+             self.taylorseer_kwargs["n_derivatives"] = max(
+                 2, min(3, self.taylorseer_kwargs["warmup_steps"])
+             )
+
+         if self.enable_taylorseer:
+             self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
+             if self.enable_alter_cache:
+                 self.alter_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
+
+         if self.enable_encoder_taylorseer:
+             self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
+             if self.enable_alter_cache:
+                 self.alter_encoder_taylorseer = TaylorSeer(
+                     **self.taylorseer_kwargs
+                 )
+
+     @torch.compiler.disable
      def get_incremental_name(self, name=None):
          if name is None:
              name = "default"
@@ -81,9 +117,11 @@ class DBCacheContext:
          self.incremental_name_counters[name] += 1
          return f"{name}_{idx}"

+     @torch.compiler.disable
      def reset_incremental_names(self):
          self.incremental_name_counters.clear()

+     @torch.compiler.disable
      def get_residual_diff_threshold(self):
          if self.enable_alter_cache:
              residual_diff_threshold = self.alter_residual_diff_threshold
@@ -96,25 +134,30 @@ class DBCacheContext:
              residual_diff_threshold = residual_diff_threshold.item()
          return residual_diff_threshold

+     @torch.compiler.disable
      def get_buffer(self, name):
          if self.enable_alter_cache and self.is_alter_cache:
              name = f"{name}_alter"
          return self.buffers.get(name)

+     @torch.compiler.disable
      def set_buffer(self, name, buffer):
          if self.enable_alter_cache and self.is_alter_cache:
              name = f"{name}_alter"
          self.buffers[name] = buffer

+     @torch.compiler.disable
      def remove_buffer(self, name):
          if self.enable_alter_cache and self.is_alter_cache:
              name = f"{name}_alter"
          if name in self.buffers:
              del self.buffers[name]

+     @torch.compiler.disable
      def clear_buffers(self):
          self.buffers.clear()

+     @torch.compiler.disable
      def mark_step_begin(self):
          if not self.enable_alter_cache:
              self.executed_steps += 1
@@ -129,25 +172,53 @@ class DBCacheContext:
          self.cached_steps.clear()
          self.residual_diffs.clear()
          self.reset_incremental_names()
+         # Reset the TaylorSeers cache at the beginning of each inference.
+         # reset_cache will set the current step to -1 for TaylorSeer,
+         if self.enable_taylorseer or self.enable_encoder_taylorseer:
+             taylorseer, encoder_taylorseer = self.get_taylorseers()
+             if taylorseer is not None:
+                 taylorseer.reset_cache()
+             if encoder_taylorseer is not None:
+                 encoder_taylorseer.reset_cache()
+
+         # mark_step_begin of TaylorSeer must be called after the cache is reset.
+         if self.enable_taylorseer or self.enable_encoder_taylorseer:
+             taylorseer, encoder_taylorseer = self.get_taylorseers()
+             if taylorseer is not None:
+                 taylorseer.mark_step_begin()
+             if encoder_taylorseer is not None:
+                 encoder_taylorseer.mark_step_begin()

+     @torch.compiler.disable
+     def get_taylorseers(self):
+         if self.enable_alter_cache and self.is_alter_cache:
+             return self.alter_taylorseer, self.alter_encoder_taylorseer
+         return self.taylorseer, self.encoder_tarlorseer
+
+     @torch.compiler.disable
      def add_residual_diff(self, diff):
          step = str(self.get_current_step())
          if step not in self.residual_diffs:
              # Only add the diff if it is not already recorded for this step
              self.residual_diffs[step] = diff

+     @torch.compiler.disable
      def get_residual_diffs(self):
          return self.residual_diffs.copy()

+     @torch.compiler.disable
      def add_cached_step(self):
          self.cached_steps.append(self.get_current_step())

+     @torch.compiler.disable
      def get_cached_steps(self):
          return self.cached_steps.copy()

+     @torch.compiler.disable
      def get_current_step(self):
          return self.executed_steps - 1

+     @torch.compiler.disable
      def is_in_warmup(self):
          return self.get_current_step() < self.warmup_steps
@@ -229,6 +300,50 @@ def get_residual_diffs():
      return cache_context.get_residual_diffs()


+ @torch.compiler.disable
+ def is_taylorseer_enabled():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.enable_taylorseer
+
+
+ @torch.compiler.disable
+ def is_encoder_taylorseer_enabled():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.enable_encoder_taylorseer
+
+
+ @torch.compiler.disable
+ def get_taylorseers():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.get_taylorseers()
+
+
+ @torch.compiler.disable
+ def is_taylorseer_cache_residual():
+     cache_context = get_current_cache_context()
+     assert cache_context is not None, "cache_context must be set before"
+     return cache_context.taylorseer_cache_type == "residual"
+
+
+ @torch.compiler.disable
+ def is_cache_residual():
+     if is_taylorseer_enabled():
+         # residual or hidden_states
+         return is_taylorseer_cache_residual()
+     return True
+
+
+ @torch.compiler.disable
+ def is_encoder_cache_residual():
+     if is_encoder_taylorseer_enabled():
+         # residual or hidden_states
+         return is_taylorseer_cache_residual()
+     return True
+
+
  @torch.compiler.disable
  def is_alter_cache_enabled():
      cache_context = get_current_cache_context()
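The new is_cache_residual / is_encoder_cache_residual helpers reduce to one small rule: residuals are cached unless TaylorSeer is enabled and taylorseer_cache_type selects hidden states. A standalone sketch of that rule (illustrative only, not package code):

    def cache_kind(enable_taylorseer: bool, taylorseer_cache_type: str = "hidden_states") -> str:
        # Mirrors is_cache_residual(): without TaylorSeer, DBCache always stores
        # residuals; with TaylorSeer, the configured cache type decides.
        return taylorseer_cache_type if enable_taylorseer else "residual"

    assert cache_kind(False) == "residual"
    assert cache_kind(True) == "hidden_states"
    assert cache_kind(True, "residual") == "residual"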
@@ -380,16 +495,21 @@ def collect_cache_kwargs(default_attrs: dict, **kwargs):
          for attr in cache_attrs
      }

+     def _safe_set_sequence_field(
+         field_name: str,
+         default_value: Any = None,
+     ):
+         if field_name not in cache_kwargs:
+             cache_kwargs[field_name] = kwargs.pop(
+                 field_name,
+                 default_value,
+             )
+
      # Manually set sequence fields, namely, Fn_compute_blocks_ids
      # and Bn_compute_blocks_ids, which are lists or sets.
-     cache_kwargs["Fn_compute_blocks_ids"] = kwargs.pop(
-         "Fn_compute_blocks_ids",
-         [],
-     )
-     cache_kwargs["Bn_compute_blocks_ids"] = kwargs.pop(
-         "Bn_compute_blocks_ids",
-         [],
-     )
+     _safe_set_sequence_field("Fn_compute_blocks_ids", [])
+     _safe_set_sequence_field("Bn_compute_blocks_ids", [])
+     _safe_set_sequence_field("taylorseer_kwargs", {})

      assert default_attrs is not None, "default_attrs must be set before"
      for attr in cache_attrs:
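_safe_set_sequence_field only fills a key that has not already been populated, so sequence fields the caller configured earlier are no longer overwritten by their defaults, and the new taylorseer_kwargs dict is collected the same way. A small self-contained illustration of the pattern (hypothetical values, not package code):

    cache_kwargs = {"Fn_compute_blocks_ids": [0, 1]}   # already set upstream
    extra_kwargs = {"Bn_compute_blocks_ids": [7]}

    def safe_set(field, default):
        # Same idea as _safe_set_sequence_field: keep an existing entry,
        # otherwise pop it from the incoming kwargs or fall back to the default.
        if field not in cache_kwargs:
            cache_kwargs[field] = extra_kwargs.pop(field, default)

    safe_set("Fn_compute_blocks_ids", [])   # kept as [0, 1], not reset
    safe_set("Bn_compute_blocks_ids", [])   # taken from extra_kwargs -> [7]
    safe_set("taylorseer_kwargs", {})       # falls back to the default {}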
@@ -484,6 +604,7 @@ def are_two_tensors_similar(
  @torch.compiler.disable
  def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
      # Set hidden_states or residual for Fn blocks.
+     # This buffer is only use for L1 diff calculation.
      downsample_factor = get_downsample_factor()
      if downsample_factor > 1:
          buffer = buffer[..., ::downsample_factor]
@@ -510,22 +631,79 @@ def get_Fn_encoder_buffer(prefix: str = "Fn"):
510
631
  @torch.compiler.disable
511
632
  def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
512
633
  # Set hidden_states or residual for Bn blocks.
513
- set_buffer(f"{prefix}_buffer", buffer)
634
+ # This buffer is use for hidden states approximation.
635
+ if is_taylorseer_enabled():
636
+ # taylorseer, encoder_taylorseer
637
+ taylorseer, _ = get_taylorseers()
638
+ if taylorseer is not None:
639
+ # Use TaylorSeer to update the buffer
640
+ taylorseer.update(buffer)
641
+ else:
642
+ if logger.isEnabledFor(logging.DEBUG):
643
+ logger.debug(
644
+ "TaylorSeer is enabled but not set in the cache context. "
645
+ "Falling back to default buffer retrieval."
646
+ )
647
+ set_buffer(f"{prefix}_buffer", buffer)
648
+ else:
649
+ set_buffer(f"{prefix}_buffer", buffer)
514
650
 
515
651
 
516
652
  @torch.compiler.disable
517
653
  def get_Bn_buffer(prefix: str = "Bn"):
518
- return get_buffer(f"{prefix}_buffer")
654
+ if is_taylorseer_enabled():
655
+ taylorseer, _ = get_taylorseers()
656
+ if taylorseer is not None:
657
+ return taylorseer.approximate_value()
658
+ else:
659
+ if logger.isEnabledFor(logging.DEBUG):
660
+ logger.debug(
661
+ "TaylorSeer is enabled but not set in the cache context. "
662
+ "Falling back to default buffer retrieval."
663
+ )
664
+ # Fallback to default buffer retrieval
665
+ return get_buffer(f"{prefix}_buffer")
666
+ else:
667
+ return get_buffer(f"{prefix}_buffer")
519
668
 
520
669
 
521
670
  @torch.compiler.disable
522
671
  def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
523
- set_buffer(f"{prefix}_encoder_buffer", buffer)
672
+ # This buffer is use for encoder hidden states approximation.
673
+ if is_encoder_taylorseer_enabled():
674
+ # taylorseer, encoder_taylorseer
675
+ _, encoder_taylorseer = get_taylorseers()
676
+ if encoder_taylorseer is not None:
677
+ # Use TaylorSeer to update the buffer
678
+ encoder_taylorseer.update(buffer)
679
+ else:
680
+ if logger.isEnabledFor(logging.DEBUG):
681
+ logger.debug(
682
+ "TaylorSeer is enabled but not set in the cache context. "
683
+ "Falling back to default buffer retrieval."
684
+ )
685
+ set_buffer(f"{prefix}_encoder_buffer", buffer)
686
+ else:
687
+ set_buffer(f"{prefix}_encoder_buffer", buffer)
524
688
 
525
689
 
526
690
  @torch.compiler.disable
527
691
  def get_Bn_encoder_buffer(prefix: str = "Bn"):
528
- return get_buffer(f"{prefix}_encoder_buffer")
692
+ if is_encoder_taylorseer_enabled():
693
+ _, encoder_taylorseer = get_taylorseers()
694
+ if encoder_taylorseer is not None:
695
+ # Use TaylorSeer to approximate the value
696
+ return encoder_taylorseer.approximate_value()
697
+ else:
698
+ if logger.isEnabledFor(logging.DEBUG):
699
+ logger.debug(
700
+ "TaylorSeer is enabled but not set in the cache context. "
701
+ "Falling back to default buffer retrieval."
702
+ )
703
+ # Fallback to default buffer retrieval
704
+ return get_buffer(f"{prefix}_encoder_buffer")
705
+ else:
706
+ return get_buffer(f"{prefix}_encoder_buffer")
529
707
 
530
708
 
531
709
  @torch.compiler.disable
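The Bn buffer setters and getters now route through the TaylorSeer object: update() records freshly computed states on full-compute steps, approximate_value() extrapolates them on cached steps, and mark_step_begin() / reset_cache() (called from DBCacheContext.mark_step_begin) keep its step counter in sync. cache_dit/cache_factory/taylorseer.py itself is not shown in this diff, so the sketch below is only a first-order illustration of that calling contract, not the package's implementation:

    import torch

    class ToyTaylorSeer:
        """First-order stand-in showing the update / approximate_value contract."""

        def __init__(self):
            self.reset_cache()

        def reset_cache(self):
            self.step = -1
            self.last = None        # value at the last full-compute step
            self.derivative = None  # finite-difference estimate
            self.last_step = -1

        def mark_step_begin(self):
            self.step += 1

        def update(self, value: torch.Tensor):
            # Called on full-compute steps with the freshly computed states.
            if self.last is not None and self.step > self.last_step:
                self.derivative = (value - self.last) / (self.step - self.last_step)
            self.last, self.last_step = value, self.step

        def approximate_value(self) -> torch.Tensor:
            # Called on cached steps: extrapolate from the last computed step.
            if self.derivative is None:
                return self.last
            return self.last + self.derivative * (self.step - self.last_step)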
@@ -533,29 +711,38 @@ def apply_hidden_states_residual(
      hidden_states: torch.Tensor,
      encoder_hidden_states: torch.Tensor,
      prefix: str = "Bn",
+     encoder_prefix: str = "Bn_encoder",
  ):
      # Allow Bn and Fn prefix to be used for residual cache.
      if "Bn" in prefix:
-         hidden_states_residual = get_Bn_buffer(prefix)
+         hidden_states_prev = get_Bn_buffer(prefix)
      else:
-         hidden_states_residual = get_Fn_buffer(prefix)
+         hidden_states_prev = get_Fn_buffer(prefix)

-     assert (
-         hidden_states_residual is not None
-     ), f"{prefix}_buffer must be set before"
-     hidden_states = hidden_states_residual + hidden_states
+     assert hidden_states_prev is not None, f"{prefix}_buffer must be set before"

-     if "Bn" in prefix:
-         encoder_hidden_states_residual = get_Bn_encoder_buffer(prefix)
+     if is_cache_residual():
+         hidden_states = hidden_states_prev + hidden_states
+     else:
+         # If cache is not residual, we use the hidden states directly
+         hidden_states = hidden_states_prev
+
+     if "Bn" in encoder_prefix:
+         encoder_hidden_states_prev = get_Bn_encoder_buffer(encoder_prefix)
      else:
-         encoder_hidden_states_residual = get_Fn_encoder_buffer(prefix)
+         encoder_hidden_states_prev = get_Fn_encoder_buffer(encoder_prefix)

      assert (
-         encoder_hidden_states_residual is not None
+         encoder_hidden_states_prev is not None
      ), f"{prefix}_encoder_buffer must be set before"
-     encoder_hidden_states = (
-         encoder_hidden_states_residual + encoder_hidden_states
-     )
+
+     if is_encoder_cache_residual():
+         encoder_hidden_states = (
+             encoder_hidden_states_prev + encoder_hidden_states
+         )
+     else:
+         # If encoder cache is not residual, we use the encoder hidden states directly
+         encoder_hidden_states = encoder_hidden_states_prev

      hidden_states = hidden_states.contiguous()
      encoder_hidden_states = encoder_hidden_states.contiguous()
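With the new encoder_prefix argument, apply_hidden_states_residual reconstructs both streams from whichever form was cached: a cached residual is added back onto the freshly computed Fn output, while a cached (TaylorSeer-approximated) hidden state replaces it outright. A tiny sketch of the two reconstruction modes, with hypothetical variable names:

    import torch

    def reconstruct(prev: torch.Tensor, fresh: torch.Tensor, residual_cache: bool) -> torch.Tensor:
        # residual_cache=True: the buffer holds a residual, add it to the fresh Fn output.
        # residual_cache=False: the buffer holds approximated hidden states, use them as-is.
        return prev + fresh if residual_cache else prev

    fresh = torch.randn(2, 4)
    prev = torch.randn(2, 4)
    assert torch.equal(reconstruct(prev, fresh, True), prev + fresh)
    assert torch.equal(reconstruct(prev, fresh, False), prev)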
@@ -687,11 +874,22 @@ class DBCachedTransformerBlocks(torch.nn.Module):

          torch._dynamo.graph_break()
          if can_use_cache:
+             torch._dynamo.graph_break()
              add_cached_step()
              del Fn_hidden_states_residual
              hidden_states, encoder_hidden_states = apply_hidden_states_residual(
-                 hidden_states, encoder_hidden_states, prefix="Bn_residual"
+                 hidden_states,
+                 encoder_hidden_states,
+                 prefix=(
+                     "Bn_residual" if is_cache_residual() else "Bn_hidden_states"
+                 ),
+                 encoder_prefix=(
+                     "Bn_residual"
+                     if is_encoder_cache_residual()
+                     else "Bn_hidden_states"
+                 ),
              )
+             torch._dynamo.graph_break()
              # Call last `n` blocks to further process the hidden states
              # for higher precision.
              hidden_states, encoder_hidden_states = (
@@ -703,11 +901,13 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                  )
              )
          else:
+             torch._dynamo.graph_break()
              set_Fn_buffer(Fn_hidden_states_residual, prefix="Fn_residual")
              if is_l1_diff_enabled():
                  # for hidden states L1 diff
                  set_Fn_buffer(hidden_states, "Fn_hidden_states")
              del Fn_hidden_states_residual
+             torch._dynamo.graph_break()
              (
                  hidden_states,
                  encoder_hidden_states,
@@ -719,10 +919,30 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                  *args,
                  **kwargs,
              )
-             set_Bn_buffer(hidden_states_residual, prefix="Bn_residual")
-             set_Bn_encoder_buffer(
-                 encoder_hidden_states_residual, prefix="Bn_residual"
-             )
+             torch._dynamo.graph_break()
+             if is_cache_residual():
+                 set_Bn_buffer(
+                     hidden_states_residual,
+                     prefix="Bn_residual",
+                 )
+             else:
+                 # TaylorSeer
+                 set_Bn_buffer(
+                     hidden_states,
+                     prefix="Bn_hidden_states",
+                 )
+             if is_encoder_cache_residual():
+                 set_Bn_encoder_buffer(
+                     encoder_hidden_states_residual,
+                     prefix="Bn_residual",
+                 )
+             else:
+                 # TaylorSeer
+                 set_Bn_encoder_buffer(
+                     encoder_hidden_states,
+                     prefix="Bn_hidden_states",
+                 )
+             torch._dynamo.graph_break()
              # Call last `n` blocks to further process the hidden states
              # for higher precision.
@@ -772,16 +992,6 @@ class DBCachedTransformerBlocks(torch.nn.Module):
          selected_Fn_transformer_blocks = self.transformer_blocks[
              : Fn_compute_blocks()
          ]
-         # Skip the blocks if they are not in the Fn_compute_blocks_ids.
-         # WARN: DON'T set len(Fn_compute_blocks_ids) > 0 NOW, still have
-         # some precision issues. We don't know whether a step should be
-         # cached or not before the first Fn blocks are processed.
-         if len(Fn_compute_blocks_ids()) > 0:
-             selected_Fn_transformer_blocks = [
-                 selected_Fn_transformer_blocks[i]
-                 for i in Fn_compute_blocks_ids()
-                 if i < len(selected_Fn_transformer_blocks)
-             ]
          return selected_Fn_transformer_blocks

      @torch.compiler.disable
@@ -800,7 +1010,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
          return selected_Mn_single_transformer_blocks

      @torch.compiler.disable
-     def _Mn_transformer_blocks(self): # middle blocks
+     def _Mn_transformer_blocks(self):  # middle blocks
          # M(N-2n): only transformer_blocks [n,...,N-n], middle
          if Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
              selected_Mn_transformer_blocks = self.transformer_blocks[
@@ -1074,6 +1284,10 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                  Bn_i_original_hidden_states,
                  prefix=f"Bn_{block_id}_single_original",
              )
+             set_Bn_encoder_buffer(
+                 Bn_i_original_hidden_states,
+                 prefix=f"Bn_{block_id}_single_original",
+             )

              set_Bn_buffer(
                  Bn_i_hidden_states_residual,
@@ -1121,7 +1335,16 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                  apply_hidden_states_residual(
                      Bn_i_original_hidden_states,
                      Bn_i_original_encoder_hidden_states,
-                     prefix=f"Bn_{block_id}_single_residual",
+                     prefix=(
+                         f"Bn_{block_id}_single_residual"
+                         if is_cache_residual()
+                         else f"Bn_{block_id}_single_original"
+                     ),
+                     encoder_prefix=(
+                         f"Bn_{block_id}_single_residual"
+                         if is_encoder_cache_residual()
+                         else f"Bn_{block_id}_single_original"
+                     ),
                  )
              )
          hidden_states = torch.cat(
@@ -1142,7 +1365,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
          self,
          # Block index in the transformer blocks
          # Bn: 8, block_id should be in [0, 8)
-         block_id: int,
+         block_id: int,
          # Below are the inputs to the block
          block,  # The transformer block to be executed
          hidden_states: torch.Tensor,
@@ -1188,6 +1411,10 @@ class DBCachedTransformerBlocks(torch.nn.Module):
              Bn_i_original_hidden_states,
              prefix=f"Bn_{block_id}_original",
          )
+         set_Bn_encoder_buffer(
+             Bn_i_original_encoder_hidden_states,
+             prefix=f"Bn_{block_id}_original",
+         )

          set_Bn_buffer(
              Bn_i_hidden_states_residual,
@@ -1234,7 +1461,16 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                  apply_hidden_states_residual(
                      hidden_states,
                      encoder_hidden_states,
-                     prefix=f"Bn_{block_id}_residual",
+                     prefix=(
+                         f"Bn_{block_id}_residual"
+                         if is_cache_residual()
+                         else f"Bn_{block_id}_original"
+                     ),
+                     encoder_prefix=(
+                         f"Bn_{block_id}_residual"
+                         if is_encoder_cache_residual()
+                         else f"Bn_{block_id}_original"
+                     ),
                  )
              )
          else:
@@ -1362,17 +1598,6 @@ def patch_cached_stats(
      if transformer is None:
          return

-     cached_transformer_blocks = getattr(transformer, "transformer_blocks", None)
-     if cached_transformer_blocks is None:
-         return
-
-     if isinstance(cached_transformer_blocks, torch.nn.ModuleList):
-         cached_transformer_blocks = cached_transformer_blocks[0]
-     if not isinstance(
-         cached_transformer_blocks, DBCachedTransformerBlocks
-     ) or not isinstance(transformer, torch.nn.Module):
-         return
-
      # TODO: Patch more cached stats to the transformer
      transformer._cached_steps = get_cached_steps()
      transformer._residual_diffs = get_residual_diffs()
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py CHANGED
@@ -1,5 +1,3 @@
- # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/__init__.py
-
  import importlib

  from diffusers import DiffusionPipeline
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py CHANGED
@@ -1,5 +1,3 @@
- # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/cogvideox.py
-
  import functools
  import unittest

cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py CHANGED
@@ -1,5 +1,3 @@
- # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/flux.py
-
  import functools
  import unittest

cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py CHANGED
@@ -1,4 +1,3 @@
- # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/hunyuan_video.py
  import functools
  import unittest
  from typing import Any, Dict, Optional, Union
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py CHANGED
@@ -1,4 +1,3 @@
- # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/mochi.py
  import functools
  import unittest

cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py CHANGED
@@ -1,5 +1,3 @@
- # Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/wan.py
-
  import functools
  import unittest

@@ -94,6 +92,6 @@ def apply_db_cache_on_pipe(
      pipe.__class__._is_cached = True

      if not shallow_patch:
-         apply_db_cache_on_transformer(pipe.transformer, **kwargs)
+         apply_db_cache_on_transformer(pipe.transformer)

      return pipe
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py CHANGED
@@ -1,5 +1,3 @@
- # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/__init__.py
-
  import importlib

  from diffusers import DiffusionPipeline
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py CHANGED
@@ -1,5 +1,3 @@
- # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/cogvideox.py
-
  import functools
  import unittest

cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py CHANGED
@@ -1,5 +1,3 @@
- # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/flux.py
-
  import functools
  import unittest

cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py CHANGED
@@ -1,4 +1,3 @@
- # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/hunyuan_video.py
  import functools
  import unittest
  from typing import Any, Dict, Optional, Union
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py CHANGED
@@ -1,5 +1,3 @@
- # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/mochi.py
-
  import functools
  import unittest

cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py CHANGED
@@ -1,5 +1,3 @@
- # Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/wan.py
-
  import functools
  import unittest