cache-dit 0.2.21__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +2 -4
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +5 -62
- cache_dit/cache_factory/cache_adapters.py +194 -91
- cache_dit/cache_factory/cache_blocks.py +9 -3
- cache_dit/cache_factory/cache_context.py +40 -40
- cache_dit/cache_factory/cache_interface.py +149 -0
- cache_dit/cache_factory/cache_types.py +21 -52
- cache_dit/cache_factory/taylorseer.py +0 -2
- cache_dit/cache_factory/utils.py +4 -0
- cache_dit/compile/utils.py +6 -2
- cache_dit/utils.py +33 -4
- {cache_dit-0.2.21.dist-info → cache_dit-0.2.23.dist-info}/METADATA +19 -17
- cache_dit-0.2.23.dist-info/RECORD +33 -0
- cache_dit-0.2.21.dist-info/RECORD +0 -32
- {cache_dit-0.2.21.dist-info → cache_dit-0.2.23.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.21.dist-info → cache_dit-0.2.23.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.21.dist-info → cache_dit-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.21.dist-info → cache_dit-0.2.23.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
@@ -7,20 +7,18 @@ except ImportError:
 from cache_dit.cache_factory import load_options
 from cache_dit.cache_factory import enable_cache
 from cache_dit.cache_factory import cache_type
-from cache_dit.cache_factory import default_options
 from cache_dit.cache_factory import block_range
 from cache_dit.cache_factory import CacheType
+from cache_dit.cache_factory import BlockAdapter
 from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory import BlockAdapterParams
 from cache_dit.compile import set_compile_configs
 from cache_dit.utils import summary
+from cache_dit.utils import strify
 from cache_dit.logger import init_logger
 
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
 
-BlockAdapter = BlockAdapterParams
-
 Forward_Pattern_0 = ForwardPattern.Pattern_0
 Forward_Pattern_1 = ForwardPattern.Pattern_1
 Forward_Pattern_2 = ForwardPattern.Pattern_2
cache_dit/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.21'
-__version_tuple__ = version_tuple = (0, 2, 21)
+__version__ = version = '0.2.23'
+__version_tuple__ = version_tuple = (0, 2, 23)
 
 __commit_id__ = commit_id = None
cache_dit/cache_factory/__init__.py
CHANGED
@@ -1,65 +1,8 @@
-from typing import Dict, List
-from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.cache_types import CacheType
-from cache_dit.cache_factory.
+from cache_dit.cache_factory.cache_types import cache_type
+from cache_dit.cache_factory.cache_types import block_range
+from cache_dit.cache_factory.cache_adapters import BlockAdapter
 from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
-from cache_dit.cache_factory.
-
-from cache_dit.logger import init_logger
-
-logger = init_logger(__name__)
-
-
-def load_options(path: str):
-    return load_cache_options_from_yaml(path)
-
-
-def cache_type(
-    type_hint: "CacheType | str",
-) -> CacheType:
-    return CacheType.type(cache_type=type_hint)
-
-
-def default_options(
-    cache_type: CacheType = CacheType.DBCache,
-) -> Dict:
-    return CacheType.default_options(cache_type)
-
-
-def block_range(
-    start: int,
-    end: int,
-    step: int = 1,
-) -> List[int]:
-    return CacheType.block_range(
-        start,
-        end,
-        step,
-    )
-
-
-def enable_cache(
-    pipe_or_adapter: DiffusionPipeline | BlockAdapterParams,
-    forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
-    **cache_options_kwargs,
-) -> DiffusionPipeline:
-    if isinstance(pipe_or_adapter, BlockAdapterParams):
-        return UnifiedCacheAdapter.apply(
-            pipe=None,
-            adapter_params=pipe_or_adapter,
-            forward_pattern=forward_pattern,
-            **cache_options_kwargs,
-        )
-    elif isinstance(pipe_or_adapter, DiffusionPipeline):
-        return UnifiedCacheAdapter.apply(
-            pipe=pipe_or_adapter,
-            adapter_params=None,
-            forward_pattern=forward_pattern,
-            **cache_options_kwargs,
-        )
-    else:
-        raise ValueError(
-            "Please pass DiffusionPipeline or BlockAdapterParams"
-            "(BlockAdapter) for the 1 position param: pipe_or_adapter"
-        )
+from cache_dit.cache_factory.cache_interface import enable_cache
+from cache_dit.cache_factory.utils import load_options
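
The public enable_cache entry point now lives in cache_factory/cache_interface.py (added in this release, +149 lines) instead of being defined in this __init__.py. A minimal usage sketch, assuming the relocated function keeps the pipe-or-adapter calling convention of the removed code above (the model id is a placeholder):

    from diffusers import DiffusionPipeline
    from cache_dit import enable_cache, ForwardPattern

    pipe = DiffusionPipeline.from_pretrained("some/dit-pipeline")  # placeholder id
    # Pass either a DiffusionPipeline or a BlockAdapter, as the removed
    # implementation did; the keyword name is assumed to be unchanged.
    pipe = enable_cache(pipe, forward_pattern=ForwardPattern.Pattern_0)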
cache_dit/cache_factory/cache_adapters.py
CHANGED
@@ -5,7 +5,7 @@ import unittest
 import functools
 import dataclasses
 
-from typing import Any
+from typing import Any, Tuple, List
 from contextlib import ExitStack
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.patch.flux import (
@@ -23,28 +23,152 @@ logger = init_logger(__name__)
 
 
 @dataclasses.dataclass
-class
+class BlockAdapter:
     pipe: DiffusionPipeline = None
     transformer: torch.nn.Module = None
     blocks: torch.nn.ModuleList = None
     # transformer_blocks, blocks, etc.
     blocks_name: str = None
     dummy_blocks_names: list[str] = dataclasses.field(default_factory=list)
+    # flags to control auto block adapter
+    auto: bool = False
+    allow_prefixes: List[str] = dataclasses.field(
+        default_factory=lambda: [
+            "transformer",
+            "single_transformer",
+            "blocks",
+            "layers",
+        ]
+    )
+    allow_suffixes: List[str] = dataclasses.field(
+        default_factory=lambda: ["TransformerBlock"]
+    )
+    check_suffixes: bool = False
+    blocks_policy: str = dataclasses.field(
+        default="max", metadata={"allowed_values": ["max", "min"]}
+    )
+
+    @staticmethod
+    def auto_block_adapter(adapter: "BlockAdapter") -> "BlockAdapter":
+        assert adapter.auto, (
+            "Please manually set `auto` to True, or, manually "
+            "set all the transformer blocks configuration."
+        )
+        assert adapter.pipe is not None, "adapter.pipe can not be None."
+        pipe = adapter.pipe
+
+        assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
+
+        transformer = pipe.transformer
+
+        # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
+        blocks, blocks_name = BlockAdapter.find_blocks(
+            transformer=transformer,
+            allow_prefixes=adapter.allow_prefixes,
+            allow_suffixes=adapter.allow_suffixes,
+            check_suffixes=adapter.check_suffixes,
+            blocks_policy=adapter.blocks_policy,
+        )
 
-
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=transformer,
+            blocks=blocks,
+            blocks_name=blocks_name,
+        )
+
+    @staticmethod
+    def check_block_adapter(adapter: "BlockAdapter") -> bool:
         if (
-            isinstance(
-            and
-            and
-            and
+            isinstance(adapter.pipe, DiffusionPipeline)
+            and adapter.transformer is not None
+            and adapter.blocks is not None
+            and adapter.blocks_name is not None
+            and isinstance(adapter.blocks, torch.nn.ModuleList)
         ):
             return True
         return False
 
+    @staticmethod
+    def find_blocks(
+        transformer: torch.nn.Module,
+        allow_prefixes: List[str] = [
+            "transformer",
+            "single_transformer",
+            "blocks",
+            "layers",
+        ],
+        allow_suffixes: List[str] = [
+            "TransformerBlock",
+        ],
+        check_suffixes: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.nn.ModuleList, str]:
+
+        blocks_names = []
+        for attr_name in dir(transformer):
+            for prefix in allow_prefixes:
+                if attr_name.startswith(prefix):
+                    blocks_names.append(attr_name)
+
+        # Type check
+        valid_names = []
+        valid_count = []
+        for blocks_name in blocks_names:
+            if blocks := getattr(transformer, blocks_name, None):
+                if isinstance(blocks, torch.nn.ModuleList):
+                    block = blocks[0]
+                    block_cls_name = block.__class__.__name__
+                    if isinstance(block, torch.nn.Module) and (
+                        any(
+                            (
+                                block_cls_name.endswith(allow_suffix)
+                                for allow_suffix in allow_suffixes
+                            )
+                        )
+                        or (not check_suffixes)
+                    ):
+                        valid_names.append(blocks_name)
+                        valid_count.append(len(blocks))
+
+        if not valid_names:
+            raise ValueError(
+                "Auto selected transformer blocks failed, please set it manually."
+            )
+
+        final_name = valid_names[0]
+        final_count = valid_count[0]
+        block_policy = kwargs.get("blocks_policy", "max")
+        for blocks_name, count in zip(valid_names, valid_count):
+            blocks = getattr(transformer, blocks_name)
+            logger.info(
+                f"Auto selected transformer blocks: {blocks_name}, "
+                f"class: {blocks[0].__class__.__name__}, "
+                f"num blocks: {count}"
+            )
+            if block_policy == "max":
+                if final_count < count:
+                    final_count = count
+                    final_name = blocks_name
+            else:
+                if final_count > count:
+                    final_count = count
+                    final_name = blocks_name
+
+        final_blocks = getattr(transformer, final_name)
+
+        logger.info(
+            f"Final selected transformer blocks: {final_name}, "
+            f"class: {final_blocks[0].__class__.__name__}, "
+            f"num blocks: {final_count}, block_policy: {block_policy}."
+        )
+
+        return final_blocks, final_name
+
 
 @dataclasses.dataclass
 class UnifiedCacheParams:
-
+    block_adapter: BlockAdapter = None
     forward_pattern: ForwardPattern = ForwardPattern.Pattern_0
 
 
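
To illustrate the auto-discovery logic added above, here is a toy sketch of BlockAdapter.find_blocks based solely on the code in this hunk (the toy module is invented for the example and is not part of the package):

    import torch
    from cache_dit.cache_factory.cache_adapters import BlockAdapter

    class ToyTransformer(torch.nn.Module):
        # Two candidate ModuleLists whose names match the default prefixes.
        def __init__(self):
            super().__init__()
            self.transformer_blocks = torch.nn.ModuleList(
                [torch.nn.Identity() for _ in range(4)]
            )
            self.single_transformer_blocks = torch.nn.ModuleList(
                [torch.nn.Identity() for _ in range(8)]
            )

    blocks, name = BlockAdapter.find_blocks(ToyTransformer())
    # With the default blocks_policy="max" the larger list is selected,
    # so name == "single_transformer_blocks" and len(blocks) == 8.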
@@ -85,7 +209,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, FluxTransformer2DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=(
@@ -102,7 +226,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, MochiTransformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -116,7 +240,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, CogVideoXTransformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -136,7 +260,7 @@ class UnifiedCacheAdapter:
             (WanTransformer3DModel, WanVACETransformer3DModel),
         )
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.blocks,
@@ -150,7 +274,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 blocks=(
                     pipe.transformer.transformer_blocks
@@ -166,7 +290,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -180,7 +304,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -194,7 +318,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, AllegroTransformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -208,7 +332,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, CogView3PlusTransformer2DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -222,7 +346,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, CogView4Transformer2DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -236,7 +360,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, CosmosTransformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -250,7 +374,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, EasyAnimateTransformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -264,7 +388,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, SkyReelsV2Transformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.blocks,
@@ -278,7 +402,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, SD3Transformer2DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -294,13 +418,13 @@ class UnifiedCacheAdapter:
     def apply(
         cls,
         pipe: DiffusionPipeline = None,
-
+        block_adapter: BlockAdapter = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
         assert (
-            pipe is not None or
-        ), "pipe or
+            pipe is not None or block_adapter is not None
+        ), "pipe or block_adapter can not both None!"
 
         if pipe is not None:
             if cls.is_supported(pipe):
@@ -310,7 +434,7 @@ class UnifiedCacheAdapter:
                 )
                 params = cls.get_params(pipe)
                 return cls.cachify(
-                    params.
+                    params.block_adapter,
                     forward_pattern=params.forward_pattern,
                     **cache_context_kwargs,
                 )
@@ -324,7 +448,7 @@ class UnifiedCacheAdapter:
                 "Adapting cache acceleration using custom BlockAdapter!"
             )
             return cls.cachify(
-
+                block_adapter,
                 forward_pattern=forward_pattern,
                 **cache_context_kwargs,
             )
@@ -332,25 +456,28 @@ class UnifiedCacheAdapter:
     @classmethod
     def cachify(
         cls,
-
+        block_adapter: BlockAdapter,
         *,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
-
-
+
+        if block_adapter.auto:
+            block_adapter = BlockAdapter.auto_block_adapter(block_adapter)
+
+        if BlockAdapter.check_block_adapter(block_adapter):
+            assert isinstance(block_adapter.blocks, torch.nn.ModuleList)
             # Apply cache on pipeline: wrap cache context
-            cls.create_context(
+            cls.create_context(block_adapter.pipe, **cache_context_kwargs)
             # Apply cache on transformer: mock cached transformer blocks
             cls.mock_blocks(
-
+                block_adapter,
                 forward_pattern=forward_pattern,
             )
-
-        return adapter_params.pipe
+        return block_adapter.pipe
 
     @classmethod
-    def
+    def has_separate_cfg(
         cls,
         pipe_or_transformer: DiffusionPipeline | Any,
     ) -> bool:
@@ -364,20 +491,9 @@ class UnifiedCacheAdapter:
     @classmethod
     def check_context_kwargs(cls, pipe, **cache_context_kwargs):
         # Check cache_context_kwargs
-        if not cache_context_kwargs:
-
-
-            cache_context_kwargs["do_separate_classifier_free_guidance"] = (
-                True
-            )
-            logger.warning(
-                "cache_context_kwargs is empty, use default "
-                f"cache options: {cache_context_kwargs}"
-            )
-        else:
-            # Allow empty cache_type, we only support DBCache now.
-            if cache_context_kwargs.get("cache_type", None):
-                cache_context_kwargs["cache_type"] = CacheType.DBCache
+        if not cache_context_kwargs["do_separate_cfg"]:
+            # Check cfg for some specific case if users don't set it as True
+            cache_context_kwargs["do_separate_cfg"] = cls.has_separate_cfg(pipe)
 
         if cache_type := cache_context_kwargs.pop("cache_type", None):
             assert (
@@ -424,22 +540,33 @@ class UnifiedCacheAdapter:
     @classmethod
     def mock_blocks(
         cls,
-
+        block_adapter: BlockAdapter,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     ) -> torch.nn.Module:
-
-
+
+        if (
+            block_adapter.transformer is None
+            or block_adapter.blocks_name is None
+            or block_adapter.blocks is None
+        ):
+            assert block_adapter.auto, (
+                "Please manually set `auto` to True, or, "
+                "manually set transformer blocks configuration."
+            )
+
+        if getattr(block_adapter.transformer, "_is_cached", False):
+            return block_adapter.transformer
 
         # Firstly, process some specificial cases (TODO: more patches)
-        if
-
-
-            blocks=
+        if block_adapter.transformer.__class__.__name__.startswith("Flux"):
+            block_adapter.transformer = maybe_patch_flux_transformer(
+                block_adapter.transformer,
+                blocks=block_adapter.blocks,
             )
 
         # Check block forward pattern matching
         assert cls.match_pattern(
-
+            block_adapter.blocks,
             forward_pattern=forward_pattern,
         ), (
             "No block forward pattern matched, "
@@ -450,22 +577,17 @@ class UnifiedCacheAdapter:
         cached_blocks = torch.nn.ModuleList(
             [
                 DBCachedTransformerBlocks(
-
-                    transformer=
+                    block_adapter.blocks,
+                    transformer=block_adapter.transformer,
                     forward_pattern=forward_pattern,
                 )
             ]
         )
         dummy_blocks = torch.nn.ModuleList()
 
-        original_forward =
+        original_forward = block_adapter.transformer.forward
 
-        assert isinstance(
-        if adapter_params.blocks_name is None:
-            adapter_params.blocks_name = cls.find_blocks_name(
-                adapter_params.transformer
-            )
-        assert adapter_params.blocks_name is not None
+        assert isinstance(block_adapter.dummy_blocks_names, list)
 
         @functools.wraps(original_forward)
         def new_forward(self, *args, **kwargs):
@@ -473,11 +595,11 @@ class UnifiedCacheAdapter:
                 stack.enter_context(
                     unittest.mock.patch.object(
                         self,
-
+                        block_adapter.blocks_name,
                        cached_blocks,
                     )
                 )
-                for dummy_name in
+                for dummy_name in block_adapter.dummy_blocks_names:
                     stack.enter_context(
                         unittest.mock.patch.object(
                             self,
@@ -487,12 +609,12 @@ class UnifiedCacheAdapter:
                    )
                 return original_forward(*args, **kwargs)
 
-
-
+        block_adapter.transformer.forward = new_forward.__get__(
+            block_adapter.transformer
        )
-
+        block_adapter.transformer._is_cached = True
 
-        return
+        return block_adapter.transformer
 
     @classmethod
     def match_pattern(
@@ -536,22 +658,3 @@ class UnifiedCacheAdapter:
         )
 
         return pattern_matched
-
-    @classmethod
-    def find_blocks_name(cls, transformer):
-        blocks_name = None
-        allow_prefixes = ["transformer", "blocks"]
-        for attr_name in dir(transformer):
-            if blocks_name is None:
-                for prefix in allow_prefixes:
-                    # transformer_blocks, blocks
-                    if attr_name.startswith(prefix):
-                        blocks_name = attr_name
-                        logger.info(f"Auto selected blocks name: {blocks_name}")
-                        # only find one transformer blocks name
-                        break
-        if blocks_name is None:
-            logger.warning(
-                "Auto selected blocks name failed, please set it manually."
-            )
-        return blocks_name
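
Putting the adapter pieces above together, a hedged end-to-end sketch of the custom-adapter path through UnifiedCacheAdapter.apply (the pipeline attribute names and the default handling of do_separate_cfg are assumptions; only the call shapes come from this diff):

    from diffusers import DiffusionPipeline
    from cache_dit.cache_factory.cache_adapters import BlockAdapter, UnifiedCacheAdapter

    pipe = DiffusionPipeline.from_pretrained("some/dit-pipeline")  # placeholder id

    # Describe the blocks explicitly ...
    adapter = BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=pipe.transformer.transformer_blocks,  # assumes this attribute exists
        blocks_name="transformer_blocks",
    )
    # ... or let auto_block_adapter()/find_blocks() discover them:
    # adapter = BlockAdapter(pipe=pipe, auto=True)

    pipe = UnifiedCacheAdapter.apply(
        block_adapter=adapter,
        do_separate_cfg=False,  # key read by check_context_kwargs above; its default is not shown in this diff
    )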
cache_dit/cache_factory/cache_blocks.py
CHANGED
@@ -1,5 +1,6 @@
 import inspect
 import torch
+import torch.distributed as dist
 
 from cache_dit.cache_factory import cache_context
 from cache_dit.cache_factory import ForwardPattern
@@ -179,10 +180,15 @@ class DBCachedTransformerBlocks(torch.nn.Module):
     @torch.compiler.disable
     def _is_parallelized(self):
         # Compatible with distributed inference.
-        return
+        return any(
             (
-
-
+                all(
+                    (
+                        self.transformer is not None,
+                        getattr(self.transformer, "_is_parallelized", False),
+                    )
+                ),
+                (dist.is_initialized() and dist.get_world_size() > 1),
             )
         )
 
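
The _is_parallelized change above means cached blocks are now treated as parallelized either when the transformer was explicitly flagged, or when torch.distributed is running with more than one rank. A standalone restatement of that predicate (a sketch, not part of the package's API):

    import torch.distributed as dist

    def is_parallelized(transformer) -> bool:
        # Same truth table as the updated method: a flagged transformer,
        # or an initialized process group with world_size > 1.
        flagged = transformer is not None and getattr(
            transformer, "_is_parallelized", False
        )
        multi_rank = dist.is_initialized() and dist.get_world_size() > 1
        return flagged or multi_rank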