cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/offload_utils.py +115 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/caching/cache_interface.py +358 -0
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
- cache_dit/caching/utils.py +68 -0
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/metrics/metrics.py +3 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- cache_dit-1.0.14.dist-info/METADATA +301 -0
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
- cache_dit/cache_factory/cache_blocks/utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
- cache_dit/cache_factory/cache_interface.py +0 -217
- cache_dit/cache_factory/patch_functors/__init__.py +0 -12
- cache_dit/cache_factory/utils.py +0 -57
- cache_dit-0.3.2.dist-info/METADATA +0 -753
- cache_dit-0.3.2.dist-info/RECORD +0 -56
- cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
|
@@ -4,32 +4,50 @@ except ImportError:
|
|
|
4
4
|
__version__ = "unknown version"
|
|
5
5
|
version_tuple = (0, 0, "unknown version")
|
|
6
6
|
|
|
7
|
-
|
|
8
|
-
from cache_dit.utils import strify
|
|
7
|
+
|
|
9
8
|
from cache_dit.utils import disable_print
|
|
10
9
|
from cache_dit.logger import init_logger
|
|
11
|
-
from cache_dit.
|
|
12
|
-
from cache_dit.
|
|
13
|
-
from cache_dit.
|
|
14
|
-
from cache_dit.
|
|
15
|
-
from cache_dit.
|
|
16
|
-
from cache_dit.
|
|
17
|
-
from cache_dit.
|
|
18
|
-
from cache_dit.
|
|
19
|
-
from cache_dit.
|
|
20
|
-
from cache_dit.
|
|
21
|
-
from cache_dit.
|
|
22
|
-
from cache_dit.
|
|
23
|
-
from cache_dit.
|
|
24
|
-
from cache_dit.
|
|
25
|
-
from cache_dit.
|
|
26
|
-
from cache_dit.
|
|
10
|
+
from cache_dit.caching import load_options
|
|
11
|
+
from cache_dit.caching import enable_cache
|
|
12
|
+
from cache_dit.caching import disable_cache
|
|
13
|
+
from cache_dit.caching import cache_type
|
|
14
|
+
from cache_dit.caching import block_range
|
|
15
|
+
from cache_dit.caching import CacheType
|
|
16
|
+
from cache_dit.caching import BlockAdapter
|
|
17
|
+
from cache_dit.caching import ParamsModifier
|
|
18
|
+
from cache_dit.caching import ForwardPattern
|
|
19
|
+
from cache_dit.caching import PatchFunctor
|
|
20
|
+
from cache_dit.caching import BasicCacheConfig
|
|
21
|
+
from cache_dit.caching import DBCacheConfig
|
|
22
|
+
from cache_dit.caching import DBPruneConfig
|
|
23
|
+
from cache_dit.caching import CalibratorConfig
|
|
24
|
+
from cache_dit.caching import TaylorSeerCalibratorConfig
|
|
25
|
+
from cache_dit.caching import FoCaCalibratorConfig
|
|
26
|
+
from cache_dit.caching import supported_pipelines
|
|
27
|
+
from cache_dit.caching import get_adapter
|
|
27
28
|
from cache_dit.compile import set_compile_configs
|
|
28
|
-
from cache_dit.
|
|
29
|
+
from cache_dit.parallelism import ParallelismBackend
|
|
30
|
+
from cache_dit.parallelism import ParallelismConfig
|
|
31
|
+
from cache_dit.summary import supported_matrix
|
|
32
|
+
from cache_dit.summary import summary
|
|
33
|
+
from cache_dit.summary import strify
|
|
34
|
+
|
|
35
|
+
try:
|
|
36
|
+
from cache_dit.quantize import quantize
|
|
37
|
+
except ImportError as e: # noqa: F841
|
|
38
|
+
err_msg = str(e)
|
|
39
|
+
|
|
40
|
+
def quantize(*args, **kwargs):
|
|
41
|
+
raise ImportError(
|
|
42
|
+
"Quantization requires additional dependencies. "
|
|
43
|
+
"Please install cache-dit[quantization] or cache-dit[all] "
|
|
44
|
+
f"to use this feature. Error message: {err_msg}"
|
|
45
|
+
)
|
|
29
46
|
|
|
30
47
|
|
|
31
48
|
NONE = CacheType.NONE
|
|
32
49
|
DBCache = CacheType.DBCache
|
|
50
|
+
DBPrune = CacheType.DBPrune
|
|
33
51
|
|
|
34
52
|
Pattern_0 = ForwardPattern.Pattern_0
|
|
35
53
|
Pattern_1 = ForwardPattern.Pattern_1
|
cache_dit/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.
|
|
32
|
-
__version_tuple__ = version_tuple = (
|
|
31
|
+
__version__ = version = '1.0.14'
|
|
32
|
+
__version_tuple__ = version_tuple = (1, 0, 14)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -0,0 +1,36 @@
|
|
|
1
|
+
from cache_dit.caching.cache_types import CacheType
|
|
2
|
+
from cache_dit.caching.cache_types import cache_type
|
|
3
|
+
from cache_dit.caching.cache_types import block_range
|
|
4
|
+
|
|
5
|
+
from cache_dit.caching.forward_pattern import ForwardPattern
|
|
6
|
+
from cache_dit.caching.params_modifier import ParamsModifier
|
|
7
|
+
from cache_dit.caching.patch_functors import PatchFunctor
|
|
8
|
+
|
|
9
|
+
from cache_dit.caching.block_adapters import BlockAdapter
|
|
10
|
+
from cache_dit.caching.block_adapters import BlockAdapterRegistry
|
|
11
|
+
from cache_dit.caching.block_adapters import FakeDiffusionPipeline
|
|
12
|
+
|
|
13
|
+
from cache_dit.caching.cache_contexts import BasicCacheConfig
|
|
14
|
+
from cache_dit.caching.cache_contexts import DBCacheConfig
|
|
15
|
+
from cache_dit.caching.cache_contexts import CachedContext
|
|
16
|
+
from cache_dit.caching.cache_contexts import CachedContextManager
|
|
17
|
+
from cache_dit.caching.cache_contexts import DBPruneConfig
|
|
18
|
+
from cache_dit.caching.cache_contexts import PrunedContext
|
|
19
|
+
from cache_dit.caching.cache_contexts import PrunedContextManager
|
|
20
|
+
from cache_dit.caching.cache_contexts import ContextManager
|
|
21
|
+
from cache_dit.caching.cache_contexts import CalibratorConfig
|
|
22
|
+
from cache_dit.caching.cache_contexts import TaylorSeerCalibratorConfig
|
|
23
|
+
from cache_dit.caching.cache_contexts import FoCaCalibratorConfig
|
|
24
|
+
|
|
25
|
+
from cache_dit.caching.cache_blocks import CachedBlocks
|
|
26
|
+
from cache_dit.caching.cache_blocks import PrunedBlocks
|
|
27
|
+
from cache_dit.caching.cache_blocks import UnifiedBlocks
|
|
28
|
+
|
|
29
|
+
from cache_dit.caching.cache_adapters import CachedAdapter
|
|
30
|
+
|
|
31
|
+
from cache_dit.caching.cache_interface import enable_cache
|
|
32
|
+
from cache_dit.caching.cache_interface import disable_cache
|
|
33
|
+
from cache_dit.caching.cache_interface import supported_pipelines
|
|
34
|
+
from cache_dit.caching.cache_interface import get_adapter
|
|
35
|
+
|
|
36
|
+
from cache_dit.caching.utils import load_options
|
|
@@ -1,7 +1,10 @@
|
|
|
1
|
-
from cache_dit.
|
|
2
|
-
from cache_dit.
|
|
3
|
-
from cache_dit.
|
|
4
|
-
|
|
1
|
+
from cache_dit.caching.forward_pattern import ForwardPattern
|
|
2
|
+
from cache_dit.caching.block_adapters.block_adapters import BlockAdapter
|
|
3
|
+
from cache_dit.caching.block_adapters.block_adapters import (
|
|
4
|
+
FakeDiffusionPipeline,
|
|
5
|
+
)
|
|
6
|
+
from cache_dit.caching.block_adapters.block_adapters import ParamsModifier
|
|
7
|
+
from cache_dit.caching.block_adapters.block_registers import (
|
|
5
8
|
BlockAdapterRegistry,
|
|
6
9
|
)
|
|
7
10
|
|
|
@@ -12,7 +15,10 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
12
15
|
from cache_dit.utils import is_diffusers_at_least_0_3_5
|
|
13
16
|
|
|
14
17
|
assert isinstance(pipe.transformer, FluxTransformer2DModel)
|
|
15
|
-
|
|
18
|
+
transformer_cls_name: str = pipe.transformer.__class__.__name__
|
|
19
|
+
if is_diffusers_at_least_0_3_5() and not transformer_cls_name.startswith(
|
|
20
|
+
"Nunchaku"
|
|
21
|
+
):
|
|
16
22
|
return BlockAdapter(
|
|
17
23
|
pipe=pipe,
|
|
18
24
|
transformer=pipe.transformer,
|
|
@@ -24,6 +30,7 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
24
30
|
ForwardPattern.Pattern_1,
|
|
25
31
|
ForwardPattern.Pattern_1,
|
|
26
32
|
],
|
|
33
|
+
check_forward_pattern=True,
|
|
27
34
|
**kwargs,
|
|
28
35
|
)
|
|
29
36
|
else:
|
|
@@ -38,6 +45,7 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
38
45
|
ForwardPattern.Pattern_1,
|
|
39
46
|
ForwardPattern.Pattern_3,
|
|
40
47
|
],
|
|
48
|
+
check_forward_pattern=True,
|
|
41
49
|
**kwargs,
|
|
42
50
|
)
|
|
43
51
|
|
|
@@ -52,6 +60,7 @@ def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
52
60
|
transformer=pipe.transformer,
|
|
53
61
|
blocks=pipe.transformer.transformer_blocks,
|
|
54
62
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
63
|
+
check_forward_pattern=True,
|
|
55
64
|
**kwargs,
|
|
56
65
|
)
|
|
57
66
|
|
|
@@ -66,6 +75,7 @@ def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
66
75
|
transformer=pipe.transformer,
|
|
67
76
|
blocks=pipe.transformer.transformer_blocks,
|
|
68
77
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
78
|
+
check_forward_pattern=True,
|
|
69
79
|
**kwargs,
|
|
70
80
|
)
|
|
71
81
|
|
|
@@ -101,6 +111,7 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
101
111
|
ForwardPattern.Pattern_2,
|
|
102
112
|
ForwardPattern.Pattern_2,
|
|
103
113
|
],
|
|
114
|
+
check_forward_pattern=True,
|
|
104
115
|
has_separate_cfg=True,
|
|
105
116
|
**kwargs,
|
|
106
117
|
)
|
|
@@ -111,6 +122,7 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
111
122
|
transformer=pipe.transformer,
|
|
112
123
|
blocks=pipe.transformer.blocks,
|
|
113
124
|
forward_pattern=ForwardPattern.Pattern_2,
|
|
125
|
+
check_forward_pattern=True,
|
|
114
126
|
has_separate_cfg=True,
|
|
115
127
|
**kwargs,
|
|
116
128
|
)
|
|
@@ -132,6 +144,7 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
132
144
|
ForwardPattern.Pattern_0,
|
|
133
145
|
ForwardPattern.Pattern_0,
|
|
134
146
|
],
|
|
147
|
+
check_forward_pattern=True,
|
|
135
148
|
# The type hint in diffusers is wrong
|
|
136
149
|
check_num_outputs=False,
|
|
137
150
|
**kwargs,
|
|
@@ -143,14 +156,32 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
143
156
|
from diffusers import QwenImageTransformer2DModel
|
|
144
157
|
|
|
145
158
|
assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
|
|
146
|
-
|
|
147
|
-
|
|
148
|
-
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
|
|
152
|
-
|
|
153
|
-
|
|
159
|
+
|
|
160
|
+
pipe_cls_name: str = pipe.__class__.__name__
|
|
161
|
+
if pipe_cls_name.startswith("QwenImageControlNet"):
|
|
162
|
+
from cache_dit.caching.patch_functors import (
|
|
163
|
+
QwenImageControlNetPatchFunctor,
|
|
164
|
+
)
|
|
165
|
+
|
|
166
|
+
return BlockAdapter(
|
|
167
|
+
pipe=pipe,
|
|
168
|
+
transformer=pipe.transformer,
|
|
169
|
+
blocks=pipe.transformer.transformer_blocks,
|
|
170
|
+
forward_pattern=ForwardPattern.Pattern_1,
|
|
171
|
+
patch_functor=QwenImageControlNetPatchFunctor(),
|
|
172
|
+
check_forward_pattern=True,
|
|
173
|
+
has_separate_cfg=True,
|
|
174
|
+
)
|
|
175
|
+
else:
|
|
176
|
+
return BlockAdapter(
|
|
177
|
+
pipe=pipe,
|
|
178
|
+
transformer=pipe.transformer,
|
|
179
|
+
blocks=pipe.transformer.transformer_blocks,
|
|
180
|
+
forward_pattern=ForwardPattern.Pattern_1,
|
|
181
|
+
check_forward_pattern=True,
|
|
182
|
+
has_separate_cfg=True,
|
|
183
|
+
**kwargs,
|
|
184
|
+
)
|
|
154
185
|
|
|
155
186
|
|
|
156
187
|
@BlockAdapterRegistry.register("LTX")
|
|
@@ -163,6 +194,7 @@ def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
163
194
|
transformer=pipe.transformer,
|
|
164
195
|
blocks=pipe.transformer.transformer_blocks,
|
|
165
196
|
forward_pattern=ForwardPattern.Pattern_2,
|
|
197
|
+
check_forward_pattern=True,
|
|
166
198
|
**kwargs,
|
|
167
199
|
)
|
|
168
200
|
|
|
@@ -177,6 +209,7 @@ def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
177
209
|
transformer=pipe.transformer,
|
|
178
210
|
blocks=pipe.transformer.transformer_blocks,
|
|
179
211
|
forward_pattern=ForwardPattern.Pattern_2,
|
|
212
|
+
check_forward_pattern=True,
|
|
180
213
|
**kwargs,
|
|
181
214
|
)
|
|
182
215
|
|
|
@@ -191,6 +224,7 @@ def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
191
224
|
transformer=pipe.transformer,
|
|
192
225
|
blocks=pipe.transformer.transformer_blocks,
|
|
193
226
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
227
|
+
check_forward_pattern=True,
|
|
194
228
|
**kwargs,
|
|
195
229
|
)
|
|
196
230
|
|
|
@@ -205,6 +239,7 @@ def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
205
239
|
transformer=pipe.transformer,
|
|
206
240
|
blocks=pipe.transformer.transformer_blocks,
|
|
207
241
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
242
|
+
check_forward_pattern=True,
|
|
208
243
|
has_separate_cfg=True,
|
|
209
244
|
**kwargs,
|
|
210
245
|
)
|
|
@@ -220,6 +255,7 @@ def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
220
255
|
transformer=pipe.transformer,
|
|
221
256
|
blocks=pipe.transformer.transformer_blocks,
|
|
222
257
|
forward_pattern=ForwardPattern.Pattern_2,
|
|
258
|
+
check_forward_pattern=True,
|
|
223
259
|
has_separate_cfg=True,
|
|
224
260
|
**kwargs,
|
|
225
261
|
)
|
|
@@ -235,6 +271,7 @@ def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
235
271
|
transformer=pipe.transformer,
|
|
236
272
|
blocks=pipe.transformer.transformer_blocks,
|
|
237
273
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
274
|
+
check_forward_pattern=True,
|
|
238
275
|
**kwargs,
|
|
239
276
|
)
|
|
240
277
|
|
|
@@ -252,6 +289,7 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
252
289
|
# encoder_hidden_states will never change in the blocks
|
|
253
290
|
# forward loop.
|
|
254
291
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
292
|
+
check_forward_pattern=True,
|
|
255
293
|
has_separate_cfg=True,
|
|
256
294
|
**kwargs,
|
|
257
295
|
)
|
|
@@ -267,6 +305,7 @@ def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
267
305
|
transformer=pipe.transformer,
|
|
268
306
|
blocks=pipe.transformer.transformer_blocks,
|
|
269
307
|
forward_pattern=ForwardPattern.Pattern_1,
|
|
308
|
+
check_forward_pattern=True,
|
|
270
309
|
**kwargs,
|
|
271
310
|
)
|
|
272
311
|
|
|
@@ -281,6 +320,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
281
320
|
transformer=pipe.transformer,
|
|
282
321
|
blocks=pipe.transformer.transformer_blocks,
|
|
283
322
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
323
|
+
check_forward_pattern=True,
|
|
284
324
|
**kwargs,
|
|
285
325
|
)
|
|
286
326
|
|
|
@@ -288,7 +328,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
288
328
|
@BlockAdapterRegistry.register("DiT")
|
|
289
329
|
def dit_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
290
330
|
from diffusers import DiTTransformer2DModel
|
|
291
|
-
from cache_dit.
|
|
331
|
+
from cache_dit.caching.patch_functors import DiTPatchFunctor
|
|
292
332
|
|
|
293
333
|
assert isinstance(pipe.transformer, DiTTransformer2DModel)
|
|
294
334
|
return BlockAdapter(
|
|
@@ -297,6 +337,7 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
297
337
|
blocks=pipe.transformer.transformer_blocks,
|
|
298
338
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
299
339
|
patch_functor=DiTPatchFunctor(),
|
|
340
|
+
check_forward_pattern=True,
|
|
300
341
|
**kwargs,
|
|
301
342
|
)
|
|
302
343
|
|
|
@@ -311,6 +352,7 @@ def amused_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
311
352
|
transformer=pipe.transformer,
|
|
312
353
|
blocks=pipe.transformer.transformer_layers,
|
|
313
354
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
355
|
+
check_forward_pattern=True,
|
|
314
356
|
**kwargs,
|
|
315
357
|
)
|
|
316
358
|
|
|
@@ -331,6 +373,7 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
331
373
|
ForwardPattern.Pattern_0,
|
|
332
374
|
ForwardPattern.Pattern_0,
|
|
333
375
|
],
|
|
376
|
+
check_forward_pattern=True,
|
|
334
377
|
**kwargs,
|
|
335
378
|
)
|
|
336
379
|
|
|
@@ -348,6 +391,7 @@ def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
348
391
|
transformer=pipe.transformer,
|
|
349
392
|
blocks=pipe.transformer.layers,
|
|
350
393
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
394
|
+
check_forward_pattern=True,
|
|
351
395
|
**kwargs,
|
|
352
396
|
)
|
|
353
397
|
|
|
@@ -362,6 +406,7 @@ def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
362
406
|
transformer=pipe.transformer,
|
|
363
407
|
blocks=pipe.transformer.layers,
|
|
364
408
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
409
|
+
check_forward_pattern=True,
|
|
365
410
|
**kwargs,
|
|
366
411
|
)
|
|
367
412
|
|
|
@@ -376,6 +421,7 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
376
421
|
transformer=pipe.transformer,
|
|
377
422
|
blocks=pipe.transformer.transformer_blocks,
|
|
378
423
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
424
|
+
check_forward_pattern=True,
|
|
379
425
|
**kwargs,
|
|
380
426
|
)
|
|
381
427
|
|
|
@@ -390,6 +436,7 @@ def sana_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
390
436
|
transformer=pipe.transformer,
|
|
391
437
|
blocks=pipe.transformer.transformer_blocks,
|
|
392
438
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
439
|
+
check_forward_pattern=True,
|
|
393
440
|
**kwargs,
|
|
394
441
|
)
|
|
395
442
|
|
|
@@ -404,6 +451,7 @@ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
404
451
|
transformer=pipe.transformer,
|
|
405
452
|
blocks=pipe.transformer.transformer_blocks,
|
|
406
453
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
454
|
+
check_forward_pattern=True,
|
|
407
455
|
**kwargs,
|
|
408
456
|
)
|
|
409
457
|
|
|
@@ -426,6 +474,7 @@ def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
426
474
|
ForwardPattern.Pattern_1,
|
|
427
475
|
ForwardPattern.Pattern_1,
|
|
428
476
|
],
|
|
477
|
+
check_forward_pattern=True,
|
|
429
478
|
**kwargs,
|
|
430
479
|
)
|
|
431
480
|
else:
|
|
@@ -440,6 +489,7 @@ def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
440
489
|
ForwardPattern.Pattern_1,
|
|
441
490
|
ForwardPattern.Pattern_3,
|
|
442
491
|
],
|
|
492
|
+
check_forward_pattern=True,
|
|
443
493
|
**kwargs,
|
|
444
494
|
)
|
|
445
495
|
|
|
@@ -454,6 +504,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
454
504
|
transformer=pipe.transformer,
|
|
455
505
|
blocks=pipe.transformer.single_transformer_blocks,
|
|
456
506
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
507
|
+
check_forward_pattern=True,
|
|
457
508
|
**kwargs,
|
|
458
509
|
)
|
|
459
510
|
|
|
@@ -461,7 +512,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
461
512
|
@BlockAdapterRegistry.register("Chroma")
|
|
462
513
|
def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
463
514
|
from diffusers import ChromaTransformer2DModel
|
|
464
|
-
from cache_dit.
|
|
515
|
+
from cache_dit.caching.patch_functors import ChromaPatchFunctor
|
|
465
516
|
|
|
466
517
|
assert isinstance(pipe.transformer, ChromaTransformer2DModel)
|
|
467
518
|
return BlockAdapter(
|
|
@@ -476,6 +527,7 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
476
527
|
ForwardPattern.Pattern_3,
|
|
477
528
|
],
|
|
478
529
|
patch_functor=ChromaPatchFunctor(),
|
|
530
|
+
check_forward_pattern=True,
|
|
479
531
|
has_separate_cfg=True,
|
|
480
532
|
**kwargs,
|
|
481
533
|
)
|
|
@@ -491,6 +543,7 @@ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
491
543
|
transformer=pipe.prior,
|
|
492
544
|
blocks=pipe.prior.transformer_blocks,
|
|
493
545
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
546
|
+
check_forward_pattern=True,
|
|
494
547
|
**kwargs,
|
|
495
548
|
)
|
|
496
549
|
|
|
@@ -503,7 +556,7 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
503
556
|
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
|
|
504
557
|
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
|
|
505
558
|
from diffusers import HiDreamImageTransformer2DModel
|
|
506
|
-
from cache_dit.
|
|
559
|
+
from cache_dit.caching.patch_functors import HiDreamPatchFunctor
|
|
507
560
|
|
|
508
561
|
assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
|
|
509
562
|
return BlockAdapter(
|
|
@@ -528,7 +581,7 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
528
581
|
@BlockAdapterRegistry.register("HunyuanDiT")
|
|
529
582
|
def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
530
583
|
from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
|
|
531
|
-
from cache_dit.
|
|
584
|
+
from cache_dit.caching.patch_functors import HunyuanDiTPatchFunctor
|
|
532
585
|
|
|
533
586
|
assert isinstance(
|
|
534
587
|
pipe.transformer,
|
|
@@ -540,6 +593,7 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
540
593
|
blocks=pipe.transformer.blocks,
|
|
541
594
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
542
595
|
patch_functor=HunyuanDiTPatchFunctor(),
|
|
596
|
+
check_forward_pattern=True,
|
|
543
597
|
**kwargs,
|
|
544
598
|
)
|
|
545
599
|
|
|
@@ -547,7 +601,7 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
547
601
|
@BlockAdapterRegistry.register("HunyuanDiTPAG")
|
|
548
602
|
def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
549
603
|
from diffusers import HunyuanDiT2DModel
|
|
550
|
-
from cache_dit.
|
|
604
|
+
from cache_dit.caching.patch_functors import HunyuanDiTPatchFunctor
|
|
551
605
|
|
|
552
606
|
assert isinstance(pipe.transformer, HunyuanDiT2DModel)
|
|
553
607
|
return BlockAdapter(
|
|
@@ -556,5 +610,82 @@ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
556
610
|
blocks=pipe.transformer.blocks,
|
|
557
611
|
forward_pattern=ForwardPattern.Pattern_3,
|
|
558
612
|
patch_functor=HunyuanDiTPatchFunctor(),
|
|
613
|
+
check_forward_pattern=True,
|
|
559
614
|
**kwargs,
|
|
560
615
|
)
|
|
616
|
+
|
|
617
|
+
|
|
618
|
+
@BlockAdapterRegistry.register("Kandinsky5")
|
|
619
|
+
def kandinsky5_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
620
|
+
try:
|
|
621
|
+
from diffusers import Kandinsky5Transformer3DModel
|
|
622
|
+
|
|
623
|
+
assert isinstance(pipe.transformer, Kandinsky5Transformer3DModel)
|
|
624
|
+
return BlockAdapter(
|
|
625
|
+
pipe=pipe,
|
|
626
|
+
transformer=pipe.transformer,
|
|
627
|
+
blocks=pipe.transformer.visual_transformer_blocks,
|
|
628
|
+
forward_pattern=ForwardPattern.Pattern_3, # or Pattern_2
|
|
629
|
+
has_separate_cfg=True,
|
|
630
|
+
check_forward_pattern=False,
|
|
631
|
+
check_num_outputs=False,
|
|
632
|
+
**kwargs,
|
|
633
|
+
)
|
|
634
|
+
except ImportError:
|
|
635
|
+
raise ImportError(
|
|
636
|
+
"Kandinsky5Transformer3DModel is not available in the current diffusers version. "
|
|
637
|
+
"Please upgrade diffusers>=0.36.dev0 to use this adapter."
|
|
638
|
+
)
|
|
639
|
+
|
|
640
|
+
|
|
641
|
+
@BlockAdapterRegistry.register("PRX")
|
|
642
|
+
def prx_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
643
|
+
try:
|
|
644
|
+
from diffusers import PRXTransformer2DModel
|
|
645
|
+
|
|
646
|
+
assert isinstance(pipe.transformer, PRXTransformer2DModel)
|
|
647
|
+
return BlockAdapter(
|
|
648
|
+
pipe=pipe,
|
|
649
|
+
transformer=pipe.transformer,
|
|
650
|
+
blocks=pipe.transformer.blocks,
|
|
651
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
652
|
+
check_forward_pattern=True,
|
|
653
|
+
check_num_outputs=False,
|
|
654
|
+
**kwargs,
|
|
655
|
+
)
|
|
656
|
+
except ImportError:
|
|
657
|
+
raise ImportError(
|
|
658
|
+
"PRXTransformer2DModel is not available in the current diffusers version. "
|
|
659
|
+
"Please upgrade diffusers>=0.36.dev0 to use this adapter."
|
|
660
|
+
)
|
|
661
|
+
|
|
662
|
+
|
|
663
|
+
@BlockAdapterRegistry.register("HunyuanImage")
|
|
664
|
+
def hunyuan_image_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
665
|
+
try:
|
|
666
|
+
from diffusers import HunyuanImageTransformer2DModel
|
|
667
|
+
|
|
668
|
+
assert isinstance(pipe.transformer, HunyuanImageTransformer2DModel)
|
|
669
|
+
return BlockAdapter(
|
|
670
|
+
pipe=pipe,
|
|
671
|
+
transformer=pipe.transformer,
|
|
672
|
+
blocks=[
|
|
673
|
+
pipe.transformer.transformer_blocks,
|
|
674
|
+
pipe.transformer.single_transformer_blocks,
|
|
675
|
+
],
|
|
676
|
+
forward_pattern=[
|
|
677
|
+
ForwardPattern.Pattern_0,
|
|
678
|
+
ForwardPattern.Pattern_0,
|
|
679
|
+
],
|
|
680
|
+
# set `has_separate_cfg` as True to enable separate cfg caching
|
|
681
|
+
# since in hyimage-2.1 the `guider_state` contains 2 input batches.
|
|
682
|
+
# The cfg is `enabled` by default in AdaptiveProjectedMixGuidance.
|
|
683
|
+
has_separate_cfg=True,
|
|
684
|
+
check_forward_pattern=True,
|
|
685
|
+
**kwargs,
|
|
686
|
+
)
|
|
687
|
+
except ImportError:
|
|
688
|
+
raise ImportError(
|
|
689
|
+
"HunyuanImageTransformer2DModel is not available in the current diffusers version. "
|
|
690
|
+
"Please upgrade diffusers>=0.36.dev0 to use this adapter."
|
|
691
|
+
)
|