cache-dit 0.2.23__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/cache_adapters.py +137 -76
- cache_dit/cache_factory/cache_context.py +112 -39
- cache_dit/cache_factory/cache_interface.py +11 -4
- cache_dit/cache_factory/taylorseer.py +5 -4
- cache_dit/cache_factory/utils.py +1 -1
- cache_dit/compile/utils.py +1 -1
- cache_dit/quantize/__init__.py +1 -0
- cache_dit/quantize/quantize_ao.py +182 -0
- cache_dit/quantize/quantize_interface.py +46 -0
- cache_dit/quantize/quantize_svdq.py +0 -0
- cache_dit/utils.py +68 -34
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.25.dist-info}/METADATA +15 -15
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.25.dist-info}/RECORD +19 -16
- cache_dit/primitives.py +0 -152
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.25.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.25.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.25.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.23.dist-info → cache_dit-0.2.25.dist-info}/top_level.txt +0 -0
cache_dit/quantize/quantize_ao.py
ADDED
@@ -0,0 +1,182 @@
+import gc
+import time
+import torch
+from typing import Callable, Optional, List
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def quantize_ao(
+    module: torch.nn.Module,
+    quant_type: str = "fp8_w8a8_dq",
+    per_row: bool = True,
+    exclude_layers: List[str] = [
+        "embedder",
+        "embed",
+    ],
+    filter_fn: Optional[Callable] = None,
+    **kwargs,
+) -> torch.nn.Module:
+    # Apply FP8 DQ for module and skip any `embed` modules
+    # by default to avoid non-trivial precision downgrade. Please
+    # set `exclude_layers` as `[]` if you don't want this behavior.
+    assert isinstance(module, torch.nn.Module)
+
+    quant_type = quant_type.lower()
+    assert quant_type in (
+        "fp8_w8a8_dq",
+        "fp8_w8a16_wo",
+        "int8_w8a8_dq",
+        "int8_w8a16_wo",
+        "int4_w4a8_dq",
+        "int4_w4a4_dq",
+        "int4_w4a16_wo",
+    ), f"{quant_type} is not supported for torchao backend now!"
+
+    if "fp8" in quant_type:
+        assert torch.cuda.get_device_capability() >= (
+            8,
+            9,
+        ), "FP8 is not supported for current device."
+
+    num_quant_linear = 0
+    num_skip_linear = 0
+    num_linear_layers = 0
+    num_layers = 0
+
+    # Ensure bfloat16 for per_row
+    def _filter_fn(m: torch.nn.Module, name: str) -> bool:
+        nonlocal num_quant_linear, num_skip_linear, num_linear_layers, num_layers
+        num_layers += 1
+        if isinstance(m, torch.nn.Linear):
+            num_linear_layers += 1
+            for exclude_name in exclude_layers:
+                if exclude_name in name:
+                    logger.info(
+                        f"Skip Quantization: {name} -> "
+                        f"pattern<{exclude_name}>"
+                    )
+
+                    num_skip_linear += 1
+                    return False
+
+            if (
+                per_row
+                and m.weight.dtype != torch.bfloat16
+                and quant_type == "fp8_w8a8_dq"
+            ):
+                logger.info(
+                    f"Skip Quantization: {name} -> "
+                    f"pattern<dtype({m.weight.dtype})!=bfloat16>"
+                )
+
+                num_skip_linear += 1
+                return False
+
+            num_quant_linear += 1
+            return True
+
+        return False
+
+    def _quantization_fn():
+        try:
+            if quant_type == "fp8_w8a8_dq":
+                from torchao.quantization import (
+                    float8_dynamic_activation_float8_weight,
+                    PerTensor,
+                    PerRow,
+                )
+
+                quantization_fn = float8_dynamic_activation_float8_weight(
+                    granularity=(
+                        ((PerRow(), PerRow()))
+                        if per_row
+                        else ((PerTensor(), PerTensor()))
+                    )
+                )
+
+            elif quant_type == "fp8_w8a16_wo":
+                from torchao.quantization import float8_weight_only
+
+                quantization_fn = float8_weight_only()
+
+            elif quant_type == "int8_w8a8_dq":
+                from torchao.quantization import (
+                    int8_dynamic_activation_int8_weight,
+                )
+
+                quantization_fn = int8_dynamic_activation_int8_weight()
+
+            elif quant_type == "int8_w8a16_wo":
+                from torchao.quantization import int8_weight_only
+
+                quantization_fn = int8_weight_only(
+                    # group_size is None -> per_channel, else per group
+                    group_size=kwargs.get("group_size", None),
+                )
+
+            elif quant_type == "int4_w4a8_dq":
+                from torchao.quantization import (
+                    int8_dynamic_activation_int4_weight,
+                )
+
+                quantization_fn = int8_dynamic_activation_int4_weight(
+                    group_size=kwargs.get("group_size", 32),
+                )
+
+            elif quant_type == "int4_w4a4_dq":
+                from torchao.quantization import (
+                    int4_dynamic_activation_int4_weight,
+                )
+
+                quantization_fn = int4_dynamic_activation_int4_weight()
+
+            elif quant_type == "int4_w4a16_wo":
+                from torchao.quantization import int4_weight_only
+
+                quantization_fn = int4_weight_only(
+                    group_size=kwargs.get("group_size", 32),
+                )
+
+            else:
+                raise ValueError(
+                    f"quant_type: {quant_type} is not supported now!"
+                )
+
+        except ImportError as e:
+            e.msg += (
+                f"{quant_type} is not supported in torchao backend now! "
+                "Please upgrade the torchao library."
+            )
+            raise e
+
+        return quantization_fn
+
+    from torchao.quantization import quantize_
+
+    quantize_(
+        module,
+        _quantization_fn(),
+        filter_fn=_filter_fn if filter_fn is None else filter_fn,
+        **kwargs,
+    )
+
+    force_empty_cache()
+
+    logger.info(
+        f"Quantized Linear Layers: {num_quant_linear:>5}\n"
+        f"Skipped Linear Layers: {num_skip_linear:>5}\n"
+        f"Total Linear Layers: {num_linear_layers:>5}\n"
+        f"Total (all) Layers: {num_layers:>5}"
+    )
+    return module


+def force_empty_cache():
+    time.sleep(1)
+    gc.collect()
+    torch.cuda.empty_cache()
+    time.sleep(1)
+    gc.collect()
+    torch.cuda.empty_cache()
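For orientation, a minimal sketch of how this new `quantize_ao` helper could be applied to a diffusers transformer. The checkpoint, dtype, and device below are illustrative assumptions rather than part of the diff; FP8 W8A8 dynamic quantization additionally assumes `torchao` is installed and the GPU has compute capability 8.9 or higher.

```python
# Sketch only: the checkpoint, dtype and device are illustrative assumptions.
import torch
from diffusers import DiffusionPipeline

from cache_dit.quantize.quantize_ao import quantize_ao

pipe = DiffusionPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # hypothetical choice of pipeline
    torch_dtype=torch.bfloat16,      # per_row FP8 path expects bfloat16 weights
).to("cuda")

# Quantize the transformer's Linear layers in place; modules whose names
# match the default exclude list ("embedder"/"embed") are skipped.
pipe.transformer = quantize_ao(
    pipe.transformer,
    quant_type="fp8_w8a8_dq",
    per_row=True,
)
```

Note the default `exclude_layers=["embedder", "embed"]` keeps embedding projections unquantized, which the inline comment explains is meant to avoid a non-trivial precision downgrade.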
cache_dit/quantize/quantize_interface.py
ADDED
@@ -0,0 +1,46 @@
+import torch
+from typing import Callable, Optional, List
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def quantize(
+    module: torch.nn.Module,
+    quant_type: str = "fp8_w8a8_dq",
+    backend: str = "ao",
+    # only for fp8_w8a8_dq
+    per_row: bool = True,
+    exclude_layers: List[str] = [
+        "embedder",
+        "embed",
+    ],
+    filter_fn: Optional[Callable] = None,
+    **kwargs,
+) -> torch.nn.Module:
+    assert isinstance(module, torch.nn.Module)
+
+    if backend.lower() in ("ao", "torchao"):
+        from cache_dit.quantize.quantize_ao import quantize_ao
+
+        quant_type = quant_type.lower()
+        assert quant_type in (
+            "fp8_w8a8_dq",
+            "fp8_w8a16_wo",
+            "int8_w8a8_dq",
+            "int8_w8a16_wo",
+            "int4_w4a8_dq",
+            "int4_w4a4_dq",
+            "int4_w4a16_wo",
+        ), f"{quant_type} is not supported for torchao backend now!"
+
+        return quantize_ao(
+            module,
+            quant_type=quant_type,
+            per_row=per_row,
+            exclude_layers=exclude_layers,
+            filter_fn=filter_fn,
+            **kwargs,
+        )
+    else:
+        raise ValueError(f"backend: {backend} is not supported now!")
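The one-line additions to `cache_dit/__init__.py` and `cache_dit/quantize/__init__.py` suggest this `quantize()` wrapper is re-exported from the package; a hedged sketch of the call pattern under that assumption (the toy module is illustrative):

```python
# Sketch only: assumes the +1 line in cache_dit/quantize/__init__.py
# re-exports quantize(), so it can be imported from the subpackage.
import torch
from cache_dit.quantize import quantize

# A toy stand-in for a real DiT transformer.
model = torch.nn.Sequential(
    torch.nn.Linear(64, 64),
    torch.nn.GELU(),
    torch.nn.Linear(64, 64),
).to("cuda")

# Weight-only INT8 via the torchao backend; no FP8-capable GPU required.
model = quantize(model, quant_type="int8_w8a16_wo", backend="ao")
```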
cache_dit/quantize/quantize_svdq.py
File without changes
cache_dit/utils.py
CHANGED
@@ -5,7 +5,10 @@ import numpy as np
 from pprint import pprint
 from diffusers import DiffusionPipeline
 
+from typing import Dict, Any
 from cache_dit.logger import init_logger
+from cache_dit.cache_factory import CacheType
+
 
 logger = init_logger(__name__)
 
@@ -27,22 +30,26 @@ class CacheStats:
 
 
 def summary(
-
-
+    pipe_or_transformer: DiffusionPipeline | torch.nn.Module,
+    details: bool = False,
+    logging: bool = True,
+) -> CacheStats:
     cache_stats = CacheStats()
-
+    cls_name = pipe_or_transformer.__class__.__name__
+    if isinstance(pipe_or_transformer, DiffusionPipeline):
+        transformer = pipe_or_transformer.transformer
+    else:
+        transformer = pipe_or_transformer
 
-    if hasattr(
-        cache_options =
+    if hasattr(transformer, "_cache_context_kwargs"):
+        cache_options = transformer._cache_context_kwargs
         cache_stats.cache_options = cache_options
         if logging:
-            print(f"\n🤗Cache Options: {
+            print(f"\n🤗Cache Options: {cls_name}\n\n{cache_options}")
 
-    if hasattr(
-        cached_steps: list[int] =
-        residual_diffs: dict[str, float] = dict(
-            pipe.transformer._residual_diffs
-        )
+    if hasattr(transformer, "_cached_steps"):
+        cached_steps: list[int] = transformer._cached_steps
+        residual_diffs: dict[str, float] = dict(transformer._residual_diffs)
         cache_stats.cached_steps = cached_steps
         cache_stats.residual_diffs = residual_diffs
 
@@ -57,7 +64,7 @@ def summary(
        qmax = np.max(diffs_values)
 
        print(
-            f"\n⚡️Cache Steps and Residual Diffs Statistics: {
+            f"\n⚡️Cache Steps and Residual Diffs Statistics: {cls_name}\n"
        )
 
        print(
@@ -74,9 +81,7 @@ def summary(
        print("")
 
        if details:
-            print(
-                f"📚Cache Steps and Residual Diffs Details: {pipe_cls_name}\n"
-            )
+            print(f"📚Cache Steps and Residual Diffs Details: {cls_name}\n")
            pprint(
                f"Cache Steps: {len(cached_steps)}, {cached_steps}",
            )
@@ -85,10 +90,10 @@ def summary(
                compact=True,
            )
 
-    if hasattr(
-        cfg_cached_steps: list[int] =
+    if hasattr(transformer, "_cfg_cached_steps"):
+        cfg_cached_steps: list[int] = transformer._cfg_cached_steps
        cfg_residual_diffs: dict[str, float] = dict(
-
+            transformer._cfg_residual_diffs
        )
        cache_stats.cfg_cached_steps = cfg_cached_steps
        cache_stats.cfg_residual_diffs = cfg_residual_diffs
@@ -104,7 +109,7 @@ def summary(
        qmax = np.max(cfg_diffs_values)
 
        print(
-            f"\n⚡️CFG Cache Steps and Residual Diffs Statistics: {
+            f"\n⚡️CFG Cache Steps and Residual Diffs Statistics: {cls_name}\n"
        )
 
        print(
@@ -122,7 +127,7 @@ def summary(
 
        if details:
            print(
-                f"📚CFG Cache Steps and Residual Diffs Details: {
+                f"📚CFG Cache Steps and Residual Diffs Details: {cls_name}\n"
            )
            pprint(
                f"CFG Cache Steps: {len(cfg_cached_steps)}, {cfg_cached_steps}",
@@ -135,27 +140,56 @@ def summary(
     return cache_stats
 
 
-def strify(
-
+def strify(
+    pipe_or_stats: DiffusionPipeline | CacheStats | Dict[str, Any],
+) -> str:
+    if isinstance(pipe_or_stats, DiffusionPipeline):
        stats = summary(pipe_or_stats, logging=False)
-
+        cache_options = stats.cache_options
+        cached_steps = len(stats.cached_steps)
+    elif isinstance(pipe_or_stats, CacheStats):
        stats = pipe_or_stats
-
-
-
+        cache_options = stats.cache_options
+        cached_steps = len(stats.cached_steps)
+    elif isinstance(pipe_or_stats, dict):
+        # Assume cache_context_kwargs
+        cache_options = pipe_or_stats
+        cached_steps = None
+    else:
+        raise ValueError(
+            "Please set pipe_or_stats param as one of: "
+            "DiffusionPipeline | CacheStats | Dict[str, Any]"
+        )
 
     if not cache_options:
         return "NONE"
 
+    if cache_options.get("cache_type", None) != CacheType.DBCache:
+        return "NONE"
+
+    def get_taylorseer_order():
+        taylorseer_order = 0
+        if "taylorseer_kwargs" in cache_options:
+            if "n_derivatives" in cache_options["taylorseer_kwargs"]:
+                taylorseer_order = cache_options["taylorseer_kwargs"][
+                    "n_derivatives"
+                ]
+        elif "taylorseer_order" in cache_options:
+            taylorseer_order = cache_options["taylorseer_order"]
+        return taylorseer_order
+
     cache_type_str = (
-        f"DBCACHE_F{cache_options
-        f"B{cache_options
-        f"W{cache_options
-        f"M{max(0, cache_options
-        f"
-        f"
-        f"
-        f"
+        f"DBCACHE_F{cache_options.get('Fn_compute_blocks', 1)}"
+        f"B{cache_options.get('Bn_compute_blocks', 0)}_"
+        f"W{cache_options.get('max_warmup_steps', 0)}"
+        f"M{max(0, cache_options.get('max_cached_steps', -1))}"
+        f"MC{max(0, cache_options.get('max_continuous_cached_steps', -1))}_"
+        f"T{int(cache_options.get('enable_taylorseer', False))}"
+        f"O{get_taylorseer_order()}_"
+        f"R{cache_options.get('residual_diff_threshold', 0.08)}"
     )
 
+    if cached_steps:
+        cache_type_str += f"_S{cached_steps}"
+
     return cache_type_str
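The reworked `summary()` now accepts either a `DiffusionPipeline` or a bare transformer module, and `strify()` additionally accepts a `CacheStats` object or a plain options dict. A hedged sketch of the dict path, assuming `strify` remains exported at the `cache_dit` package root (the option values are illustrative):

```python
# Sketch only: assumes strify() stays exported at the cache_dit root.
import cache_dit
from cache_dit.cache_factory import CacheType  # import path shown in this diff

# The same shape as the `_cache_context_kwargs` attached to a cached transformer.
options = {
    "cache_type": CacheType.DBCache,  # strify() returns "NONE" for non-DBCache options
    "Fn_compute_blocks": 8,
    "Bn_compute_blocks": 0,
    "max_warmup_steps": 8,
    "residual_diff_threshold": 0.08,
}
print(cache_dit.strify(options))  # -> "DBCACHE_F8B0_W8M0MC0_T0O0_R0.08"
```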
{cache_dit-0.2.23.dist-info → cache_dit-0.2.25.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.23
+Version: 0.2.25
 Summary: 🤗 CacheDiT: An Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -11,12 +11,13 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: packaging
 Requires-Dist: pyyaml
-Requires-Dist: torch>=2.
-Requires-Dist: transformers>=4.
-Requires-Dist: diffusers>=0.
+Requires-Dist: torch>=2.7.1
+Requires-Dist: transformers>=4.55.2
+Requires-Dist: diffusers>=0.35.1
 Requires-Dist: scikit-image
 Requires-Dist: scipy
 Requires-Dist: lpips==0.1.4
+Requires-Dist: torchao>=0.12.0
 Provides-Extra: all
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
@@ -59,12 +60,13 @@ Dynamic: requires-python
 </p>
 <p align="center">
 🎉Now, <b>cache-dit</b> covers <b>Most</b> mainstream <b>Diffusers'</b> Pipelines</b>🎉<br>
-🔥<b><a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1</a> | <a href="#supported"> ... </a> | <a href="#supported">CogVideoX</a></b>🔥
+🔥<b><a href="#supported">Qwen-Image</a> | <a href="#supported">FLUX.1</a> | <a href="#supported">Wan 2.1/2.2</a> | <a href="#supported"> ... </a> | <a href="#supported">CogVideoX</a></b>🔥
 </p>
 </div>
 
 ## 🔥News
 
+- [2025-08-26] 🎉[**Wan2.2**](https://github.com/Wan-Video) **1.8x⚡️** speedup with `cache-dit + compile`! Check the [example](./examples/run_wan_2.2.py).
 - [2025-08-19] 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x⚡️** speedup! Check example [run_qwen_image_edit.py](./examples/run_qwen_image_edit.py).
 - [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
 - [2025-08-11] 🔥[**Qwen-Image**](https://github.com/QwenLM/Qwen-Image) **1.8x⚡️** speedup! Please refer [run_qwen_image.py](./examples/run_qwen_image.py) as an example.
@@ -119,6 +121,7 @@ Currently, **cache-dit** library supports almost **Any** Diffusion Transformers
 - [🚀FLUX.1-Kontext-dev](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀CogVideoX](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀CogVideoX1.5](https://github.com/vipshop/cache-dit/raw/main/examples)
+- [🚀Wan2.2-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-T2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀Wan2.1-FLF2V](https://github.com/vipshop/cache-dit/raw/main/examples)
 - [🚀HunyuanVideo](https://github.com/vipshop/cache-dit/raw/main/examples)
@@ -166,7 +169,7 @@ cache_dit.enable_cache(pipe)
 output = pipe(...)
 ```
 
-### 🔥
+### 🔥Automatic Block Adapter
 
 But in some cases, you may have a **modified** Diffusion Pipeline or Transformer that is not located in the diffusers library or not officially supported by **cache-dit** at this time. The **BlockAdapter** can help you solve this problems. Please refer to [Qwen-Image w/ BlockAdapter](./examples/run_qwen_image_adapter.py) as an example.
 
@@ -181,7 +184,7 @@ cache_dit.enable_cache(
     forward_pattern=ForwardPattern.Pattern_1,
 )
 
-# Or,
+# Or, manually setup transformer configurations.
 cache_dit.enable_cache(
     BlockAdapter(
         pipe=pipe, # Qwen-Image, etc.
@@ -238,7 +241,7 @@ cache_dit.enable_cache(pipe)
 # Custom options, F8B8, higher precision
 cache_dit.enable_cache(
     pipe,
-
+    max_warmup_steps=8, # steps do not cache
     max_cached_steps=-1, # -1 means no limit
     Fn_compute_blocks=8, # Fn, F8, etc.
     Bn_compute_blocks=8, # Bn, B8, etc.
@@ -292,14 +295,11 @@ cache_dit.enable_cache(
     enable_encoder_taylorseer=True,
     # Taylorseer cache type cache be hidden_states or residual.
     taylorseer_cache_type="residual",
-    # Higher values of
-    #
-
-        "n_derivatives": 2, # default is 2.
-    },
-    warmup_steps=3, # prefer: >= n_derivatives + 1
+    # Higher values of order will lead to longer computation time
+    taylorseer_order=2, # default is 2.
+    max_warmup_steps=3, # prefer: >= order + 1
     residual_diff_threshold=0.12
-)
+)s
 ```
 
 > [!Important]
{cache_dit-0.2.23.dist-info → cache_dit-0.2.25.dist-info}/RECORD
CHANGED
@@ -1,22 +1,21 @@
-cache_dit/__init__.py,sha256=
-cache_dit/_version.py,sha256=
+cache_dit/__init__.py,sha256=VsT0f0R0COp8v6Sx9hGNsqxiElERaDpfG11a9MfK0is,945
+cache_dit/_version.py,sha256=t9iixyDuMWz1nP7KM0bgrLNIpwu8JK6uZApA8DoCQwM,706
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
-cache_dit/
-cache_dit/utils.py,sha256=3UgVhfmTFG28w6CV-Rfxp5u1uzLrRozocHwLCTGiQ5M,5865
+cache_dit/utils.py,sha256=1oWDMYs6E7FRsd8cidsVOPT-meIRKeuqbGbE6CrCUec,7236
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
 cache_dit/cache_factory/__init__.py,sha256=evWenCin1kuBGa6W5BCKMrDZc1C1R2uVPSg0BjXgdXE,499
-cache_dit/cache_factory/cache_adapters.py,sha256=
+cache_dit/cache_factory/cache_adapters.py,sha256=Yugqljm9tm615srM2BGQlR_tA0QiZo3PbLPceObh4dQ,25988
 cache_dit/cache_factory/cache_blocks.py,sha256=ZeazBsYvLIjI5M_OnLL2xP2W7zMeM0rxVfBBwIVHBRs,18661
-cache_dit/cache_factory/cache_context.py,sha256=
-cache_dit/cache_factory/cache_interface.py,sha256=
+cache_dit/cache_factory/cache_context.py,sha256=HhA5IMSdF-i-koSB1jqf5AMC_UyDV7VinpHm4Qee9Ig,41800
+cache_dit/cache_factory/cache_interface.py,sha256=HymagnKEDs48Ly_x3IM5jTMNJpLrIdJnppVlkr2xHaM,8433
 cache_dit/cache_factory/cache_types.py,sha256=FIFa6ZBfvvSMMHyBBhvarvgg2Y2wbRgITcG_uGylGe0,991
 cache_dit/cache_factory/forward_pattern.py,sha256=B2YeqV2t_zo2Ar8m7qimPBjwQgoXHGp2grPZmEAhi8s,1286
-cache_dit/cache_factory/taylorseer.py,sha256=
-cache_dit/cache_factory/utils.py,sha256=
+cache_dit/cache_factory/taylorseer.py,sha256=etSUIZzDvqW3ScKCbccTPcFaSmxV1T-xAXdk-p3e3wk,3802
+cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeHOe4,1871
 cache_dit/cache_factory/patch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/patch/flux.py,sha256=iNQ-1RlOgXupZ4uPiEvJ__Ro6vKT_fOKja9JrpMrO78,8998
 cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
-cache_dit/compile/utils.py,sha256=
+cache_dit/compile/utils.py,sha256=nN2OIrSdwRR5zGxJinKDqb07pXpvTNTF3g_OgLkeeBU,3858
 cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
@@ -25,9 +24,13 @@ cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,1706
 cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
 cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
 cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
-cache_dit
-cache_dit
-cache_dit
-cache_dit
-cache_dit-0.2.
-cache_dit-0.2.
+cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70A,59
+cache_dit/quantize/quantize_ao.py,sha256=sKz_RmVtxLOpAPnUv_iOjzY_226pfaxgB_HMNrfyqB8,5465
+cache_dit/quantize/quantize_interface.py,sha256=NG4WP7s8CLW6KhVFb9e1aAjW30KWTCcM2aS5n8QuwsA,1241
+cache_dit/quantize/quantize_svdq.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cache_dit-0.2.25.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.25.dist-info/METADATA,sha256=a5wbENMZ9BDjHbM3Ejb7Il7x4QzD8W7Lzmu4poo95Wo,19913
+cache_dit-0.2.25.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.25.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-0.2.25.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.25.dist-info/RECORD,,