cache-dit 0.2.24__py3-none-any.whl → 0.2.25__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit has been flagged as possibly problematic by the registry; see the registry listing for details.
- cache_dit/__init__.py +1 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/cache_context.py +27 -24
- cache_dit/cache_factory/cache_interface.py +1 -1
- cache_dit/compile/utils.py +1 -1
- cache_dit/quantize/__init__.py +1 -0
- cache_dit/quantize/quantize_ao.py +182 -0
- cache_dit/quantize/quantize_interface.py +46 -0
- cache_dit/quantize/quantize_svdq.py +0 -0
- cache_dit/utils.py +46 -15
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.25.dist-info}/METADATA +10 -12
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.25.dist-info}/RECORD +16 -12
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.25.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.25.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.25.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.25.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
```diff
@@ -12,6 +12,7 @@ from cache_dit.cache_factory import CacheType
 from cache_dit.cache_factory import BlockAdapter
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.compile import set_compile_configs
+from cache_dit.quantize import quantize
 from cache_dit.utils import summary
 from cache_dit.utils import strify
 from cache_dit.logger import init_logger
```
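This single added import exposes `quantize` at the package root, next to `enable_cache`, `summary`, and `strify`. Below is a minimal sketch of the combined workflow through the top-level namespace; `pipe` and `pipe.transformer` are assumptions (a loaded diffusers pipeline, as in the README examples quoted in the METADATA diff further down), not something this diff defines.

```python
import cache_dit

# `pipe` is assumed to be an already-loaded DiffusionPipeline.
cache_dit.enable_cache(pipe)          # training-free DBCache acceleration
cache_dit.quantize(pipe.transformer)  # new in 0.2.25: torchao-backed quantization

# ... run inference ...

stats = cache_dit.summary(pipe)       # collect per-run cache statistics
print(cache_dit.strify(stats))        # compact tag such as "DBCACHE_F..."
```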
cache_dit/_version.py
CHANGED
```diff
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.24'
-__version_tuple__ = version_tuple = (0, 2, 24)
+__version__ = version = '0.2.25'
+__version_tuple__ = version_tuple = (0, 2, 25)
 
 __commit_id__ = commit_id = None
```
cache_dit/cache_factory/cache_context.py
CHANGED
```diff
@@ -328,6 +328,33 @@ class DBCacheContext:
         return self.get_current_step() < self.max_warmup_steps
 
 
+# TODO: Support context manager for different cache_context
+
+
+def create_cache_context(*args, **kwargs):
+    return DBCacheContext(*args, **kwargs)
+
+
+def get_current_cache_context():
+    return _current_cache_context
+
+
+def set_current_cache_context(cache_context=None):
+    global _current_cache_context
+    _current_cache_context = cache_context
+
+
+@contextlib.contextmanager
+def cache_context(cache_context):
+    global _current_cache_context
+    old_cache_context = _current_cache_context
+    _current_cache_context = cache_context
+    try:
+        yield
+    finally:
+        _current_cache_context = old_cache_context
+
+
 @torch.compiler.disable
 def get_residual_diff_threshold():
     cache_context = get_current_cache_context()
@@ -657,19 +684,6 @@ def cfg_diff_compute_separate():
 _current_cache_context: DBCacheContext = None
 
 
-def create_cache_context(*args, **kwargs):
-    return DBCacheContext(*args, **kwargs)
-
-
-def get_current_cache_context():
-    return _current_cache_context
-
-
-def set_current_cache_context(cache_context=None):
-    global _current_cache_context
-    _current_cache_context = cache_context
-
-
 def collect_cache_kwargs(default_attrs: dict, **kwargs):
     # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
     # default_attrs: specific settings for different pipelines
@@ -716,17 +730,6 @@ def collect_cache_kwargs(default_attrs: dict, **kwargs):
     return cache_kwargs, kwargs
 
 
-@contextlib.contextmanager
-def cache_context(cache_context):
-    global _current_cache_context
-    old_cache_context = _current_cache_context
-    _current_cache_context = cache_context
-    try:
-        yield
-    finally:
-        _current_cache_context = old_cache_context
-
-
 @torch.compiler.disable
 def are_two_tensors_similar(
     t1: torch.Tensor, # prev residual R(t-1,n) = H(t-1,n) - H(t-1,0)
```
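The context helpers (`create_cache_context`, `get_current_cache_context`, `set_current_cache_context`, and the `cache_context` context manager) are only moved earlier in the module here; their behavior is unchanged. A minimal usage sketch, assuming `DBCacheContext` can be constructed with its defaults:

```python
import cache_dit.cache_factory.cache_context as cc

# Build a fresh DBCacheContext (all fields assumed to have defaults).
ctx = cc.create_cache_context()

with cc.cache_context(ctx):
    # Inside the block, the module-level getters resolve to `ctx`;
    # the previous context is restored on exit, even on exceptions.
    assert cc.get_current_cache_context() is ctx
```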
cache_dit/cache_factory/cache_interface.py
CHANGED
```diff
@@ -23,7 +23,7 @@ def enable_cache(
     # Cache CFG or not
     do_separate_cfg: bool = False,
     cfg_compute_first: bool = False,
-    cfg_diff_compute_separate: bool =
+    cfg_diff_compute_separate: bool = True,
     # Hybird TaylorSeer
     enable_taylorseer: bool = False,
     enable_encoder_taylorseer: bool = False,
```
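The only change here is the default value of `cfg_diff_compute_separate`, which is now `True`. Callers who prefer a different behavior can pass the flag explicitly; a minimal sketch, assuming `pipe` is a loaded DiffusionPipeline passed to `enable_cache` as in the README example:

```python
import cache_dit

cache_dit.enable_cache(
    pipe,
    do_separate_cfg=True,             # cache the CFG branch separately
    cfg_compute_first=False,
    cfg_diff_compute_separate=False,  # opt out of the new default (True)
)
```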
cache_dit/compile/utils.py
CHANGED
```diff
@@ -24,7 +24,7 @@ def epilogue_prologue_fusion_enabled(**kwargs) -> bool:
 
 
 def set_compile_configs(
-    descent_tuning: bool =
+    descent_tuning: bool = False,
    cuda_graphs: bool = False,
    force_disable_compile_caches: bool = False,
    use_fast_math: bool = False,
```
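`descent_tuning` now defaults to `False`. A minimal sketch of requesting the tuning behavior explicitly; the remaining keyword names are taken from the signature shown above:

```python
from cache_dit.compile import set_compile_configs

set_compile_configs(
    descent_tuning=True,   # re-enable tuning now that the default is False
    cuda_graphs=False,
    force_disable_compile_caches=False,
    use_fast_math=False,
)
```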
cache_dit/quantize/__init__.py
ADDED
```diff
@@ -0,0 +1 @@
+from cache_dit.quantize.quantize_interface import quantize
```
cache_dit/quantize/quantize_ao.py
ADDED
```diff
@@ -0,0 +1,182 @@
+import gc
+import time
+import torch
+from typing import Callable, Optional, List
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def quantize_ao(
+    module: torch.nn.Module,
+    quant_type: str = "fp8_w8a8_dq",
+    per_row: bool = True,
+    exclude_layers: List[str] = [
+        "embedder",
+        "embed",
+    ],
+    filter_fn: Optional[Callable] = None,
+    **kwargs,
+) -> torch.nn.Module:
+    # Apply FP8 DQ for module and skip any `embed` modules
+    # by default to avoid non-trivial precision downgrade. Please
+    # set `exclude_layers` as `[]` if you don't want this behavior.
+    assert isinstance(module, torch.nn.Module)
+
+    quant_type = quant_type.lower()
+    assert quant_type in (
+        "fp8_w8a8_dq",
+        "fp8_w8a16_wo",
+        "int8_w8a8_dq",
+        "int8_w8a16_wo",
+        "int4_w4a8_dq",
+        "int4_w4a4_dq",
+        "int4_w4a16_wo",
+    ), f"{quant_type} is not supported for torchao backend now!"
+
+    if "fp8" in quant_type:
+        assert torch.cuda.get_device_capability() >= (
+            8,
+            9,
+        ), "FP8 is not supported for current device."
+
+    num_quant_linear = 0
+    num_skip_linear = 0
+    num_linear_layers = 0
+    num_layers = 0
+
+    # Ensure bfloat16 for per_row
+    def _filter_fn(m: torch.nn.Module, name: str) -> bool:
+        nonlocal num_quant_linear, num_skip_linear, num_linear_layers, num_layers
+        num_layers += 1
+        if isinstance(m, torch.nn.Linear):
+            num_linear_layers += 1
+            for exclude_name in exclude_layers:
+                if exclude_name in name:
+                    logger.info(
+                        f"Skip Quantization: {name} -> "
+                        f"pattern<{exclude_name}>"
+                    )
+
+                    num_skip_linear += 1
+                    return False
+
+            if (
+                per_row
+                and m.weight.dtype != torch.bfloat16
+                and quant_type == "fp8_w8a8_dq"
+            ):
+                logger.info(
+                    f"Skip Quantization: {name} -> "
+                    f"pattern<dtype({m.weight.dtype})!=bfloat16>"
+                )
+
+                num_skip_linear += 1
+                return False
+
+            num_quant_linear += 1
+            return True
+
+        return False
+
+    def _quantization_fn():
+        try:
+            if quant_type == "fp8_w8a8_dq":
+                from torchao.quantization import (
+                    float8_dynamic_activation_float8_weight,
+                    PerTensor,
+                    PerRow,
+                )
+
+                quantization_fn = float8_dynamic_activation_float8_weight(
+                    granularity=(
+                        ((PerRow(), PerRow()))
+                        if per_row
+                        else ((PerTensor(), PerTensor()))
+                    )
+                )
+
+            elif quant_type == "fp8_w8a16_wo":
+                from torchao.quantization import float8_weight_only
+
+                quantization_fn = float8_weight_only()
+
+            elif quant_type == "int8_w8a8_dq":
+                from torchao.quantization import (
+                    int8_dynamic_activation_int8_weight,
+                )
+
+                quantization_fn = int8_dynamic_activation_int8_weight()
+
+            elif quant_type == "int8_w8a16_wo":
+                from torchao.quantization import int8_weight_only
+
+                quantization_fn = int8_weight_only(
+                    # group_size is None -> per_channel, else per group
+                    group_size=kwargs.get("group_size", None),
+                )
+
+            elif quant_type == "int4_w4a8_dq":
+                from torchao.quantization import (
+                    int8_dynamic_activation_int4_weight,
+                )
+
+                quantization_fn = int8_dynamic_activation_int4_weight(
+                    group_size=kwargs.get("group_size", 32),
+                )
+
+            elif quant_type == "int4_w4a4_dq":
+                from torchao.quantization import (
+                    int4_dynamic_activation_int4_weight,
+                )
+
+                quantization_fn = int4_dynamic_activation_int4_weight()
+
+            elif quant_type == "int4_w4a16_wo":
+                from torchao.quantization import int4_weight_only
+
+                quantization_fn = int4_weight_only(
+                    group_size=kwargs.get("group_size", 32),
+                )
+
+            else:
+                raise ValueError(
+                    f"quant_type: {quant_type} is not supported now!"
+                )
+
+        except ImportError as e:
+            e.msg += (
+                f"{quant_type} is not supported in torchao backend now! "
+                "Please upgrade the torchao library."
+            )
+            raise e
+
+        return quantization_fn
+
+    from torchao.quantization import quantize_
+
+    quantize_(
+        module,
+        _quantization_fn(),
+        filter_fn=_filter_fn if filter_fn is None else filter_fn,
+        **kwargs,
+    )
+
+    force_empty_cache()
+
+    logger.info(
+        f"Quantized Linear Layers: {num_quant_linear:>5}\n"
+        f"Skipped Linear Layers: {num_skip_linear:>5}\n"
+        f"Total Linear Layers: {num_linear_layers:>5}\n"
+        f"Total (all) Layers: {num_layers:>5}"
+    )
+    return module
+
+
+def force_empty_cache():
+    time.sleep(1)
+    gc.collect()
+    torch.cuda.empty_cache()
+    time.sleep(1)
+    gc.collect()
+    torch.cuda.empty_cache()
```
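`quantize_ao` can also be called directly on any `torch.nn.Module`. A minimal sketch using int8 weight-only quantization on a toy block, assuming torchao >= 0.12.0 is installed; `TinyBlock` is an illustrative stand-in, not part of cache-dit:

```python
import torch
from cache_dit.quantize.quantize_ao import quantize_ao


class TinyBlock(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.patch_embedder = torch.nn.Linear(64, 128)  # skipped: name matches "embed"
        self.proj = torch.nn.Linear(128, 128)           # quantized

    def forward(self, x):
        return self.proj(self.patch_embedder(x))


block = TinyBlock()
# Quantizes Linear weights in place; layers whose qualified names contain
# "embedder"/"embed" are excluded by default to protect precision.
quantize_ao(block, quant_type="int8_w8a16_wo")
```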
cache_dit/quantize/quantize_interface.py
ADDED
```diff
@@ -0,0 +1,46 @@
+import torch
+from typing import Callable, Optional, List
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+def quantize(
+    module: torch.nn.Module,
+    quant_type: str = "fp8_w8a8_dq",
+    backend: str = "ao",
+    # only for fp8_w8a8_dq
+    per_row: bool = True,
+    exclude_layers: List[str] = [
+        "embedder",
+        "embed",
+    ],
+    filter_fn: Optional[Callable] = None,
+    **kwargs,
+) -> torch.nn.Module:
+    assert isinstance(module, torch.nn.Module)
+
+    if backend.lower() in ("ao", "torchao"):
+        from cache_dit.quantize.quantize_ao import quantize_ao
+
+        quant_type = quant_type.lower()
+        assert quant_type in (
+            "fp8_w8a8_dq",
+            "fp8_w8a16_wo",
+            "int8_w8a8_dq",
+            "int8_w8a16_wo",
+            "int4_w4a8_dq",
+            "int4_w4a4_dq",
+            "int4_w4a16_wo",
+        ), f"{quant_type} is not supported for torchao backend now!"
+
+        return quantize_ao(
+            module,
+            quant_type=quant_type,
+            per_row=per_row,
+            exclude_layers=exclude_layers,
+            filter_fn=filter_fn,
+            **kwargs,
+        )
+    else:
+        raise ValueError(f"backend: {backend} is not supported now!")
```
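End to end, this public `quantize` entry point dispatches to the torchao backend above. A minimal sketch against a diffusers pipeline; the model id, device, and dtype are illustrative assumptions, and the default `"fp8_w8a8_dq"` path needs a GPU with compute capability >= 8.9 plus bfloat16 weights when `per_row=True`:

```python
import torch
import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained(
    "Qwen/Qwen-Image",           # illustrative model id
    torch_dtype=torch.bfloat16,  # bfloat16 weights for per-row FP8 dynamic quant
).to("cuda")

# Quantize the transformer in place; embedding-like Linear layers are skipped.
cache_dit.quantize(
    pipe.transformer,
    quant_type="fp8_w8a8_dq",
    backend="ao",                # dispatches to cache_dit.quantize.quantize_ao
    per_row=True,
)
```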
cache_dit/quantize/quantize_svdq.py
File without changes
cache_dit/utils.py
CHANGED
```diff
@@ -5,7 +5,10 @@ import numpy as np
 from pprint import pprint
 from diffusers import DiffusionPipeline
 
+from typing import Dict, Any
 from cache_dit.logger import init_logger
+from cache_dit.cache_factory import CacheType
+
 
 logger = init_logger(__name__)
 
@@ -137,28 +140,56 @@ def summary(
     return cache_stats
 
 
-def strify(
-
+def strify(
+    pipe_or_stats: DiffusionPipeline | CacheStats | Dict[str, Any],
+) -> str:
+    if isinstance(pipe_or_stats, DiffusionPipeline):
         stats = summary(pipe_or_stats, logging=False)
-
+        cache_options = stats.cache_options
+        cached_steps = len(stats.cached_steps)
+    elif isinstance(pipe_or_stats, CacheStats):
         stats = pipe_or_stats
-
-
-
+        cache_options = stats.cache_options
+        cached_steps = len(stats.cached_steps)
+    elif isinstance(pipe_or_stats, dict):
+        # Assume cache_context_kwargs
+        cache_options = pipe_or_stats
+        cached_steps = None
+    else:
+        raise ValueError(
+            "Please set pipe_or_stats param as one of: "
+            "DiffusionPipeline | CacheStats | Dict[str, Any]"
+        )
 
     if not cache_options:
         return "NONE"
 
+    if cache_options.get("cache_type", None) != CacheType.DBCache:
+        return "NONE"
+
+    def get_taylorseer_order():
+        taylorseer_order = 0
+        if "taylorseer_kwargs" in cache_options:
+            if "n_derivatives" in cache_options["taylorseer_kwargs"]:
+                taylorseer_order = cache_options["taylorseer_kwargs"][
+                    "n_derivatives"
+                ]
+        elif "taylorseer_order" in cache_options:
+            taylorseer_order = cache_options["taylorseer_order"]
+        return taylorseer_order
+
     cache_type_str = (
-        f"DBCACHE_F{cache_options
-        f"B{cache_options
-        f"W{cache_options
-        f"M{max(0, cache_options
-        f"MC{max(0, cache_options
-        f"T{int(cache_options
-        f"O{
-        f"R{cache_options
-        f"S{cached_steps}" # skiped steps
+        f"DBCACHE_F{cache_options.get('Fn_compute_blocks', 1)}"
+        f"B{cache_options.get('Bn_compute_blocks', 0)}_"
+        f"W{cache_options.get('max_warmup_steps', 0)}"
+        f"M{max(0, cache_options.get('max_cached_steps', -1))}"
+        f"MC{max(0, cache_options.get('max_continuous_cached_steps', -1))}_"
+        f"T{int(cache_options.get('enable_taylorseer', False))}"
+        f"O{get_taylorseer_order()}_"
+        f"R{cache_options.get('residual_diff_threshold', 0.08)}"
     )
 
+    if cached_steps:
+        cache_type_str += f"_S{cached_steps}"
+
     return cache_type_str
```
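`strify` now also accepts a plain dict of cache options (assumed to be `cache_context_kwargs`). Tracing the new code above, a dict has no recorded cached steps, so no trailing `_S<N>` suffix is appended; a small sketch with a hypothetical set of options:

```python
from cache_dit.cache_factory import CacheType
from cache_dit.utils import strify

opts = {
    "cache_type": CacheType.DBCache,  # anything else returns "NONE"
    "Fn_compute_blocks": 8,
    "max_warmup_steps": 8,
    "enable_taylorseer": True,
    "taylorseer_order": 2,
    "residual_diff_threshold": 0.12,
}
# Unset options fall back to the .get() defaults shown above, so this
# should print something like: DBCACHE_F8B0_W8M0MC0_T1O2_R0.12
print(strify(opts))
```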
{cache_dit-0.2.24.dist-info → cache_dit-0.2.25.dist-info}/METADATA
CHANGED
````diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.24
+Version: 0.2.25
 Summary: 🤗 CacheDiT: An Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -11,12 +11,13 @@ Description-Content-Type: text/markdown
 License-File: LICENSE
 Requires-Dist: packaging
 Requires-Dist: pyyaml
-Requires-Dist: torch>=2.
-Requires-Dist: transformers>=4.
-Requires-Dist: diffusers>=0.
+Requires-Dist: torch>=2.7.1
+Requires-Dist: transformers>=4.55.2
+Requires-Dist: diffusers>=0.35.1
 Requires-Dist: scikit-image
 Requires-Dist: scipy
 Requires-Dist: lpips==0.1.4
+Requires-Dist: torchao>=0.12.0
 Provides-Extra: all
 Provides-Extra: dev
 Requires-Dist: pre-commit; extra == "dev"
@@ -65,7 +66,7 @@ Dynamic: requires-python
 
 ## 🔥News
 
-- [2025-08-26] 🎉[**Wan2.2**](https://github.com/Wan-Video) **1.
+- [2025-08-26] 🎉[**Wan2.2**](https://github.com/Wan-Video) **1.8x⚡️** speedup with `cache-dit + compile`! Check the [example](./examples/run_wan_2.2.py).
 - [2025-08-19] 🔥[**Qwen-Image-Edit**](https://github.com/QwenLM/Qwen-Image) **2x⚡️** speedup! Check example [run_qwen_image_edit.py](./examples/run_qwen_image_edit.py).
 - [2025-08-12] 🎉First caching mechanism in [QwenLM/Qwen-Image](https://github.com/QwenLM/Qwen-Image) with **[cache-dit](https://github.com/vipshop/cache-dit)**, check the [PR](https://github.com/QwenLM/Qwen-Image/pull/61).
 - [2025-08-11] 🔥[**Qwen-Image**](https://github.com/QwenLM/Qwen-Image) **1.8x⚡️** speedup! Please refer [run_qwen_image.py](./examples/run_qwen_image.py) as an example.
@@ -294,14 +295,11 @@ cache_dit.enable_cache(
     enable_encoder_taylorseer=True,
     # Taylorseer cache type cache be hidden_states or residual.
     taylorseer_cache_type="residual",
-    # Higher values of
-    #
-
-        "n_derivatives": 2, # default is 2.
-    },
-    max_warmup_steps=3, # prefer: >= n_derivatives + 1
+    # Higher values of order will lead to longer computation time
+    taylorseer_order=2, # default is 2.
+    max_warmup_steps=3, # prefer: >= order + 1
     residual_diff_threshold=0.12
-)
+)s
 ```
 
 > [!Important]
````
{cache_dit-0.2.24.dist-info → cache_dit-0.2.25.dist-info}/RECORD
CHANGED
```diff
@@ -1,13 +1,13 @@
-cache_dit/__init__.py,sha256=
-cache_dit/_version.py,sha256=
+cache_dit/__init__.py,sha256=VsT0f0R0COp8v6Sx9hGNsqxiElERaDpfG11a9MfK0is,945
+cache_dit/_version.py,sha256=t9iixyDuMWz1nP7KM0bgrLNIpwu8JK6uZApA8DoCQwM,706
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
-cache_dit/utils.py,sha256=
+cache_dit/utils.py,sha256=1oWDMYs6E7FRsd8cidsVOPT-meIRKeuqbGbE6CrCUec,7236
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
 cache_dit/cache_factory/__init__.py,sha256=evWenCin1kuBGa6W5BCKMrDZc1C1R2uVPSg0BjXgdXE,499
 cache_dit/cache_factory/cache_adapters.py,sha256=Yugqljm9tm615srM2BGQlR_tA0QiZo3PbLPceObh4dQ,25988
 cache_dit/cache_factory/cache_blocks.py,sha256=ZeazBsYvLIjI5M_OnLL2xP2W7zMeM0rxVfBBwIVHBRs,18661
-cache_dit/cache_factory/cache_context.py,sha256=
-cache_dit/cache_factory/cache_interface.py,sha256=
+cache_dit/cache_factory/cache_context.py,sha256=HhA5IMSdF-i-koSB1jqf5AMC_UyDV7VinpHm4Qee9Ig,41800
+cache_dit/cache_factory/cache_interface.py,sha256=HymagnKEDs48Ly_x3IM5jTMNJpLrIdJnppVlkr2xHaM,8433
 cache_dit/cache_factory/cache_types.py,sha256=FIFa6ZBfvvSMMHyBBhvarvgg2Y2wbRgITcG_uGylGe0,991
 cache_dit/cache_factory/forward_pattern.py,sha256=B2YeqV2t_zo2Ar8m7qimPBjwQgoXHGp2grPZmEAhi8s,1286
 cache_dit/cache_factory/taylorseer.py,sha256=etSUIZzDvqW3ScKCbccTPcFaSmxV1T-xAXdk-p3e3wk,3802
@@ -15,7 +15,7 @@ cache_dit/cache_factory/utils.py,sha256=XkVM9AXcB9zYq8-S8QKAsGz80r3tA6U3lBNGDGeH
 cache_dit/cache_factory/patch/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/cache_factory/patch/flux.py,sha256=iNQ-1RlOgXupZ4uPiEvJ__Ro6vKT_fOKja9JrpMrO78,8998
 cache_dit/compile/__init__.py,sha256=FcTVzCeyypl-mxlc59_ehHL3lBNiDAFsXuRoJ-5Cfi0,56
-cache_dit/compile/utils.py,sha256=
+cache_dit/compile/utils.py,sha256=nN2OIrSdwRR5zGxJinKDqb07pXpvTNTF3g_OgLkeeBU,3858
 cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
@@ -24,9 +24,13 @@ cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,1706
 cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
 cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
 cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
-cache_dit
-cache_dit
-cache_dit
-cache_dit
-cache_dit-0.2.
-cache_dit-0.2.
+cache_dit/quantize/__init__.py,sha256=kWYoMAyZgBXu9BJlZjTQ0dRffW9GqeeY9_iTkXrb70A,59
+cache_dit/quantize/quantize_ao.py,sha256=sKz_RmVtxLOpAPnUv_iOjzY_226pfaxgB_HMNrfyqB8,5465
+cache_dit/quantize/quantize_interface.py,sha256=NG4WP7s8CLW6KhVFb9e1aAjW30KWTCcM2aS5n8QuwsA,1241
+cache_dit/quantize/quantize_svdq.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+cache_dit-0.2.25.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.25.dist-info/METADATA,sha256=a5wbENMZ9BDjHbM3Ejb7Il7x4QzD8W7Lzmu4poo95Wo,19913
+cache_dit-0.2.25.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.25.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-0.2.25.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.25.dist-info/RECORD,,
```
{cache_dit-0.2.24.dist-info → cache_dit-0.2.25.dist-info}/WHEEL
File without changes
{cache_dit-0.2.24.dist-info → cache_dit-0.2.25.dist-info}/entry_points.txt
File without changes
{cache_dit-0.2.24.dist-info → cache_dit-0.2.25.dist-info}/licenses/LICENSE
File without changes
{cache_dit-0.2.24.dist-info → cache_dit-0.2.25.dist-info}/top_level.txt
File without changes
|