cache-dit 0.1.1.dev2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

Files changed (30)
  1. cache_dit/__init__.py +0 -0
  2. cache_dit/_version.py +21 -0
  3. cache_dit/cache_factory/__init__.py +166 -0
  4. cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
  5. cache_dit/cache_factory/dual_block_cache/cache_context.py +1361 -0
  6. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +45 -0
  7. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +89 -0
  8. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +100 -0
  9. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +88 -0
  10. cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
  11. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +45 -0
  12. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +89 -0
  13. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +100 -0
  14. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +89 -0
  15. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +979 -0
  16. cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
  17. cache_dit/cache_factory/first_block_cache/cache_context.py +727 -0
  18. cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +53 -0
  19. cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +89 -0
  20. cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +100 -0
  21. cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +89 -0
  22. cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +98 -0
  23. cache_dit/cache_factory/taylorseer.py +76 -0
  24. cache_dit/cache_factory/utils.py +0 -0
  25. cache_dit/logger.py +97 -0
  26. cache_dit/primitives.py +152 -0
  27. cache_dit-0.1.1.dev2.dist-info/METADATA +31 -0
  28. cache_dit-0.1.1.dev2.dist-info/RECORD +30 -0
  29. cache_dit-0.1.1.dev2.dist-info/WHEEL +5 -0
  30. cache_dit-0.1.1.dev2.dist-info/top_level.txt +1 -0
cache_dit/__init__.py ADDED
File without changes
cache_dit/_version.py ADDED
@@ -0,0 +1,21 @@
# Version metadata written by setuptools-scm at build time.
# Regenerate it through the build rather than editing it by hand.

__all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]

# Kept as a plain module-level flag so that importing ``typing`` is never
# needed at runtime; static type checkers treat the name specially.
TYPE_CHECKING = False
if not TYPE_CHECKING:
    VERSION_TUPLE = object
else:
    from typing import Tuple, Union

    VERSION_TUPLE = Tuple[Union[int, str], ...]

version: str
__version__: str
__version_tuple__: VERSION_TUPLE
version_tuple: VERSION_TUPLE

# The dunder names and the plain names are aliases of the same objects.
version = '0.1.1.dev2'
__version__ = version
version_tuple = (0, 1, 1, 'dev2')
__version_tuple__ = version_tuple
cache_dit/cache_factory/__init__.py ADDED
@@ -0,0 +1,166 @@
1
+ from enum import Enum
2
+
3
+ from diffusers import DiffusionPipeline
4
+
5
+ from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
6
+ apply_db_cache_on_pipe,
7
+ )
8
+ from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
9
+ apply_fb_cache_on_pipe,
10
+ )
11
+ from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters import (
12
+ apply_db_prune_on_pipe,
13
+ )
14
+ from cache_dit.logger import init_logger
15
+
16
+
# Module-level logger for this file, created through cache_dit.logger.init_logger
# and used below to warn about implicit or disabled cache types.
logger = init_logger(__name__)
18
+
19
+
class CacheType(Enum):
    """Acceleration strategies that ``apply_cache_on_pipe`` can apply.

    Members carry descriptive string values. Use :meth:`type` (or
    :meth:`cache_type`) to resolve a member from a case-insensitive alias
    string such as ``"fb"``, ``"db_cache"`` or ``"dbp"``.
    """

    NONE = "NONE"
    FBCache = "First_Block_Cache"
    DBCache = "Dual_Block_Cache"
    DBPrune = "Dynamic_Block_Prune"

    @staticmethod
    def type(cache_type: "CacheType | str") -> "CacheType":
        """Resolve ``cache_type`` to a member.

        Members are returned unchanged. Anything else is delegated to
        :meth:`CacheType.cache_type`, which handles ``None`` and aliases.
        """
        if isinstance(cache_type, CacheType):
            return cache_type
        return CacheType.cache_type(cache_type)

    @staticmethod
    def cache_type(cache_type: "CacheType | str") -> "CacheType":
        """Map a member, an alias string or ``None`` to a member.

        Args:
            cache_type: A ``CacheType`` (returned as-is), ``None`` (mapped to
                ``CacheType.NONE``), or a case-insensitive alias string.

        Returns:
            The matching ``CacheType`` member.

        Raises:
            ValueError: If the string matches no known alias.
        """
        if cache_type is None:
            return CacheType.NONE

        if isinstance(cache_type, CacheType):
            return cache_type

        # Normalise once rather than once per alias group.
        key = cache_type.lower()
        if key in ("first_block_cache", "fb_cache", "fbcache", "fb"):
            return CacheType.FBCache
        if key in ("dual_block_cache", "db_cache", "dbcache", "db"):
            return CacheType.DBCache
        if key in ("dynamic_block_prune", "db_prune", "dbprune", "dbp"):
            return CacheType.DBPrune
        if key in (
            "none_cache",
            "nonecache",
            "no_cache",
            "nocache",
            "none",
            "no",
        ):
            return CacheType.NONE
        raise ValueError(f"Unknown cache type: {cache_type}")

    @staticmethod
    def range(start: int, end: int, step: int = 1) -> list[int]:
        """Return sorted, de-duplicated block ids for a DB Cache config.

        The result is ``range(start, end, step)`` plus ``0`` and ``end - 1``.
        It is ``[]`` when ``start > end``, ``end <= 0`` or ``step <= 1``.

        NOTE(review): because of the ``step <= 1`` guard, the default
        ``step=1`` always yields ``[]``; confirm that this is intended.
        """
        if start > end or end <= 0 or step <= 1:
            return []
        # Always compute 0 and end - 1 blocks for DB Cache.
        # (Inside this function ``range`` is the builtin, not this method.)
        return sorted({0, *range(start, end, step), end - 1})

    @staticmethod
    def default_options(cache_type: "CacheType | str") -> dict:
        """Return a fresh dict of default options for ``cache_type``.

        Args:
            cache_type: A member, an alias string accepted by
                :meth:`CacheType.type`, or ``None`` (treated as ``NONE``).

        Returns:
            A newly built dict on every call, so callers may mutate it
            (including the list values) without affecting later calls.

        Raises:
            ValueError: If ``cache_type`` is an unknown alias.
        """
        # Accept aliases as the annotation promises. Previously a string
        # such as "db" never compared equal to a member and always raised.
        cache_type = CacheType.type(cache_type)

        fn_compute_blocks = 8
        bn_compute_blocks = 8

        if cache_type == CacheType.FBCache:
            return {
                "cache_type": CacheType.FBCache,
                "residual_diff_threshold": 0.08,
                "warmup_steps": 8,
                "max_cached_steps": 8,
            }
        if cache_type == CacheType.DBCache:
            return {
                "cache_type": CacheType.DBCache,
                "residual_diff_threshold": 0.12,
                "warmup_steps": 8,
                "max_cached_steps": -1,  # -1 means no limit
                # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
                "Fn_compute_blocks": fn_compute_blocks,
                "Bn_compute_blocks": bn_compute_blocks,
                "max_Fn_compute_blocks": 16,
                "max_Bn_compute_blocks": 16,
                "Fn_compute_blocks_ids": [],  # 0, 1, 2, ..., 7, etc.
                "Bn_compute_blocks_ids": [],  # 0, 1, 2, ..., 7, etc.
            }
        if cache_type == CacheType.DBPrune:
            return {
                "cache_type": CacheType.DBPrune,
                "residual_diff_threshold": 0.08,
                "Fn_compute_blocks": fn_compute_blocks,
                "Bn_compute_blocks": bn_compute_blocks,
                "warmup_steps": 8,
                "max_pruned_steps": -1,  # -1 means no limit
            }
        if cache_type == CacheType.NONE:
            return {
                "cache_type": CacheType.NONE,
            }
        # Defensive only: CacheType.type already rejects unknown values.
        raise ValueError(f"Unknown cache type: {cache_type}")
130
+
131
+
def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    """Apply the requested caching/pruning strategy to a diffusers pipeline.

    Args:
        pipe: The pipeline to accelerate. It is returned unchanged if it is
            already flagged via a truthy ``_is_cached`` or ``_is_pruned``
            attribute (presumably set by the adapters; not visible here).
        *args: Forwarded to the selected ``apply_*_on_pipe`` adapter.
        **kwargs: ``cache_type`` (a ``CacheType`` or alias string) selects
            the adapter and is not forwarded. All other keyword arguments are
            forwarded. If ``cache_type`` is missing or ``None``, DBCache is
            used: its ``CacheType.default_options`` are the base, and the
            remaining keyword arguments override them.

    Returns:
        The selected adapter's return value, or ``pipe`` itself when no
        caching is applied.

    Raises:
        TypeError: If ``pipe`` is not a ``DiffusionPipeline``.
        ValueError: If ``cache_type`` is not a recognised alias.
    """
    # Explicit check instead of ``assert`` so validation survives ``python -O``.
    if not isinstance(pipe, DiffusionPipeline):
        raise TypeError(
            "apply_cache_on_pipe expects a DiffusionPipeline, "
            f"got {type(pipe).__name__}"
        )

    # Leave pipelines that are already flagged as cached or pruned untouched.
    if getattr(pipe, "_is_cached", False):
        return pipe

    if getattr(pipe, "_is_pruned", False):
        return pipe

    cache_type = kwargs.pop("cache_type", None)
    if cache_type is None:
        logger.warning(
            "No cache type specified, we will use DBCache by default. "
            "Please specify the cache_type explicitly if you want to "
            "use a different cache type."
        )
        # Force DBCache, starting from its default options. Caller-supplied
        # options override the defaults instead of being silently dropped.
        options = CacheType.default_options(CacheType.DBCache)
        options.update(kwargs)
        return apply_db_cache_on_pipe(pipe, *args, **options)

    cache_type = CacheType.type(cache_type)

    if cache_type == CacheType.FBCache:
        return apply_fb_cache_on_pipe(pipe, *args, **kwargs)
    elif cache_type == CacheType.DBCache:
        return apply_db_cache_on_pipe(pipe, *args, **kwargs)
    elif cache_type == CacheType.DBPrune:
        return apply_db_prune_on_pipe(pipe, *args, **kwargs)
    elif cache_type == CacheType.NONE:
        logger.warning("Cache type is NONE, no caching will be applied.")
        return pipe
    else:
        # Defensive only: CacheType.type already rejects unknown values.
        raise ValueError(f"Unknown cache type: {cache_type}")
cache_dit/cache_factory/dual_block_cache/__init__.py ADDED
File without changes