cache-dit 0.2.16__py3-none-any.whl → 0.2.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

Files changed (38)
  1. cache_dit/__init__.py +13 -0
  2. cache_dit/_version.py +16 -3
  3. cache_dit/cache_factory/.gitignore +2 -0
  4. cache_dit/cache_factory/__init__.py +61 -3
  5. cache_dit/cache_factory/cache_adapters.py +655 -0
  6. cache_dit/cache_factory/{dual_block_cache/cache_blocks.py → cache_blocks.py} +1 -1
  7. cache_dit/cache_factory/{dual_block_cache/cache_context.py → cache_context.py} +1 -2
  8. cache_dit/cache_factory/patch/flux.py +16 -8
  9. cache_dit/cache_factory/utils.py +1 -1
  10. cache_dit/compile/__init__.py +1 -1
  11. cache_dit/compile/utils.py +1 -1
  12. cache_dit/utils.py +125 -0
  13. {cache_dit-0.2.16.dist-info → cache_dit-0.2.18.dist-info}/METADATA +98 -136
  14. cache_dit-0.2.18.dist-info/RECORD +30 -0
  15. cache_dit/cache_factory/adapters.py +0 -211
  16. cache_dit/cache_factory/dual_block_cache/__init__.py +0 -4
  17. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -55
  18. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -90
  19. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -108
  20. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +0 -297
  21. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -90
  22. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/qwen_image.py +0 -91
  23. cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +0 -100
  24. cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -4
  25. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -55
  26. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -90
  27. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -104
  28. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +0 -297
  29. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -90
  30. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/qwen_image.py +0 -94
  31. cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +0 -100
  32. cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +0 -276
  33. cache_dit/cache_factory/dynamic_block_prune/prune_context.py +0 -717
  34. cache_dit-0.2.16.dist-info/RECORD +0 -47
  35. {cache_dit-0.2.16.dist-info → cache_dit-0.2.18.dist-info}/WHEEL +0 -0
  36. {cache_dit-0.2.16.dist-info → cache_dit-0.2.18.dist-info}/entry_points.txt +0 -0
  37. {cache_dit-0.2.16.dist-info → cache_dit-0.2.18.dist-info}/licenses/LICENSE +0 -0
  38. {cache_dit-0.2.16.dist-info → cache_dit-0.2.18.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py CHANGED
@@ -3,3 +3,16 @@ try:
3
3
  except ImportError:
4
4
  __version__ = "unknown version"
5
5
  version_tuple = (0, 0, "unknown version")
6
+
7
+ from cache_dit.cache_factory import load_options
8
+ from cache_dit.cache_factory import enable_cache
9
+ from cache_dit.cache_factory import cache_type
10
+ from cache_dit.cache_factory import default_options
11
+ from cache_dit.cache_factory import block_range
12
+ from cache_dit.cache_factory import CacheType
13
+ from cache_dit.compile import set_compile_configs
14
+ from cache_dit.utils import summary
15
+ from cache_dit.logger import init_logger
16
+
17
# Top-level aliases for the two CacheType members, so they are reachable as
# ``cache_dit.NONE`` / ``cache_dit.DBCache`` without going through CacheType.
NONE, DBCache = CacheType.NONE, CacheType.DBCache
cache_dit/_version.py CHANGED
@@ -1,7 +1,14 @@
1
1
  # file generated by setuptools-scm
2
2
  # don't change, don't track in version control
3
3
 
4
- __all__ = ["__version__", "__version_tuple__", "version", "version_tuple"]
4
+ __all__ = [
5
+ "__version__",
6
+ "__version_tuple__",
7
+ "version",
8
+ "version_tuple",
9
+ "__commit_id__",
10
+ "commit_id",
11
+ ]
5
12
 
6
13
  TYPE_CHECKING = False
7
14
  if TYPE_CHECKING:
@@ -9,13 +16,19 @@ if TYPE_CHECKING:
9
16
  from typing import Union
10
17
 
11
18
  VERSION_TUPLE = Tuple[Union[int, str], ...]
19
+ COMMIT_ID = Union[str, None]
12
20
  else:
13
21
  VERSION_TUPLE = object
22
+ COMMIT_ID = object
14
23
 
15
24
  version: str
16
25
  __version__: str
17
26
  __version_tuple__: VERSION_TUPLE
18
27
  version_tuple: VERSION_TUPLE
28
+ commit_id: COMMIT_ID
29
+ __commit_id__: COMMIT_ID
19
30
 
20
- __version__ = version = '0.2.16'
21
- __version_tuple__ = version_tuple = (0, 2, 16)
31
+ __version__ = version = '0.2.18'
32
+ __version_tuple__ = version_tuple = (0, 2, 18)
33
+
34
+ __commit_id__ = commit_id = None
@@ -0,0 +1,2 @@
1
+ __pycache__
2
+ deprecated
@@ -1,4 +1,62 @@
1
- from cache_dit.cache_factory.adapters import CacheType
2
- from cache_dit.cache_factory.adapters import apply_cache_on_pipe
3
- from cache_dit.cache_factory.adapters import apply_cache_on_transformer
1
+ import torch
2
+ from typing import Dict, List
3
+ from diffusers import DiffusionPipeline
4
+ from cache_dit.cache_factory.cache_adapters import CacheType
5
+ from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
4
6
  from cache_dit.cache_factory.utils import load_cache_options_from_yaml
7
+
8
+ from cache_dit.logger import init_logger
9
+
10
+ logger = init_logger(__name__)
11
+
12
+
13
def load_options(path: str):
    """Load cache options from the YAML file at ``path``.

    Thin wrapper around ``load_cache_options_from_yaml``; whatever that
    helper returns is handed back unchanged.
    """
    options = load_cache_options_from_yaml(path)
    return options
15
+
16
+
17
def cache_type(
    type_hint: "CacheType | str",
) -> CacheType:
    """Resolve ``type_hint`` (a CacheType member or a string) to a CacheType.

    The conversion itself is delegated to ``CacheType.type``; the argument is
    passed by keyword as ``cache_type``.
    """
    resolved = CacheType.type(cache_type=type_hint)
    return resolved
21
+
22
+
23
def default_options(
    cache_type: CacheType = CacheType.DBCache,
) -> Dict:
    """Return the default option dict for ``cache_type``.

    Defaults to ``CacheType.DBCache``. The lookup is delegated to
    ``CacheType.default_options``. Note that the parameter name shadows the
    module-level ``cache_type`` function inside this body only.
    """
    options = CacheType.default_options(cache_type)
    return options
27
+
28
+
29
def block_range(
    start: int,
    end: int,
    step: int = 1,
) -> List[int]:
    """Return a list of block indices built by ``CacheType.block_range``.

    ``start``, ``end`` and ``step`` (default 1) are forwarded positionally.
    NOTE(review): presumably mirrors ``range(start, end, step)`` semantics
    (end exclusive) — confirm against ``CacheType.block_range``.
    """
    return CacheType.block_range(start, end, step)
39
+
40
+
41
def enable_cache(
    pipe: DiffusionPipeline,
    *,
    transformer: torch.nn.Module = None,
    blocks: torch.nn.ModuleList = None,
    # Attribute name of the block list, e.g. "transformer_blocks", "blocks".
    blocks_name: str = None,
    dummy_blocks_names: List[str] = None,
    return_hidden_states_first: bool = True,
    return_hidden_states_only: bool = False,
    **cache_options_kwargs,
) -> DiffusionPipeline:
    """Enable caching on ``pipe`` through ``UnifiedCacheAdapter.apply``.

    Every argument is forwarded to ``UnifiedCacheAdapter.apply`` and its
    result is returned. All parameters after ``pipe`` are keyword-only.

    Args:
        pipe: The diffusers pipeline to patch.
        transformer: Optional transformer module to use.
            NOTE(review): presumably overrides ``pipe.transformer`` when
            given — confirm against ``UnifiedCacheAdapter.apply``.
        blocks: Optional ``ModuleList`` of transformer blocks.
        blocks_name: Optional attribute name of the block list.
        dummy_blocks_names: Optional list of block names. ``None`` (the
            default) is replaced by a new empty list on every call.
        return_hidden_states_first: Forwarded unchanged.
        return_hidden_states_only: Forwarded unchanged.
        **cache_options_kwargs: Extra cache options, forwarded unchanged.

    Returns:
        Whatever ``UnifiedCacheAdapter.apply`` returns (annotated as a
        ``DiffusionPipeline``).
    """
    # The old ``= []`` default was evaluated once at definition time, so one
    # list object was shared by every call; a mutation downstream would have
    # leaked into later calls. Build a fresh list per call instead.
    if dummy_blocks_names is None:
        dummy_blocks_names = []

    return UnifiedCacheAdapter.apply(
        pipe,
        transformer=transformer,
        blocks=blocks,
        blocks_name=blocks_name,
        dummy_blocks_names=dummy_blocks_names,
        return_hidden_states_first=return_hidden_states_first,
        return_hidden_states_only=return_hidden_states_only,
        **cache_options_kwargs,
    )