cache-dit 0.2.9__py3-none-any.whl → 0.2.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +3 -168
- cache_dit/cache_factory/adapters.py +169 -0
- cache_dit/cache_factory/utils.py +53 -0
- cache_dit/metrics/metrics.py +226 -56
- {cache_dit-0.2.9.dist-info → cache_dit-0.2.10.dist-info}/METADATA +3 -2
- {cache_dit-0.2.9.dist-info → cache_dit-0.2.10.dist-info}/RECORD +11 -10
- {cache_dit-0.2.9.dist-info → cache_dit-0.2.10.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.9.dist-info → cache_dit-0.2.10.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.9.dist-info → cache_dit-0.2.10.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.9.dist-info → cache_dit-0.2.10.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/__init__.py
CHANGED
|
@@ -1,168 +1,3 @@
|
|
|
1
|
-
from
|
|
2
|
-
|
|
3
|
-
from
|
|
4
|
-
|
|
5
|
-
from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
|
|
6
|
-
apply_db_cache_on_pipe,
|
|
7
|
-
)
|
|
8
|
-
from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
|
|
9
|
-
apply_fb_cache_on_pipe,
|
|
10
|
-
)
|
|
11
|
-
from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters import (
|
|
12
|
-
apply_db_prune_on_pipe,
|
|
13
|
-
)
|
|
14
|
-
from cache_dit.logger import init_logger
|
|
15
|
-
|
|
16
|
-
|
|
17
|
-
logger = init_logger(__name__)
|
|
18
|
-
|
|
19
|
-
|
|
20
|
-
class CacheType(Enum):
    """Cache acceleration strategies understood by ``apply_cache_on_pipe``.

    Besides the members themselves, the helpers below accept case-insensitive
    string aliases such as ``"fb"``, ``"db_cache"``, ``"dbprune"`` or
    ``"none"``, and ``None`` (meaning no caching).
    """

    # NOTE: only enum members may be assigned at class level here; any other
    # plain (non-descriptor) class attribute would itself become a member.
    NONE = "NONE"
    FBCache = "First_Block_Cache"
    DBCache = "Dual_Block_Cache"
    DBPrune = "Dynamic_Block_Prune"

    @staticmethod
    def type(cache_type: "CacheType | str") -> "CacheType":
        """Return *cache_type* as a ``CacheType`` member.

        Members are returned unchanged; anything else is resolved through
        :meth:`CacheType.cache_type`.

        Raises:
            ValueError: if *cache_type* cannot be resolved.
        """
        if isinstance(cache_type, CacheType):
            return cache_type
        return CacheType.cache_type(cache_type)

    @staticmethod
    def cache_type(cache_type: "CacheType | str") -> "CacheType":
        """Resolve a member, an alias string or ``None`` to a ``CacheType``.

        ``None`` maps to ``CacheType.NONE``. Strings are matched
        case-insensitively, ignoring surrounding whitespace, against the
        alias table below (which includes every member name and value).

        Raises:
            ValueError: if *cache_type* is not a member, ``None`` or a known
                alias string.
        """
        if cache_type is None:
            return CacheType.NONE
        if isinstance(cache_type, CacheType):
            return cache_type
        # Reject non-string input explicitly instead of failing on .lower().
        if not isinstance(cache_type, str):
            raise ValueError(f"Unknown cache type: {cache_type}")

        # Built per call rather than as a class attribute so that it does not
        # turn into an enum member.
        aliases = {
            CacheType.FBCache: (
                "first_block_cache",
                "fb_cache",
                "fbcache",
                "fb",
            ),
            CacheType.DBCache: (
                "dual_block_cache",
                "db_cache",
                "dbcache",
                "db",
            ),
            CacheType.DBPrune: (
                "dynamic_block_prune",
                "db_prune",
                "dbprune",
                "dbp",
            ),
            CacheType.NONE: (
                "none_cache",
                "nonecache",
                "no_cache",
                "nocache",
                "none",
                "no",
            ),
        }
        normalized = cache_type.strip().lower()
        for member, names in aliases.items():
            if normalized in names:
                return member
        raise ValueError(f"Unknown cache type: {cache_type}")

    @staticmethod
    def range(start: int, end: int, step: int = 1) -> list[int]:
        """Return the sorted, de-duplicated block ids to always compute.

        The result is ``range(start, end, step)`` plus blocks ``0`` and
        ``end - 1``. An empty list is returned when ``start > end``,
        ``end <= 0`` or ``step <= 1``.
        """
        # NOTE(review): step <= 1 (including the default step=1) yields []
        # rather than every block id. Kept as-is because callers may rely on
        # it; confirm this is intended.
        if start > end or end <= 0 or step <= 1:
            return []
        # Inside a method body ``range`` resolves to the builtin, not this
        # staticmethod: class scope is not searched from function bodies.
        # Always compute 0 and end - 1 blocks for DB Cache.
        return sorted({0, *range(start, end, step), end - 1})

    @staticmethod
    def default_options(cache_type: "CacheType | str") -> dict:
        """Return a fresh dict of default options for *cache_type*.

        *cache_type* may be a member, an alias string or ``None`` (see
        :meth:`CacheType.type`). A new dict, including new inner lists, is
        built on every call, so callers may mutate the result freely.

        Raises:
            ValueError: if *cache_type* cannot be resolved.
        """
        # Normalize first: the annotation admits alias strings, but a raw
        # string never compares equal to an enum member.
        cache_type = CacheType.type(cache_type)

        _Fn_compute_blocks = 8
        _Bn_compute_blocks = 8

        options_by_type = {
            CacheType.NONE: {
                "cache_type": CacheType.NONE,
            },
            CacheType.FBCache: {
                "cache_type": CacheType.FBCache,
                "residual_diff_threshold": 0.08,
                "warmup_steps": 8,
                "max_cached_steps": 8,
            },
            CacheType.DBCache: {
                "cache_type": CacheType.DBCache,
                "residual_diff_threshold": 0.12,
                "warmup_steps": 8,
                "max_cached_steps": -1,  # -1 means no limit
                # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
                "Fn_compute_blocks": _Fn_compute_blocks,
                "Bn_compute_blocks": _Bn_compute_blocks,
                "max_Fn_compute_blocks": 16,
                "max_Bn_compute_blocks": 16,
                "Fn_compute_blocks_ids": [],  # 0, 1, 2, ..., 7, etc.
                "Bn_compute_blocks_ids": [],  # 0, 1, 2, ..., 7, etc.
            },
            CacheType.DBPrune: {
                "cache_type": CacheType.DBPrune,
                "residual_diff_threshold": 0.08,
                "Fn_compute_blocks": _Fn_compute_blocks,
                "Bn_compute_blocks": _Bn_compute_blocks,
                "warmup_steps": 8,
                "max_pruned_steps": -1,  # -1 means no limit
            },
        }
        return options_by_type[cache_type]
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    """Apply the requested cache/prune strategy to a diffusers pipeline.

    Args:
        pipe: the pipeline to accelerate; must be a ``DiffusionPipeline``.
        *args: forwarded positionally to the selected ``apply_*`` function.
        **kwargs: cache options. The optional ``cache_type`` key (a
            ``CacheType``, an alias string or ``None``) selects the strategy;
            all remaining keys are forwarded to the selected ``apply_*``
            function. When ``cache_type`` is missing, DBCache is used and the
            given options override its defaults.

    Returns:
        The result of the selected ``apply_*`` function, or ``pipe`` itself
        when it is already flagged as cached/pruned or the cache type is NONE.

    Raises:
        ValueError: if ``cache_type`` is not a recognised cache type.
    """
    assert isinstance(pipe, DiffusionPipeline)

    # A pipeline already flagged as cached or pruned is returned unchanged.
    if getattr(pipe, "_is_cached", False) or getattr(pipe, "_is_pruned", False):
        return pipe

    cache_type = kwargs.pop("cache_type", None)
    if cache_type is None:
        logger.warning(
            "No cache type specified, we will use DBCache by default. "
            "Please specify the cache_type explicitly if you want to "
            "use a different cache type."
        )
        # Start from the DBCache defaults, but let explicitly passed options
        # override them instead of silently discarding them.
        options = {**CacheType.default_options(CacheType.DBCache), **kwargs}
        return apply_db_cache_on_pipe(pipe, *args, **options)

    cache_type = CacheType.type(cache_type)

    if cache_type == CacheType.NONE:
        logger.warning(
            f"Cache type is {cache_type}, no caching will be applied."
        )
        return pipe

    appliers = {
        CacheType.FBCache: apply_fb_cache_on_pipe,
        CacheType.DBCache: apply_db_cache_on_pipe,
        CacheType.DBPrune: apply_db_prune_on_pipe,
    }
    apply_fn = appliers.get(cache_type)
    if apply_fn is None:
        raise ValueError(f"Unknown cache type: {cache_type}")
    return apply_fn(pipe, *args, **kwargs)
|
|
1
|
+
from cache_dit.cache_factory.adapters import CacheType
|
|
2
|
+
from cache_dit.cache_factory.adapters import apply_cache_on_pipe
|
|
3
|
+
from cache_dit.cache_factory.utils import load_cache_options_from_yaml
|
|
cache_dit/cache_factory/adapters.py
CHANGED
|
@@ -0,0 +1,169 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
from diffusers import DiffusionPipeline
|
|
4
|
+
|
|
5
|
+
from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
|
|
6
|
+
apply_db_cache_on_pipe,
|
|
7
|
+
)
|
|
8
|
+
from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
|
|
9
|
+
apply_fb_cache_on_pipe,
|
|
10
|
+
)
|
|
11
|
+
from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters import (
|
|
12
|
+
apply_db_prune_on_pipe,
|
|
13
|
+
)
|
|
14
|
+
|
|
15
|
+
from cache_dit.logger import init_logger
|
|
16
|
+
|
|
17
|
+
|
|
18
|
+
logger = init_logger(__name__)
|
|
19
|
+
|
|
20
|
+
|
|
21
|
+
class CacheType(Enum):
    """Cache acceleration strategies understood by ``apply_cache_on_pipe``.

    Besides the members themselves, the helpers below accept case-insensitive
    string aliases such as ``"fb"``, ``"db_cache"``, ``"dbprune"`` or
    ``"none"``, and ``None`` (meaning no caching).
    """

    # NOTE: only enum members may be assigned at class level here; any other
    # plain (non-descriptor) class attribute would itself become a member.
    NONE = "NONE"
    FBCache = "First_Block_Cache"
    DBCache = "Dual_Block_Cache"
    DBPrune = "Dynamic_Block_Prune"

    @staticmethod
    def type(cache_type: "CacheType | str") -> "CacheType":
        """Return *cache_type* as a ``CacheType`` member.

        Members are returned unchanged; anything else is resolved through
        :meth:`CacheType.cache_type`.

        Raises:
            ValueError: if *cache_type* cannot be resolved.
        """
        if isinstance(cache_type, CacheType):
            return cache_type
        return CacheType.cache_type(cache_type)

    @staticmethod
    def cache_type(cache_type: "CacheType | str") -> "CacheType":
        """Resolve a member, an alias string or ``None`` to a ``CacheType``.

        ``None`` maps to ``CacheType.NONE``. Strings are matched
        case-insensitively, ignoring surrounding whitespace, against the
        alias table below (which includes every member name and value).

        Raises:
            ValueError: if *cache_type* is not a member, ``None`` or a known
                alias string.
        """
        if cache_type is None:
            return CacheType.NONE
        if isinstance(cache_type, CacheType):
            return cache_type
        # Reject non-string input explicitly instead of failing on .lower().
        if not isinstance(cache_type, str):
            raise ValueError(f"Unknown cache type: {cache_type}")

        # Built per call rather than as a class attribute so that it does not
        # turn into an enum member.
        aliases = {
            CacheType.FBCache: (
                "first_block_cache",
                "fb_cache",
                "fbcache",
                "fb",
            ),
            CacheType.DBCache: (
                "dual_block_cache",
                "db_cache",
                "dbcache",
                "db",
            ),
            CacheType.DBPrune: (
                "dynamic_block_prune",
                "db_prune",
                "dbprune",
                "dbp",
            ),
            CacheType.NONE: (
                "none_cache",
                "nonecache",
                "no_cache",
                "nocache",
                "none",
                "no",
            ),
        }
        normalized = cache_type.strip().lower()
        for member, names in aliases.items():
            if normalized in names:
                return member
        raise ValueError(f"Unknown cache type: {cache_type}")

    @staticmethod
    def range(start: int, end: int, step: int = 1) -> list[int]:
        """Return the sorted, de-duplicated block ids to always compute.

        The result is ``range(start, end, step)`` plus blocks ``0`` and
        ``end - 1``. An empty list is returned when ``start > end``,
        ``end <= 0`` or ``step <= 1``.
        """
        # NOTE(review): step <= 1 (including the default step=1) yields []
        # rather than every block id. Kept as-is because callers may rely on
        # it; confirm this is intended.
        if start > end or end <= 0 or step <= 1:
            return []
        # Inside a method body ``range`` resolves to the builtin, not this
        # staticmethod: class scope is not searched from function bodies.
        # Always compute 0 and end - 1 blocks for DB Cache.
        return sorted({0, *range(start, end, step), end - 1})

    @staticmethod
    def default_options(cache_type: "CacheType | str") -> dict:
        """Return a fresh dict of default options for *cache_type*.

        *cache_type* may be a member, an alias string or ``None`` (see
        :meth:`CacheType.type`). A new dict, including new inner lists, is
        built on every call, so callers may mutate the result freely.

        Raises:
            ValueError: if *cache_type* cannot be resolved.
        """
        # Normalize first: the annotation admits alias strings, but a raw
        # string never compares equal to an enum member.
        cache_type = CacheType.type(cache_type)

        _Fn_compute_blocks = 8
        _Bn_compute_blocks = 8

        options_by_type = {
            CacheType.NONE: {
                "cache_type": CacheType.NONE,
            },
            CacheType.FBCache: {
                "cache_type": CacheType.FBCache,
                "residual_diff_threshold": 0.08,
                "warmup_steps": 8,
                "max_cached_steps": 8,
            },
            CacheType.DBCache: {
                "cache_type": CacheType.DBCache,
                "residual_diff_threshold": 0.12,
                "warmup_steps": 8,
                "max_cached_steps": -1,  # -1 means no limit
                # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
                "Fn_compute_blocks": _Fn_compute_blocks,
                "Bn_compute_blocks": _Bn_compute_blocks,
                "max_Fn_compute_blocks": 16,
                "max_Bn_compute_blocks": 16,
                "Fn_compute_blocks_ids": [],  # 0, 1, 2, ..., 7, etc.
                "Bn_compute_blocks_ids": [],  # 0, 1, 2, ..., 7, etc.
            },
            CacheType.DBPrune: {
                "cache_type": CacheType.DBPrune,
                "residual_diff_threshold": 0.08,
                "Fn_compute_blocks": _Fn_compute_blocks,
                "Bn_compute_blocks": _Bn_compute_blocks,
                "warmup_steps": 8,
                "max_pruned_steps": -1,  # -1 means no limit
            },
        }
        return options_by_type[cache_type]
|
|
131
|
+
|
|
132
|
+
|
|
133
|
+
def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    """Apply the requested cache/prune strategy to a diffusers pipeline.

    Args:
        pipe: the pipeline to accelerate; must be a ``DiffusionPipeline``.
        *args: forwarded positionally to the selected ``apply_*`` function.
        **kwargs: cache options. The optional ``cache_type`` key (a
            ``CacheType``, an alias string or ``None``) selects the strategy;
            all remaining keys are forwarded to the selected ``apply_*``
            function. When ``cache_type`` is missing, DBCache is used and the
            given options override its defaults.

    Returns:
        The result of the selected ``apply_*`` function, or ``pipe`` itself
        when it is already flagged as cached/pruned or the cache type is NONE.

    Raises:
        ValueError: if ``cache_type`` is not a recognised cache type.
    """
    assert isinstance(pipe, DiffusionPipeline)

    # A pipeline already flagged as cached or pruned is returned unchanged.
    if getattr(pipe, "_is_cached", False) or getattr(pipe, "_is_pruned", False):
        return pipe

    cache_type = kwargs.pop("cache_type", None)
    if cache_type is None:
        logger.warning(
            "No cache type specified, we will use DBCache by default. "
            "Please specify the cache_type explicitly if you want to "
            "use a different cache type."
        )
        # Start from the DBCache defaults, but let explicitly passed options
        # override them instead of silently discarding them.
        options = {**CacheType.default_options(CacheType.DBCache), **kwargs}
        return apply_db_cache_on_pipe(pipe, *args, **options)

    cache_type = CacheType.type(cache_type)

    if cache_type == CacheType.NONE:
        logger.warning(
            f"Cache type is {cache_type}, no caching will be applied."
        )
        return pipe

    appliers = {
        CacheType.FBCache: apply_fb_cache_on_pipe,
        CacheType.DBCache: apply_db_cache_on_pipe,
        CacheType.DBPrune: apply_db_prune_on_pipe,
    }
    apply_fn = appliers.get(cache_type)
    if apply_fn is None:
        raise ValueError(f"Unknown cache type: {cache_type}")
    return apply_fn(pipe, *args, **kwargs)
|
cache_dit/cache_factory/utils.py
CHANGED
|
@@ -0,0 +1,53 @@
|
|
|
1
|
+
import yaml
|
|
2
|
+
from cache_dit.cache_factory.adapters import CacheType
|
|
3
|
+
|
|
4
|
+
|
|
5
|
+
def load_cache_options_from_yaml(yaml_file_path):
    """Load and validate cache options from a YAML file.

    Args:
        yaml_file_path: path to the YAML configuration file.

    Returns:
        dict: the parsed options. ``cache_type`` is converted to a
        ``CacheType`` member. ``taylorseer_kwargs`` defaults to
        ``{"n_derivatives": 2}`` when ``enable_taylorseer`` is truthy and no
        ``taylorseer_kwargs`` were given.

    Raises:
        FileNotFoundError: if the file does not exist.
        yaml.YAMLError: if the file is not valid YAML.
        ValueError: if the document is not a mapping, a required key is
            missing, or ``cache_type`` is invalid.
    """
    # Only the file read and the YAML parse can raise the errors re-wrapped
    # here; validation below raises ValueError directly.
    try:
        with open(yaml_file_path, "r", encoding="utf-8") as f:
            config = yaml.safe_load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"Configuration file not found: {yaml_file_path}"
        ) from e
    except yaml.YAMLError as e:
        raise yaml.YAMLError(f"YAML file parsing error: {str(e)}") from e

    # An empty file loads as None, and a scalar/list document is not an
    # options mapping: fail with a clear message instead of a TypeError.
    if not isinstance(config, dict):
        raise ValueError(
            "Configuration file must contain a mapping of cache options, "
            f"got: {type(config)}"
        )

    required_keys = [
        "cache_type",
        "warmup_steps",
        "max_cached_steps",
        "Fn_compute_blocks",
        "Bn_compute_blocks",
        "residual_diff_threshold",
    ]
    for key in required_keys:
        if key not in config:
            raise ValueError(
                f"Configuration file missing required item: {key}"
            )

    # Convert cache_type to a CacheType member. CacheType.type accepts the
    # member names (e.g. "DBCache") as well as the case-insensitive aliases
    # (e.g. "db_cache", "Dual_Block_Cache").
    raw_cache_type = config["cache_type"]
    if isinstance(raw_cache_type, str):
        try:
            config["cache_type"] = CacheType.type(raw_cache_type)
        except ValueError:
            valid_types = [ct.name for ct in CacheType]
            raise ValueError(
                f"Invalid cache_type value: {raw_cache_type}, "
                f"valid values are: {valid_types}"
            ) from None
    elif not isinstance(raw_cache_type, CacheType):
        raise ValueError(
            f"cache_type must be a string or CacheType enum, "
            f"got: {type(raw_cache_type)}"
        )

    # Handle default value for taylorseer_kwargs
    if "taylorseer_kwargs" not in config and config.get(
        "enable_taylorseer", False
    ):
        config["taylorseer_kwargs"] = {"n_derivatives": 2}

    return config
|
cache_dit/metrics/metrics.py
CHANGED
|
@@ -348,8 +348,9 @@ def get_args():
|
|
|
348
348
|
"all",
|
|
349
349
|
]
|
|
350
350
|
parser.add_argument(
|
|
351
|
-
"
|
|
351
|
+
"metrics",
|
|
352
352
|
type=str,
|
|
353
|
+
nargs="+",
|
|
353
354
|
default="psnr",
|
|
354
355
|
choices=METRICS_CHOICES,
|
|
355
356
|
help=f"Metric choices: {METRICS_CHOICES}",
|
|
@@ -389,6 +390,47 @@ def get_args():
|
|
|
389
390
|
default=False,
|
|
390
391
|
help="Show metrics progress verbose",
|
|
391
392
|
)
|
|
393
|
+
|
|
394
|
+
# Image 1 vs N pattern
|
|
395
|
+
parser.add_argument(
|
|
396
|
+
"--img-source-dir",
|
|
397
|
+
"-d",
|
|
398
|
+
type=str,
|
|
399
|
+
default=None,
|
|
400
|
+
help="Path to dir that contains dirs of images",
|
|
401
|
+
)
|
|
402
|
+
parser.add_argument(
|
|
403
|
+
"--ref-img-dir",
|
|
404
|
+
"-r",
|
|
405
|
+
type=str,
|
|
406
|
+
default=None,
|
|
407
|
+
help="Path to ref dir that contains ground truth images",
|
|
408
|
+
)
|
|
409
|
+
|
|
410
|
+
# Video 1 vs N pattern
|
|
411
|
+
parser.add_argument(
|
|
412
|
+
"--video-source-dir",
|
|
413
|
+
"-vd",
|
|
414
|
+
type=str,
|
|
415
|
+
default=None,
|
|
416
|
+
help="Path to dir that contains many videos",
|
|
417
|
+
)
|
|
418
|
+
parser.add_argument(
|
|
419
|
+
"--ref-video",
|
|
420
|
+
"-rv",
|
|
421
|
+
type=str,
|
|
422
|
+
default=None,
|
|
423
|
+
help="Path to ground truth video",
|
|
424
|
+
)
|
|
425
|
+
|
|
426
|
+
# FID batch size
|
|
427
|
+
parser.add_argument(
|
|
428
|
+
"--fid-batch-size",
|
|
429
|
+
"-b",
|
|
430
|
+
type=int,
|
|
431
|
+
default=1,
|
|
432
|
+
help="Batch size for FID compute",
|
|
433
|
+
)
|
|
392
434
|
return parser.parse_args()
|
|
393
435
|
|
|
394
436
|
|
|
@@ -401,67 +443,195 @@ def entrypoint():
|
|
|
401
443
|
set_metrics_verbose(True)
|
|
402
444
|
DISABLE_VERBOSE = not get_metrics_verbose()
|
|
403
445
|
|
|
404
|
-
if args.
|
|
405
|
-
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
409
|
-
|
|
410
|
-
|
|
446
|
+
if "all" in args.metrics or "fid" in args.metrics:
|
|
447
|
+
FID = FrechetInceptionDistance(
|
|
448
|
+
disable_tqdm=DISABLE_VERBOSE,
|
|
449
|
+
batch_size=args.fid_batch_size,
|
|
450
|
+
)
|
|
451
|
+
|
|
452
|
+
# run one metric
|
|
453
|
+
def _run_metric(
|
|
454
|
+
mertric: str,
|
|
455
|
+
img_true: str = None,
|
|
456
|
+
img_test: str = None,
|
|
457
|
+
video_true: str = None,
|
|
458
|
+
video_test: str = None,
|
|
459
|
+
) -> None:
|
|
460
|
+
nonlocal FID
|
|
461
|
+
mertric = mertric.lower()
|
|
462
|
+
if img_true is not None and img_test is not None:
|
|
463
|
+
if any(
|
|
464
|
+
(
|
|
465
|
+
not os.path.exists(img_true),
|
|
466
|
+
not os.path.exists(img_test),
|
|
467
|
+
)
|
|
468
|
+
):
|
|
469
|
+
return
|
|
470
|
+
# img_true and img_test can be files or dirs
|
|
471
|
+
img_true_info = os.path.basename(img_true)
|
|
472
|
+
img_test_info = os.path.basename(img_test)
|
|
473
|
+
if mertric == "psnr" or mertric == "all":
|
|
474
|
+
img_psnr, n = compute_psnr(img_true, img_test)
|
|
475
|
+
logger.info(
|
|
476
|
+
f"{img_true_info} vs {img_test_info}, "
|
|
477
|
+
f"Num: {n}, PSNR: {img_psnr}"
|
|
478
|
+
)
|
|
479
|
+
if mertric == "ssim" or mertric == "all":
|
|
480
|
+
img_ssim, n = compute_ssim(img_true, img_test)
|
|
481
|
+
logger.info(
|
|
482
|
+
f"{img_true_info} vs {img_test_info}, "
|
|
483
|
+
f"Num: {n}, SSIM: {img_ssim}"
|
|
484
|
+
)
|
|
485
|
+
if mertric == "mse" or mertric == "all":
|
|
486
|
+
img_mse, n = compute_mse(img_true, img_test)
|
|
487
|
+
logger.info(
|
|
488
|
+
f"{img_true_info} vs {img_test_info}, "
|
|
489
|
+
f"Num: {n}, MSE: {img_mse}"
|
|
490
|
+
)
|
|
491
|
+
if mertric == "fid" or mertric == "all":
|
|
492
|
+
img_fid, n = FID.compute_fid(img_true, img_test)
|
|
493
|
+
logger.info(
|
|
494
|
+
f"{img_true_info} vs {img_test_info}, "
|
|
495
|
+
f"Num: {n}, FID: {img_fid}"
|
|
496
|
+
)
|
|
497
|
+
if video_true is not None and video_test is not None:
|
|
498
|
+
if any(
|
|
499
|
+
(
|
|
500
|
+
not os.path.exists(video_true),
|
|
501
|
+
not os.path.exists(video_test),
|
|
502
|
+
)
|
|
503
|
+
):
|
|
504
|
+
return
|
|
505
|
+
# video_true and video_test can be files or dirs
|
|
506
|
+
video_true_info = os.path.basename(video_true)
|
|
507
|
+
video_test_info = os.path.basename(video_test)
|
|
508
|
+
if mertric == "psnr" or mertric == "all":
|
|
509
|
+
video_psnr, n = compute_video_psnr(video_true, video_test)
|
|
510
|
+
logger.info(
|
|
511
|
+
f"{video_true_info} vs {video_test_info}, "
|
|
512
|
+
f"Frames: {n}, PSNR: {video_psnr}"
|
|
513
|
+
)
|
|
514
|
+
if mertric == "ssim" or mertric == "all":
|
|
515
|
+
video_ssim, n = compute_video_ssim(video_true, video_test)
|
|
516
|
+
logger.info(
|
|
517
|
+
f"{video_true_info} vs {video_test_info}, "
|
|
518
|
+
f"Frames: {n}, SSIM: {video_ssim}"
|
|
519
|
+
)
|
|
520
|
+
if mertric == "mse" or mertric == "all":
|
|
521
|
+
video_mse, n = compute_video_mse(video_true, video_test)
|
|
522
|
+
logger.info(
|
|
523
|
+
f"{video_true_info} vs {video_test_info}, "
|
|
524
|
+
f"Frames: {n}, MSE: {video_mse}"
|
|
525
|
+
)
|
|
526
|
+
if mertric == "fid" or mertric == "all":
|
|
527
|
+
video_fid, n = FID.compute_video_fid(video_true, video_test)
|
|
528
|
+
logger.info(
|
|
529
|
+
f"{video_true_info} vs {video_test_info}, "
|
|
530
|
+
f"Frames: {n}, FID: {video_fid}"
|
|
531
|
+
)
|
|
532
|
+
|
|
533
|
+
# run selected metrics
|
|
534
|
+
if not DISABLE_VERBOSE:
|
|
535
|
+
logger.info(f"Selected metrics: {args.metrics}")
|
|
536
|
+
|
|
537
|
+
def _is_image_1vsN_pattern() -> bool:
|
|
538
|
+
return args.img_source_dir is not None and args.ref_img_dir is not None
|
|
539
|
+
|
|
540
|
+
def _is_video_1vsN_pattern() -> bool:
|
|
541
|
+
return args.video_source_dir is not None and args.ref_video is not None
|
|
542
|
+
|
|
543
|
+
assert not all((_is_image_1vsN_pattern(), _is_video_1vsN_pattern()))
|
|
544
|
+
|
|
545
|
+
if _is_image_1vsN_pattern():
|
|
546
|
+
# Glob Image dirs
|
|
547
|
+
if not os.path.exists(args.img_source_dir):
|
|
548
|
+
logger.error(f"{args.img_source_dir} not exist!")
|
|
411
549
|
return
|
|
412
|
-
|
|
413
|
-
|
|
414
|
-
img_psnr, n = compute_psnr(args.img_true, args.img_test)
|
|
415
|
-
logger.info(
|
|
416
|
-
f"{args.img_true} vs {args.img_test}, Num: {n}, PSNR: {img_psnr}"
|
|
417
|
-
)
|
|
418
|
-
if args.metric == "ssim" or args.metric == "all":
|
|
419
|
-
img_ssim, n = compute_ssim(args.img_true, args.img_test)
|
|
420
|
-
logger.info(
|
|
421
|
-
f"{args.img_true} vs {args.img_test}, Num: {n}, SSIM: {img_ssim}"
|
|
422
|
-
)
|
|
423
|
-
if args.metric == "mse" or args.metric == "all":
|
|
424
|
-
img_mse, n = compute_mse(args.img_true, args.img_test)
|
|
425
|
-
logger.info(
|
|
426
|
-
f"{args.img_true} vs {args.img_test}, Num: {n}, MSE: {img_mse}"
|
|
427
|
-
)
|
|
428
|
-
if args.metric == "fid" or args.metric == "all":
|
|
429
|
-
FID = FrechetInceptionDistance(disable_tqdm=DISABLE_VERBOSE)
|
|
430
|
-
img_fid, n = FID.compute_fid(args.img_true, args.img_test)
|
|
431
|
-
logger.info(
|
|
432
|
-
f"{args.img_true} vs {args.img_test}, Num: {n}, FID: {img_fid}"
|
|
433
|
-
)
|
|
434
|
-
if args.video_true is not None and args.video_test is not None:
|
|
435
|
-
if any(
|
|
436
|
-
(
|
|
437
|
-
not os.path.exists(args.video_true),
|
|
438
|
-
not os.path.exists(args.video_test),
|
|
439
|
-
)
|
|
440
|
-
):
|
|
550
|
+
if not os.path.exists(args.ref_img_dir):
|
|
551
|
+
logger.error(f"{args.ref_img_dir} not exist!")
|
|
441
552
|
return
|
|
442
|
-
|
|
443
|
-
|
|
444
|
-
|
|
445
|
-
|
|
446
|
-
|
|
447
|
-
|
|
448
|
-
|
|
449
|
-
|
|
553
|
+
|
|
554
|
+
directories = []
|
|
555
|
+
for item in os.listdir(args.img_source_dir):
|
|
556
|
+
item_path = os.path.join(args.img_source_dir, item)
|
|
557
|
+
if os.path.isdir(item_path):
|
|
558
|
+
if os.path.basename(item_path) == os.path.basename(
|
|
559
|
+
args.ref_img_dir
|
|
560
|
+
):
|
|
561
|
+
continue
|
|
562
|
+
directories.append(item_path)
|
|
563
|
+
|
|
564
|
+
if len(directories) == 0:
|
|
565
|
+
return
|
|
566
|
+
|
|
567
|
+
directories = sorted(directories)
|
|
568
|
+
if not DISABLE_VERBOSE:
|
|
450
569
|
logger.info(
|
|
451
|
-
f"{args.
|
|
570
|
+
f"Compare {args.ref_img_dir} vs {directories}, "
|
|
571
|
+
f"Num compares: {len(directories)}"
|
|
452
572
|
)
|
|
453
|
-
|
|
454
|
-
|
|
573
|
+
|
|
574
|
+
for metric in args.metrics:
|
|
575
|
+
for img_test_dir in directories:
|
|
576
|
+
_run_metric(
|
|
577
|
+
mertric=metric,
|
|
578
|
+
img_true=args.ref_img_dir,
|
|
579
|
+
img_test=img_test_dir,
|
|
580
|
+
)
|
|
581
|
+
|
|
582
|
+
elif _is_video_1vsN_pattern():
|
|
583
|
+
# Glob videos
|
|
584
|
+
if not os.path.exists(args.video_source_dir):
|
|
585
|
+
logger.error(f"{args.video_source_dir} not exist!")
|
|
586
|
+
return
|
|
587
|
+
if not os.path.exists(args.ref_video):
|
|
588
|
+
logger.error(f"{args.ref_video} not exist!")
|
|
589
|
+
return
|
|
590
|
+
|
|
591
|
+
video_source_dir: pathlib.Path = pathlib.Path(args.video_source_dir)
|
|
592
|
+
video_source_files = sorted(
|
|
593
|
+
[
|
|
594
|
+
file
|
|
595
|
+
for ext in _VIDEO_EXTENSIONS
|
|
596
|
+
for file in video_source_dir.rglob("*.{}".format(ext))
|
|
597
|
+
]
|
|
598
|
+
)
|
|
599
|
+
video_source_files = [file.as_posix() for file in video_source_files]
|
|
600
|
+
|
|
601
|
+
video_source_selected = []
|
|
602
|
+
for video_source_file in video_source_files:
|
|
603
|
+
if os.path.basename(video_source_file) == os.path.basename(
|
|
604
|
+
args.ref_video
|
|
605
|
+
):
|
|
606
|
+
continue
|
|
607
|
+
video_source_selected.append(video_source_file)
|
|
608
|
+
|
|
609
|
+
if len(video_source_selected) == 0:
|
|
610
|
+
return
|
|
611
|
+
|
|
612
|
+
video_source_selected = sorted(video_source_selected)
|
|
613
|
+
if not DISABLE_VERBOSE:
|
|
455
614
|
logger.info(
|
|
456
|
-
f"{args.
|
|
457
|
-
|
|
458
|
-
if args.metric == "fid" or args.metric == "all":
|
|
459
|
-
FID = FrechetInceptionDistance(disable_tqdm=DISABLE_VERBOSE)
|
|
460
|
-
video_fid, n = FID.compute_video_fid(
|
|
461
|
-
args.video_true, args.video_test
|
|
615
|
+
f"Compare {args.ref_video} vs {video_source_selected}, "
|
|
616
|
+
f"Num compares: {len(video_source_selected)}"
|
|
462
617
|
)
|
|
463
|
-
|
|
464
|
-
|
|
618
|
+
|
|
619
|
+
for metric in args.metrics:
|
|
620
|
+
for video_test in video_source_selected:
|
|
621
|
+
_run_metric(
|
|
622
|
+
mertric=metric,
|
|
623
|
+
video_true=args.ref_video,
|
|
624
|
+
video_test=video_test,
|
|
625
|
+
)
|
|
626
|
+
|
|
627
|
+
else:
|
|
628
|
+
for metric in args.metrics:
|
|
629
|
+
_run_metric(
|
|
630
|
+
mertric=metric,
|
|
631
|
+
img_true=args.img_true,
|
|
632
|
+
img_test=args.img_test,
|
|
633
|
+
video_true=args.video_true,
|
|
634
|
+
video_test=args.video_test,
|
|
465
635
|
)
|
|
466
636
|
|
|
467
637
|
|
|
{cache_dit-0.2.9.dist-info → cache_dit-0.2.10.dist-info}/METADATA
CHANGED
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: cache_dit
|
|
3
|
-
Version: 0.2.9
|
|
3
|
+
Version: 0.2.10
|
|
4
4
|
Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
|
|
5
5
|
Author: DefTruth, vipshop.com, etc.
|
|
6
6
|
Maintainer: DefTruth, vipshop.com, etc
|
|
@@ -10,6 +10,7 @@ Requires-Python: >=3.10
|
|
|
10
10
|
Description-Content-Type: text/markdown
|
|
11
11
|
License-File: LICENSE
|
|
12
12
|
Requires-Dist: packaging
|
|
13
|
+
Requires-Dist: pyyaml
|
|
13
14
|
Requires-Dist: torch>=2.5.1
|
|
14
15
|
Requires-Dist: transformers>=4.51.3
|
|
15
16
|
Requires-Dist: diffusers>=0.33.1
|
|
@@ -63,7 +64,7 @@ Dynamic: requires-python
|
|
|
63
64
|
|
|
64
65
|
## 🔥News🔥
|
|
65
66
|
|
|
66
|
-
- [2025-07-13]
|
|
67
|
+
- [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of [huggingface/flux-fast](https://github.com/huggingface/flux-fast) that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20 while still maintaining **high precision**.
|
|
67
68
|
|
|
68
69
|
|
|
69
70
|
## 🤗 Introduction
|
|
{cache_dit-0.2.9.dist-info → cache_dit-0.2.10.dist-info}/RECORD
CHANGED
|
@@ -1,10 +1,11 @@
|
|
|
1
1
|
cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
2
|
-
cache_dit/_version.py,sha256=
|
|
2
|
+
cache_dit/_version.py,sha256=fMOyoyXAggjNTgl2YJ-8HW1bnjjDPiNACUsDoNufScI,513
|
|
3
3
|
cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
|
|
4
4
|
cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
|
|
5
|
-
cache_dit/cache_factory/__init__.py,sha256=
|
|
5
|
+
cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
|
|
6
|
+
cache_dit/cache_factory/adapters.py,sha256=QMCaXnmqM7NT7sx4bCF1mMLn-QcXX9h1RmgLAypDedg,5256
|
|
6
7
|
cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
|
|
7
|
-
cache_dit/cache_factory/utils.py,sha256=
|
|
8
|
+
cache_dit/cache_factory/utils.py,sha256=V-Mb5Jn07geEUUWo4QAfh6pmSzkL-2OGDn0VAXbG6hQ,1799
|
|
8
9
|
cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
9
10
|
cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=itVEb6gT2eZuncAHUmP51ZS0r6v6cGtRvnPjyeXqKH8,71156
|
|
10
11
|
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=krNAICf-aS3JLmSG8vOB9tpLa04uYRcABsC8PMbVUKY,1870
|
|
@@ -37,10 +38,10 @@ cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE
|
|
|
37
38
|
cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
|
|
38
39
|
cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
|
|
39
40
|
cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
|
|
40
|
-
cache_dit/metrics/metrics.py,sha256=
|
|
41
|
-
cache_dit-0.2.
|
|
42
|
-
cache_dit-0.2.
|
|
43
|
-
cache_dit-0.2.
|
|
44
|
-
cache_dit-0.2.
|
|
45
|
-
cache_dit-0.2.
|
|
46
|
-
cache_dit-0.2.
|
|
41
|
+
cache_dit/metrics/metrics.py,sha256=O9a8qV6deQDWEoez7UZ_aqDLQ9rJXAJUMHGnJM7RUMs,19927
|
|
42
|
+
cache_dit-0.2.10.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
|
|
43
|
+
cache_dit-0.2.10.dist-info/METADATA,sha256=f5PF-lhexcLdB2HmBWETBEyg011eZOc99tlVI1lozYA,28002
|
|
44
|
+
cache_dit-0.2.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
45
|
+
cache_dit-0.2.10.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
|
|
46
|
+
cache_dit-0.2.10.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
|
|
47
|
+
cache_dit-0.2.10.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|