cache-dit 0.2.20 (py3-none-any wheel) → 0.2.22 (py3-none-any wheel)

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit may be problematic; see the package's registry page for more details.

cache_dit/__init__.py CHANGED
@@ -7,12 +7,21 @@ except ImportError:
7
7
  from cache_dit.cache_factory import load_options
8
8
  from cache_dit.cache_factory import enable_cache
9
9
  from cache_dit.cache_factory import cache_type
10
- from cache_dit.cache_factory import default_options
11
10
  from cache_dit.cache_factory import block_range
12
11
  from cache_dit.cache_factory import CacheType
12
+ from cache_dit.cache_factory import ForwardPattern
13
+ from cache_dit.cache_factory import BlockAdapterParams
13
14
  from cache_dit.compile import set_compile_configs
14
15
  from cache_dit.utils import summary
16
+ from cache_dit.utils import strify
15
17
  from cache_dit.logger import init_logger
16
18
 
17
19
  NONE = CacheType.NONE
18
20
  DBCache = CacheType.DBCache
21
+
22
+ BlockAdapter = BlockAdapterParams
23
+
24
+ Forward_Pattern_0 = ForwardPattern.Pattern_0
25
+ Forward_Pattern_1 = ForwardPattern.Pattern_1
26
+ Forward_Pattern_2 = ForwardPattern.Pattern_2
27
+ Forward_Pattern_3 = ForwardPattern.Pattern_3
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.2.20'
32
- __version_tuple__ = version_tuple = (0, 2, 20)
31
+ __version__ = version = '0.2.22'
32
+ __version_tuple__ = version_tuple = (0, 2, 22)
33
33
 
34
34
  __commit_id__ = commit_id = None
cache_dit/cache_factory/__init__.py CHANGED
@@ -1,62 +1,8 @@
1
- import torch
2
- from typing import Dict, List
3
- from diffusers import DiffusionPipeline
4
- from cache_dit.cache_factory.cache_adapters import CacheType
1
+ from cache_dit.cache_factory.forward_pattern import ForwardPattern
2
+ from cache_dit.cache_factory.cache_types import CacheType
3
+ from cache_dit.cache_factory.cache_types import cache_type
4
+ from cache_dit.cache_factory.cache_types import block_range
5
+ from cache_dit.cache_factory.cache_adapters import BlockAdapterParams
5
6
  from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
6
- from cache_dit.cache_factory.utils import load_cache_options_from_yaml
7
-
8
- from cache_dit.logger import init_logger
9
-
10
- logger = init_logger(__name__)
11
-
12
-
13
- def load_options(path: str):
14
- return load_cache_options_from_yaml(path)
15
-
16
-
17
- def cache_type(
18
- type_hint: "CacheType | str",
19
- ) -> CacheType:
20
- return CacheType.type(cache_type=type_hint)
21
-
22
-
23
- def default_options(
24
- cache_type: CacheType = CacheType.DBCache,
25
- ) -> Dict:
26
- return CacheType.default_options(cache_type)
27
-
28
-
29
- def block_range(
30
- start: int,
31
- end: int,
32
- step: int = 1,
33
- ) -> List[int]:
34
- return CacheType.block_range(
35
- start,
36
- end,
37
- step,
38
- )
39
-
40
-
41
- def enable_cache(
42
- pipe: DiffusionPipeline,
43
- *,
44
- transformer: torch.nn.Module = None,
45
- blocks: torch.nn.ModuleList = None,
46
- # transformer_blocks, blocks, etc.
47
- blocks_name: str = None,
48
- dummy_blocks_names: list[str] = [],
49
- return_hidden_states_first: bool = True,
50
- return_hidden_states_only: bool = False,
51
- **cache_options_kwargs,
52
- ) -> DiffusionPipeline:
53
- return UnifiedCacheAdapter.apply(
54
- pipe,
55
- transformer=transformer,
56
- blocks=blocks,
57
- blocks_name=blocks_name,
58
- dummy_blocks_names=dummy_blocks_names,
59
- return_hidden_states_first=return_hidden_states_first,
60
- return_hidden_states_only=return_hidden_states_only,
61
- **cache_options_kwargs,
62
- )
7
+ from cache_dit.cache_factory.cache_interface import enable_cache
8
+ from cache_dit.cache_factory.utils import load_options