cache-dit 0.2.32__py3-none-any.whl → 0.2.33__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Note: this release of cache-dit has been flagged as potentially problematic on the registry.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/block_adapters.py +1 -1
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +23 -62
- cache_dit/cache_factory/cache_blocks/pattern_base.py +23 -168
- cache_dit/cache_factory/cache_contexts/cache_context.py +7 -53
- cache_dit/cache_factory/cache_contexts/cache_manager.py +18 -66
- cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -8
- cache_dit/quantize/quantize_ao.py +3 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.33.dist-info}/METADATA +36 -31
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.33.dist-info}/RECORD +14 -15
- cache_dit/quantize/quantize_svdq.py +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.33.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.33.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.33.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.32.dist-info → cache_dit-0.2.33.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.32'
-__version_tuple__ = version_tuple = (0, 2, 32)
+__version__ = version = '0.2.33'
+__version_tuple__ = version_tuple = (0, 2, 33)
 
 __commit_id__ = commit_id = None
cache_dit/cache_factory/block_adapters/block_adapters.py
CHANGED
@@ -579,7 +579,7 @@ class BlockAdapter:
             assert isinstance(adapter[0], torch.nn.Module)
             return getattr(adapter[0], "_is_cached", False)
         else:
-            … (1 removed line not captured in this diff view)
+            return getattr(adapter, "_is_cached", False)
 
     @classmethod
     def nested_depth(cls, obj: Any):
cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py
CHANGED
@@ -1,6 +1,5 @@
 import torch
 
-from typing import Dict, Any
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory.cache_blocks.pattern_base import (
     CachedBlocks_Pattern_Base,
@@ -24,14 +23,12 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         **kwargs,
     ):
         # Use it's own cache context.
-        self.cache_manager.set_context(
-            self.cache_context,
-        )
+        self.cache_manager.set_context(self.cache_context)
+        self._check_cache_params()
 
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
         # more stable diff calculation.
-        # encoder_hidden_states: None Pattern 3, else 4, 5
         hidden_states, new_encoder_hidden_states = self.call_Fn_blocks(
             hidden_states,
             *args,
@@ -109,10 +106,7 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             *args,
             **kwargs,
         )
-
-        new_encoder_hidden_states_residual = (
-            new_encoder_hidden_states - old_encoder_hidden_states
-        )
+
         torch._dynamo.graph_break()
         if self.cache_manager.is_cache_residual():
             self.cache_manager.set_Bn_buffer(
@@ -125,6 +119,10 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                 prefix=f"{self.cache_prefix}_Bn_hidden_states",
             )
 
+        if new_encoder_hidden_states is not None:
+            new_encoder_hidden_states_residual = (
+                new_encoder_hidden_states - old_encoder_hidden_states
+            )
         if self.cache_manager.is_encoder_cache_residual():
             if new_encoder_hidden_states is not None:
                 self.cache_manager.set_Bn_encoder_buffer(
@@ -159,27 +157,12 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
             )
         )
 
-    @torch.compiler.disable
-    def maybe_update_kwargs(
-        self, encoder_hidden_states, kwargs: Dict[str, Any]
-    ) -> Dict[str, Any]:
-        # if "encoder_hidden_states" in kwargs:
-        #     kwargs["encoder_hidden_states"] = encoder_hidden_states
-        # return kwargs
-        return kwargs
-
     def call_Fn_blocks(
         self,
         hidden_states: torch.Tensor,
         *args,
         **kwargs,
     ):
-        assert self.cache_manager.Fn_compute_blocks() <= len(
-            self.transformer_blocks
-        ), (
-            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
-            f"the number of transformer blocks {len(self.transformer_blocks)}"
-        )
         new_encoder_hidden_states = None
         for block in self._Fn_blocks():
             hidden_states = block(
@@ -194,10 +177,6 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                     new_encoder_hidden_states,
                     hidden_states,
                 )
-            kwargs = self.maybe_update_kwargs(
-                new_encoder_hidden_states,
-                kwargs,
-            )
 
         return hidden_states, new_encoder_hidden_states
 
@@ -222,11 +201,6 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
                     new_encoder_hidden_states,
                     hidden_states,
                 )
-            kwargs = self.maybe_update_kwargs(
-                new_encoder_hidden_states,
-                kwargs,
-            )
-
         # compute hidden_states residual
         hidden_states = hidden_states.contiguous()
         hidden_states_residual = hidden_states - original_hidden_states
@@ -243,35 +217,22 @@ class CachedBlocks_Pattern_3_4_5(CachedBlocks_Pattern_Base):
         *args,
         **kwargs,
     ):
-        … (9 removed lines not captured in this diff view)
-            f"patterns: {self._supported_patterns}."
+        new_encoder_hidden_states = None
+        if self.cache_manager.Bn_compute_blocks() == 0:
+            return hidden_states, new_encoder_hidden_states
+
+        for block in self._Bn_blocks():
+            hidden_states = block(
+                hidden_states,
+                *args,
+                **kwargs,
             )
-        … (7 removed lines not captured in this diff view)
-            )
-        if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
-            hidden_states, new_encoder_hidden_states = hidden_states
-            if not self.forward_pattern.Return_H_First:
-                hidden_states, new_encoder_hidden_states = (
-                    new_encoder_hidden_states,
-                    hidden_states,
-                )
-        kwargs = self.maybe_update_kwargs(
-            new_encoder_hidden_states,
-            kwargs,
-        )
+            if not isinstance(hidden_states, torch.Tensor):  # Pattern 4,5
+                hidden_states, new_encoder_hidden_states = hidden_states
+                if not self.forward_pattern.Return_H_First:
+                    hidden_states, new_encoder_hidden_states = (
+                        new_encoder_hidden_states,
+                        hidden_states,
+                    )
 
         return hidden_states, new_encoder_hidden_states
cache_dit/cache_factory/cache_blocks/pattern_base.py
CHANGED
@@ -93,6 +93,21 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                 required_param in forward_parameters
             ), f"The input parameters must contains: {required_param}."
 
+    @torch.compiler.disable
+    def _check_cache_params(self):
+        assert self.cache_manager.Fn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+        assert self.cache_manager.Bn_compute_blocks() <= len(
+            self.transformer_blocks
+        ), (
+            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
+            f"the number of transformer blocks {len(self.transformer_blocks)}"
+        )
+
     def forward(
         self,
         hidden_states: torch.Tensor,
@@ -100,7 +115,9 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
+        # Use it's own cache context.
         self.cache_manager.set_context(self.cache_context)
+        self._check_cache_params()
 
         original_hidden_states = hidden_states
         # Call first `n` blocks to process the hidden states for
@@ -191,18 +208,17 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                # TaylorSeer
                 self.cache_manager.set_Bn_buffer(
                     hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
                 )
+
             if self.cache_manager.is_encoder_cache_residual():
                 self.cache_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states_residual,
                     prefix=f"{self.cache_prefix}_Bn_residual",
                 )
             else:
-                # TaylorSeer
                 self.cache_manager.set_Bn_encoder_buffer(
                     encoder_hidden_states,
                     prefix=f"{self.cache_prefix}_Bn_hidden_states",
@@ -296,12 +312,6 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         *args,
         **kwargs,
     ):
-        assert self.cache_manager.Fn_compute_blocks() <= len(
-            self.transformer_blocks
-        ), (
-            f"Fn_compute_blocks {self.cache_manager.Fn_compute_blocks()} must be less than "
-            f"the number of transformer blocks {len(self.transformer_blocks)}"
-        )
         for block in self._Fn_blocks():
             hidden_states = block(
                 hidden_states,
@@ -366,28 +376,17 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                     encoder_hidden_states_residual,
                 )
 
-    def _compute_or_cache_block(
+    def call_Bn_blocks(
         self,
-        # Block index in the transformer blocks
-        # Bn: 8, block_id should be in [0, 8)
-        block_id: int,
-        # Below are the inputs to the block
-        block,  # The transformer block to be executed
         hidden_states: torch.Tensor,
         encoder_hidden_states: torch.Tensor,
         *args,
        **kwargs,
     ):
-        … (4 removed lines not captured in this diff view)
-        # and cache the residuals in non-cache steps.
-
-        # Normal steps: Compute the block and cache the residuals.
-        if not self._is_in_cache_step():
-            Bn_i_original_hidden_states = hidden_states
-            Bn_i_original_encoder_hidden_states = encoder_hidden_states
+        if self.cache_manager.Bn_compute_blocks() == 0:
+            return hidden_states, encoder_hidden_states
+
+        for block in self._Bn_blocks():
             hidden_states = block(
                 hidden_states,
                 encoder_hidden_states,
@@ -401,149 +400,5 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
                         encoder_hidden_states,
                         hidden_states,
                     )
-            # Cache residuals for the non-compute Bn blocks for
-            # subsequent cache steps.
-            if block_id not in self.cache_manager.Bn_compute_blocks_ids():
-                Bn_i_hidden_states_residual = (
-                    hidden_states - Bn_i_original_hidden_states
-                )
-                if (
-                    encoder_hidden_states is not None
-                    and Bn_i_original_encoder_hidden_states is not None
-                ):
-                    Bn_i_encoder_hidden_states_residual = (
-                        encoder_hidden_states
-                        - Bn_i_original_encoder_hidden_states
-                    )
-                else:
-                    Bn_i_encoder_hidden_states_residual = None
-
-                # Save original_hidden_states for diff calculation.
-                self.cache_manager.set_Bn_buffer(
-                    Bn_i_original_hidden_states,
-                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
-                )
-                self.cache_manager.set_Bn_encoder_buffer(
-                    Bn_i_original_encoder_hidden_states,
-                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",
-                )
-
-                self.cache_manager.set_Bn_buffer(
-                    Bn_i_hidden_states_residual,
-                    prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
-                )
-                self.cache_manager.set_Bn_encoder_buffer(
-                    Bn_i_encoder_hidden_states_residual,
-                    prefix=f"{self.cache_prefix}_Bn_{block_id}_residual",
-                )
-                del Bn_i_hidden_states_residual
-                del Bn_i_encoder_hidden_states_residual
-
-            del Bn_i_original_hidden_states
-            del Bn_i_original_encoder_hidden_states
-
-        else:
-            # Cache steps: Reuse the cached residuals.
-            # Check if the block is in the Bn_compute_blocks_ids.
-            if block_id in self.cache_manager.Bn_compute_blocks_ids():
-                hidden_states = block(
-                    hidden_states,
-                    encoder_hidden_states,
-                    *args,
-                    **kwargs,
-                )
-                if not isinstance(hidden_states, torch.Tensor):
-                    hidden_states, encoder_hidden_states = hidden_states
-                    if not self.forward_pattern.Return_H_First:
-                        hidden_states, encoder_hidden_states = (
-                            encoder_hidden_states,
-                            hidden_states,
-                        )
-            else:
-                # Skip the block if it is not in the Bn_compute_blocks_ids.
-                # Use the cached residuals instead.
-                # Check if can use the cached residuals.
-                if self.cache_manager.can_cache(
-                    hidden_states,  # curr step
-                    parallelized=self._is_parallelized(),
-                    threshold=self.cache_manager.non_compute_blocks_diff_threshold(),
-                    prefix=f"{self.cache_prefix}_Bn_{block_id}_original",  # prev step
-                ):
-                    hidden_states, encoder_hidden_states = (
-                        self.cache_manager.apply_cache(
-                            hidden_states,
-                            encoder_hidden_states,
-                            prefix=(
-                                f"{self.cache_prefix}_Bn_{block_id}_residual"
-                                if self.cache_manager.is_cache_residual()
-                                else f"{self.cache_prefix}_Bn_{block_id}_original"
-                            ),
-                            encoder_prefix=(
-                                f"{self.cache_prefix}_Bn_{block_id}_residual"
-                                if self.cache_manager.is_encoder_cache_residual()
-                                else f"{self.cache_prefix}_Bn_{block_id}_original"
-                            ),
-                        )
-                    )
-                else:
-                    hidden_states = block(
-                        hidden_states,
-                        encoder_hidden_states,
-                        *args,
-                        **kwargs,
-                    )
-                    if not isinstance(hidden_states, torch.Tensor):
-                        hidden_states, encoder_hidden_states = hidden_states
-                        if not self.forward_pattern.Return_H_First:
-                            hidden_states, encoder_hidden_states = (
-                                encoder_hidden_states,
-                                hidden_states,
-                            )
-        return hidden_states, encoder_hidden_states
-
-    def call_Bn_blocks(
-        self,
-        hidden_states: torch.Tensor,
-        encoder_hidden_states: torch.Tensor,
-        *args,
-        **kwargs,
-    ):
-        if self.cache_manager.Bn_compute_blocks() == 0:
-            return hidden_states, encoder_hidden_states
-
-        assert self.cache_manager.Bn_compute_blocks() <= len(
-            self.transformer_blocks
-        ), (
-            f"Bn_compute_blocks {self.cache_manager.Bn_compute_blocks()} must be less than "
-            f"the number of transformer blocks {len(self.transformer_blocks)}"
-        )
-        if len(self.cache_manager.Bn_compute_blocks_ids()) > 0:
-            for i, block in enumerate(self._Bn_blocks()):
-                hidden_states, encoder_hidden_states = (
-                    self._compute_or_cache_block(
-                        i,
-                        block,
-                        hidden_states,
-                        encoder_hidden_states,
-                        *args,
-                        **kwargs,
-                    )
-                )
-        else:
-            # Compute all Bn blocks if no specific Bn compute blocks ids are set.
-            for block in self._Bn_blocks():
-                hidden_states = block(
-                    hidden_states,
-                    encoder_hidden_states,
-                    *args,
-                    **kwargs,
-                )
-                if not isinstance(hidden_states, torch.Tensor):
-                    hidden_states, encoder_hidden_states = hidden_states
-                    if not self.forward_pattern.Return_H_First:
-                        hidden_states, encoder_hidden_states = (
-                            encoder_hidden_states,
-                            hidden_states,
-                        )
 
         return hidden_states, encoder_hidden_states
cache_dit/cache_factory/cache_contexts/cache_context.py
CHANGED
@@ -14,13 +14,9 @@ logger = init_logger(__name__)
 @dataclasses.dataclass
 class CachedContext:  # Internal CachedContext Impl class
     name: str = "default"
-    # Dual Block Cache
-    # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
+    # Dual Block Cache with flexible FnBn configuration.
     Fn_compute_blocks: int = 1
     Bn_compute_blocks: int = 0
-    # We have added residual cache pattern for selected compute blocks
-    Fn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
-    Bn_compute_blocks_ids: List[int] = dataclasses.field(default_factory=list)
     # non compute blocks diff threshold, we don't skip the non
     # compute blocks if the diff >= threshold
     non_compute_blocks_diff_threshold: float = 0.08
@@ -31,13 +27,6 @@ class CachedContext:  # Internal CachedContext Impl class
     l1_hidden_states_diff_threshold: float = None
     important_condition_threshold: float = 0.0
 
-    # Alter Cache Settings
-    # Pattern: 0 F 1 T 2 F 3 T 4 F 5 T ...
-    enable_alter_cache: bool = False
-    is_alter_cache: bool = True
-    # 1.0 means we always cache the residuals if alter_cache is enabled.
-    alter_residual_diff_threshold: Optional[Union[torch.Tensor, float]] = 1.0
-
     # Buffer for storing the residuals and other tensors
     buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
     incremental_name_counters: DefaultDict[str, int] = dataclasses.field(
@@ -63,7 +52,6 @@ class CachedContext:  # Internal CachedContext Impl class
     # Url: https://arxiv.org/pdf/2503.06923
     enable_taylorseer: bool = False
     enable_encoder_taylorseer: bool = False
-    # NOTE: use residual cache for taylorseer may incur precision loss
     taylorseer_cache_type: str = "hidden_states"  # residual or hidden_states
     taylorseer_order: int = 2  # The order for TaylorSeer
     taylorseer_kwargs: Dict[str, Any] = dataclasses.field(default_factory=dict)
@@ -97,16 +85,11 @@ class CachedContext:  # Internal CachedContext Impl class
     )
     cfg_continuous_cached_steps: int = 0
 
-    @torch.compiler.disable
     def __post_init__(self):
         if logger.isEnabledFor(logging.DEBUG):
             logger.info(f"Created _CacheContext: {self.name}")
         # Some checks for settings
         if self.enable_spearate_cfg:
-            assert self.enable_alter_cache is False, (
-                "enable_alter_cache must set as False if "
-                "enable_spearate_cfg is enabled."
-            )
             if self.cfg_diff_compute_separate:
                 assert self.cfg_compute_first is False, (
                     "cfg_compute_first must set as False if "
@@ -135,47 +118,32 @@ class CachedContext:  # Internal CachedContext Impl class
                 **self.taylorseer_kwargs
             )
 
-    @torch.compiler.disable
     def get_residual_diff_threshold(self):
-        … (3 removed lines not captured in this diff view)
-        residual_diff_threshold = self.
-        if self.l1_hidden_states_diff_threshold is not None:
-            # Use the L1 hidden states diff threshold if set
-            residual_diff_threshold = self.l1_hidden_states_diff_threshold
+        residual_diff_threshold = self.residual_diff_threshold
+        if self.l1_hidden_states_diff_threshold is not None:
+            # Use the L1 hidden states diff threshold if set
+            residual_diff_threshold = self.l1_hidden_states_diff_threshold
         if isinstance(residual_diff_threshold, torch.Tensor):
             residual_diff_threshold = residual_diff_threshold.item()
         return residual_diff_threshold
 
-    @torch.compiler.disable
     def get_buffer(self, name):
-        if self.enable_alter_cache and self.is_alter_cache:
-            name = f"{name}_alter"
         return self.buffers.get(name)
 
-    @torch.compiler.disable
     def set_buffer(self, name, buffer):
-        if self.enable_alter_cache and self.is_alter_cache:
-            name = f"{name}_alter"
         self.buffers[name] = buffer
 
-    @torch.compiler.disable
     def remove_buffer(self, name):
-        if self.enable_alter_cache and self.is_alter_cache:
-            name = f"{name}_alter"
         if name in self.buffers:
             del self.buffers[name]
 
-    @torch.compiler.disable
     def clear_buffers(self):
         self.buffers.clear()
 
-    @torch.compiler.disable
     def mark_step_begin(self):
         # Always increase transformer executed steps
-        # incr
-        # current
+        # incr step: prev 0 -> 1; prev 1 -> 2
+        # current step: incr step - 1
         self.transformer_executed_steps += 1
         if not self.enable_spearate_cfg:
             self.executed_steps += 1
@@ -190,10 +158,6 @@ class CachedContext:  # Internal CachedContext Impl class
             # transformer step: 0,2,4,...
             self.executed_steps += 1
 
-        if not self.enable_alter_cache:
-            # 0 F 1 T 2 F 3 T 4 F 5 T ...
-            self.is_alter_cache = not self.is_alter_cache
-
         # Reset the cached steps and residual diffs at the beginning
         # of each inference.
         if self.get_current_transformer_step() == 0:
@@ -248,7 +212,6 @@ class CachedContext:  # Internal CachedContext Impl class
     def get_cfg_taylorseers(self) -> Tuple[TaylorSeer, TaylorSeer]:
         return self.cfg_taylorseer, self.cfg_encoder_taylorseer
 
-    @torch.compiler.disable
     def add_residual_diff(self, diff):
         # step: executed_steps - 1, not transformer_steps - 1
         step = str(self.get_current_step())
@@ -260,15 +223,12 @@ class CachedContext:  # Internal CachedContext Impl class
             if step not in self.cfg_residual_diffs:
                 self.cfg_residual_diffs[step] = diff
 
-    @torch.compiler.disable
     def get_residual_diffs(self):
         return self.residual_diffs.copy()
 
-    @torch.compiler.disable
     def get_cfg_residual_diffs(self):
         return self.cfg_residual_diffs.copy()
 
-    @torch.compiler.disable
     def add_cached_step(self):
         curr_cached_step = self.get_current_step()
         if not self.is_separate_cfg_step():
@@ -296,23 +256,18 @@ class CachedContext:  # Internal CachedContext Impl class
 
             self.cfg_cached_steps.append(curr_cached_step)
 
-    @torch.compiler.disable
     def get_cached_steps(self):
         return self.cached_steps.copy()
 
-    @torch.compiler.disable
     def get_cfg_cached_steps(self):
         return self.cfg_cached_steps.copy()
 
-    @torch.compiler.disable
     def get_current_step(self):
         return self.executed_steps - 1
 
-    @torch.compiler.disable
     def get_current_transformer_step(self):
         return self.transformer_executed_steps - 1
 
-    @torch.compiler.disable
     def is_separate_cfg_step(self):
         if not self.enable_spearate_cfg:
             return False
@@ -322,6 +277,5 @@ class CachedContext:  # Internal CachedContext Impl class
         # CFG steps: 1, 3, 5, 7, ...
         return self.get_current_transformer_step() % 2 != 0
 
-    @torch.compiler.disable
     def is_in_warmup(self):
         return self.get_current_step() < self.max_warmup_steps
cache_dit/cache_factory/cache_contexts/cache_manager.py
CHANGED
@@ -122,10 +122,7 @@ class CachedContextManager:
             default_value,
         )
 
-        # Manually set sequence fields
-        # and Bn_compute_blocks_ids, which are lists or sets.
-        _safe_set_sequence_field("Fn_compute_blocks_ids", [])
-        _safe_set_sequence_field("Bn_compute_blocks_ids", [])
+        # Manually set sequence fields
        _safe_set_sequence_field("taylorseer_kwargs", {})
 
        for attr in cache_attrs:
@@ -301,18 +298,6 @@ class CachedContextManager:
             return self.is_taylorseer_cache_residual()
         return True
 
-    @torch.compiler.disable
-    def is_alter_cache_enabled(self) -> bool:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.enable_alter_cache
-
-    @torch.compiler.disable
-    def is_alter_cache(self) -> bool:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        return cached_context.is_alter_cache
-
     @torch.compiler.disable
     def is_in_warmup(self) -> bool:
         cached_context = self.get_context()
@@ -359,20 +344,6 @@ class CachedContextManager:
         )
         return cached_context.Fn_compute_blocks
 
-    @torch.compiler.disable
-    def Fn_compute_blocks_ids(self) -> List[int]:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        assert (
-            len(cached_context.Fn_compute_blocks_ids)
-            <= cached_context.Fn_compute_blocks
-        ), (
-            "The num of Fn_compute_blocks_ids must be <= Fn_compute_blocks "
-            f"{cached_context.Fn_compute_blocks}, but got "
-            f"{len(cached_context.Fn_compute_blocks_ids)}"
-        )
-        return cached_context.Fn_compute_blocks_ids
-
     @torch.compiler.disable
     def Bn_compute_blocks(self) -> int:
         cached_context = self.get_context()
@@ -392,20 +363,6 @@ class CachedContextManager:
         )
         return cached_context.Bn_compute_blocks
 
-    @torch.compiler.disable
-    def Bn_compute_blocks_ids(self) -> List[int]:
-        cached_context = self.get_context()
-        assert cached_context is not None, "cached_context must be set before"
-        assert (
-            len(cached_context.Bn_compute_blocks_ids)
-            <= cached_context.Bn_compute_blocks
-        ), (
-            "The num of Bn_compute_blocks_ids must be <= Bn_compute_blocks "
-            f"{cached_context.Bn_compute_blocks}, but got "
-            f"{len(cached_context.Bn_compute_blocks_ids)}"
-        )
-        return cached_context.Bn_compute_blocks_ids
-
     @torch.compiler.disable
     def enable_spearate_cfg(self) -> bool:
         cached_context = self.get_context()
@@ -525,6 +482,9 @@ class CachedContextManager:
     # Fn buffers
     @torch.compiler.disable
     def set_Fn_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
+        # DON'T set None Buffer
+        if buffer is None:
+            return
         # Set hidden_states or residual for Fn blocks.
         # This buffer is only use for L1 diff calculation.
         downsample_factor = self.get_downsample_factor()
@@ -548,6 +508,9 @@ class CachedContextManager:
 
     @torch.compiler.disable
     def set_Fn_encoder_buffer(self, buffer: torch.Tensor, prefix: str = "Fn"):
+        # DON'T set None Buffer
+        if buffer is None:
+            return
         if self.is_separate_cfg_step():
             self._debugging_set_buffer(f"{prefix}_encoder_buffer_cfg")
             self.set_buffer(f"{prefix}_encoder_buffer_cfg", buffer)
@@ -566,6 +529,9 @@ class CachedContextManager:
     # Bn buffers
     @torch.compiler.disable
     def set_Bn_buffer(self, buffer: torch.Tensor, prefix: str = "Bn"):
+        # DON'T set None Buffer
+        if buffer is None:
+            return
         # Set hidden_states or residual for Bn blocks.
         # This buffer is use for hidden states approximation.
         if self.is_taylorseer_enabled():
@@ -820,26 +786,12 @@ class CachedContextManager:
         else:
             prev_states_tensor = self.get_Fn_buffer(prefix)
 
-        … (8 removed lines not captured in this diff view)
-            )
-        else:
-            # Only cache in the alter cache steps
-            can_cache = (
-                prev_states_tensor is not None
-                and self.similarity(
-                    prev_states_tensor,
-                    states_tensor,
-                    threshold=threshold,
-                    parallelized=parallelized,
-                    prefix=prefix,
-                )
-                and self.is_alter_cache()
-            )
+        # Dynamic cache according to the residual diff
+        can_cache = prev_states_tensor is not None and self.similarity(
+            prev_states_tensor,
+            states_tensor,
+            threshold=threshold,
+            parallelized=parallelized,
+            prefix=prefix,
+        )
         return can_cache
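Aside: the `can_cache` path above gates cache reuse on a similarity check between the previous and current states under a residual-diff threshold. The snippet below is a minimal, self-contained sketch of that kind of test; the relative-L1 metric, function names, and threshold value are illustrative assumptions, not the package's exact implementation.

```python
from typing import Optional

import torch


def relative_l1_diff(prev: torch.Tensor, curr: torch.Tensor) -> float:
    # Mean absolute change, normalized by the magnitude of the previous states.
    return ((curr - prev).abs().mean() / (prev.abs().mean() + 1e-8)).item()


def can_reuse_cache(
    prev: Optional[torch.Tensor], curr: torch.Tensor, threshold: float = 0.12
) -> bool:
    # Reuse cached residuals only when the step-to-step change is small enough.
    if prev is None:
        return False
    return relative_l1_diff(prev, curr) < threshold


prev = torch.randn(2, 16, 64)
curr = prev + 1e-3 * torch.randn_like(prev)
print(can_reuse_cache(prev, curr))  # True: a tiny perturbation falls under the threshold
```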
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import math
|
|
2
|
-
import torch
|
|
3
2
|
|
|
4
3
|
|
|
5
4
|
class TaylorSeer:
|
|
@@ -17,7 +16,6 @@ class TaylorSeer:
|
|
|
17
16
|
self.compute_step_map = compute_step_map
|
|
18
17
|
self.reset_cache()
|
|
19
18
|
|
|
20
|
-
@torch.compiler.disable
|
|
21
19
|
def reset_cache(self):
|
|
22
20
|
self.state = {
|
|
23
21
|
"dY_prev": [None] * self.ORDER,
|
|
@@ -26,7 +24,6 @@ class TaylorSeer:
|
|
|
26
24
|
self.current_step = -1
|
|
27
25
|
self.last_non_approximated_step = -1
|
|
28
26
|
|
|
29
|
-
@torch.compiler.disable
|
|
30
27
|
def should_compute_full(self, step=None):
|
|
31
28
|
step = self.current_step if step is None else step
|
|
32
29
|
if self.compute_step_map is not None:
|
|
@@ -39,7 +36,6 @@ class TaylorSeer:
|
|
|
39
36
|
return True
|
|
40
37
|
return False
|
|
41
38
|
|
|
42
|
-
@torch.compiler.disable
|
|
43
39
|
def approximate_derivative(self, Y):
|
|
44
40
|
# n-th order Taylor expansion:
|
|
45
41
|
# Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!
|
|
@@ -58,7 +54,6 @@ class TaylorSeer:
|
|
|
58
54
|
break
|
|
59
55
|
return dY_current
|
|
60
56
|
|
|
61
|
-
@torch.compiler.disable
|
|
62
57
|
def approximate_value(self):
|
|
63
58
|
# TODO: Custom Triton/CUDA kernel for better performance,
|
|
64
59
|
# especially for large n_derivatives.
|
|
@@ -71,11 +66,9 @@ class TaylorSeer:
|
|
|
71
66
|
break
|
|
72
67
|
return output
|
|
73
68
|
|
|
74
|
-
@torch.compiler.disable
|
|
75
69
|
def mark_step_begin(self):
|
|
76
70
|
self.current_step += 1
|
|
77
71
|
|
|
78
|
-
@torch.compiler.disable
|
|
79
72
|
def update(self, Y):
|
|
80
73
|
# Directly call this method will ingnore the warmup
|
|
81
74
|
# policy and force full computation.
|
|
@@ -94,7 +87,6 @@ class TaylorSeer:
|
|
|
94
87
|
self.state["dY_current"] = self.approximate_derivative(Y)
|
|
95
88
|
self.last_non_approximated_step = self.current_step
|
|
96
89
|
|
|
97
|
-
@torch.compiler.disable
|
|
98
90
|
def step(self, Y):
|
|
99
91
|
self.mark_step_begin()
|
|
100
92
|
if self.should_compute_full():
|
|
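The TaylorSeer hunks above keep the comment `Y(t) = Y(0) + dY(0)/dt * t + d^2Y(0)/dt^2 * t^2 / 2!`. As a standalone illustration of that idea only (it does not mirror the class's internal state, order handling, or warmup policy), a first-order variant looks like this:

```python
import torch


def taylor_extrapolate(
    y_prev: torch.Tensor, y_curr: torch.Tensor, dt: float = 1.0, steps_ahead: int = 1
) -> torch.Tensor:
    # Finite-difference estimate of dY/dt from the last two fully computed states.
    dy_dt = (y_curr - y_prev) / dt
    # First-order Taylor expansion: Y(t + h) ≈ Y(t) + dY/dt * h.
    return y_curr + dy_dt * (steps_ahead * dt)


y0 = torch.zeros(4)
y1 = torch.full((4,), 0.1)  # state one step later
print(taylor_extrapolate(y0, y1))  # ≈ tensor([0.2, 0.2, 0.2, 0.2]) one step ahead
```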
{cache_dit-0.2.32.dist-info → cache_dit-0.2.33.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.32
+Version: 0.2.33
 Summary: 🤗 A Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -61,39 +61,57 @@ Dynamic: requires-python
 <p align="center">
 🎉Now, <b>cache-dit</b> covers <b>most</b> mainstream Diffusers' <b>DiT</b> Pipelines🎉<br>
 🔥<a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Qwen-Image-Lightning</a> | <a href="#supported"> Wan 2.1/2.2 </a>🔥<br>
-🔥<a href="#supported">
-🔥<a href="#supported">CogView3Plus</a> | <a href="#supported">CogView4</a> | <a href="#supported">Chroma</a> | <a href="#supported"> LTXVideo </a> | <a href="#supported">
+🔥<a href="#supported">HunyuanImage-2.1</a> | <a href="#supported">HunyuanVideo</a> | <a href="#supported">HunyuanDiT</a> | <a href="#supported">HiDream</a> | <a href="#supported">Mochi</a>🔥<br>
+🔥<a href="#supported">CogView3Plus</a> | <a href="#supported">CogView4</a> | <a href="#supported">Chroma</a> | <a href="#supported"> LTXVideo </a> | <a href="#supported">CogVideoX 1/1.5</a>🔥<br>
 🔥<a href="#supported">Cosmos</a> | <a href="#supported">SkyReelsV2</a> | <a href="#supported">VisualCloze</a> | <a href="#supported"> OmniGen </a> | <a href="#supported">Lumina 1/2</a>🔥<br>
-🔥<a href="#supported">Allegro</a> | <a href="#supported">EasyAnimate</a> | <a href="#supported">SD 3/3.5</a> | <a href="#supported"> ... </a> | <a href="#supported">
+🔥<a href="#supported">Allegro</a> | <a href="#supported">EasyAnimate</a> | <a href="#supported">SD 3/3.5</a> | <a href="#supported"> ... </a> | <a href="#supported">PixArt</a>🔥
 </p>
 </div>
 <div align='center'>
-<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C0_Q0_NONE.gif width=
-<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q0_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=
-<img src
-<
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C0_Q0_NONE.gif width=124px>
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/gifs/wan2.2.C1_Q0_DBCACHE_F1B0_W2M8MC2_T1O2_R0.08.gif width=124px>
+<img src=./assets/gifs/hunyuan_video.C0_L0_Q0_NONE.gif width=126px>
+<img src=./assets/gifs/hunyuan_video.C0_L0_Q0_DBCACHE_F1B0_W8M0MC2_T0O2_R0.12_S27.gif width=126px>
+<p><b>🔥Wan2.2 MoE</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.0x↑🎉 | <b>HunyuanVideo</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.1x↑🎉</p>
 <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C0_Q0_NONE.png width=160px>
 <img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image.C1_Q0_DBCACHE_F8B0_W8M0MC0_T1O4_R0.12_S23.png width=160px>
-<img src
-<
-<
-<img src
-<
+<img src=./assets/flux.C0_Q0_NONE_T23.69s.png width=90px>
+<img src=./assets/flux.C0_Q0_DBCACHE_F1B0_W4M0MC0_T1O2_R0.15_S16_T11.39s.png width=90px>
+<p><b>🔥Qwen-Image</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.8x↑🎉 | <b>FLUX.1-dev</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.1x↑🎉</p>
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-lightning.4steps.C0_L1_Q0_NONE.png width=160px>
+<img src=https://github.com/vipshop/cache-dit/raw/main/assets/qwen-image-lightning.4steps.C0_L1_Q0_DBCACHE_F16B16_W2M1MC1_T0O2_R0.9_S1.png width=160px>
+<img src=./assets/sd_3_5.C0_L0_Q0_NONE.png width=90px>
+<img src=./assets/sd_3_5.C0_L0_Q0_DBCACHE_F1B0_W8M0MC3_T0O2_R0.12_S30.png width=90px>
+<p><b>🔥Qwen-Image-Lightning</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.14x↑🎉 | <b>SD 3.5</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:2.5x↑🎉</p>
+<img src=./assets/hidream.C0_L0_Q0_NONE.png width=100px>
+<img src=./assets/hidream.C0_L0_Q0_DBCACHE_F1B0_W8M0MC0_T0O2_R0.08_S24.png width=100px>
+<img src=./assets/cogview4.C0_L0_Q0_NONE.png width=100px>
+<img src=./assets/cogview4.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S15.png width=100px>
+<img src=./assets/cogview4.C0_L0_Q0_DBCACHE_F1B0_W4M0MC4_T0O2_R0.2_S22.png width=100px>
+<p><b>🔥HiDream-I1</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.9x↑🎉 | <b>CogView4</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.4x↑🎉 | 1.7x↑🎉</p>
+<img src=./assets/gifs/mochi.C0_L0_Q0_NONE.gif width=160px>
+<img src=./assets/gifs/mochi.C0_L0_Q0_DBCACHE_F8B0_W8M0MC0_T0O2_R0.08_S34.gif width=160px>
+<img src=./assets/hunyuan-image-2.1.C0_L0_Q1_fp8_w8a16_wo_NONE.png width=91px>
+<img src=./assets/hunyuan-image-2.1.C0_L0_Q1_fp8_w8a16_wo_DBCACHE_F8B0_W8M0MC2_T1O2_R0.12_S25.png width=91px>
+<p><b>🔥Mochi-1</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.8x↑🎉 | <b>HunyuanImage-2.1</b> | <a href="https://github.com/vipshop/cache-dit">+cache-dit</a>:1.7x↑🎉
 <br>♥️ Please consider to leave a <b>⭐️ Star</b> to support us ~ ♥️</p>
 </div>
 
 ## 🔥News
 
+- [2025-09-10] 🎉Day 1 support [**HunyuanImage-2.1**](https://github.com/Tencent-Hunyuan/HunyuanImage-2.1) with **1.7x↑🎉** speedup! Check this [example](./examples/pipeline/run_hunyuan_image_2.1.py).
 - [2025-09-08] 🔥[**Qwen-Image-Lightning**](./examples/pipeline/run_qwen_image_lightning.py) **7.1/3.5 steps🎉** inference with **[DBCache: F16B16](https://github.com/vipshop/cache-dit)**.
 - [2025-09-03] 🎉[**Wan2.2-MoE**](https://github.com/Wan-Video) **2.4x↑🎉** speedup! Please refer to [run_wan_2.2.py](./examples/pipeline/run_wan_2.2.py) as an example.
 - [2025-08-19] 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x↑🎉** speedup! Check the example: [run_qwen_image_edit.py](./examples/pipeline/run_qwen_image_edit.py).
-- [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
 - [2025-08-11] 🔥[**Qwen-Image**](https://github.com/QwenLM/Qwen-Image) **1.8x↑🎉** speedup! Please refer to [run_qwen_image.py](./examples/pipeline/run_qwen_image.py) as an example.
 - [2025-07-13] 🎉[**FLUX.1-dev**](https://github.com/xlite-dev/flux-faster) **3.3x↑🎉** speedup! NVIDIA L20 with **[cache-dit](https://github.com/vipshop/cache-dit)** + **compile + FP8 DQ**.
 
 <details>
 <summary> Previous News </summary>
 
+- [2025-09-08] 🎉First caching mechanism in [Qwen-Image-Lightning](https://github.com/ModelTC/Qwen-Image-Lightning) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/ModelTC/Qwen-Image-Lightning/pull/35).
+- [2025-09-08] 🎉First caching mechanism in [Wan2.2](https://github.com/Wan-Video/Wan2.2) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/Wan-Video/Wan2.2/pull/127) for more details.
+- [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check this [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
 - [2025-09-01] 📚[**Hybird Forward Pattern**](#unified) is supported! Please check [FLUX.1-dev](./examples/run_flux_adapter.py) as an example.
 - [2025-08-10] 🔥[**FLUX.1-Kontext-dev**](https://huggingface.co/black-forest-labs/FLUX.1-Kontext-dev) is supported! Please refer [run_flux_kontext.py](./examples/pipeline/run_flux_kontext.py) as an example.
 - [2025-07-18] 🎉First caching mechanism in [🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/huggingface/flux-fast/pull/13).
@@ -139,6 +157,7 @@ pip3 install git+https://github.com/vipshop/cache-dit.git
 
 Currently, **cache-dit** library supports almost **Any** Diffusion Transformers (with **Transformer Blocks** that match the specific Input and Output **patterns**). Please check [🎉Unified Cache APIs](#unified) for more details. Here are just some of the tested models listed:
 
+- [🚀HunyuanImage-2.1](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Qwen-Image-Lightning](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Qwen-Image-Edit](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Qwen-Image](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -147,6 +166,8 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
 - [🚀FLUX.1-Kontext-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀CogVideoX1.5](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀CogView3-Plus](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀CogView4](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.2-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -294,23 +315,7 @@ cache_dit.enable_cache(
     Bn_compute_blocks=8, # Bn, B8, etc.
     residual_diff_threshold=0.12,
 )
-```
-Moreover, users configuring higher **Bn** values (e.g., **F8B16**) while aiming to maintain good performance can specify **Bn_compute_blocks_ids** to work with Bn. DBCache will only compute the specified blocks, with the remaining estimated using the previous step's residual cache.
-
-```python
-# Custom options, F8B16, higher precision with good performance.
-cache_dit.enable_cache(
-    pipe,
-    Fn_compute_blocks=8, # Fn, F8, etc.
-    Bn_compute_blocks=16, # Bn, B16, etc.
-    # 0, 2, 4, ..., 14, 15, etc. [0,16)
-    Bn_compute_blocks_ids=cache_dit.block_range(0, 16, 2),
-    # If the L1 difference is below this threshold, skip Bn blocks
-    # not in `Bn_compute_blocks_ids`(1, 3,..., etc), Otherwise,
-    # compute these blocks.
-    non_compute_blocks_diff_threshold=0.08,
-)
-```
+```
 
 <div align="center">
 <p align="center">
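For orientation only: the README hunk above keeps a DBCache configuration of the form `cache_dit.enable_cache(pipe, ..., Bn_compute_blocks=8, residual_diff_threshold=0.12)`. A minimal usage sketch along those lines is shown below; the pipeline, model id, and parameter values are illustrative, and the accepted keyword arguments should be checked against the installed cache-dit version.

```python
import torch
from diffusers import FluxPipeline

import cache_dit

# Any supported Diffusers DiT pipeline; FLUX.1-dev is used here as an example.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
).to("cuda")

# DBCache options as in the README snippet above: F8B8 with a 0.12 residual threshold.
cache_dit.enable_cache(
    pipe,
    Fn_compute_blocks=8,
    Bn_compute_blocks=8,
    residual_diff_threshold=0.12,
)

image = pipe("a cat wearing sunglasses", num_inference_steps=28).images[0]
image.save("flux_cached.png")
```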
{cache_dit-0.2.32.dist-info → cache_dit-0.2.33.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 cache_dit/__init__.py,sha256=kX9V-FegZG4c8LMwI4PTmMqH794MEW0pzDArdhC0cJw,1241
-cache_dit/_version.py,sha256=
+cache_dit/_version.py,sha256=gTEHTWtuqv38KTvjBsXd5hC019b6d7AyfC8gLMY7KAo,706
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
 cache_dit/utils.py,sha256=WK7eqgH6gCYNHXNLmWyxBDU0XSHTPg7CfOcyXlGXBqE,10510
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
@@ -10,17 +10,17 @@ cache_dit/cache_factory/cache_types.py,sha256=ooukxQRG55uTLmaZ0SKw6gIeY6SQHhMxkb
 cache_dit/cache_factory/forward_pattern.py,sha256=FumlCuZ-TSmSYH0hGBHctSJ-oGLCftdZjLygqhsmdR4,2258
 cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
 cache_dit/cache_factory/block_adapters/__init__.py,sha256=OZM5vJwmQIkoIwVmMxKXiHqKvs31NyAva1Z91C_ko3w,17547
-cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=
+cache_dit/cache_factory/block_adapters/block_adapters.py,sha256=IqHV10aK2qA8kEVDi7EEoUSBt0GzwCUM4GpLNf8Jgww,21656
 cache_dit/cache_factory/block_adapters/block_registers.py,sha256=ZeN2wGPmuf2u3puSsBx8x-rl3wRo8-cWcuWNcrssVfA,2553
 cache_dit/cache_factory/cache_blocks/__init__.py,sha256=08Ox7kD05lkRKCOsVTdEZeKAWBheqpxfrAT1Nz7eclI,2916
 cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py,sha256=ElMps6_7uI74tSF9GDR_dEI0bZEhdzcepM29xFWnYo8,428
-cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=
-cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=
+cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py,sha256=Bv56qETXhsREvCrNvnZpSqDIIHsi6Ze3FJW4Yk2x3uI,8597
+cache_dit/cache_factory/cache_blocks/pattern_base.py,sha256=d4H9kEB0AgnVMT8aF0Y54SUMUQUxw5HQ8gRkoCuTQ_A,14577
 cache_dit/cache_factory/cache_blocks/utils.py,sha256=dGOC1tMMOvcbvEgx44eTESKn_jsv-0RZ3tRHPa3wmQ4,1315
 cache_dit/cache_factory/cache_contexts/__init__.py,sha256=rqnJ5__zqnpVHK5A1OqWILpNh5Ss-0ZDTGgtxZMKGGo,250
-cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=
-cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=
-cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=
+cache_dit/cache_factory/cache_contexts/cache_context.py,sha256=zqixcxV_LjnyoYDZ6q3HAC-hqYyVV6g0MWKBI2hA1nQ,11855
+cache_dit/cache_factory/cache_contexts/cache_manager.py,sha256=Mcj1upIpXT_CwO4AdY4ZNJSWoOXn3Lx2mBZRi_QuLbU,32710
+cache_dit/cache_factory/cache_contexts/taylorseer.py,sha256=hgLmgIkQgwbFTjxqtLUCJ3mgDGEcJK09B7RK8sBdPiI,3593
 cache_dit/cache_factory/patch_functors/__init__.py,sha256=06zdddrjvSCgBzJ0a8niRHd3ucF2qsbzlbL00d4aCvk,451
 cache_dit/cache_factory/patch_functors/functor_base.py,sha256=Ahk0fTfrHgNdEl-9JSkACvfyyv9G-Ei5OSz7XBIlX5o,357
 cache_dit/cache_factory/patch_functors/functor_chroma.py,sha256=2iLxlsc-1dDHRveqCXaC07E9CeMNOuBNkvpJ1atpK7E,10048
@@ -38,12 +38,11 @@ cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR
 cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
 cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
 cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70A,59
-cache_dit/quantize/quantize_ao.py,sha256=
+cache_dit/quantize/quantize_ao.py,sha256=Fx1KW4l3gdEkdrcAYtPoDW7WKBJWrs3glOHiEwW_TgE,6160
 cache_dit/quantize/quantize_interface.py,sha256=2s_R7xPSKuJeFpEGeLwRxnq_CqJcBG3a3lzyW5wh-UM,1241
-cache_dit/
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.32.dist-info/RECORD,,
+cache_dit-0.2.33.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.33.dist-info/METADATA,sha256=GQBvDzKLXL3tABguCRqLNc-Z39h0AcMK_J37demDTu8,25977
+cache_dit-0.2.33.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.33.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-0.2.33.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.33.dist-info/RECORD,,
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|