cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
cache_dit/caching/block_adapters/block_registers.py
@@ -0,0 +1,118 @@
+import torch
+from typing import Any, Tuple, List, Dict, Callable, Union
+
+from diffusers import DiffusionPipeline
+from cache_dit.caching.block_adapters.block_adapters import (
+    BlockAdapter,
+    FakeDiffusionPipeline,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class BlockAdapterRegistry:
+    _adapters: Dict[str, Callable[..., BlockAdapter]] = {}
+    _predefined_adapters_has_separate_cfg: List[str] = [
+        "QwenImage",
+        "Wan",
+        "CogView4",
+        "Cosmos",
+        "SkyReelsV2",
+        "Chroma",
+        "Lumina2",
+        "Kandinsky5",
+    ]
+
+    @classmethod
+    def register(cls, name: str, supported: bool = True):
+        def decorator(
+            func: Callable[..., BlockAdapter]
+        ) -> Callable[..., BlockAdapter]:
+            if supported:
+                cls._adapters[name] = func
+            return func
+
+        return decorator
+
+    @classmethod
+    def get_adapter(
+        cls,
+        pipe_or_module: DiffusionPipeline | torch.nn.Module | str | Any,
+        **kwargs,
+    ) -> BlockAdapter | None:
+        if not isinstance(pipe_or_module, str):
+            cls_name: str = pipe_or_module.__class__.__name__
+        else:
+            cls_name = pipe_or_module
+
+        for name in cls._adapters:
+            if cls_name.startswith(name):
+                if not isinstance(pipe_or_module, DiffusionPipeline):
+                    assert isinstance(pipe_or_module, torch.nn.Module)
+                    # NOTE: Make pre-registered adapters support Transformer-only case.
+                    # WARN: This branch is not officially supported and only for testing
+                    # purpose. We construct a fake diffusion pipeline that contains the
+                    # given transformer module. Currently, only works for DiT models which
+                    # only have one transformer module. Case like multiple transformers
+                    # is not supported, e.g, Wan2.2. Please use BlockAdapter directly for
+                    # such cases.
+                    return cls._adapters[name](
+                        FakeDiffusionPipeline(pipe_or_module), **kwargs
+                    )
+                else:
+                    return cls._adapters[name](pipe_or_module, **kwargs)
+
+        return None
+
+    @classmethod
+    def has_separate_cfg(
+        cls,
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            FakeDiffusionPipeline,
+            BlockAdapter,
+            Any,
+        ],
+    ) -> bool:
+
+        # Prefer custom setting from block adapter.
+        if isinstance(pipe_or_adapter, BlockAdapter):
+            return pipe_or_adapter.has_separate_cfg
+
+        has_separate_cfg = False
+        if isinstance(pipe_or_adapter, FakeDiffusionPipeline):
+            return False
+
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            adapter = cls.get_adapter(
+                pipe_or_adapter,
+                skip_post_init=True,  # check cfg setting only
+            )
+            if adapter is not None:
+                has_separate_cfg = adapter.has_separate_cfg
+
+        if has_separate_cfg:
+            return True
+
+        pipe_cls_name = pipe_or_adapter.__class__.__name__
+        for name in cls._predefined_adapters_has_separate_cfg:
+            if pipe_cls_name.startswith(name):
+                return True
+
+        return False
+
+    @classmethod
+    def is_supported(cls, pipe_or_module) -> bool:
+        cls_name: str = pipe_or_module.__class__.__name__
+
+        for name in cls._adapters:
+            if cls_name.startswith(name):
+                return True
+        return False
+
+    @classmethod
+    def supported_pipelines(cls, **kwargs) -> Tuple[int, List[str]]:
+        val_pipelines = cls._adapters.keys()
+        return len(val_pipelines), [p for p in val_pipelines]
cache_dit/caching/cache_adapters/__init__.py
@@ -0,0 +1 @@
+from cache_dit.caching.cache_adapters.cache_adapter import CachedAdapter