cache-dit 0.2.20__py3-none-any.whl → 0.2.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cache_dit/__init__.py CHANGED
@@ -10,9 +10,18 @@ from cache_dit.cache_factory import cache_type
 from cache_dit.cache_factory import default_options
 from cache_dit.cache_factory import block_range
 from cache_dit.cache_factory import CacheType
+from cache_dit.cache_factory import ForwardPattern
+from cache_dit.cache_factory import BlockAdapterParams
 from cache_dit.compile import set_compile_configs
 from cache_dit.utils import summary
 from cache_dit.logger import init_logger
 
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
+
+BlockAdapter = BlockAdapterParams
+
+Forward_Pattern_0 = ForwardPattern.Pattern_0
+Forward_Pattern_1 = ForwardPattern.Pattern_1
+Forward_Pattern_2 = ForwardPattern.Pattern_2
+Forward_Pattern_3 = ForwardPattern.Pattern_3
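The new names are plain aliases over the existing cache_factory objects. A minimal sketch of what they resolve to after this change (every name here comes from the hunk above; nothing else is assumed):

import cache_dit

# New in 0.2.21: BlockAdapter aliases BlockAdapterParams, and the
# Forward_Pattern_0..3 names alias the matching ForwardPattern members.
assert cache_dit.BlockAdapter is cache_dit.BlockAdapterParams
assert cache_dit.Forward_Pattern_0 is cache_dit.ForwardPattern.Pattern_0
assert cache_dit.Forward_Pattern_3 is cache_dit.ForwardPattern.Pattern_3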
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.20'
-__version_tuple__ = version_tuple = (0, 2, 20)
+__version__ = version = '0.2.21'
+__version_tuple__ = version_tuple = (0, 2, 21)
 
 __commit_id__ = commit_id = None
@@ -1,7 +1,8 @@
-import torch
 from typing import Dict, List
 from diffusers import DiffusionPipeline
-from cache_dit.cache_factory.cache_adapters import CacheType
+from cache_dit.cache_factory.forward_pattern import ForwardPattern
+from cache_dit.cache_factory.cache_types import CacheType
+from cache_dit.cache_factory.cache_adapters import BlockAdapterParams
 from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
 from cache_dit.cache_factory.utils import load_cache_options_from_yaml
 
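One consequence of the import shuffle above: inside this module, CacheType is no longer pulled from cache_adapters but from cache_factory.cache_types. Code that imported it from the old internal path would switch to the path shown in the hunk (the top-level cache_dit.CacheType export is unchanged, since the first file still imports CacheType from cache_dit.cache_factory):

from cache_dit.cache_factory.cache_types import CacheType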
@@ -39,24 +40,26 @@ def block_range(
 
 
 def enable_cache(
-    pipe: DiffusionPipeline,
-    *,
-    transformer: torch.nn.Module = None,
-    blocks: torch.nn.ModuleList = None,
-    # transformer_blocks, blocks, etc.
-    blocks_name: str = None,
-    dummy_blocks_names: list[str] = [],
-    return_hidden_states_first: bool = True,
-    return_hidden_states_only: bool = False,
+    pipe_or_adapter: DiffusionPipeline | BlockAdapterParams,
+    forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     **cache_options_kwargs,
 ) -> DiffusionPipeline:
-    return UnifiedCacheAdapter.apply(
-        pipe,
-        transformer=transformer,
-        blocks=blocks,
-        blocks_name=blocks_name,
-        dummy_blocks_names=dummy_blocks_names,
-        return_hidden_states_first=return_hidden_states_first,
-        return_hidden_states_only=return_hidden_states_only,
-        **cache_options_kwargs,
-    )
+    if isinstance(pipe_or_adapter, BlockAdapterParams):
+        return UnifiedCacheAdapter.apply(
+            pipe=None,
+            adapter_params=pipe_or_adapter,
+            forward_pattern=forward_pattern,
+            **cache_options_kwargs,
+        )
+    elif isinstance(pipe_or_adapter, DiffusionPipeline):
+        return UnifiedCacheAdapter.apply(
+            pipe=pipe_or_adapter,
+            adapter_params=None,
+            forward_pattern=forward_pattern,
+            **cache_options_kwargs,
+        )
+    else:
+        raise ValueError(
+            "Please pass DiffusionPipeline or BlockAdapterParams"
+            "(BlockAdapter) for the 1 position param: pipe_or_adapter"
+        )
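With the reworked signature, enabling caching now takes a single positional argument, either a DiffusionPipeline or a BlockAdapterParams (aliased as BlockAdapter at the package top level), plus an optional forward pattern; the per-block keyword arguments from 0.2.20 (transformer, blocks, blocks_name, and so on) are gone from the public call. A minimal usage sketch for the pipeline branch, assuming enable_cache is re-exported at the cache_dit top level like the other cache_factory names, and using a purely illustrative model id:

import cache_dit
from diffusers import DiffusionPipeline

# Hypothetical model id, for illustration only.
pipe = DiffusionPipeline.from_pretrained("some/dit-model")

# 0.2.21 style: the pipeline is the single positional argument, and the
# block wiring is selected via a ForwardPattern instead of explicit
# blocks/blocks_name keywords.
pipe = cache_dit.enable_cache(
    pipe,
    forward_pattern=cache_dit.Forward_Pattern_0,
)

Passing a BlockAdapterParams instance instead goes through the first isinstance branch above, which forwards it to UnifiedCacheAdapter.apply as adapter_params; any other argument type raises the ValueError shown at the end of the hunk.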