cache-dit 0.2.16__py3-none-any.whl → 0.2.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

Files changed (37)
  1. cache_dit/__init__.py +12 -0
  2. cache_dit/_version.py +16 -3
  3. cache_dit/cache_factory/.gitignore +2 -0
  4. cache_dit/cache_factory/__init__.py +52 -3
  5. cache_dit/cache_factory/cache_adapters.py +654 -0
  6. cache_dit/cache_factory/{dual_block_cache/cache_blocks.py → cache_blocks.py} +1 -1
  7. cache_dit/cache_factory/{dual_block_cache/cache_context.py → cache_context.py} +1 -2
  8. cache_dit/cache_factory/patch/flux.py +16 -8
  9. cache_dit/cache_factory/utils.py +1 -1
  10. cache_dit/compile/__init__.py +1 -1
  11. cache_dit/compile/utils.py +1 -1
  12. {cache_dit-0.2.16.dist-info → cache_dit-0.2.17.dist-info}/METADATA +73 -136
  13. cache_dit-0.2.17.dist-info/RECORD +30 -0
  14. cache_dit/cache_factory/adapters.py +0 -211
  15. cache_dit/cache_factory/dual_block_cache/__init__.py +0 -4
  16. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -55
  17. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -90
  18. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -108
  19. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +0 -297
  20. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -90
  21. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/qwen_image.py +0 -91
  22. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +0 -100
  23. cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -4
  24. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -55
  25. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -90
  26. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -104
  27. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +0 -297
  28. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -90
  29. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/qwen_image.py +0 -94
  30. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +0 -100
  31. cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +0 -276
  32. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +0 -717
  33. cache_dit-0.2.16.dist-info/RECORD +0 -47
  34. {cache_dit-0.2.16.dist-info → cache_dit-0.2.17.dist-info}/WHEEL +0 -0
  35. {cache_dit-0.2.16.dist-info → cache_dit-0.2.17.dist-info}/entry_points.txt +0 -0
  36. {cache_dit-0.2.16.dist-info → cache_dit-0.2.17.dist-info}/licenses/LICENSE +0 -0
  37. {cache_dit-0.2.16.dist-info → cache_dit-0.2.17.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py CHANGED
@@ -3,3 +3,15 @@ try:
3
3
  except ImportError:
4
4
  __version__ = "unknown version"
5
5
  version_tuple = (0, 0, "unknown version")
6
+
7
+ from cache_dit.cache_factory import load_options
8
+ from cache_dit.cache_factory import enable_cache
9
+ from cache_dit.cache_factory import cache_type
10
+ from cache_dit.cache_factory import default_options
11
+ from cache_dit.cache_factory import block_range
12
+ from cache_dit.cache_factory import CacheType
13
+ from cache_dit.compile import set_compile_configs
14
+ from cache_dit.logger import init_logger
15
+
16
+ NONE = CacheType.NONE
17
+ DBCache = CacheType.DBCache
cache_dit/_version.py CHANGED
@@ -1,7 +1,14 @@
1
1
  # file generated by setuptools-scm
2
2
  # don't change, don't track in version control
3
3
 
4
- __all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
4
+ __all__ = [
5
+ "__version__",
6
+ "__version_tuple__",
7
+ "version",
8
+ "version_tuple",
9
+ "__commit_id__",
10
+ "commit_id",
11
+ ]
5
12
 
6
13
  TYPE_CHECKING = False
7
14
  if TYPE_CHECKING:
@@ -9,13 +16,19 @@ if TYPE_CHECKING:
9
16
  from typing import Union
10
17
 
11
18
  VERSION_TUPLE = Tuple[Union[int, str], ...]
19
+ COMMIT_ID = Union[str, None]
12
20
  else:
13
21
  VERSION_TUPLE = object
22
+ COMMIT_ID = object
14
23
 
15
24
  version: str
16
25
  __version__: str
17
26
  __version_tuple__: VERSION_TUPLE
18
27
  version_tuple: VERSION_TUPLE
28
+ commit_id: COMMIT_ID
29
+ __commit_id__: COMMIT_ID
19
30
 
20
- __version__ = version = '0.2.16'
21
- __version_tuple__ = version_tuple = (0, 2, 16)
31
+ __version__ = version = '0.2.17'
32
+ __version_tuple__ = version_tuple = (0, 2, 17)
33
+
34
+ __commit_id__ = commit_id = None
cache_dit/cache_factory/.gitignore ADDED
@@ -0,0 +1,2 @@
1
+ __pycache__
2
+ deprecated
cache_dit/cache_factory/__init__.py CHANGED
@@ -1,4 +1,53 @@
1
- from cache_dit.cache_factory.adapters import CacheType
2
- from cache_dit.cache_factory.adapters import apply_cache_on_pipe
3
- from cache_dit.cache_factory.adapters import apply_cache_on_transformer
1
+ import torch
2
+ from typing import Dict, List
3
+ from diffusers import DiffusionPipeline
4
+ from cache_dit.cache_factory.cache_adapters import CacheType
5
+ from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
4
6
  from cache_dit.cache_factory.utils import load_cache_options_from_yaml
7
+
8
+ from cache_dit.logger import init_logger
9
+
10
+ logger = init_logger(__name__)
11
+
12
+
13
+ def load_options(path: str):
14
+ """cache_dit.load_options(cache_config.yaml)"""
15
+ return load_cache_options_from_yaml(path)
16
+
17
+
18
+ def cache_type(type_hint: "CacheType | str") -> CacheType:
19
+ return CacheType.type(cache_type=type_hint)
20
+
21
+
22
+ def default_options(cache_type: CacheType = None) -> Dict:
23
+ if cache_type is None:
24
+ return CacheType.default_options(CacheType.DBCache)
25
+ return CacheType.default_options(cache_type)
26
+
27
+
28
+ def block_range(start: int, end: int, step: int = 1) -> List[int]:
29
+ return CacheType.block_range(start, end, step)
30
+
31
+
32
+ def enable_cache(
33
+ pipe: DiffusionPipeline,
34
+ *,
35
+ transformer: torch.nn.Module = None,
36
+ blocks: torch.nn.ModuleList = None,
37
+ # transformer_blocks, blocks, etc.
38
+ blocks_name: str = None,
39
+ dummy_blocks_names: list[str] = [],
40
+ return_hidden_states_first: bool = True,
41
+ return_hidden_states_only: bool = False,
42
+ **cache_options_kwargs,
43
+ ) -> DiffusionPipeline:
44
+ return UnifiedCacheAdapter.apply(
45
+ pipe,
46
+ transformer=transformer,
47
+ blocks=blocks,
48
+ blocks_name=blocks_name,
49
+ dummy_blocks_names=dummy_blocks_names,
50
+ return_hidden_states_first=return_hidden_states_first,
51
+ return_hidden_states_only=return_hidden_states_only,
52
+ **cache_options_kwargs,
53
+ )