cache-dit 0.2.0__py3-none-any.whl → 0.2.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/dual_block_cache/cache_context.py +322 -69
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +0 -1
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -1
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +0 -1
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +1 -3
- cache_dit/cache_factory/first_block_cache/cache_context.py +3 -0
- cache_dit/cache_factory/taylorseer.py +30 -0
- {cache_dit-0.2.0.dist-info → cache_dit-0.2.2.dist-info}/METADATA +88 -36
- {cache_dit-0.2.0.dist-info → cache_dit-0.2.2.dist-info}/RECORD +21 -21
- {cache_dit-0.2.0.dist-info → cache_dit-0.2.2.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.0.dist-info → cache_dit-0.2.2.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.0.dist-info → cache_dit-0.2.2.dist-info}/top_level.txt +0 -0
cache_dit/_version.py  CHANGED

cache_dit/cache_factory/dual_block_cache/cache_context.py  CHANGED

@@ -9,6 +9,7 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union
 import torch
 
 import cache_dit.primitives as DP
+from cache_dit.cache_factory.taylorseer import TaylorSeer
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
@@ -60,7 +61,55 @@ class DBCacheContext:
     residual_diffs: DefaultDict[str, float] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
+    # Support TaylorSeers in Dual Block Cache
+    # Title: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers
+    # Url: https://arxiv.org/pdf/2503.06923
+    enable_taylorseer: bool = False
+    enable_encoder_taylorseer: bool = False
+    # NOTE: use residual cache for taylorseer may incur precision loss
+    taylorseer_cache_type: str = "hidden_states"  # residual or hidden_states
+    taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
+    taylorseer: Optional[TaylorSeer] = None
+    encoder_tarlorseer: Optional[TaylorSeer] = None
+    alter_taylorseer: Optional[TaylorSeer] = None
+    alter_encoder_taylorseer: Optional[TaylorSeer] = None
+
+    # TODO: Support SLG in Dual Block Cache
+    # Skip Layer Guidance, SLG
+    # https://github.com/huggingface/candle/issues/2588
+    slg_layers: Optional[List[int]] = None
+    slg_start: float = 0.0
+    slg_end: float = 0.1
 
+    @torch.compiler.disable
+    def __post_init__(self):
+
+        if "warmup_steps" not in self.taylorseer_kwargs:
+            # If warmup_steps is not set in taylorseer_kwargs,
+            # set the same as warmup_steps for DBCache
+            self.taylorseer_kwargs["warmup_steps"] = (
+                self.warmup_steps if self.warmup_steps > 0 else 1
+            )
+
+        # Only set n_derivatives as 2 or 3, which is enough for most cases.
+        if "n_derivatives" not in self.taylorseer_kwargs:
+            self.taylorseer_kwargs["n_derivatives"] = max(
+                2, min(3, self.taylorseer_kwargs["warmup_steps"])
+            )
+
+        if self.enable_taylorseer:
+            self.taylorseer = TaylorSeer(**self.taylorseer_kwargs)
+            if self.enable_alter_cache:
+                self.alter_taylorseer = TaylorSeer(**self.taylorseer_kwargs)
+
+        if self.enable_encoder_taylorseer:
+            self.encoder_tarlorseer = TaylorSeer(**self.taylorseer_kwargs)
+            if self.enable_alter_cache:
+                self.alter_encoder_taylorseer = TaylorSeer(
+                    **self.taylorseer_kwargs
+                )
+
+    @torch.compiler.disable
     def get_incremental_name(self, name=None):
         if name is None:
             name = "default"
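The `__post_init__` hook above derives the TaylorSeer settings from the DBCache configuration when they are not given explicitly. A minimal standalone sketch of that defaulting rule, using a plain dict in place of the real dataclass (the value passed below is illustrative):

```python
# Sketch of the defaulting logic added in __post_init__ (illustrative only).
def resolve_taylorseer_kwargs(taylorseer_kwargs: dict, warmup_steps: int) -> dict:
    kwargs = dict(taylorseer_kwargs)  # do not mutate the caller's dict
    # warmup_steps falls back to the DBCache warmup_steps (at least 1).
    kwargs.setdefault("warmup_steps", warmup_steps if warmup_steps > 0 else 1)
    # n_derivatives is clamped to the 2..3 range, bounded by warmup_steps.
    kwargs.setdefault("n_derivatives", max(2, min(3, kwargs["warmup_steps"])))
    return kwargs

print(resolve_taylorseer_kwargs({}, warmup_steps=8))
# {'warmup_steps': 8, 'n_derivatives': 3}
```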
@@ -68,9 +117,11 @@ class DBCacheContext:
         self.incremental_name_counters[name] += 1
         return f"{name}_{idx}"
 
+    @torch.compiler.disable
     def reset_incremental_names(self):
         self.incremental_name_counters.clear()
 
+    @torch.compiler.disable
     def get_residual_diff_threshold(self):
         if self.enable_alter_cache:
             residual_diff_threshold = self.alter_residual_diff_threshold
@@ -83,25 +134,30 @@ class DBCacheContext:
             residual_diff_threshold = residual_diff_threshold.item()
         return residual_diff_threshold
 
+    @torch.compiler.disable
     def get_buffer(self, name):
         if self.enable_alter_cache and self.is_alter_cache:
             name = f"{name}_alter"
         return self.buffers.get(name)
 
+    @torch.compiler.disable
     def set_buffer(self, name, buffer):
         if self.enable_alter_cache and self.is_alter_cache:
             name = f"{name}_alter"
         self.buffers[name] = buffer
 
+    @torch.compiler.disable
     def remove_buffer(self, name):
         if self.enable_alter_cache and self.is_alter_cache:
             name = f"{name}_alter"
         if name in self.buffers:
             del self.buffers[name]
 
+    @torch.compiler.disable
     def clear_buffers(self):
         self.buffers.clear()
 
+    @torch.compiler.disable
     def mark_step_begin(self):
         if not self.enable_alter_cache:
             self.executed_steps += 1
@@ -116,25 +172,53 @@ class DBCacheContext:
         self.cached_steps.clear()
         self.residual_diffs.clear()
         self.reset_incremental_names()
+        # Reset the TaylorSeers cache at the beginning of each inference.
+        # reset_cache will set the current step to -1 for TaylorSeer,
+        if self.enable_taylorseer or self.enable_encoder_taylorseer:
+            taylorseer, encoder_taylorseer = self.get_taylorseers()
+            if taylorseer is not None:
+                taylorseer.reset_cache()
+            if encoder_taylorseer is not None:
+                encoder_taylorseer.reset_cache()
+
+        # mark_step_begin of TaylorSeer must be called after the cache is reset.
+        if self.enable_taylorseer or self.enable_encoder_taylorseer:
+            taylorseer, encoder_taylorseer = self.get_taylorseers()
+            if taylorseer is not None:
+                taylorseer.mark_step_begin()
+            if encoder_taylorseer is not None:
+                encoder_taylorseer.mark_step_begin()
 
+    @torch.compiler.disable
+    def get_taylorseers(self):
+        if self.enable_alter_cache and self.is_alter_cache:
+            return self.alter_taylorseer, self.alter_encoder_taylorseer
+        return self.taylorseer, self.encoder_tarlorseer
+
+    @torch.compiler.disable
     def add_residual_diff(self, diff):
         step = str(self.get_current_step())
         if step not in self.residual_diffs:
             # Only add the diff if it is not already recorded for this step
             self.residual_diffs[step] = diff
 
+    @torch.compiler.disable
     def get_residual_diffs(self):
         return self.residual_diffs.copy()
 
+    @torch.compiler.disable
     def add_cached_step(self):
         self.cached_steps.append(self.get_current_step())
 
+    @torch.compiler.disable
     def get_cached_steps(self):
         return self.cached_steps.copy()
 
+    @torch.compiler.disable
     def get_current_step(self):
         return self.executed_steps - 1
 
+    @torch.compiler.disable
     def is_in_warmup(self):
         return self.get_current_step() < self.warmup_steps
 
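`mark_step_begin` ties the TaylorSeer lifecycle to the denoising loop: the predictor is reset once per inference run and then advanced once per step before any cache lookup. A minimal sketch of that call order, using a stand-in TaylorSeer-like object (the `seer` name and step count are illustrative):

```python
# Illustrative call order per inference run (assumes a TaylorSeer-like object `seer`).
def run_inference(seer, num_steps: int = 28) -> None:
    seer.reset_cache()            # once per inference: current_step goes back to -1
    for _ in range(num_steps):
        seer.mark_step_begin()    # advance the step counter before any cache access
        ...                       # compute or approximate the features for this step
```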
@@ -216,6 +300,50 @@ def get_residual_diffs():
     return cache_context.get_residual_diffs()
 
 
+@torch.compiler.disable
+def is_taylorseer_enabled():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.enable_taylorseer
+
+
+@torch.compiler.disable
+def is_encoder_taylorseer_enabled():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.enable_encoder_taylorseer
+
+
+@torch.compiler.disable
+def get_taylorseers():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.get_taylorseers()
+
+
+@torch.compiler.disable
+def is_taylorseer_cache_residual():
+    cache_context = get_current_cache_context()
+    assert cache_context is not None, "cache_context must be set before"
+    return cache_context.taylorseer_cache_type == "residual"
+
+
+@torch.compiler.disable
+def is_cache_residual():
+    if is_taylorseer_enabled():
+        # residual or hidden_states
+        return is_taylorseer_cache_residual()
+    return True
+
+
+@torch.compiler.disable
+def is_encoder_cache_residual():
+    if is_encoder_taylorseer_enabled():
+        # residual or hidden_states
+        return is_taylorseer_cache_residual()
+    return True
+
+
 @torch.compiler.disable
 def is_alter_cache_enabled():
     cache_context = get_current_cache_context()
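The helpers above decide whether the Bn cache stores residuals or full hidden states: plain DBCache always caches residuals, while TaylorSeer makes the choice configurable through `taylorseer_cache_type`. A condensed standalone restatement of the `is_cache_residual()` decision (illustrative only):

```python
# Standalone restatement of the is_cache_residual() decision (illustrative).
def cache_residual(enable_taylorseer: bool, taylorseer_cache_type: str) -> bool:
    if enable_taylorseer:
        # With TaylorSeer the cache type is configurable.
        return taylorseer_cache_type == "residual"
    # Without TaylorSeer, DBCache always caches residuals.
    return True

assert cache_residual(False, "hidden_states") is True
assert cache_residual(True, "hidden_states") is False
assert cache_residual(True, "residual") is True
```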
@@ -367,16 +495,21 @@ def collect_cache_kwargs(default_attrs: dict, **kwargs):
         for attr in cache_attrs
     }
 
+    def _safe_set_sequence_field(
+        field_name: str,
+        default_value: Any = None,
+    ):
+        if field_name not in cache_kwargs:
+            cache_kwargs[field_name] = kwargs.pop(
+                field_name,
+                default_value,
+            )
+
     # Manually set sequence fields, namely, Fn_compute_blocks_ids
     # and Bn_compute_blocks_ids, which are lists or sets.
-    cache_kwargs["Fn_compute_blocks_ids"] = kwargs.pop(
-        "Fn_compute_blocks_ids",
-        [],
-    )
-    cache_kwargs["Bn_compute_blocks_ids"] = kwargs.pop(
-        "Bn_compute_blocks_ids",
-        [],
-    )
+    _safe_set_sequence_field("Fn_compute_blocks_ids", [])
+    _safe_set_sequence_field("Bn_compute_blocks_ids", [])
+    _safe_set_sequence_field("taylorseer_kwargs", {})
 
     assert default_attrs is not None, "default_attrs must be set before"
     for attr in cache_attrs:
@@ -471,6 +604,7 @@ def are_two_tensors_similar(
 @torch.compiler.disable
 def set_Fn_buffer(buffer: torch.Tensor, prefix: str = "Fn"):
     # Set hidden_states or residual for Fn blocks.
+    # This buffer is only use for L1 diff calculation.
     downsample_factor = get_downsample_factor()
     if downsample_factor > 1:
         buffer = buffer[..., ::downsample_factor]
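The Fn buffer feeds the residual-diff check that decides whether a step can be cached. The exact metric lives in `are_two_tensors_similar`, which this diff does not expand, so the mean-relative-L1 formulation below is given purely as an illustrative assumption:

```python
import torch

# Illustrative relative L1 diff between the previous and current Fn residuals.
# This is an assumption about the metric; see are_two_tensors_similar in the
# package for the authoritative implementation.
def relative_l1_diff(prev: torch.Tensor, curr: torch.Tensor) -> float:
    return ((curr - prev).abs().mean() / (prev.abs().mean() + 1e-6)).item()

prev = torch.randn(2, 16, 64)
curr = prev + 0.01 * torch.randn_like(prev)
print(relative_l1_diff(prev, curr))  # small value -> step could be cached
```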
@@ -497,22 +631,79 @@ def get_Fn_encoder_buffer(prefix: str = "Fn"):
 @torch.compiler.disable
 def set_Bn_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
     # Set hidden_states or residual for Bn blocks.
-    set_buffer(f"{prefix}_buffer", buffer)
+    # This buffer is use for hidden states approximation.
+    if is_taylorseer_enabled():
+        # taylorseer, encoder_taylorseer
+        taylorseer, _ = get_taylorseers()
+        if taylorseer is not None:
+            # Use TaylorSeer to update the buffer
+            taylorseer.update(buffer)
+        else:
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "TaylorSeer is enabled but not set in the cache context. "
+                    "Falling back to default buffer retrieval."
+                )
+            set_buffer(f"{prefix}_buffer", buffer)
+    else:
+        set_buffer(f"{prefix}_buffer", buffer)
 
 
 @torch.compiler.disable
 def get_Bn_buffer(prefix: str = "Bn"):
-    return get_buffer(f"{prefix}_buffer")
+    if is_taylorseer_enabled():
+        taylorseer, _ = get_taylorseers()
+        if taylorseer is not None:
+            return taylorseer.approximate_value()
+        else:
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "TaylorSeer is enabled but not set in the cache context. "
+                    "Falling back to default buffer retrieval."
+                )
+            # Fallback to default buffer retrieval
+            return get_buffer(f"{prefix}_buffer")
+    else:
+        return get_buffer(f"{prefix}_buffer")
 
 
 @torch.compiler.disable
 def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
-    set_buffer(f"{prefix}_encoder_buffer", buffer)
+    # This buffer is use for encoder hidden states approximation.
+    if is_encoder_taylorseer_enabled():
+        # taylorseer, encoder_taylorseer
+        _, encoder_taylorseer = get_taylorseers()
+        if encoder_taylorseer is not None:
+            # Use TaylorSeer to update the buffer
+            encoder_taylorseer.update(buffer)
+        else:
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "TaylorSeer is enabled but not set in the cache context. "
+                    "Falling back to default buffer retrieval."
+                )
+            set_buffer(f"{prefix}_encoder_buffer", buffer)
+    else:
+        set_buffer(f"{prefix}_encoder_buffer", buffer)
 
 
 @torch.compiler.disable
 def get_Bn_encoder_buffer(prefix: str = "Bn"):
-    return get_buffer(f"{prefix}_encoder_buffer")
+    if is_encoder_taylorseer_enabled():
+        _, encoder_taylorseer = get_taylorseers()
+        if encoder_taylorseer is not None:
+            # Use TaylorSeer to approximate the value
+            return encoder_taylorseer.approximate_value()
+        else:
+            if logger.isEnabledFor(logging.DEBUG):
+                logger.debug(
+                    "TaylorSeer is enabled but not set in the cache context. "
+                    "Falling back to default buffer retrieval."
+                )
+            # Fallback to default buffer retrieval
+            return get_buffer(f"{prefix}_encoder_buffer")
+    else:
+        return get_buffer(f"{prefix}_encoder_buffer")
 
 
 @torch.compiler.disable
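With TaylorSeer enabled, `set_Bn_buffer` feeds every fully computed Bn state into `TaylorSeer.update`, and `get_Bn_buffer` answers cached steps from `TaylorSeer.approximate_value` instead of a stored tensor. A minimal sketch of that write/read pattern (the `seer` object, shapes and step split are illustrative):

```python
import torch

# Illustrative compute-vs-cache loop around a TaylorSeer-like predictor `seer`.
def run_bn_blocks(seer, num_steps: int, cached_steps: set) -> list:
    outputs = []
    for step in range(num_steps):
        seer.mark_step_begin()
        if step in cached_steps:
            hidden = seer.approximate_value()   # cached step: Taylor extrapolation
        else:
            hidden = torch.randn(2, 16, 64)     # stand-in for the real Bn output
            seer.update(hidden)                 # compute step: refresh derivatives
        outputs.append(hidden)
    return outputs
```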
@@ -520,29 +711,38 @@ def apply_hidden_states_residual(
     hidden_states: torch.Tensor,
     encoder_hidden_states: torch.Tensor,
     prefix: str = "Bn",
+    encoder_prefix: str = "Bn_encoder",
 ):
     # Allow Bn and Fn prefix to be used for residual cache.
     if "Bn" in prefix:
-        hidden_states_residual = get_Bn_buffer(prefix)
+        hidden_states_prev = get_Bn_buffer(prefix)
     else:
-        hidden_states_residual = get_Fn_buffer(prefix)
+        hidden_states_prev = get_Fn_buffer(prefix)
 
-    assert (
-        hidden_states_residual is not None
-    ), f"{prefix}_buffer must be set before"
-    hidden_states = hidden_states_residual + hidden_states
+    assert hidden_states_prev is not None, f"{prefix}_buffer must be set before"
 
-    if "Bn" in prefix:
-        encoder_hidden_states_residual = get_Bn_encoder_buffer(prefix)
+    if is_cache_residual():
+        hidden_states = hidden_states_prev + hidden_states
     else:
-        encoder_hidden_states_residual = get_Fn_encoder_buffer(prefix)
+        # If cache is not residual, we use the hidden states directly
+        hidden_states = hidden_states_prev
+
+    if "Bn" in encoder_prefix:
+        encoder_hidden_states_prev = get_Bn_encoder_buffer(encoder_prefix)
+    else:
+        encoder_hidden_states_prev = get_Fn_encoder_buffer(encoder_prefix)
 
     assert (
-        encoder_hidden_states_residual is not None
+        encoder_hidden_states_prev is not None
     ), f"{prefix}_encoder_buffer must be set before"
-    encoder_hidden_states = (
-        encoder_hidden_states_residual + encoder_hidden_states
-    )
+
+    if is_encoder_cache_residual():
+        encoder_hidden_states = (
+            encoder_hidden_states_prev + encoder_hidden_states
+        )
+    else:
+        # If encoder cache is not residual, we use the encoder hidden states directly
+        encoder_hidden_states = encoder_hidden_states_prev
 
     hidden_states = hidden_states.contiguous()
     encoder_hidden_states = encoder_hidden_states.contiguous()
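`apply_hidden_states_residual` now covers both cache types: a residual cache is added onto the freshly computed Fn output, while a hidden-states cache (the TaylorSeer path) replaces it outright. A toy sketch of the two branches (tensors and the `residual` flag are illustrative):

```python
import torch

# Toy illustration of the two cache-application modes used above.
def apply_cached(hidden_states: torch.Tensor,
                 cached: torch.Tensor,
                 residual: bool) -> torch.Tensor:
    if residual:
        # Residual cache: add the cached delta onto the current Fn output.
        return (cached + hidden_states).contiguous()
    # Hidden-states cache (TaylorSeer): use the approximated states directly.
    return cached.contiguous()

x = torch.randn(2, 16, 64)
delta = torch.randn_like(x)
assert apply_cached(x, delta, residual=True).shape == x.shape
```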
@@ -674,11 +874,22 @@ class DBCachedTransformerBlocks(torch.nn.Module):
 
         torch._dynamo.graph_break()
         if can_use_cache:
+            torch._dynamo.graph_break()
             add_cached_step()
             del Fn_hidden_states_residual
             hidden_states, encoder_hidden_states = apply_hidden_states_residual(
-                hidden_states,
+                hidden_states,
+                encoder_hidden_states,
+                prefix=(
+                    "Bn_residual" if is_cache_residual() else "Bn_hidden_states"
+                ),
+                encoder_prefix=(
+                    "Bn_residual"
+                    if is_encoder_cache_residual()
+                    else "Bn_hidden_states"
+                ),
             )
+            torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
             hidden_states, encoder_hidden_states = (
@@ -690,26 +901,48 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 )
             )
         else:
+            torch._dynamo.graph_break()
             set_Fn_buffer(Fn_hidden_states_residual, prefix="Fn_residual")
             if is_l1_diff_enabled():
                 # for hidden states L1 diff
                 set_Fn_buffer(hidden_states, "Fn_hidden_states")
             del Fn_hidden_states_residual
+            torch._dynamo.graph_break()
             (
                 hidden_states,
                 encoder_hidden_states,
                 hidden_states_residual,
                 encoder_hidden_states_residual,
-            ) = self.
+            ) = self.call_Mn_transformer_blocks(  # middle
                 hidden_states,
                 encoder_hidden_states,
                 *args,
                 **kwargs,
             )
-
-
-
-
+            torch._dynamo.graph_break()
+            if is_cache_residual():
+                set_Bn_buffer(
+                    hidden_states_residual,
+                    prefix="Bn_residual",
+                )
+            else:
+                # TaylorSeer
+                set_Bn_buffer(
+                    hidden_states,
+                    prefix="Bn_hidden_states",
+                )
+            if is_encoder_cache_residual():
+                set_Bn_encoder_buffer(
+                    encoder_hidden_states_residual,
+                    prefix="Bn_residual",
+                )
+            else:
+                # TaylorSeer
+                set_Bn_encoder_buffer(
+                    encoder_hidden_states,
+                    prefix="Bn_hidden_states",
+                )
+            torch._dynamo.graph_break()
             # Call last `n` blocks to further process the hidden states
             # for higher precision.
             hidden_states, encoder_hidden_states = (
@@ -759,45 +992,35 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         selected_Fn_transformer_blocks = self.transformer_blocks[
             : Fn_compute_blocks()
         ]
-        # Skip the blocks if they are not in the Fn_compute_blocks_ids.
-        # WARN: DON'T set len(Fn_compute_blocks_ids) > 0 NOW, still have
-        # some precision issues. We don't know whether a step should be
-        # cached or not before the first Fn blocks are processed.
-        if len(Fn_compute_blocks_ids()) > 0:
-            selected_Fn_transformer_blocks = [
-                selected_Fn_transformer_blocks[i]
-                for i in Fn_compute_blocks_ids()
-                if i < len(selected_Fn_transformer_blocks)
-            ]
         return selected_Fn_transformer_blocks
 
     @torch.compiler.disable
-    def
+    def _Mn_single_transformer_blocks(self):  # middle blocks
         # M(N-2n): transformer_blocks [n,...] + single_transformer_blocks [0,...,N-n]
-
+        selected_Mn_single_transformer_blocks = []
         if self.single_transformer_blocks is not None:
             if Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
-
+                selected_Mn_single_transformer_blocks = (
                     self.single_transformer_blocks
                 )
             else:
-
+                selected_Mn_single_transformer_blocks = (
                     self.single_transformer_blocks[: -Bn_compute_blocks()]
                 )
-        return
+        return selected_Mn_single_transformer_blocks
 
     @torch.compiler.disable
-    def
+    def _Mn_transformer_blocks(self):  # middle blocks
         # M(N-2n): only transformer_blocks [n,...,N-n], middle
         if Bn_compute_blocks() == 0:  # WARN: x[:-0] = []
-
+            selected_Mn_transformer_blocks = self.transformer_blocks[
                 Fn_compute_blocks() :
             ]
         else:
-
+            selected_Mn_transformer_blocks = self.transformer_blocks[
                 Fn_compute_blocks() : -Bn_compute_blocks()
             ]
-        return
+        return selected_Mn_transformer_blocks
 
     @torch.compiler.disable
     def _Bn_single_transformer_blocks(self):
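The helpers above split the transformer stack into the first Fn blocks, the middle Mn blocks and the last Bn blocks. A compact sketch of that partition on a plain list (block counts chosen for illustration, using a 19-block stack):

```python
# Illustrative Fn / Mn / Bn partition of a block list, mirroring the slicing above.
def partition_blocks(blocks: list, fn: int, bn: int):
    fn_blocks = blocks[:fn]
    mn_blocks = blocks[fn:] if bn == 0 else blocks[fn:-bn]  # note: x[:-0] == []
    bn_blocks = blocks[-bn:] if bn > 0 else []
    return fn_blocks, mn_blocks, bn_blocks

blocks = list(range(19))
fn_blocks, mn_blocks, bn_blocks = partition_blocks(blocks, fn=8, bn=8)
assert len(fn_blocks) == 8 and len(mn_blocks) == 3 and len(bn_blocks) == 8
```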
@@ -845,7 +1068,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
 
         return hidden_states, encoder_hidden_states
 
-    def
+    def call_Mn_transformer_blocks(
         self,
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
@@ -873,7 +1096,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
             hidden_states = torch.cat(
                 [encoder_hidden_states, hidden_states], dim=1
             )
-            for block in self.
+            for block in self._Mn_single_transformer_blocks():
                 hidden_states = block(
                     hidden_states,
                     *args,
@@ -887,7 +1110,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 dim=1,
             )
         else:
-            for block in self.
+            for block in self._Mn_transformer_blocks():
                 hidden_states = block(
                     hidden_states,
                     encoder_hidden_states,
@@ -1016,7 +1239,9 @@ class DBCachedTransformerBlocks(torch.nn.Module):
 
     def _compute_and_cache_single_transformer_block(
         self,
-
+        # Block index in the transformer blocks
+        # Bn: 8, block_id should be in [0, 8)
+        block_id: int,
         # Helper inputs for hidden states split and reshape
         original_hidden_states: torch.Tensor,
         original_encoder_hidden_states: torch.Tensor,
@@ -1042,7 +1267,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
             )
             # Cache residuals for the non-compute Bn blocks for
             # subsequent cache steps.
-            if
+            if block_id not in Bn_compute_blocks_ids():
                 Bn_i_hidden_states = hidden_states
                 (
                     Bn_i_hidden_states_residual,
@@ -1057,16 +1282,20 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 # Save original_hidden_states for diff calculation.
                 set_Bn_buffer(
                     Bn_i_original_hidden_states,
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_single_original",
+                )
+                set_Bn_encoder_buffer(
+                    Bn_i_original_hidden_states,
+                    prefix=f"Bn_{block_id}_single_original",
                 )
 
                 set_Bn_buffer(
                     Bn_i_hidden_states_residual,
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_single_residual",
                 )
                 set_Bn_encoder_buffer(
                     Bn_i_encoder_hidden_states_residual,
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_single_residual",
                 )
                 del Bn_i_hidden_states
                 del Bn_i_hidden_states_residual
@@ -1077,7 +1306,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         else:
             # Cache steps: Reuse the cached residuals.
             # Check if the block is in the Bn_compute_blocks_ids.
-            if
+            if block_id in Bn_compute_blocks_ids():
                 hidden_states = block(
                     hidden_states,
                     *args,
@@ -1091,7 +1320,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                     hidden_states,  # curr step
                     parallelized=self._is_parallelized(),
                     threshold=non_compute_blocks_diff_threshold(),
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_single_original",  # prev step
                 ):
                     Bn_i_original_hidden_states = hidden_states
                     (
@@ -1106,7 +1335,16 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                         apply_hidden_states_residual(
                             Bn_i_original_hidden_states,
                             Bn_i_original_encoder_hidden_states,
-                            prefix=
+                            prefix=(
+                                f"Bn_{block_id}_single_residual"
+                                if is_cache_residual()
+                                else f"Bn_{block_id}_single_original"
+                            ),
+                            encoder_prefix=(
+                                f"Bn_{block_id}_single_residual"
+                                if is_encoder_cache_residual()
+                                else f"Bn_{block_id}_single_original"
+                            ),
                         )
                     )
                     hidden_states = torch.cat(
@@ -1125,7 +1363,9 @@ class DBCachedTransformerBlocks(torch.nn.Module):
 
     def _compute_and_cache_transformer_block(
         self,
-
+        # Block index in the transformer blocks
+        # Bn: 8, block_id should be in [0, 8)
+        block_id: int,
         # Below are the inputs to the block
         block,  # The transformer block to be executed
         hidden_states: torch.Tensor,
@@ -1158,7 +1398,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
             )
             # Cache residuals for the non-compute Bn blocks for
             # subsequent cache steps.
-            if
+            if block_id not in Bn_compute_blocks_ids():
                 Bn_i_hidden_states_residual = (
                     hidden_states - Bn_i_original_hidden_states
                 )
@@ -1169,16 +1409,20 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                 # Save original_hidden_states for diff calculation.
                 set_Bn_buffer(
                     Bn_i_original_hidden_states,
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_original",
+                )
+                set_Bn_encoder_buffer(
+                    Bn_i_original_encoder_hidden_states,
+                    prefix=f"Bn_{block_id}_original",
                 )
 
                 set_Bn_buffer(
                     Bn_i_hidden_states_residual,
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_residual",
                 )
                 set_Bn_encoder_buffer(
                     Bn_i_encoder_hidden_states_residual,
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_residual",
                 )
                 del Bn_i_hidden_states_residual
                 del Bn_i_encoder_hidden_states_residual
@@ -1189,7 +1433,7 @@ class DBCachedTransformerBlocks(torch.nn.Module):
         else:
             # Cache steps: Reuse the cached residuals.
             # Check if the block is in the Bn_compute_blocks_ids.
-            if
+            if block_id in Bn_compute_blocks_ids():
                 hidden_states = block(
                     hidden_states,
                     encoder_hidden_states,
@@ -1211,13 +1455,22 @@ class DBCachedTransformerBlocks(torch.nn.Module):
                     hidden_states,  # curr step
                     parallelized=self._is_parallelized(),
                     threshold=non_compute_blocks_diff_threshold(),
-                    prefix=f"Bn_{
+                    prefix=f"Bn_{block_id}_original",  # prev step
                 ):
                     hidden_states, encoder_hidden_states = (
                         apply_hidden_states_residual(
                             hidden_states,
                             encoder_hidden_states,
-                            prefix=
+                            prefix=(
+                                f"Bn_{block_id}_residual"
+                                if is_cache_residual()
+                                else f"Bn_{block_id}_original"
+                            ),
+                            encoder_prefix=(
+                                f"Bn_{block_id}_residual"
+                                if is_encoder_cache_residual()
+                                else f"Bn_{block_id}_original"
+                            ),
                         )
                     )
                 else:
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py  CHANGED

@@ -1,5 +1,3 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/wan.py
-
 import functools
 import unittest
 
@@ -56,7 +54,7 @@ def apply_cache_on_pipe(
     shallow_patch: bool = False,
     residual_diff_threshold=0.03,
     downsample_factor=1,
-    # SLG is not supported in WAN with
+    # SLG is not supported in WAN with DBPrune yet
     # slg_layers=None,
     # slg_start: float = 0.0,
     # slg_end: float = 0.1,
cache_dit/cache_factory/first_block_cache/cache_context.py  CHANGED

@@ -370,6 +370,9 @@ def apply_prev_hidden_states_residual(
         hidden_states = hidden_states_residual + hidden_states
 
         hidden_states = hidden_states.contiguous()
+        # NOTE: We should also support taylorseer for
+        # encoder_hidden_states approximation. Please
+        # use DBCache instead.
     else:
         hidden_states_residual = get_hidden_states_residual()
         assert (
cache_dit/cache_factory/taylorseer.py  CHANGED

@@ -1,5 +1,6 @@
 # Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/taylorseer.py
 import math
+import torch
 
 
 class TaylorSeer:
@@ -17,6 +18,7 @@ class TaylorSeer:
         self.compute_step_map = compute_step_map
         self.reset_cache()
 
+    @torch.compiler.disable
     def reset_cache(self):
         self.state = {
             "dY_prev": [None] * self.ORDER,
@@ -25,6 +27,7 @@ class TaylorSeer:
         self.current_step = -1
         self.last_non_approximated_step = -1
 
+    @torch.compiler.disable
    def should_compute_full(self, step=None):
         step = self.current_step if step is None else step
         if self.compute_step_map is not None:
@@ -36,7 +39,15 @@ class TaylorSeer:
                 return True
         return False
 
+    @torch.compiler.disable
     def approximate_derivative(self, Y):
+        # n-th order Taylor expansion:
+        # Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
+        #        + ... + d^nY(0)/dt^n * t^n / n!
+        # reference: https://github.com/Shenyi-Z/TaylorSeer
+        # TaylorSeer-FLUX/src/flux/taylor_utils/__init__.py
+        # TODO: Custom Triton/CUDA kernel for better performance,
+        # especially for large n_derivatives.
         dY_current = [None] * self.ORDER
         dY_current[0] = Y
         window = self.current_step - self.last_non_approximated_step
@@ -49,7 +60,10 @@ class TaylorSeer:
                 break
         return dY_current
 
+    @torch.compiler.disable
     def approximate_value(self):
+        # TODO: Custom Triton/CUDA kernel for better performance,
+        # especially for large n_derivatives.
         elapsed = self.current_step - self.last_non_approximated_step
         output = 0
         for i, derivative in enumerate(self.state["dY_current"]):
@@ -59,14 +73,30 @@ class TaylorSeer:
                 break
         return output
 
+    @torch.compiler.disable
     def mark_step_begin(self):
         self.current_step += 1
 
+    @torch.compiler.disable
     def update(self, Y):
+        # Directly call this method will ingnore the warmup
+        # policy and force full computation.
+        # Assume warmup steps is 3, and n_derivatives is 3.
+        # step 0: dY_prev    = [None, None,   None,    None   ]
+        #         dY_current = [Y0,   None,   None,    None   ]
+        # step 1: dY_prev    = [Y0,   None,   None,    None   ]
+        #         dY_current = [Y1,   dY1,    None,    None   ]
+        # step 2: dY_prev    = [Y1,   dY1,    None,    None   ]
+        #         dY_current = [Y2,   dY2/Y1, dY2/dY1, None   ]
+        # step 3: dY_prev    = [Y2,   dY2/Y1, dY2/dY1, None   ],
+        #         dY_current = [Y3,   dY3/Y2, dY3/dY2, dY3/dY1]
+        # step 4: dY_prev    = [Y3,   dY3/Y2, dY3/dY2, dY3/dY1]
+        #         dY_current = [Y4,   dY4/Y3, dY4/dY3, dY4/dY2]
         self.state["dY_prev"] = self.state["dY_current"]
         self.state["dY_current"] = self.approximate_derivative(Y)
         self.last_non_approximated_step = self.current_step
 
+    @torch.compiler.disable
     def step(self, Y):
         self.mark_step_begin()
         if self.should_compute_full():
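The `update`/`approximate_value` pair implements a finite-difference Taylor predictor: derivatives are rebuilt from consecutive fully computed features, and skipped steps are extrapolated with the truncated series. A scalar toy sketch of the same idea (values, order and step spacing are illustrative, not the class's exact bookkeeping):

```python
# Toy first-order Taylor extrapolation from finite differences (illustrative).
def taylor_extrapolate(y_prev: float, y_curr: float, window: int, elapsed: int) -> float:
    dy = (y_curr - y_prev) / window   # finite-difference derivative estimate
    return y_curr + dy * elapsed      # truncated Taylor series: Y + dY * t

# Fully computed at two steps, then predict a later step without computing it.
y0, y1 = 1.00, 1.10
print(taylor_extrapolate(y0, y1, window=1, elapsed=2))  # -> 1.30
```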
{cache_dit-0.2.0.dist-info → cache_dit-0.2.2.dist-info}/METADATA  CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.0
+Version: 0.2.2
 Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -37,40 +37,31 @@ Dynamic: requires-python
 <p align="center">
 <h2>🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration <br>Toolbox for Diffusion Transformers</h2>
 </p>
-<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit.png >
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/cache-dit-v1.png >
 <div align='center'>
 <img src=https://img.shields.io/badge/Language-Python-brightgreen.svg >
 <img src=https://img.shields.io/badge/PRs-welcome-9cf.svg >
 <img src=https://img.shields.io/badge/PyPI-pass-brightgreen.svg >
 <img src=https://static.pepy.tech/badge/cache-dit >
 <img src=https://img.shields.io/badge/Python-3.10|3.11|3.12-9cf.svg >
-<img src=https://img.shields.io/badge/Release-v0.2.0-brightgreen.svg >
+<img src=https://img.shields.io/badge/Release-v0.2.2-brightgreen.svg >
 </div>
 <p align="center">
-DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT <br>
-</p>
-<p align="center">
-<h4> 🔥Supported Models🔥</h4>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀FLUX.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Mochi</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀CogVideoX1.5</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀Wan2.1</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
-<a href=https://github.com/vipshop/cache-dit/raw/main/examples> <b>🚀HunyuanVideo</b>: ✔️DBCache, ✔️DBPrune, ✔️FBCache🔥</a> <br>
+DeepCache is for UNet not DiT. Most DiT cache speedups are complex and not training-free. CacheDiT offers <br>a set of training-free cache accelerators for DiT: <b>🔥<a href="#dbcache">DBCache</a>, <a href="#dbprune">DBPrune</a>, <a href="#taylorseer">TaylorSeer</a>, <a href="#fbcache">FBCache</a></b>, etc🔥
 </p>
 </div>
 
-
-
-<
-
-
+<div align="center">
+<p align="center">
+<b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
+</p>
+</div>
 
 ## 🤗 Introduction
 
 <div align="center">
 <p align="center">
-<h3>🔥
+<h3>🔥DBCache: Dual Block Caching for Diffusion Transformers</h3>
 </p>
 </div>
 
@@ -86,9 +77,9 @@ The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi
 |:---:|:---:|:---:|:---:|:---:|:---:|
 |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
-|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.
+|**Baseline(L20x1)**|**F1B0 (0.08)**|**F8B8 (0.12)**|**F8B12 (0.12)**|**F8B16 (0.20)**|**F8B20 (0.20)**|
 |27.85s|6.04s|5.88s|5.77s|6.01s|6.20s|
-|<img src=https://github.com/
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_NONE_R0.08.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F1B0_R0.08.png width=105px> |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B8_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B12_R0.12.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B16_R0.2.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/TEXTURE_DBCACHE_F8B20_R0.2.png width=105px>|
 
 <div align="center">
 <p align="center">
@@ -100,7 +91,7 @@ These case studies demonstrate that even with relatively high thresholds (such a
 
 <div align="center">
 <p align="center">
-<h3>🔥
+<h3>🔥DBPrune: Dynamic Block Prune with Residual Caching</h3>
 </p>
 </div>
 
@@ -119,11 +110,11 @@ These case studies demonstrate that even with relatively high thresholds (such a
 
 <div align="center">
 <p align="center">
-<h3>🔥
+<h3>🔥Context Parallelism and Torch Compile</h3>
 </p>
 </div>
 
-Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference.
+Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand with [ParaAttention](https://github.com/chengzeyi/ParaAttention). Users can easily tap into its **Context Parallelism** features for distributed inference. CacheDiT is designed to work compatibly with **torch.compile.** You can easily use CacheDiT with torch.compile to further achieve a better performance.
 
 <div align="center">
 <p align="center">
@@ -137,12 +128,6 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 |+L20x4:7.75s|6.62s|6.03s|5.81s|5.24s|3.93s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_NONE_R0.08_S0_T20.43s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.03_P24.0_T16.25s.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.04_P34.6_T14.12s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.045_P38.2_T13.41s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.055_P45.1_T12.00s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBPRUNE_F1B0_R0.2_P59.5_T8.86s.png width=105px>|
 
-<div align="center">
-<p align="center">
-<b>♥️ Please consider to leave a ⭐️ Star to support us ~ ♥️</b>
-</p>
-</div>
-
 ## ©️Citations
 
 ```BibTeX
@@ -155,12 +140,20 @@ Moreover, **CacheDiT** are **plug-and-play** solutions that works hand-in-hand w
 }
 ```
 
+## 👋Reference
+
+<div id="reference"></div>
+
+The **CacheDiT** codebase is adapted from [FBCache](https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache). Special thanks to their excellent work!
+
 ## 📖Contents
 
 <div id="contents"></div>
 
 - [⚙️Installation](#️installation)
+- [🔥Supported Models](#supported)
 - [⚡️Dual Block Cache](#dbcache)
+- [🔥Hybrid TaylorSeer](#taylorseer)
 - [🎉First Block Cache](#fbcache)
 - [⚡️Dynamic Block Prune](#dbprune)
 - [🎉Context Parallelism](#context-parallelism)
@@ -183,16 +176,31 @@ Or you can install the latest develop version from GitHub:
 pip3 install git+https://github.com/vipshop/cache-dit.git
 ```
 
+## 🔥Supported Models
+
+<div id="supported"></div>
+
+- [🚀FLUX.1](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀Mochi](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀CogVideoX1.5](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀Wan2.1](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
+
+
 ## ⚡️DBCache: Dual Block Cache
 
 <div id="dbcache"></div>
 
-
+
 
 **DBCache** provides configurable parameters for custom optimization, enabling a balanced trade-off between performance and precision:
 
 - **Fn**: Specifies that DBCache uses the **first n** Transformer blocks to fit the information at time step t, enabling the calculation of a more stable L1 diff and delivering more accurate information to subsequent blocks.
 - **Bn**: Further fuses approximate information in the **last n** Transformer blocks to enhance prediction accuracy. These blocks act as an auto-scaler for approximate hidden states that use residual cache.
+
+
+
 - **warmup_steps**: (default: 0) DBCache does not apply the caching strategy when the number of running steps is less than or equal to this value, ensuring the model sufficiently learns basic features during warmup.
 - **max_cached_steps**: (default: -1) DBCache disables the caching strategy when the previous cached steps exceed this value to prevent precision degradation.
 - **residual_diff_threshold**: The value of residual diff threshold, a higher value leads to faster performance at the cost of lower precision.
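These parameters are passed to `apply_cache_on_pipe` as a plain options dict; the `apply_cache_on_pipe`/`CacheType` imports and the `**cache_options` call appear as context lines later in this README. The sketch below is a hedged usage example with illustrative F8B8 values, and the exact option key names should be treated as assumptions checked against the package's own snippets:

```python
import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# Illustrative DBCache F8B8 configuration; thresholds and steps are example values.
cache_options = {
    "cache_type": CacheType.DBCache,
    "warmup_steps": 8,
    "max_cached_steps": -1,          # -1 means no limit
    "Fn_compute_blocks": 8,          # Fn: first n blocks
    "Bn_compute_blocks": 8,          # Bn: last n blocks
    "residual_diff_threshold": 0.12,
}

apply_cache_on_pipe(pipe, **cache_options)
image = pipe("A cat holding a sign that says hello world", num_inference_steps=28).images[0]
```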
@@ -248,11 +256,50 @@ cache_options = {
 |24.85s|15.59s|8.58s|15.41s|15.11s|17.74s|
 |<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.08_S11.png width=105px> | <img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F1B0S1_R0.2_S19.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F8B8S1_R0.15_S15.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F12B12S4_R0.2_S16.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/DBCACHE_F16B16S4_R0.2_S13.png width=105px>|
 
+## 🔥Hybrid TaylorSeer
+
+<div id="taylorseer"></div>
+
+We have supported the [TaylorSeers: From Reusing to Forecasting: Accelerating Diffusion Models with TaylorSeers](https://arxiv.org/pdf/2503.06923) algorithm to further improve the precision of DBCache in cases where the cached steps are large, namely, **Hybrid TaylorSeer + DBCache**. At timesteps with significant intervals, the feature similarity in diffusion models decreases substantially, significantly harming the generation quality.
+
+$$
+\mathcal{F}\_{\text {pred }, m}\left(x_{t-k}^l\right)=\mathcal{F}\left(x_t^l\right)+\sum_{i=1}^m \frac{\Delta^i \mathcal{F}\left(x_t^l\right)}{i!\cdot N^i}(-k)^i
+$$
+
+**TaylorSeer** employs a differential method to approximate the higher-order derivatives of features and predict features in future timesteps with Taylor series expansion. The TaylorSeer implemented in CacheDiT supports both hidden states and residual cache types. That is $\mathcal{F}\_{\text {pred }, m}\left(x_{t-k}^l\right)$ can be a residual cache or a hidden-state cache.
+
+```python
+cache_options = {
+    # TaylorSeer options
+    "enable_taylorseer": True,
+    "enable_encoder_taylorseer": True,
+    # Taylorseer cache type cache be hidden_states or residual.
+    "taylorseer_cache_type": "residual",
+    # Higher values of n_derivatives will lead to longer
+    # computation time but may improve precision significantly.
+    "taylorseer_kwargs": {
+        "n_derivatives": 2,  # default is 2.
+    },
+    "warmup_steps": 3,  # n_derivatives + 1
+    "residual_diff_threshold": 0.12,
+}
+```
+<div align="center">
+<p align="center">
+<b>DBCache F1B0 + TaylorSeer</b>, L20x1, Steps: 28, <br>"A cat holding a sign that says hello world with complex background"
+</p>
+</div>
+
+|Baseline(L20x1)|F1B0 (0.12)|+TaylorSeer|F1B0 (0.15)|+TaylorSeer|+compile|
+|:---:|:---:|:---:|:---:|:---:|:---:|
+|24.85s|12.85s|12.86s|10.27s|10.28s|8.48s|
+|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/NONE_R0.08_S0.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.12_S14_T12.85s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.12_S14_T12.86s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T0ET0_R0.15_S17_T10.27s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C0_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T10.28s.png width=105px>|<img src=https://github.com/vipshop/cache-dit/raw/main/assets/U0_C1_DBCACHE_F1B0S1W0T1ET1_R0.15_S17_T8.48s.png width=105px>|
+
 ## 🎉FBCache: First Block Cache
 
 <div id="fbcache"></div>
 
-
+
 
 **DBCache** is a more general cache algorithm than **FBCache**. When Fn=1 and Bn=0, DBCache behaves identically to FBCache. Therefore, you can either use the original FBCache implementation directly or configure **DBCache** with **F1B0** settings to achieve the same functionality.
@@ -286,7 +333,7 @@ apply_cache_on_pipe(pipe, **cache_options)
 
 <div id="dbprune"></div>
 
-
+
 
 We have further implemented a new **Dynamic Block Prune** algorithm based on **Residual Caching** for Diffusion Transformers, which is referred to as **DBPrune**. DBPrune caches each block's hidden states and residuals, then dynamically prunes blocks during inference by computing the L1 distance between previous hidden states. When a block is pruned, its output is approximated using the cached residuals. DBPrune is currently in the experimental phase, and we kindly invite you to stay tuned for upcoming updates.
@@ -340,6 +387,9 @@ cache_options = {
 apply_cache_on_pipe(pipe, **cache_options)
 ```
 
+> [!Important]
+> Please note that for GPUs with lower VRAM, DBPrune may not be suitable for use on video DiTs, as it caches the hidden states and residuals of each block, leading to higher GPU memory requirements. In such cases, please use DBCache, which only caches the hidden states and residuals of 2 blocks.
+
 <div align="center">
 <p align="center">
 DBPrune, <b> L20x1 </b>, Steps: 28, "A cat holding a sign that says hello world with complex background"
@@ -370,7 +420,7 @@ from para_attn.context_parallel import init_context_parallel_mesh
 from para_attn.context_parallel.diffusers_adapters import parallelize_pipe
 from cache_dit.cache_factory import apply_cache_on_pipe, CacheType
 
-
+# Init distributed process group
 dist.init_process_group()
 torch.cuda.set_device(dist.get_rank())
 
@@ -417,14 +467,16 @@ torch._dynamo.config.recompile_limit = 96  # default is 8
 torch._dynamo.config.accumulated_recompile_limit = 2048  # default is 256
 ```
 
+Please check [bench.py](./bench/bench.py) for more details.
+
 ## 👋Contribute
 <div id="contribute"></div>
 
-How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](
+How to contribute? Star ⭐️ this repo to support us or check [CONTRIBUTE.md](https://github.com/vipshop/cache-dit/raw/main/CONTRIBUTE.md).
 
 ## ©️License
 
 <div id="license"></div>
 
 
-We have followed the original License from [ParaAttention](https://github.com/chengzeyi/ParaAttention), please check [LICENSE](
+We have followed the original License from [ParaAttention](https://github.com/chengzeyi/ParaAttention), please check [LICENSE](https://github.com/vipshop/cache-dit/raw/main/LICENSE) for more details.
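The README snippet above raises the torch._dynamo recompile limits so that DBCache's graph breaks do not exhaust them. A minimal, hedged sketch of combining CacheDiT with torch.compile (pipeline choice and option values are illustrative):

```python
import torch
from diffusers import FluxPipeline
from cache_dit.cache_factory import apply_cache_on_pipe, CacheType

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")
apply_cache_on_pipe(pipe, **{"cache_type": CacheType.DBCache, "residual_diff_threshold": 0.12})

# DBCache inserts graph breaks, so raise the recompile limits as the README suggests.
torch._dynamo.config.recompile_limit = 96
torch._dynamo.config.accumulated_recompile_limit = 2048
pipe.transformer = torch.compile(pipe.transformer)

image = pipe("A cat holding a sign that says hello world", num_inference_steps=28).images[0]
```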
{cache_dit-0.2.0.dist-info → cache_dit-0.2.2.dist-info}/RECORD  CHANGED

@@ -1,36 +1,36 @@
 cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/_version.py,sha256=
+cache_dit/_version.py,sha256=OjGGK5TcHVG44Y62aAqeJH4CskkZoY9ydbHOtCDew50,511
 cache_dit/logger.py,sha256=dKfNe_RRk9HJwfgHGeRR1f0LbskJpKdGmISCbL9roQs,3443
 cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
 cache_dit/cache_factory/__init__.py,sha256=5RNuhWakvvqrOV4vkqrEBA7d-V1LwcNSsjtW14mkqK8,5255
-cache_dit/cache_factory/taylorseer.py,sha256=
+cache_dit/cache_factory/taylorseer.py,sha256=G3Bu7tPqKP1zt4pzN4gJ1T4VJobyaktC3zNSR1MACqc,4005
 cache_dit/cache_factory/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=
+cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=ip3bY79lEGo2xDcGZuRxDu065q1eYabXLJUV8BsfhvM,59698
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=krNAICf-aS3JLmSG8vOB9tpLa04uYRcABsC8PMbVUKY,1870
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=fibkeU-FHa30BNT-uPV2Eqcd5IRli07EKb25tMDp23c,2270
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=fddSpTHXU24COMGAY-Z21EmHHAEArZBv_-XLRFD6ADU,2625
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=wcZdBhjUB8WSfz40A268BtSe3nr_hRsIi2BNlg1FHRU,9965
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=Cmy0KHRDgwXqtmqfkrr7kw0CP6CmkSnuz29gDHcD6sQ,2262
+cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=oOb-NNesuyBAMwpRbFDG93TCwffLO7-TXZ4bvEtvGJc,2604
 cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=YRDwZ_16yjThpgVgDv6YaIB4QCE9nEkE-MOru0jOd50,35026
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py,sha256=
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py,sha256=
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=hVBTXj9MMGFGVezT3j8MntFRBiphSaUL4YhSOd8JtuY,1870
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=KP8NxtHAKzzBOoX0lhvlMgY_5dmP4Z3T5TOfwl4SSyg,2273
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=kCB7lL4OIq8TZn-baMIF8D_PVPTFW60omCMVQCb8ebs,2628
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py,sha256=xAkd40BGsfuCKdW3Abrx35VwgZQg4CZFz13P4VY71eY,9968
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=zXgoRDDjus3a2WSjtNh4ERtQp20ceb6nzohHMDlo2zY,2265
+cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py,sha256=PA7nuLgfAelnaI8usQx0Kxi8XATzMapyR1WndEdFoZA,2604
 cache_dit/cache_factory/first_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=
+cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=oeOmVDho8aJa86p8LGACAWltu6Fe4chOW2OW8aPtB5c,23643
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py,sha256=-FFgA2MoudEo7uDacg4aWgm1KwfLZFsEDTVxatgbq9M,2146
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py,sha256=qO5CWyurtwW30mvOe6cxeQPTSXLDlPJcezm72zEjDq8,2375
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py,sha256=Dcd4OzABCtyQCZNX2KNnUTdVoO1E1ApM7P8gcVYzcK0,2733
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py,sha256=OL7W4ukYlZz0IDmBR1zVV6XT3Mgciglj9Hqzv1wUAkQ,10092
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
 cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=dBNzHBECAuTTA1a7kLdvZL20YzaKTAS3iciVLzKKEWA,2638
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
+cache_dit-0.2.2.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.2.dist-info/METADATA,sha256=-4Sl6vCYl5bPBQVtlHGhHvEz0Nn-o3R5mYArGARsjLU,24521
+cache_dit-0.2.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.2.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.2.dist-info/RECORD,,