cache-dit 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +9 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +24 -21
- cache_dit/cache_factory/cache_adapters.py +251 -350
- cache_dit/cache_factory/cache_blocks.py +26 -22
- cache_dit/cache_factory/cache_context.py +1 -1
- cache_dit/cache_factory/cache_types.py +70 -0
- cache_dit/cache_factory/forward_pattern.py +63 -0
- cache_dit/cache_factory/utils.py +1 -1
- cache_dit/utils.py +16 -16
- {cache_dit-0.2.20.dist-info → cache_dit-0.2.21.dist-info}/METADATA +74 -55
- {cache_dit-0.2.20.dist-info → cache_dit-0.2.21.dist-info}/RECORD +16 -14
- {cache_dit-0.2.20.dist-info → cache_dit-0.2.21.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.20.dist-info → cache_dit-0.2.21.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.20.dist-info → cache_dit-0.2.21.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.20.dist-info → cache_dit-0.2.21.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
|
@@ -10,9 +10,18 @@ from cache_dit.cache_factory import cache_type
|
|
|
10
10
|
from cache_dit.cache_factory import default_options
|
|
11
11
|
from cache_dit.cache_factory import block_range
|
|
12
12
|
from cache_dit.cache_factory import CacheType
|
|
13
|
+
from cache_dit.cache_factory import ForwardPattern
|
|
14
|
+
from cache_dit.cache_factory import BlockAdapterParams
|
|
13
15
|
from cache_dit.compile import set_compile_configs
|
|
14
16
|
from cache_dit.utils import summary
|
|
15
17
|
from cache_dit.logger import init_logger
|
|
16
18
|
|
|
17
19
|
NONE = CacheType.NONE
|
|
18
20
|
DBCache = CacheType.DBCache
|
|
21
|
+
|
|
22
|
+
BlockAdapter = BlockAdapterParams
|
|
23
|
+
|
|
24
|
+
Forward_Pattern_0 = ForwardPattern.Pattern_0
|
|
25
|
+
Forward_Pattern_1 = ForwardPattern.Pattern_1
|
|
26
|
+
Forward_Pattern_2 = ForwardPattern.Pattern_2
|
|
27
|
+
Forward_Pattern_3 = ForwardPattern.Pattern_3
|
cache_dit/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.2.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 2,
|
|
31
|
+
__version__ = version = '0.2.21'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 2, 21)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -1,7 +1,8 @@
|
|
|
1
|
-
import torch
|
|
2
1
|
from typing import Dict, List
|
|
3
2
|
from diffusers import DiffusionPipeline
|
|
4
|
-
from cache_dit.cache_factory.
|
|
3
|
+
from cache_dit.cache_factory.forward_pattern import ForwardPattern
|
|
4
|
+
from cache_dit.cache_factory.cache_types import CacheType
|
|
5
|
+
from cache_dit.cache_factory.cache_adapters import BlockAdapterParams
|
|
5
6
|
from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
|
|
6
7
|
from cache_dit.cache_factory.utils import load_cache_options_from_yaml
|
|
7
8
|
|
|
@@ -39,24 +40,26 @@ def block_range(
|
|
|
39
40
|
|
|
40
41
|
|
|
41
42
|
def enable_cache(
|
|
42
|
-
|
|
43
|
-
|
|
44
|
-
transformer: torch.nn.Module = None,
|
|
45
|
-
blocks: torch.nn.ModuleList = None,
|
|
46
|
-
# transformer_blocks, blocks, etc.
|
|
47
|
-
blocks_name: str = None,
|
|
48
|
-
dummy_blocks_names: list[str] = [],
|
|
49
|
-
return_hidden_states_first: bool = True,
|
|
50
|
-
return_hidden_states_only: bool = False,
|
|
43
|
+
pipe_or_adapter: DiffusionPipeline | BlockAdapterParams,
|
|
44
|
+
forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
|
|
51
45
|
**cache_options_kwargs,
|
|
52
46
|
) -> DiffusionPipeline:
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
47
|
+
if isinstance(pipe_or_adapter, BlockAdapterParams):
|
|
48
|
+
return UnifiedCacheAdapter.apply(
|
|
49
|
+
pipe=None,
|
|
50
|
+
adapter_params=pipe_or_adapter,
|
|
51
|
+
forward_pattern=forward_pattern,
|
|
52
|
+
**cache_options_kwargs,
|
|
53
|
+
)
|
|
54
|
+
elif isinstance(pipe_or_adapter, DiffusionPipeline):
|
|
55
|
+
return UnifiedCacheAdapter.apply(
|
|
56
|
+
pipe=pipe_or_adapter,
|
|
57
|
+
adapter_params=None,
|
|
58
|
+
forward_pattern=forward_pattern,
|
|
59
|
+
**cache_options_kwargs,
|
|
60
|
+
)
|
|
61
|
+
else:
|
|
62
|
+
raise ValueError(
|
|
63
|
+
"Please pass DiffusionPipeline or BlockAdapterParams"
|
|
64
|
+
"(BlockAdapter) for the 1 position param: pipe_or_adapter"
|
|
65
|
+
)
|