cache-dit 0.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/__init__.py +0 -0
- cache_dit/_version.py +21 -0
- cache_dit/cache_factory/__init__.py +166 -0
- cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/dual_block_cache/cache_context.py +1361 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +45 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +88 -0
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +45 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +89 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +979 -0
- cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +727 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +53 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +89 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +98 -0
- cache_dit/cache_factory/taylorseer.py +76 -0
- cache_dit/cache_factory/utils.py +0 -0
- cache_dit/logger.py +97 -0
- cache_dit/primitives.py +152 -0
- cache_dit-0.1.0.dist-info/METADATA +350 -0
- cache_dit-0.1.0.dist-info/RECORD +31 -0
- cache_dit-0.1.0.dist-info/WHEEL +5 -0
- cache_dit-0.1.0.dist-info/licenses/LICENSE +53 -0
- cache_dit-0.1.0.dist-info/top_level.txt +1 -0
cache_dit/__init__.py
ADDED
|
File without changes
|
cache_dit/_version.py
ADDED
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
# file generated by setuptools-scm
# don't change, don't track in version control

# Public names re-exported by this module.
__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

# Local stand-in for typing.TYPE_CHECKING: always False at runtime, so the
# typing import below is never executed when the module is imported.
TYPE_CHECKING = False
if TYPE_CHECKING:
    from typing import Tuple, Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]
else:
    # Runtime placeholder so the annotations below still evaluate.
    VERSION_TUPLE = object

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# Version string, exposed under both spellings.
version = "0.1.0"
__version__ = version

# The same version as a tuple, exposed under both spellings.
version_tuple = (0, 1, 0)
__version_tuple__ = version_tuple
|
|
@@ -0,0 +1,166 @@
|
|
|
1
|
+
from enum import Enum
|
|
2
|
+
|
|
3
|
+
from diffusers import DiffusionPipeline
|
|
4
|
+
|
|
5
|
+
from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
|
|
6
|
+
apply_db_cache_on_pipe,
|
|
7
|
+
)
|
|
8
|
+
from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
|
|
9
|
+
apply_fb_cache_on_pipe,
|
|
10
|
+
)
|
|
11
|
+
from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters import (
|
|
12
|
+
apply_db_prune_on_pipe,
|
|
13
|
+
)
|
|
14
|
+
from cache_dit.logger import init_logger
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
logger = init_logger(__name__)
|
|
18
|
+
|
|
19
|
+
|
|
20
|
+
class CacheType(Enum):
    """Cache / prune strategies that can be selected for a pipeline.

    Besides the members themselves, the class offers static helpers to
    resolve user-supplied names to a member (``type`` / ``cache_type``),
    to build a list of block ids (``range``) and to obtain the default
    option dict of each strategy (``default_options``).
    """

    NONE = "NONE"
    FBCache = "First_Block_Cache"
    DBCache = "Dual_Block_Cache"
    DBPrune = "Dynamic_Block_Prune"

    @staticmethod
    def type(cache_type: "CacheType | str") -> "CacheType":
        """Resolve *cache_type* to a ``CacheType`` member.

        Equivalent to :meth:`cache_type`; see it for the accepted inputs.
        """
        return CacheType.cache_type(cache_type)

    @staticmethod
    def cache_type(cache_type: "CacheType | str") -> "CacheType":
        """Resolve a member, a name/alias string or ``None`` to a member.

        Args:
            cache_type: A ``CacheType`` member (returned unchanged), ``None``
                (resolved to ``CacheType.NONE``) or a case-insensitive
                name/alias such as ``"fb"``, ``"DBCache"`` or ``"db_prune"``.

        Returns:
            The matching ``CacheType`` member.

        Raises:
            ValueError: If *cache_type* is not a recognised name, or is
                neither a string, a member nor ``None``.
        """
        if cache_type is None:
            return CacheType.NONE
        if isinstance(cache_type, CacheType):
            return cache_type
        if not isinstance(cache_type, str):
            # Report unsupported input types the same way as unknown names
            # instead of failing on the missing ``.lower()`` attribute.
            raise ValueError(f"Unknown cache type: {cache_type}")

        # Kept local to the method: a dict assigned in the Enum class body
        # would be turned into an enum member.
        aliases = {
            CacheType.FBCache: (
                "first_block_cache",
                "fb_cache",
                "fbcache",
                "fb",
            ),
            CacheType.DBCache: (
                "dual_block_cache",
                "db_cache",
                "dbcache",
                "db",
            ),
            CacheType.DBPrune: (
                "dynamic_block_prune",
                "db_prune",
                "dbprune",
                "dbp",
            ),
            CacheType.NONE: (
                "none_cache",
                "nonecache",
                "no_cache",
                "nocache",
                "none",
                "no",
            ),
        }

        key = cache_type.lower()
        for member, names in aliases.items():
            if key in names:
                return member
        raise ValueError(f"Unknown cache type: {cache_type}")

    @staticmethod
    def range(start: int, end: int, step: int = 1) -> list[int]:
        """Return sorted block ids ``start, start + step, ...`` below *end*.

        Ids ``0`` and ``end - 1`` are always included in a non-empty result.
        An empty list is returned when ``start > end``, ``end <= 0`` or
        ``step <= 1``.

        NOTE(review): because of the ``step <= 1`` guard the default
        ``step=1`` always yields ``[]``; presumably this means "no explicit
        block ids" -- confirm against the callers before changing it.
        """
        if start > end or end <= 0 or step <= 1:
            return []
        # Always compute 0 and end - 1 blocks for DB Cache.
        # ``range`` below is the builtin: a function body does not see names
        # bound in the enclosing class body.
        return sorted({0, end - 1, *range(start, end, step)})

    @staticmethod
    def default_options(cache_type: "CacheType | str") -> dict:
        """Return a fresh dict of default options for *cache_type*.

        Args:
            cache_type: A ``CacheType`` member, or any name/alias accepted
                by :meth:`cache_type` (``None`` resolves to
                ``CacheType.NONE``).

        Returns:
            A new dict whose ``"cache_type"`` entry is the resolved member,
            followed by that strategy's default settings.

        Raises:
            ValueError: If *cache_type* cannot be resolved to a member.
        """
        # Normalise first so string aliases work as the signature promises.
        resolved = CacheType.type(cache_type)

        _Fn_compute_blocks = 8
        _Bn_compute_blocks = 8

        options = {
            CacheType.NONE: {
                "cache_type": CacheType.NONE,
            },
            CacheType.FBCache: {
                "cache_type": CacheType.FBCache,
                "residual_diff_threshold": 0.08,
                "warmup_steps": 8,
                "max_cached_steps": 8,
            },
            CacheType.DBCache: {
                "cache_type": CacheType.DBCache,
                "residual_diff_threshold": 0.12,
                "warmup_steps": 8,
                "max_cached_steps": -1,  # -1 means no limit
                # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
                "Fn_compute_blocks": _Fn_compute_blocks,
                "Bn_compute_blocks": _Bn_compute_blocks,
                "max_Fn_compute_blocks": 16,
                "max_Bn_compute_blocks": 16,
                "Fn_compute_blocks_ids": [],  # 0, 1, 2, ..., 7, etc.
                "Bn_compute_blocks_ids": [],  # 0, 1, 2, ..., 7, etc.
            },
            CacheType.DBPrune: {
                "cache_type": CacheType.DBPrune,
                "residual_diff_threshold": 0.08,
                "Fn_compute_blocks": _Fn_compute_blocks,
                "Bn_compute_blocks": _Bn_compute_blocks,
                "warmup_steps": 8,
                "max_pruned_steps": -1,  # -1 means no limit
            },
        }

        try:
            return options[resolved]
        except KeyError:
            # Only reachable if a member is added without default options.
            raise ValueError(f"Unknown cache type: {cache_type}") from None
|
|
130
|
+
|
|
131
|
+
|
|
132
|
+
def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    """Apply the cache/prune strategy named by ``cache_type`` to *pipe*.

    The ``cache_type`` keyword (a ``CacheType`` member or a name accepted by
    ``CacheType.type``) is removed from *kwargs* and selects the adapter;
    the remaining positional and keyword arguments are forwarded to it.

    When ``cache_type`` is missing or ``None`` a warning is logged and
    DBCache is applied with ``CacheType.default_options(CacheType.DBCache)``
    only -- *args* and the other *kwargs* are not forwarded on that path.

    Returns:
        *pipe* itself when it is already flagged via ``_is_cached`` /
        ``_is_pruned`` or when the resolved type is ``CacheType.NONE``;
        otherwise whatever the selected adapter returns.

    Raises:
        ValueError: If the resolved cache type has no adapter.
    """
    assert isinstance(pipe, DiffusionPipeline)

    # A pipe carrying a truthy flag is returned untouched.
    # NOTE(review): presumably these flags are set by the adapters -- confirm.
    if getattr(pipe, "_is_cached", False) or getattr(pipe, "_is_pruned", False):
        return pipe

    requested = kwargs.pop("cache_type", None)
    if requested is None:
        logger.warning(
            "No cache type specified, we will use DBCache by default. "
            "Please specify the cache_type explicitly if you want to "
            "use a different cache type."
        )
        # Force to use DBCache with default cache options
        dbcache_defaults = CacheType.default_options(CacheType.DBCache)
        return apply_db_cache_on_pipe(pipe, **dbcache_defaults)

    resolved = CacheType.type(requested)

    if resolved == CacheType.NONE:
        logger.warning("Cache type is NONE, no caching will be applied.")
        return pipe

    adapters = {
        CacheType.FBCache: apply_fb_cache_on_pipe,
        CacheType.DBCache: apply_db_cache_on_pipe,
        CacheType.DBPrune: apply_db_prune_on_pipe,
    }
    adapter = adapters.get(resolved)
    if adapter is None:
        raise ValueError(f"Unknown cache type: {resolved}")
    return adapter(pipe, *args, **kwargs)
|
|
File without changes
|