cache-dit 0.2.1__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/dual_block_cache/cache_context.py +282 -46
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +0 -1
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -1
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +0 -1
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +0 -2
- cache_dit/cache_factory/first_block_cache/cache_context.py +3 -0
- cache_dit/cache_factory/taylorseer.py +30 -0
- {cache_dit-0.2.1.dist-info → cache_dit-0.2.2.dist-info}/METADATA +72 -39
- {cache_dit-0.2.1.dist-info → cache_dit-0.2.2.dist-info}/RECORD +21 -21
- {cache_dit-0.2.1.dist-info → cache_dit-0.2.2.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.1.dist-info → cache_dit-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.1.dist-info → cache_dit-0.2.2.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED

cache_dit/cache_factory/dual_block_cache/cache_context.py
CHANGED
@@ -61,19 +61,55 @@ class DBCacheContext:
     residual_diffs: DefaultDict[str, float] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
-    #
-    # TaylorSeers:
+    # Support TaylorSeers in Dual Block Cache
     # Title: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
     # Url: https://arxiv.org/pdf/2503.06923
+    enable_taylorseer: bool = False
+    enable_encoder_taylorseer: bool = False
+    # NOTE: use residual cache for taylorseer may incur precision loss
+    taylorseer_cache_type: str = "hidden_states"  # residual or hidden_states
+    taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
     taylorseer: Optional[TaylorSeer] = None
+    encoder_tarlorseer: Optional[TaylorSeer] = None
     alter_taylorseer: Optional[TaylorSeer] = None
+    alter_encoder_taylorseer: Optional[TaylorSeer] = None

+    # TODO: Support SLG in Dual Block Cache
     # Skip Layer Guidance, SLG
     # https://github.com/huggingface/candle/issues/2588
     slg_layers: Optional[List[int]] = None
     slg_start: float = 0.0
     slg_end: float = 0.1

+    @torch.compiler.disable
+    def __post_init__(self):
+
+        if "warmup_steps" not in self.taylorseer_kwargs:
+            # If warmup_steps is not set in taylorseer_kwargs,
+            # set the same as warmup_steps for DBCache
+            self.taylorseer_kwargs["warmup_steps"] = (
+                self.warmup_steps if self.warmup_steps > 0 else 1
+            )
+
+        # Only set n_derivatives as 2 or 3, which is enough for most cases.
+        if "n_derivatives" not in self.taylorseer_kwargs:
+            self.taylorseer_kwargs["n_derivatives"] = max(
+                2, min(3, self.taylorseer_kwargs["warmup_steps"])
+            )
+
+        if self.enable_taylorseer:
+            self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
+            if self.enable_alter_cache:
+                self.alter_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
+
+        if self.enable_encoder_taylorseer:
+            self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
+            if self.enable_alter_cache:
+                self.alter_encoder_taylorseer = TaylorSeer(
+                    **self.taylorseer_kwargs
+                )
+
+    @torch.compiler.disable
     def get_incremental_name(self, name=None):
         if name is None:
             name = "default"
@@ -81,9 +117,11 @@ class DBCacheContext:
         self.incremental_name_counters[name] += 1
         return f"{name}_{idx}"

+    @torch.compiler.disable
     def reset_incremental_names(self):
         self.incremental_name_counters.clear()

+    @torch.compiler.disable
     def get_residual_diff_threshold(self):
         if self.enable_alter_cache:
             residual_diff_threshold = self.alter_residual_diff_threshold
@@ -96,25 +134,30 @@ class DBCacheContext:
             residual_diff_threshold = residual_diff_threshold.item()
         return residual_diff_threshold

+    @torch.compiler.disable
     def get_buffer(self, name):
         if self.enable_alter_cache and self.is_alter_cache:
             name = f"{name}_alter"
         return self.buffers.get(name)

+    @torch.compiler.disable
     def set_buffer(self, name, buffer):
         if self.enable_alter_cache and self.is_alter_cache:
             name = f"{name}_alter"
         self.buffers[name] = buffer

+    @torch.compiler.disable
     def remove_buffer(self, name):
         if self.enable_alter_cache and self.is_alter_cache:
             name = f"{name}_alter"
         if name in self.buffers:
             del self.buffers[name]

+    @torch.compiler.disable
     def clear_buffers(self):
         self.buffers.clear()

+    @torch.compiler.disable
     def mark_step_begin(self):
         if not self.enable_alter_cache:
             self.executed_steps += 1
@@ -129,25 +172,53 @@ class DBCacheContext:
             self.cached_steps.clear()
             self.residual_diffs.clear()
             self.reset_incremental_names()
+            # Reset the TaylorSeers cache at the beginning of each inference.
+            # reset_cache will set the current step to -1 for TaylorSeer,
+            if self.enable_taylorseer or self.enable_encoder_taylorseer:
+                taylorseer, encoder_taylorseer = self.get_taylorseers()
+                if taylorseer is not None:
+                    taylorseer.reset_cache()
+                if encoder_taylorseer is not None:
+                    encoder_taylorseer.reset_cache()
+
+        # mark_step_begin of TaylorSeer must be called after the cache is reset.
+        if self.enable_taylorseer or self.enable_encoder_taylorseer:
+            taylorseer, encoder_taylorseer = self.get_taylorseers()
+            if taylorseer is not None:
+                taylorseer.mark_step_begin()
+            if encoder_taylorseer is not None:
+                encoder_taylorseer.mark_step_begin()

+    @torch.compiler.disable
+    def get_taylorseers(self):
+        if self.enable_alter_cache and self.is_alter_cache:
+            return self.alter_taylorseer, self.alter_encoder_taylorseer
+        return self.taylorseer, self.encoder_tarlorseer
+
+    @torch.compiler.disable
     def add_residual_diff(self, diff):
         step = str(self.get_current_step())
         if step not in self.residual_diffs:
             # Only add the diff if it is not already recorded for this step
             self.residual_diffs[step] = diff

+    @torch.compiler.disable
     def get_residual_diffs(self):
         return self.residual_diffs.copy()

+    @torch.compiler.disable
     def add_cached_step(self):
         self.cached_steps.append(self.get_current_step())

+    @torch.compiler.disable
     def get_cached_steps(self):
         return self.cached_steps.copy()

+    @torch.compiler.disable
     def get_current_step(self):
         return self.executed_steps - 1

+    @torch.compiler.disable
     def is_in_warmup(self):
         return self.get_current_step() < self.warmup_steps

@@ -229,6 +300,50 @@ def get_residual_diffs():
     return cache_context.get_residual_diffs()


+@torch.compiler.disable
+def is_taylorseer_enabled():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.enable_taylorseer
+
+
+@torch.compiler.disable
+def is_encoder_taylorseer_enabled():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.enable_encoder_taylorseer
+
+
+@torch.compiler.disable
+def get_taylorseers():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.get_taylorseers()
+
+
+@torch.compiler.disable
+def is_taylorseer_cache_residual():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.taylorseer_cache_type == "residual"
+
+
+@torch.compiler.disable
+def is_cache_residual():
+    if is_taylorseer_enabled():
+        # residual or hidden_states
+        return is_taylorseer_cache_residual()
+    return True
+
+
+@torch.compiler.disable
+def is_encoder_cache_residual():
+    if is_encoder_taylorseer_enabled():
+        # residual or hidden_states
+        return is_taylorseer_cache_residual()
+    return True
+
+
 @torch.compiler.disable
 def is_alter_cache_enabled():
     cache_context = get_current_cache_context()
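The new module-level helpers above are thin wrappers around the current cache context; the one piece of logic worth calling out is that `is_cache_residual()` and `is_encoder_cache_residual()` fall back to residual caching whenever TaylorSeer is disabled. A minimal, self-contained sketch of that selection rule (plain Python, independent of the real cache context):

```python
def cache_is_residual(taylorseer_enabled: bool, taylorseer_cache_type: str) -> bool:
    # Mirrors is_cache_residual() above: without TaylorSeer the cache always
    # stores residuals; with TaylorSeer the configured cache type decides.
    if taylorseer_enabled:
        return taylorseer_cache_type == "residual"
    return True


for enabled in (False, True):
    for cache_type in ("residual", "hidden_states"):
        print(enabled, cache_type, "->", cache_is_residual(enabled, cache_type))
```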
@@ -380,16 +495,21 @@ def collect_cache_kwargs(default_attrs: dict, **kwargs):
         for attr in cache_attrs
     }

+    def _safe_set_sequence_field(
+        field_name: str,
+        default_value: Any = None,
+    ):
+        if field_name not in cache_kwargs:
+            cache_kwargs[field_name] = kwargs.pop(
+                field_name,
+                default_value,
+            )
+
     # Manually set sequence fields, namely, Fn_compute_blocks_ids
     # and Bn_compute_blocks_ids, which are lists or sets.
-    cache_kwargs["Fn_compute_blocks_ids"] = kwargs.pop(
-        "Fn_compute_blocks_ids",
-        [],
-    )
-    cache_kwargs["Bn_compute_blocks_ids"] = kwargs.pop(
-        "Bn_compute_blocks_ids",
-        [],
-    )
+    _safe_set_sequence_field("Fn_compute_blocks_ids", [])
+    _safe_set_sequence_field("Bn_compute_blocks_ids", [])
+    _safe_set_sequence_field("taylorseer_kwargs", {})

     assert default_attrs is not None, "default_attrs must be set before"
     for attr in cache_attrs:
@@ -484,6 +604,7 @@ def are_two_tensors_similar(
 @torch.compiler.disable
 def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
     # Set hidden_states or residual for Fn blocks.
+    # This buffer is only use for L1 diff calculation.
     downsample_factor = get_downsample_factor()
     if downsample_factor > 1:
         buffer = buffer[..., ::downsample_factor]
@@ -510,22 +631,79 @@ def get_Fn_encoder_buffer(prefix: str = "Fn"):
 @torch.compiler.disable
 def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
     # Set hidden_states or residual for Bn blocks.
-    set_buffer(f"{prefix}_buffer", buffer)
+    # This buffer is use for hidden states approximation.
+    if is_taylorseer_enabled():
+        # taylorseer, encoder_taylorseer
+        taylorseer, _ = get_taylorseers()
+        if taylorseer is not None:
+            # Use TaylorSeer to update the buffer
+            taylorseer.update(buffer)
+        else:
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "TaylorSeer is enabled but not set in the cache context. "
+                    "Falling back to default buffer retrieval."
+                )
+            set_buffer(f"{prefix}_buffer", buffer)
+    else:
+        set_buffer(f"{prefix}_buffer", buffer)


 @torch.compiler.disable
 def get_Bn_buffer(prefix: str = "Bn"):
-    return get_buffer(f"{prefix}_buffer")
+    if is_taylorseer_enabled():
+        taylorseer, _ = get_taylorseers()
+        if taylorseer is not None:
+            return taylorseer.approximate_value()
+        else:
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "TaylorSeer is enabled but not set in the cache context. "
+                    "Falling back to default buffer retrieval."
+                )
+            # Fallback to default buffer retrieval
+            return get_buffer(f"{prefix}_buffer")
+    else:
+        return get_buffer(f"{prefix}_buffer")


 @torch.compiler.disable
 def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
-    set_buffer(f"{prefix}_encoder_buffer", buffer)
+    # This buffer is use for encoder hidden states approximation.
+    if is_encoder_taylorseer_enabled():
+        # taylorseer, encoder_taylorseer
+        _, encoder_taylorseer = get_taylorseers()
+        if encoder_taylorseer is not None:
+            # Use TaylorSeer to update the buffer
+            encoder_taylorseer.update(buffer)
+        else:
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "TaylorSeer is enabled but not set in the cache context. "
+                    "Falling back to default buffer retrieval."
+                )
+            set_buffer(f"{prefix}_encoder_buffer", buffer)
+    else:
+        set_buffer(f"{prefix}_encoder_buffer", buffer)


 @torch.compiler.disable
 def get_Bn_encoder_buffer(prefix: str = "Bn"):
-    return get_buffer(f"{prefix}_encoder_buffer")
+    if is_encoder_taylorseer_enabled():
+        _, encoder_taylorseer = get_taylorseers()
+        if encoder_taylorseer is not None:
+            # Use TaylorSeer to approximate the value
+            return encoder_taylorseer.approximate_value()
+        else:
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "TaylorSeer is enabled but not set in the cache context. "
+                    "Falling back to default buffer retrieval."
+                )
+            # Fallback to default buffer retrieval
+            return get_buffer(f"{prefix}_encoder_buffer")
+    else:
+        return get_buffer(f"{prefix}_encoder_buffer")


 @torch.compiler.disable
@@ -533,29 +711,38 @@ def apply_hidden_states_residual(
     hidden_states: torch.Tensor,
     encoder_hidden_states: torch.Tensor,
     prefix: str = "Bn",
+    encoder_prefix: str = "Bn_encoder",
 ):
     # Allow Bn and Fn prefix to be used for residual cache.
     if "Bn" in prefix:
-        hidden_states_residual = get_Bn_buffer(prefix)
+        hidden_states_prev = get_Bn_buffer(prefix)
     else:
-        hidden_states_residual = get_Fn_buffer(prefix)
+        hidden_states_prev = get_Fn_buffer(prefix)

-    assert (
-        hidden_states_residual is not None
-    ), f"{prefix}_buffer must be set before"
-    hidden_states = hidden_states_residual + hidden_states
+    assert hidden_states_prev is not None, f"{prefix}_buffer must be set before"

-    if "Bn" in prefix:
-        encoder_hidden_states_residual = get_Bn_encoder_buffer(prefix)
+    if is_cache_residual():
+        hidden_states = hidden_states_prev + hidden_states
+    else:
+        # If cache is not residual, we use the hidden states directly
+        hidden_states = hidden_states_prev
+
+    if "Bn" in encoder_prefix:
+        encoder_hidden_states_prev = get_Bn_encoder_buffer(encoder_prefix)
     else:
-        encoder_hidden_states_residual = get_Fn_encoder_buffer(prefix)
+        encoder_hidden_states_prev = get_Fn_encoder_buffer(encoder_prefix)

     assert (
-        encoder_hidden_states_residual is not None
+        encoder_hidden_states_prev is not None
     ), f"{prefix}_encoder_buffer must be set before"
-    encoder_hidden_states = (
-        encoder_hidden_states_residual + encoder_hidden_states
-    )
+
+    if is_encoder_cache_residual():
+        encoder_hidden_states = (
+            encoder_hidden_states_prev + encoder_hidden_states
+        )
+    else:
+        # If encoder cache is not residual, we use the encoder hidden states directly
+        encoder_hidden_states = encoder_hidden_states_prev

     hidden_states = hidden_states.contiguous()
     encoder_hidden_states = encoder_hidden_states.contiguous()
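The net effect of the rewrite above is that `apply_hidden_states_residual` now supports two cache layouts: a residual cache is added back onto the freshly computed hidden states, while a hidden-states cache (the TaylorSeer path) replaces them outright. A tiny illustration of that branch with plain tensors, outside the cache context:

```python
import torch


def combine(cached: torch.Tensor, fresh: torch.Tensor, residual: bool) -> torch.Tensor:
    # residual=True  -> cached tensor is a residual, add it to the fresh output
    # residual=False -> cached tensor is already an approximated hidden state
    return cached + fresh if residual else cached


fresh = torch.randn(2, 8)
cached = torch.randn(2, 8)
assert combine(cached, fresh, residual=True).shape == (2, 8)
assert combine(cached, fresh, residual=False).shape == (2, 8)
```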
@@ -687,11 +874,22 @@ class DBCachedTransformerBlocks(torch.nn.Module):

         torch._dynamo.graph_break()
         if can_use_cache:
+            torch._dynamo.graph_break()
             add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = apply_hidden_states_residual(
-                hidden_states,
+                hidden_states,
+                encoder_hidden_states,
+                prefix=(
+                    "Bn_residual" if is_cache_residual() else "Bn_hidden_states"
+                ),
+                encoder_prefix=(
+                    "Bn_residual"
+                    if is_encoder_cache_residual()
+                    else "Bn_hidden_states"
+                ),
             )
+            torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
             hidden_states, encoder_hidden_states = (
@@ -703,11 +901,13 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 )
             )
         else:
+            torch._dynamo.graph_break()
             set_Fn_buffer(Fn_hidden_states_residual, prefix="Fn_residual")
             if is_l1_diff_enabled():
                 # for hidden states L1 diff
                 set_Fn_buffer(hidden_states, "Fn_hidden_states")
             del Fn_hidden_states_residual
+            torch._dynamo.graph_break()
             (
                 hidden_states,
                 encoder_hidden_states,
@@ -719,10 +919,30 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 *args,
                 **kwargs,
             )
+            torch._dynamo.graph_break()
+            if is_cache_residual():
+                set_Bn_buffer(
+                    hidden_states_residual,
+                    prefix="Bn_residual",
+                )
+            else:
+                # TaylorSeer
+                set_Bn_buffer(
+                    hidden_states,
+                    prefix="Bn_hidden_states",
+                )
+            if is_encoder_cache_residual():
+                set_Bn_encoder_buffer(
+                    encoder_hidden_states_residual,
+                    prefix="Bn_residual",
+                )
+            else:
+                # TaylorSeer
+                set_Bn_encoder_buffer(
+                    encoder_hidden_states,
+                    prefix="Bn_hidden_states",
+                )
+            torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
             hidden_states, encoder_hidden_states = (
@@ -772,16 +992,6 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         selected_Fn_transformer_blocks = self.transformer_blocks[
             : Fn_compute_blocks()
         ]
-        # Skip the blocks if they are not in the Fn_compute_blocks_ids.
-        # WARN: DON'T set len(Fn_compute_blocks_ids) > 0 NOW, still have
-        # some precision issues. We don't know whether a step should be
-        # cached or not before the first Fn blocks are processed.
-        if len(Fn_compute_blocks_ids()) > 0:
-            selected_Fn_transformer_blocks = [
-                selected_Fn_transformer_blocks[i]
-                for i in Fn_compute_blocks_ids()
-                if i < len(selected_Fn_transformer_blocks)
-            ]
         return selected_Fn_transformer_blocks

     @torch.compiler.disable
@@ -800,7 +1010,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         return selected_Mn_single_transformer_blocks

     @torch.compiler.disable
-    def _Mn_transformer_blocks(self):
+    def _Mn_transformer_blocks(self):  # middle blocks
         # M(N-2n): only transformer_blocks [n,...,N-n], middle
         if Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
             selected_Mn_transformer_blocks = self.transformer_blocks[
@@ -1074,6 +1284,10 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 Bn_i_original_hidden_states,
                 prefix=f"Bn_{block_id}_single_original",
             )
+            set_Bn_encoder_buffer(
+                Bn_i_original_hidden_states,
+                prefix=f"Bn_{block_id}_single_original",
+            )

             set_Bn_buffer(
                 Bn_i_hidden_states_residual,
@@ -1121,7 +1335,16 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 apply_hidden_states_residual(
                     Bn_i_original_hidden_states,
                     Bn_i_original_encoder_hidden_states,
-                    prefix=f"Bn_{block_id}_single_residual",
+                    prefix=(
+                        f"Bn_{block_id}_single_residual"
+                        if is_cache_residual()
+                        else f"Bn_{block_id}_single_original"
+                    ),
+                    encoder_prefix=(
+                        f"Bn_{block_id}_single_residual"
+                        if is_encoder_cache_residual()
+                        else f"Bn_{block_id}_single_original"
+                    ),
                 )
             )
             hidden_states = torch.cat(
@@ -1142,7 +1365,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         self,
         # Block index in the transformer blocks
         # Bn: 8, block_id should be in [0, 8)
-        block_id: int,
+        block_id: int,
         # Below are the inputs to the block
         block,  # The transformer block to be executed
         hidden_states: torch.Tensor,
@@ -1188,6 +1411,10 @@ class DBCachedTransformerBlocks(torch.nn.Module):
             Bn_i_original_hidden_states,
             prefix=f"Bn_{block_id}_original",
         )
+        set_Bn_encoder_buffer(
+            Bn_i_original_encoder_hidden_states,
+            prefix=f"Bn_{block_id}_original",
+        )

         set_Bn_buffer(
             Bn_i_hidden_states_residual,
@@ -1234,7 +1461,16 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 apply_hidden_states_residual(
                     hidden_states,
                     encoder_hidden_states,
-                    prefix=f"Bn_{block_id}_residual",
+                    prefix=(
+                        f"Bn_{block_id}_residual"
+                        if is_cache_residual()
+                        else f"Bn_{block_id}_original"
+                    ),
+                    encoder_prefix=(
+                        f"Bn_{block_id}_residual"
+                        if is_encoder_cache_residual()
+                        else f"Bn_{block_id}_original"
+                    ),
                 )
             )
         else:
cache_dit/cache_factory/first_block_cache/cache_context.py
CHANGED

@@ -370,6 +370,9 @@ def apply_prev_hidden_states_residual(
         hidden_states = hidden_states_residual + hidden_states

         hidden_states = hidden_states.contiguous()
+        # NOTE: We should also support taylorseer for
+        # encoder_hidden_states approximation. Please
+        # use DBCache instead.
     else:
         hidden_states_residual = get_hidden_states_residual()
         assert (
cache_dit/cache_factory/taylorseer.py
CHANGED

@@ -1,5 +1,6 @@
 # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/taylorseer.py
 import math
+import torch


 class TaylorSeer:
@@ -17,6 +18,7 @@ class TaylorSeer:
         self.compute_step_map = compute_step_map
         self.reset_cache()

+    @torch.compiler.disable
     def reset_cache(self):
         self.state = {
             "dY_prev": [None] * self.ORDER,
@@ -25,6 +27,7 @@ class TaylorSeer:
         self.current_step = -1
         self.last_non_approximated_step = -1

+    @torch.compiler.disable
     def should_compute_full(self, step=None):
         step = self.current_step if step is None else step
         if self.compute_step_map is not None:
@@ -36,7 +39,15 @@ class TaylorSeer:
                 return True
         return False

+    @torch.compiler.disable
     def approximate_derivative(self, Y):
+        # n-th order Taylor expansion:
+        # Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
+        #        + ... + d^nY(0)/dt^n * t^n / n!
+        # reference: https://github.com/Shenyi-Z/TaylorSeer
+        # TaylorSeer-FLUX/src/flux/taylor_utils/__init__.py
+        # TODO: Custom Triton/CUDA kernel for better performance,
+        # especially for large n_derivatives.
         dY_current = [None] * self.ORDER
         dY_current[0] = Y
         window = self.current_step - self.last_non_approximated_step
@@ -49,7 +60,10 @@ class TaylorSeer:
                 break
         return dY_current

+    @torch.compiler.disable
     def approximate_value(self):
+        # TODO: Custom Triton/CUDA kernel for better performance,
+        # especially for large n_derivatives.
         elapsed = self.current_step - self.last_non_approximated_step
         output = 0
         for i, derivative in enumerate(self.state["dY_current"]):
@@ -59,14 +73,30 @@ class TaylorSeer:
                 break
         return output

+    @torch.compiler.disable
     def mark_step_begin(self):
         self.current_step += 1

+    @torch.compiler.disable
     def update(self, Y):
+        # Directly call this method will ingnore the warmup
+        # policy and force full computation.
+        # Assume warmup steps is 3, and n_derivatives is 3.
+        # step 0: dY_prev    = [None, None,    None,    None   ]
+        #         dY_current = [Y0,   None,    None,    None   ]
+        # step 1: dY_prev    = [Y0,   None,    None,    None   ]
+        #         dY_current = [Y1,   dY1,     None,    None   ]
+        # step 2: dY_prev    = [Y1,   dY1,     None,    None   ]
+        #         dY_current = [Y2,   dY2/Y1,  dY2/dY1, None   ]
+        # step 3: dY_prev    = [Y2,   dY2/Y1,  dY2/dY1, None   ],
+        #         dY_current = [Y3,   dY3/Y2,  dY3/dY2, dY3/dY1]
+        # step 4: dY_prev    = [Y3,   dY3/Y2,  dY3/dY2, dY3/dY1]
+        #         dY_current = [Y4,   dY4/Y3,  dY4/dY3, dY4/dY2]
         self.state["dY_prev"] = self.state["dY_current"]
         self.state["dY_current"] = self.approximate_derivative(Y)
         self.last_non_approximated_step = self.current_step

+    @torch.compiler.disable
     def step(self, Y):
         self.mark_step_begin()
         if self.should_compute_full():
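The TaylorSeer hunks above only add `@torch.compiler.disable` markers and explanatory comments; the predictor's control flow is unchanged. For orientation, here is a hypothetical standalone sketch of how such a predictor is driven. It assumes the constructor accepts the `n_derivatives` and `warmup_steps` keywords that `DBCacheContext` forwards via `taylorseer_kwargs`, and the random tensor stands in for real transformer block outputs:

```python
import torch

from cache_dit.cache_factory.taylorseer import TaylorSeer

# Illustrative values, mirroring the taylorseer_kwargs defaults set in
# DBCacheContext.__post_init__ (n_derivatives clamped to 2..3, warmup_steps >= 1).
seer = TaylorSeer(n_derivatives=2, warmup_steps=3)

for _ in range(8):
    seer.mark_step_begin()
    if seer.should_compute_full():
        # Full step: run the real computation, then cache it so the
        # finite-difference derivative estimates get refreshed.
        features = torch.randn(2, 16)  # placeholder for real block outputs
        seer.update(features)
    else:
        # Skipped step: predict the features from the Taylor expansion.
        features = seer.approximate_value()
```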
{cache_dit-0.2.1.dist-info → cache_dit-0.2.2.dist-info}/METADATA
CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.1
+Version: 0.2.2
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -37,31 +37,31 @@ Dynamic: requires-python
 <p align="center">
 <h2>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h2>
 </p>
-<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit.png >
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit-v1.png >
 <div align='center'>
 <img src=https://img.shields.io/badge/Language-Python-brightgreen.svg >
 <img src=https://img.shields.io/badge/PRs-welcome-9cf.svg >
 <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
 <img src=https://static.pepy.tech/badge/cache-dit >
 <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-<img src=https://img.shields.io/badge/Release-v0.2.1-brightgreen.svg >
+<img src=https://img.shields.io/badge/Release-v0.2.2-brightgreen.svg >
 </div>
 <p align="center">
-DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>
+DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT offers <br>a set of training-free cache accelerators for DiT: <b>🔥<a href="#dbcache">DBCache</a>, <a href="#dbprune">DBPrune</a>, <a href="#taylorseer">TaylorSeer</a>, <a href="#fbcache">FBCache</a></b>, etc🔥
 </p>
 </div>
+<div align="center">
+<p align="center">
+<b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
+</p>
+</div>

 ## 🤗 Introduction

 <div align="center">
 <p align="center">
-<h3>🔥
+<h3>🔥DBCache: Dual Block Caching for Diffusion Transformers</h3>
 </p>
 </div>
@@ -77,9 +77,9 @@ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi
 |:---:|:---:|:---:|:---:|:---:|:---:|
 |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
-|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.
+|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
 |27.85s|6.04s|5.88s|5.77s|6.01s|6.20s|
-|<img src=https://github.com/
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_NONE_R0.08.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F1B0_R0.08.png width=105px> |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B8_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B12_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B16_R0.2.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B20_R0.2.png width=105px>|

 <div align="center">
 <p align="center">
@@ -91,7 +91,7 @@ These case studies demonstrate that even with relatively high thresholds (such a

 <div align="center">
 <p align="center">
-<h3>🔥
+<h3>🔥DBPrune: Dynamic Block Prune with Residual Caching</h3>
 </p>
 </div>

@@ -110,11 +110,11 @@ These case studies demonstrate that even with relatively high thresholds (such a

 <div align="center">
 <p align="center">
-<h3>🔥
+<h3>🔥Context Parallelism and Torch Compile</h3>
 </p>
 </div>

-Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.
+Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. CacheDiT is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance.

 <div align="center">
 <p align="center">
@@ -128,12 +128,6 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 |+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|

-<div align="center">
-<p align="center">
-<b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
-</p>
-</div>
-
 ## ©️Citations

 ```BibTeX
@@ -146,6 +140,12 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 }
 ```

+## 👋Reference
+
+<div id="reference"></div>
+
+The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work!
+
 ## 📖Contents

 <div id="contents"></div>
@@ -153,6 +153,7 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 - [⚙️Installation](#️installation)
 - [🔥Supported Models](#supported)
 - [⚡️Dual Block Cache](#dbcache)
+- [🔥Hybrid TaylorSeer](#taylorseer)
 - [🎉First Block Cache](#fbcache)
 - [⚡️Dynamic Block Prune](#dbprune)
 - [🎉Context Parallelism](#context-parallelism)
@@ -187,28 +188,19 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)


-<!--
-<p align="center">
-<h4> 🔥Supported Models🔥</h4>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX1.5</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀HunyuanVideo</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-</p>
--->
-
 ## ⚡️DBCache: Dual Block Cache

 <div id="dbcache"></div>

-

 **DBCache** provides configurable parameters for custom optimization, enabling a balanced trade-off between performance and precision:

 - **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
 - **Bn**: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
+
+
+
 - **warmup_steps**: (default: 0) DBCache does not apply the caching strategy when the number of running steps is less than or equal to this value, ensuring the model sufficiently learns basic features during warmup.
 - **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
 - **residual_diff_threshold**: The value of residual diff threshold, a higher value leads to faster performance at the cost of lower precision.
@@ -264,11 +256,50 @@ cache_options = {
 |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|

+## 🔥Hybrid TaylorSeer
+
+<div id="taylorseer"></div>
+
+We have supported the [TaylorSeers: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers](https://arxiv.org/pdf/2503.06923) algorithm to further improve the precision of DBCache in cases where the cached steps are large, namely, **Hybrid TaylorSeer + DBCache**. At timesteps with significant intervals, the feature similarity in diffusion models decreases substantially, significantly harming the generation quality.
+
+$$
+\mathcal{F}\_{\text {pred }, m}\left(x_{t-k}^l\right)=\mathcal{F}\left(x_t^l\right)+\sum_{i=1}^m \frac{\Delta^i \mathcal{F}\left(x_t^l\right)}{i!\cdot N^i}(-k)^i
+$$
+
+**TaylorSeer** employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden states and residual cache types. That is $\mathcal{F}\_{\text {pred }, m}\left(x_{t-k}^l\right)$ can be a residual cache or a hidden-state cache.
+
+```python
+cache_options = {
+    # TaylorSeer options
+    "enable_taylorseer": True,
+    "enable_encoder_taylorseer": True,
+    # Taylorseer cache type cache be hidden_states or residual.
+    "taylorseer_cache_type": "residual",
+    # Higher values of n_derivatives will lead to longer
+    # computation time but may improve precision significantly.
+    "taylorseer_kwargs": {
+        "n_derivatives": 2,  # default is 2.
+    },
+    "warmup_steps": 3,  # n_derivatives + 1
+    "residual_diff_threshold": 0.12,
+}
+```
+<div align="center">
+<p align="center">
+<b>DBCache F1B0 + TaylorSeer</b>, L20x1, Steps: 28, <br>"A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
+|Baseline(L20x1)|F1B0 (0.12)|+TaylorSeer|F1B0 (0.15)|+TaylorSeer|+compile|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|12.85s|12.86s|10.27s|10.28s|8.48s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.12_S14_T12.85s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.12_S14_T12.86s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.15_S17_T10.27s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T10.28s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T8.48s.png width=105px>|
+
 ## 🎉FBCache: First Block Cache

 <div id="fbcache"></div>

-<img src=https://github.com/vipshop/cache-dit/raw/main/assets/fbcache
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/fbcache-v1.png >

 **DBCache** is a more general cache algorithm than **FBCache**. When Fn=1 and Bn=0, DBCache behaves identically to FBCache. Therefore, you can either use the original FBCache implementation directly or configure **DBCache** with **F1B0** settings to achieve the same functionality.

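For context, the options above plug into the same `apply_cache_on_pipe` entry point used elsewhere in this README. A hypothetical end-to-end sketch for a Diffusers FLUX pipeline follows; the `cache_type`/`CacheType.DBCache` key is an assumption based on the imports shown in the Context Parallelism example and is not spelled out in this diff:

```python
import torch
from diffusers import FluxPipeline

from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",
    torch_dtype=torch.bfloat16,
).to("cuda")

cache_options = {
    "cache_type": CacheType.DBCache,  # assumed key, see note above
    "warmup_steps": 3,
    "residual_diff_threshold": 0.12,
    # Hybrid TaylorSeer options introduced in this release.
    "enable_taylorseer": True,
    "enable_encoder_taylorseer": True,
    "taylorseer_cache_type": "residual",
    "taylorseer_kwargs": {"n_derivatives": 2},
}

apply_cache_on_pipe(pipe, **cache_options)

image = pipe(
    "A cat holding a sign that says hello world with complex background",
    num_inference_steps=28,
).images[0]
```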
@@ -302,7 +333,7 @@ apply_cache_on_pipe(pipe, **cache_options)

 <div id="dbprune"></div>

-<img src=https://github.com/vipshop/cache-dit/raw/main/assets/dbprune
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/dbprune.png >

 We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, which is referred to as **DBPrune**. DBPrune caches each block's hidden states and residuals, then dynamically prunes blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals. DBPrune is currently in the experimental phase, and we kindly invite you to stay tuned for upcoming updates.

@@ -389,7 +420,7 @@ from para_attn.context_parallel import init_context_parallel_mesh
 from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
 from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

+# Init distributed process group
 dist.init_process_group()
 torch.cuda.set_device(dist.get_rank())

@@ -436,14 +467,16 @@ torch._dynamo.config.recompile_limit = 96 # default is 8
 torch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256
 ```

+Please check [bench.py](./bench/bench.py) for more details.
+
 ## 👋Contribute
 <div id="contribute"></div>

-How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](
+How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](https://github.com/vipshop/cache-dit/raw/main/CONTRIBUTE.md).

 ## ©️License

 <div id="license"></div>


-We have followed the original License from [ParaAttention](https://github.com/chengzeyi/ParaAttention), please check [LICENSE](
+We have followed the original License from [ParaAttention](https://github.com/chengzeyi/ParaAttention), please check [LICENSE](https://github.com/vipshop/cache-dit/raw/main/LICENSE) for more details.
{cache_dit-0.2.1.dist-info → cache_dit-0.2.2.dist-info}/RECORD
CHANGED

@@ -1,36 +1,36 @@
 cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/_version.py,sha256=
+cache_dit/_version.py,sha256=OjGGK5TcHVG44Y62aAqeJH4CskkZoY9ydbHOtCDew50,511
 cache_dit/logger.py,sha256=dKfNe_RRk9HJwfgHGeRR1f0LbskJpKdGmISCbL9roQs,3443
 cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
 cache_dit/cache_factory/__init__.py,sha256=5RNuhWakvvqrOV4vkqrEBA7d-V1LwcNSsjtW14mkqK8,5255
-cache_dit/cache_factory/taylorseer.py,sha256=
+cache_dit/cache_factory/taylorseer.py,sha256=G3Bu7tPqKP1zt4pzN4gJ1T4VJobyaktC3zNSR1MACqc,4005
 cache_dit/cache_factory/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=
+cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=ip3bY79lEGo2xDcGZuRxDu065q1eYabXLJUV8BsfhvM,59698
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=krNAICf-aS3JLmSG8vOB9tpLa04uYRcABsC8PMbVUKY,1870
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=fibkeU-FHa30BNT-uPV2Eqcd5IRli07EKb25tMDp23c,2270
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=fddSpTHXU24COMGAY-Z21EmHHAEArZBv_-XLRFD6ADU,2625
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=wcZdBhjUB8WSfz40A268BtSe3nr_hRsIi2BNlg1FHRU,9965
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=Cmy0KHRDgwXqtmqfkrr7kw0CP6CmkSnuz29gDHcD6sQ,2262
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=oOb-NNesuyBAMwpRbFDG93TCwffLO7-TXZ4bvEtvGJc,2604
 cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=YRDwZ_16yjThpgVgDv6YaIB4QCE9nEkE-MOru0jOd50,35026
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py,sha256=
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py,sha256=
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=hVBTXj9MMGFGVezT3j8MntFRBiphSaUL4YhSOd8JtuY,1870
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=KP8NxtHAKzzBOoX0lhvlMgY_5dmP4Z3T5TOfwl4SSyg,2273
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=kCB7lL4OIq8TZn-baMIF8D_PVPTFW60omCMVQCb8ebs,2628
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py,sha256=xAkd40BGsfuCKdW3Abrx35VwgZQg4CZFz13P4VY71eY,9968
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=zXgoRDDjus3a2WSjtNh4ERtQp20ceb6nzohHMDlo2zY,2265
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py,sha256=PA7nuLgfAelnaI8usQx0Kxi8XATzMapyR1WndEdFoZA,2604
 cache_dit/cache_factory/first_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=
+cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=oeOmVDho8aJa86p8LGACAWltu6Fe4chOW2OW8aPtB5c,23643
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py,sha256=-FFgA2MoudEo7uDacg4aWgm1KwfLZFsEDTVxatgbq9M,2146
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py,sha256=qO5CWyurtwW30mvOe6cxeQPTSXLDlPJcezm72zEjDq8,2375
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py,sha256=Dcd4OzABCtyQCZNX2KNnUTdVoO1E1ApM7P8gcVYzcK0,2733
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py,sha256=OL7W4ukYlZz0IDmBR1zVV6XT3Mgciglj9Hqzv1wUAkQ,10092
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=dBNzHBECAuTTA1a7kLdvZL20YzaKTAS3iciVLzKKEWA,2638
-cache_dit-0.2.1.dist-info/licenses/LICENSE,sha256=
-cache_dit-0.2.1.dist-info/METADATA,sha256=
-cache_dit-0.2.1.dist-info/WHEEL,sha256=
-cache_dit-0.2.1.dist-info/top_level.txt,sha256=
-cache_dit-0.2.1.dist-info/RECORD,,
+cache_dit-0.2.2.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.2.dist-info/METADATA,sha256=-4Sl6vCYl5bPBQVtlHGhHvEz0Nn-o3R5mYArGARsjLU,24521
+cache_dit-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.2.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.2.dist-info/RECORD,,
{cache_dit-0.2.1.dist-info → cache_dit-0.2.2.dist-info}/WHEEL
File without changes

{cache_dit-0.2.1.dist-info → cache_dit-0.2.2.dist-info}/licenses/LICENSE
File without changes

{cache_dit-0.2.1.dist-info → cache_dit-0.2.2.dist-info}/top_level.txt
File without changes