cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only. In the reconstructed hunks below, `…` marks removed-line text that was truncated in the source view.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
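
The rename map above implies a straightforward import-path migration for downstream code: `cache_dit.cache_factory.*` becomes `cache_dit.caching.*`, `cache_dit.custom_ops` becomes `cache_dit.kernels`, and `quantize_ao` moves under `cache_dit.quantize.backends.torchao`. A minimal before/after sketch (nothing in this diff shows a compatibility shim for the old paths, so treat them as gone in 1.0.14):

```python
# cache-dit 1.0.3 (old layout) -- these paths no longer exist in 1.0.14:
# from cache_dit.cache_factory.forward_pattern import ForwardPattern
# from cache_dit.custom_ops import triton_taylorseer

# cache-dit 1.0.14 (new layout):
from cache_dit.caching.forward_pattern import ForwardPattern
from cache_dit.kernels import triton_taylorseer  # may require triton at import time

# Most entry points are also re-exported at the package root (see the
# cache_dit/__init__.py diff below), which is the rename-proof spelling:
from cache_dit import BlockAdapter, ForwardPattern, enable_cache
```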
cache_dit/__init__.py
CHANGED

```diff
@@ -4,32 +4,50 @@ except ImportError:
     __version__ = "unknown version"
     version_tuple = (0, 0, "unknown version")
 
-…
-from cache_dit.utils import strify
+
 from cache_dit.utils import disable_print
 from cache_dit.logger import init_logger
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
+from cache_dit.caching import load_options
+from cache_dit.caching import enable_cache
+from cache_dit.caching import disable_cache
+from cache_dit.caching import cache_type
+from cache_dit.caching import block_range
+from cache_dit.caching import CacheType
+from cache_dit.caching import BlockAdapter
+from cache_dit.caching import ParamsModifier
+from cache_dit.caching import ForwardPattern
+from cache_dit.caching import PatchFunctor
+from cache_dit.caching import BasicCacheConfig
+from cache_dit.caching import DBCacheConfig
+from cache_dit.caching import DBPruneConfig
+from cache_dit.caching import CalibratorConfig
+from cache_dit.caching import TaylorSeerCalibratorConfig
+from cache_dit.caching import FoCaCalibratorConfig
+from cache_dit.caching import supported_pipelines
+from cache_dit.caching import get_adapter
 from cache_dit.compile import set_compile_configs
-from cache_dit.…
+from cache_dit.parallelism import ParallelismBackend
+from cache_dit.parallelism import ParallelismConfig
+from cache_dit.summary import supported_matrix
+from cache_dit.summary import summary
+from cache_dit.summary import strify
+
+try:
+    from cache_dit.quantize import quantize
+except ImportError as e:  # noqa: F841
+    err_msg = str(e)
+
+    def quantize(*args, **kwargs):
+        raise ImportError(
+            "Quantization requires additional dependencies. "
+            "Please install cache-dit[quantization] or cache-dit[all] "
+            f"to use this feature. Error message: {err_msg}"
+        )
 
 
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
+DBPrune = CacheType.DBPrune
 
 Pattern_0 = ForwardPattern.Pattern_0
 Pattern_1 = ForwardPattern.Pattern_1
```
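
Net effect of the `__init__.py` rewrite above: the whole caching API is re-exported at the package root, the parallelism and summary helpers join it, and `quantize` degrades gracefully when its optional extras are missing. A usage sketch (hedged: `pipe` is a placeholder, and the exact signatures of `enable_cache`/`strify`/`quantize` live in the modules below, not in this hunk):

```python
import cache_dit

pipe = ...  # a loaded diffusers pipeline (placeholder)

# Cache control from the package root:
cache_dit.enable_cache(pipe)
cache_dit.disable_cache(pipe)

# strify moved from cache_dit.utils to cache_dit.summary, but the top-level
# name is unchanged for callers:
print(cache_dit.strify(pipe))  # call shape assumed

# Without the quantization extras, the guarded import installs a stub so that
# `import cache_dit` still succeeds; the error surfaces only at call time:
try:
    cache_dit.quantize(pipe)  # hypothetical call shape
except ImportError as err:
    print(err)  # "Quantization requires additional dependencies. ..."
```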
cache_dit/_version.py
CHANGED

```diff
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '1.0.3'
-__version_tuple__ = version_tuple = (1, 0, 3)
+__version__ = version = '1.0.14'
+__version_tuple__ = version_tuple = (1, 0, 14)
 
 __commit_id__ = commit_id = None
```
cache_dit/caching/__init__.py
ADDED

```diff
@@ -0,0 +1,36 @@
+from cache_dit.caching.cache_types import CacheType
+from cache_dit.caching.cache_types import cache_type
+from cache_dit.caching.cache_types import block_range
+
+from cache_dit.caching.forward_pattern import ForwardPattern
+from cache_dit.caching.params_modifier import ParamsModifier
+from cache_dit.caching.patch_functors import PatchFunctor
+
+from cache_dit.caching.block_adapters import BlockAdapter
+from cache_dit.caching.block_adapters import BlockAdapterRegistry
+from cache_dit.caching.block_adapters import FakeDiffusionPipeline
+
+from cache_dit.caching.cache_contexts import BasicCacheConfig
+from cache_dit.caching.cache_contexts import DBCacheConfig
+from cache_dit.caching.cache_contexts import CachedContext
+from cache_dit.caching.cache_contexts import CachedContextManager
+from cache_dit.caching.cache_contexts import DBPruneConfig
+from cache_dit.caching.cache_contexts import PrunedContext
+from cache_dit.caching.cache_contexts import PrunedContextManager
+from cache_dit.caching.cache_contexts import ContextManager
+from cache_dit.caching.cache_contexts import CalibratorConfig
+from cache_dit.caching.cache_contexts import TaylorSeerCalibratorConfig
+from cache_dit.caching.cache_contexts import FoCaCalibratorConfig
+
+from cache_dit.caching.cache_blocks import CachedBlocks
+from cache_dit.caching.cache_blocks import PrunedBlocks
+from cache_dit.caching.cache_blocks import UnifiedBlocks
+
+from cache_dit.caching.cache_adapters import CachedAdapter
+
+from cache_dit.caching.cache_interface import enable_cache
+from cache_dit.caching.cache_interface import disable_cache
+from cache_dit.caching.cache_interface import supported_pipelines
+from cache_dit.caching.cache_interface import get_adapter
+
+from cache_dit.caching.utils import load_options
```
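
The new `cache_dit/caching/__init__.py` above is a pure re-export facade: every name is pulled up from a submodule. Callers can therefore import everything caching-related from one place and stay decoupled from the internal layout this release reshuffles, e.g.:

```python
from cache_dit.caching import (
    BlockAdapter,
    DBCacheConfig,
    TaylorSeerCalibratorConfig,
    enable_cache,
)
```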
cache_dit/{cache_factory → caching}/block_adapters/__init__.py
CHANGED

```diff
@@ -1,7 +1,10 @@
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
-…
+from cache_dit.caching.forward_pattern import ForwardPattern
+from cache_dit.caching.block_adapters.block_adapters import BlockAdapter
+from cache_dit.caching.block_adapters.block_adapters import (
+    FakeDiffusionPipeline,
+)
+from cache_dit.caching.block_adapters.block_adapters import ParamsModifier
+from cache_dit.caching.block_adapters.block_registers import (
     BlockAdapterRegistry,
 )
 
@@ -12,7 +15,10 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
     from cache_dit.utils import is_diffusers_at_least_0_3_5
 
     assert isinstance(pipe.transformer, FluxTransformer2DModel)
-…
+    transformer_cls_name: str = pipe.transformer.__class__.__name__
+    if is_diffusers_at_least_0_3_5() and not transformer_cls_name.startswith(
+        "Nunchaku"
+    ):
         return BlockAdapter(
             pipe=pipe,
             transformer=pipe.transformer,
@@ -24,6 +30,7 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
                 ForwardPattern.Pattern_1,
                 ForwardPattern.Pattern_1,
             ],
+            check_forward_pattern=True,
             **kwargs,
         )
     else:
@@ -38,6 +45,7 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
                 ForwardPattern.Pattern_1,
                 ForwardPattern.Pattern_3,
             ],
+            check_forward_pattern=True,
             **kwargs,
         )
 
@@ -52,6 +60,7 @@ def mochi_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -66,6 +75,7 @@ def cogvideox_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -101,6 +111,7 @@ def wan_adapter(pipe, **kwargs) -> BlockAdapter:
                 ForwardPattern.Pattern_2,
                 ForwardPattern.Pattern_2,
             ],
+            check_forward_pattern=True,
            has_separate_cfg=True,
             **kwargs,
         )
@@ -111,6 +122,7 @@
             transformer=pipe.transformer,
             blocks=pipe.transformer.blocks,
             forward_pattern=ForwardPattern.Pattern_2,
+            check_forward_pattern=True,
             has_separate_cfg=True,
             **kwargs,
         )
@@ -132,6 +144,7 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_0,
         ],
+        check_forward_pattern=True,
         # The type hint in diffusers is wrong
         check_num_outputs=False,
         **kwargs,
@@ -146,7 +159,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
 
     pipe_cls_name: str = pipe.__class__.__name__
     if pipe_cls_name.startswith("QwenImageControlNet"):
-        from cache_dit.…
+        from cache_dit.caching.patch_functors import (
             QwenImageControlNetPatchFunctor,
         )
 
@@ -156,6 +169,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
             blocks=pipe.transformer.transformer_blocks,
             forward_pattern=ForwardPattern.Pattern_1,
             patch_functor=QwenImageControlNetPatchFunctor(),
+            check_forward_pattern=True,
             has_separate_cfg=True,
         )
     else:
@@ -164,6 +178,7 @@ def qwenimage_adapter(pipe, **kwargs) -> BlockAdapter:
             transformer=pipe.transformer,
             blocks=pipe.transformer.transformer_blocks,
             forward_pattern=ForwardPattern.Pattern_1,
+            check_forward_pattern=True,
             has_separate_cfg=True,
             **kwargs,
         )
@@ -179,6 +194,7 @@ def ltxvideo_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_2,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -193,6 +209,7 @@ def allegro_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_2,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -207,6 +224,7 @@ def cogview3plus_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -221,6 +239,7 @@ def cogview4_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -236,6 +255,7 @@ def cosmos_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_2,
+        check_forward_pattern=True,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -251,6 +271,7 @@ def easyanimate_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -268,6 +289,7 @@ def skyreelsv2_adapter(pipe, **kwargs) -> BlockAdapter:
         # encoder_hidden_states will never change in the blocks
         # forward loop.
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -283,6 +305,7 @@ def sd3_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_1,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -297,6 +320,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_0,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -304,7 +328,7 @@ def consisid_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("DiT")
 def dit_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import DiTTransformer2DModel
-    from cache_dit.…
+    from cache_dit.caching.patch_functors import DiTPatchFunctor
 
     assert isinstance(pipe.transformer, DiTTransformer2DModel)
     return BlockAdapter(
@@ -313,6 +337,7 @@ def dit_adapter(pipe, **kwargs) -> BlockAdapter:
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
         patch_functor=DiTPatchFunctor(),
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -327,6 +352,7 @@ def amused_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_layers,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -347,6 +373,7 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_0,
         ],
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -364,6 +391,7 @@ def lumina2_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.layers,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -378,6 +406,7 @@ def omnigen_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.layers,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -392,6 +421,7 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -406,6 +436,7 @@ def sana_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -420,6 +451,7 @@ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -442,6 +474,7 @@ def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
                 ForwardPattern.Pattern_1,
                 ForwardPattern.Pattern_1,
             ],
+            check_forward_pattern=True,
             **kwargs,
         )
     else:
@@ -456,6 +489,7 @@
                 ForwardPattern.Pattern_1,
                 ForwardPattern.Pattern_3,
             ],
+            check_forward_pattern=True,
             **kwargs,
         )
 
@@ -470,6 +504,7 @@ def auraflow_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.transformer,
         blocks=pipe.transformer.single_transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -477,7 +512,7 @@
 @BlockAdapterRegistry.register("Chroma")
 def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import ChromaTransformer2DModel
-    from cache_dit.…
+    from cache_dit.caching.patch_functors import ChromaPatchFunctor
 
     assert isinstance(pipe.transformer, ChromaTransformer2DModel)
     return BlockAdapter(
@@ -492,6 +527,7 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_3,
         ],
         patch_functor=ChromaPatchFunctor(),
+        check_forward_pattern=True,
         has_separate_cfg=True,
         **kwargs,
     )
@@ -507,6 +543,7 @@ def shape_adapter(pipe, **kwargs) -> BlockAdapter:
         transformer=pipe.prior,
         blocks=pipe.prior.transformer_blocks,
         forward_pattern=ForwardPattern.Pattern_3,
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -519,7 +556,7 @@
     # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
     # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
     from diffusers import HiDreamImageTransformer2DModel
-    from cache_dit.…
+    from cache_dit.caching.patch_functors import HiDreamPatchFunctor
 
     assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
     return BlockAdapter(
@@ -544,7 +581,7 @@
 @BlockAdapterRegistry.register("HunyuanDiT")
 def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
-    from cache_dit.…
+    from cache_dit.caching.patch_functors import HunyuanDiTPatchFunctor
 
     assert isinstance(
         pipe.transformer,
@@ -556,6 +593,7 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
         blocks=pipe.transformer.blocks,
         forward_pattern=ForwardPattern.Pattern_3,
         patch_functor=HunyuanDiTPatchFunctor(),
+        check_forward_pattern=True,
         **kwargs,
     )
 
@@ -563,7 +601,7 @@ def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("HunyuanDiTPAG")
 def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import HunyuanDiT2DModel
-    from cache_dit.…
+    from cache_dit.caching.patch_functors import HunyuanDiTPatchFunctor
 
     assert isinstance(pipe.transformer, HunyuanDiT2DModel)
     return BlockAdapter(
@@ -572,5 +610,82 @@ def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
         blocks=pipe.transformer.blocks,
         forward_pattern=ForwardPattern.Pattern_3,
         patch_functor=HunyuanDiTPatchFunctor(),
+        check_forward_pattern=True,
         **kwargs,
     )
+
+
+@BlockAdapterRegistry.register("Kandinsky5")
+def kandinsky5_adapter(pipe, **kwargs) -> BlockAdapter:
+    try:
+        from diffusers import Kandinsky5Transformer3DModel
+
+        assert isinstance(pipe.transformer, Kandinsky5Transformer3DModel)
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=pipe.transformer.visual_transformer_blocks,
+            forward_pattern=ForwardPattern.Pattern_3,  # or Pattern_2
+            has_separate_cfg=True,
+            check_forward_pattern=False,
+            check_num_outputs=False,
+            **kwargs,
+        )
+    except ImportError:
+        raise ImportError(
+            "Kandinsky5Transformer3DModel is not available in the current diffusers version. "
+            "Please upgrade diffusers>=0.36.dev0 to use this adapter."
+        )
+
+
+@BlockAdapterRegistry.register("PRX")
+def prx_adapter(pipe, **kwargs) -> BlockAdapter:
+    try:
+        from diffusers import PRXTransformer2DModel
+
+        assert isinstance(pipe.transformer, PRXTransformer2DModel)
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=pipe.transformer.blocks,
+            forward_pattern=ForwardPattern.Pattern_3,
+            check_forward_pattern=True,
+            check_num_outputs=False,
+            **kwargs,
+        )
+    except ImportError:
+        raise ImportError(
+            "PRXTransformer2DModel is not available in the current diffusers version. "
+            "Please upgrade diffusers>=0.36.dev0 to use this adapter."
+        )
+
+
+@BlockAdapterRegistry.register("HunyuanImage")
+def hunyuan_image_adapter(pipe, **kwargs) -> BlockAdapter:
+    try:
+        from diffusers import HunyuanImageTransformer2DModel
+
+        assert isinstance(pipe.transformer, HunyuanImageTransformer2DModel)
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=[
+                pipe.transformer.transformer_blocks,
+                pipe.transformer.single_transformer_blocks,
+            ],
+            forward_pattern=[
+                ForwardPattern.Pattern_0,
+                ForwardPattern.Pattern_0,
+            ],
+            # set `has_separate_cfg` as True to enable separate cfg caching
+            # since in hyimage-2.1 the `guider_state` contains 2 input batches.
+            # The cfg is `enabled` by default in AdaptiveProjectedMixGuidance.
+            has_separate_cfg=True,
+            check_forward_pattern=True,
+            **kwargs,
+        )
+    except ImportError:
+        raise ImportError(
+            "HunyuanImageTransformer2DModel is not available in the current diffusers version. "
+            "Please upgrade diffusers>=0.36.dev0 to use this adapter."
+        )
```
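
All of the hunks above funnel into the same two mechanisms: `BlockAdapterRegistry.register(name)` binds a pipeline family to a factory function, and `BlockAdapter(...)` describes where the transformer blocks live and which `ForwardPattern` they follow. A hedged sketch of a third-party registration in the same shape as the new Kandinsky5/PRX/HunyuanImage adapters (the `"MyVideo"` name and the transformer attribute are hypothetical):

```python
from cache_dit.caching import (
    BlockAdapter,
    BlockAdapterRegistry,
    ForwardPattern,
)


@BlockAdapterRegistry.register("MyVideo")  # hypothetical pipeline family
def myvideo_adapter(pipe, **kwargs) -> BlockAdapter:
    return BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=pipe.transformer.transformer_blocks,  # assumed attribute
        forward_pattern=ForwardPattern.Pattern_3,
        # Leaving check_forward_pattern unset (None) defers the decision to
        # BlockAdapter.__post_init__, which enables the check for diffusers
        # transformers and disables it otherwise (see block_adapters.py below).
        **kwargs,
    )
```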
cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py
CHANGED

```diff
@@ -6,22 +6,32 @@ from collections.abc import Iterable
 
 from typing import Any, Tuple, List, Optional, Union
 
-from diffusers import DiffusionPipeline
-from cache_dit.…
-from cache_dit.…
-from cache_dit.…
+from diffusers import DiffusionPipeline, ModelMixin
+from cache_dit.caching.patch_functors import PatchFunctor
+from cache_dit.caching.forward_pattern import ForwardPattern
+from cache_dit.caching.params_modifier import ParamsModifier
 
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
+class FakeDiffusionPipeline:
+    # A placeholder for pipelines when pipe is None.
+    def __init__(
+        self,
+        transformer: Optional[torch.nn.Module | ModelMixin] = None,
+    ):
+        self.transformer = transformer  # Reference only
+
+
 @dataclasses.dataclass
 class BlockAdapter:
 
     # Transformer configurations.
     pipe: Union[
         DiffusionPipeline,
+        FakeDiffusionPipeline,
         Any,
     ] = None
 
@@ -73,7 +83,7 @@ class BlockAdapter:
         ]
     ] = None
 
-    check_forward_pattern: bool = …
+    check_forward_pattern: Optional[bool] = None
     check_num_outputs: bool = False
 
     # Pipeline Level Flags
@@ -110,12 +120,43 @@
     def __post_init__(self):
         if self.skip_post_init:
             return
+
+        self.maybe_fake_pipe()
         if any((self.pipe is not None, self.transformer is not None)):
             self.maybe_fill_attrs()
             self.maybe_patchify()
             self.maybe_skip_checks()
 
+    def maybe_fake_pipe(self):
+        if self.pipe is None:
+            self.pipe = FakeDiffusionPipeline()
+            logger.warning("pipe is None, use FakeDiffusionPipeline instead.")
+
     def maybe_skip_checks(self):
+        if self.check_forward_pattern is None:
+            if self.transformer is not None:
+                if self.nested_depth(self.transformer) == 0:
+                    transformer = self.transformer
+                elif self.nested_depth(self.transformer) == 1:
+                    transformer = self.transformer[0]
+                else:
+                    raise ValueError(
+                        "transformer nested depth can't more than 1, "
+                        f"current is: {self.nested_depth(self.transformer)}"
+                    )
+                if transformer.__module__.startswith("diffusers"):
+                    self.check_forward_pattern = True
+                    logger.info(
+                        f"Found transformer from diffusers: {transformer.__module__} "
+                        "enable check_forward_pattern by default."
+                    )
+                else:
+                    self.check_forward_pattern = False
+                    logger.info(
+                        f"Found transformer NOT from diffusers: {transformer.__module__} "
+                        "disable check_forward_pattern by default."
+                    )
+
         if getattr(self.transformer, "_hf_hook", None) is not None:
             logger.warning("_hf_hook is not None, force skip pattern check!")
             self.check_forward_pattern = False
@@ -208,7 +249,10 @@
         if self.transformer is not None:
             self.patch_functor.apply(self.transformer, *args, **kwargs)
         else:
-            assert hasattr(self.pipe, "transformer")
+            assert hasattr(self.pipe, "transformer"), (
+                "pipe.transformer can not be None when patch_functor "
+                "is provided and transformer is None."
+            )
             self.patch_functor.apply(self.pipe.transformer, *args, **kwargs)
 
     @staticmethod
@@ -224,6 +268,10 @@
             adapter.forward_pattern is not None
         ), "adapter.forward_pattern can not be None."
         pipe = adapter.pipe
+        if isinstance(pipe, FakeDiffusionPipeline):
+            raise ValueError(
+                "Can not auto block adapter for FakeDiffusionPipeline."
+            )
 
         assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
 
@@ -489,6 +537,7 @@
     @staticmethod
     def normalize(
        adapter: "BlockAdapter",
+        unique: bool = True,
     ) -> "BlockAdapter":
 
         if getattr(adapter, "_is_normalized", False):
@@ -523,7 +572,10 @@
         adapter.forward_pattern = _normalize_attr(adapter.forward_pattern)
         adapter.dummy_blocks_names = _normalize_attr(adapter.dummy_blocks_names)
         adapter.params_modifiers = _normalize_attr(adapter.params_modifiers)
-…
+        # Some times, the cache_config will be None.
+        # So we do not perform unique check here.
+        if unique:
+            BlockAdapter.unique(adapter)
 
         adapter._is_normalized = True
 
@@ -571,6 +623,10 @@
         if not getattr(adapter, "_is_normalized", False):
             raise RuntimeError("block_adapter must be normailzed.")
 
+    @classmethod
+    def is_normalized(cls, adapter: "BlockAdapter") -> bool:
+        return getattr(adapter, "_is_normalized", False)
+
     @classmethod
     def is_cached(cls, adapter: Any) -> bool:
         if isinstance(adapter, cls):
@@ -592,6 +648,21 @@
         else:
             return getattr(adapter, "_is_cached", False)
 
+    @classmethod
+    def is_parallelized(cls, adapter: Any) -> bool:
+        if isinstance(adapter, cls):
+            cls.assert_normalized(adapter)
+            return getattr(adapter.transformer[0], "_is_parallelized", False)
+        elif isinstance(adapter, DiffusionPipeline):
+            return getattr(adapter.transformer, "_is_parallelized", False)
+        elif isinstance(adapter, torch.nn.Module):
+            return getattr(adapter, "_is_parallelized", False)
+        elif isinstance(adapter, list):  # [TRN_0,...]
+            assert isinstance(adapter[0], torch.nn.Module)
+            return getattr(adapter[0], "_is_parallelized", False)
+        else:
+            return getattr(adapter, "_is_parallelized", False)
+
     @classmethod
     def nested_depth(cls, obj: Any):
         # str: 0; List[str]: 1; List[List[str]]: 2
```