cache-dit 0.2.22__py3-none-any.whl → 0.2.23__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/__init__.py +1 -3
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -1
- cache_dit/cache_factory/cache_adapters.py +190 -76
- cache_dit/cache_factory/cache_blocks.py +9 -3
- cache_dit/cache_factory/cache_interface.py +8 -8
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.23.dist-info}/METADATA +13 -6
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.23.dist-info}/RECORD +12 -12
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.23.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.23.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.23.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.23.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
@@ -9,8 +9,8 @@ from cache_dit.cache_factory import enable_cache
 from cache_dit.cache_factory import cache_type
 from cache_dit.cache_factory import block_range
 from cache_dit.cache_factory import CacheType
+from cache_dit.cache_factory import BlockAdapter
 from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory import BlockAdapterParams
 from cache_dit.compile import set_compile_configs
 from cache_dit.utils import summary
 from cache_dit.utils import strify
@@ -19,8 +19,6 @@ from cache_dit.logger import init_logger
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache

-BlockAdapter = BlockAdapterParams
-
 Forward_Pattern_0 = ForwardPattern.Pattern_0
 Forward_Pattern_1 = ForwardPattern.Pattern_1
 Forward_Pattern_2 = ForwardPattern.Pattern_2
cache_dit/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '0.2.22'
-__version_tuple__ = version_tuple = (0, 2, 22)
+__version__ = version = '0.2.23'
+__version_tuple__ = version_tuple = (0, 2, 23)

 __commit_id__ = commit_id = None
cache_dit/cache_factory/__init__.py
CHANGED
@@ -2,7 +2,7 @@ from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.cache_factory.cache_types import cache_type
 from cache_dit.cache_factory.cache_types import block_range
-from cache_dit.cache_factory.cache_adapters import BlockAdapterParams
+from cache_dit.cache_factory.cache_adapters import BlockAdapter
 from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
 from cache_dit.cache_factory.cache_interface import enable_cache
 from cache_dit.cache_factory.utils import load_options
cache_dit/cache_factory/cache_adapters.py
CHANGED
@@ -5,7 +5,7 @@ import unittest
 import functools
 import dataclasses

-from typing import Any
+from typing import Any, Tuple, List
 from contextlib import ExitStack
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.patch.flux import (
@@ -23,28 +23,152 @@ logger = init_logger(__name__)


 @dataclasses.dataclass
-class BlockAdapterParams:
+class BlockAdapter:
     pipe: DiffusionPipeline = None
     transformer: torch.nn.Module = None
     blocks: torch.nn.ModuleList = None
     # transformer_blocks, blocks, etc.
     blocks_name: str = None
     dummy_blocks_names: list[str] = dataclasses.field(default_factory=list)
+    # flags to control auto block adapter
+    auto: bool = False
+    allow_prefixes: List[str] = dataclasses.field(
+        default_factory=lambda: [
+            "transformer",
+            "single_transformer",
+            "blocks",
+            "layers",
+        ]
+    )
+    allow_suffixes: List[str] = dataclasses.field(
+        default_factory=lambda: ["TransformerBlock"]
+    )
+    check_suffixes: bool = False
+    blocks_policy: str = dataclasses.field(
+        default="max", metadata={"allowed_values": ["max", "min"]}
+    )
+
+    @staticmethod
+    def auto_block_adapter(adapter: "BlockAdapter") -> "BlockAdapter":
+        assert adapter.auto, (
+            "Please manually set `auto` to True, or, manually "
+            "set all the transformer blocks configuration."
+        )
+        assert adapter.pipe is not None, "adapter.pipe can not be None."
+        pipe = adapter.pipe
+
+        assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
+
+        transformer = pipe.transformer
+
+        # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
+        blocks, blocks_name = BlockAdapter.find_blocks(
+            transformer=transformer,
+            allow_prefixes=adapter.allow_prefixes,
+            allow_suffixes=adapter.allow_suffixes,
+            check_suffixes=adapter.check_suffixes,
+            blocks_policy=adapter.blocks_policy,
+        )
+
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=transformer,
+            blocks=blocks,
+            blocks_name=blocks_name,
+        )

-
+    @staticmethod
+    def check_block_adapter(adapter: "BlockAdapter") -> bool:
         if (
-            isinstance(
-            and
-            and
-            and
+            isinstance(adapter.pipe, DiffusionPipeline)
+            and adapter.transformer is not None
+            and adapter.blocks is not None
+            and adapter.blocks_name is not None
+            and isinstance(adapter.blocks, torch.nn.ModuleList)
         ):
             return True
         return False

+    @staticmethod
+    def find_blocks(
+        transformer: torch.nn.Module,
+        allow_prefixes: List[str] = [
+            "transformer",
+            "single_transformer",
+            "blocks",
+            "layers",
+        ],
+        allow_suffixes: List[str] = [
+            "TransformerBlock",
+        ],
+        check_suffixes: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.nn.ModuleList, str]:
+
+        blocks_names = []
+        for attr_name in dir(transformer):
+            for prefix in allow_prefixes:
+                if attr_name.startswith(prefix):
+                    blocks_names.append(attr_name)
+
+        # Type check
+        valid_names = []
+        valid_count = []
+        for blocks_name in blocks_names:
+            if blocks := getattr(transformer, blocks_name, None):
+                if isinstance(blocks, torch.nn.ModuleList):
+                    block = blocks[0]
+                    block_cls_name = block.__class__.__name__
+                    if isinstance(block, torch.nn.Module) and (
+                        any(
+                            (
+                                block_cls_name.endswith(allow_suffix)
+                                for allow_suffix in allow_suffixes
+                            )
+                        )
+                        or (not check_suffixes)
+                    ):
+                        valid_names.append(blocks_name)
+                        valid_count.append(len(blocks))
+
+        if not valid_names:
+            raise ValueError(
+                "Auto selected transformer blocks failed, please set it manually."
+            )
+
+        final_name = valid_names[0]
+        final_count = valid_count[0]
+        block_policy = kwargs.get("blocks_policy", "max")
+        for blocks_name, count in zip(valid_names, valid_count):
+            blocks = getattr(transformer, blocks_name)
+            logger.info(
+                f"Auto selected transformer blocks: {blocks_name}, "
+                f"class: {blocks[0].__class__.__name__}, "
+                f"num blocks: {count}"
+            )
+            if block_policy == "max":
+                if final_count < count:
+                    final_count = count
+                    final_name = blocks_name
+            else:
+                if final_count > count:
+                    final_count = count
+                    final_name = blocks_name
+
+        final_blocks = getattr(transformer, final_name)
+
+        logger.info(
+            f"Final selected transformer blocks: {final_name}, "
+            f"class: {final_blocks[0].__class__.__name__}, "
+            f"num blocks: {final_count}, block_policy: {block_policy}."
+        )
+
+        return final_blocks, final_name
+

 @dataclasses.dataclass
 class UnifiedCacheParams:
-    adapter_params: BlockAdapterParams = None
+    block_adapter: BlockAdapter = None
     forward_pattern: ForwardPattern = ForwardPattern.Pattern_0


@@ -85,7 +209,7 @@ class UnifiedCacheAdapter:

         assert isinstance(pipe.transformer, FluxTransformer2DModel)
         return UnifiedCacheParams(
-            adapter_params=BlockAdapterParams(
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=(
@@ -102,7 +226,7 @@ class UnifiedCacheAdapter:

         assert isinstance(pipe.transformer, MochiTransformer3DModel)
         return UnifiedCacheParams(
-            adapter_params=BlockAdapterParams(
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -116,7 +240,7 @@ class UnifiedCacheAdapter:

         assert isinstance(pipe.transformer, CogVideoXTransformer3DModel)
         return UnifiedCacheParams(
-            adapter_params=BlockAdapterParams(
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -136,7 +260,7 @@ class UnifiedCacheAdapter:
             (WanTransformer3DModel, WanVACETransformer3DModel),
         )
         return UnifiedCacheParams(
-            adapter_params=BlockAdapterParams(
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.blocks,
@@ -150,7 +274,7 @@ class UnifiedCacheAdapter:

         assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
         return UnifiedCacheParams(
-            adapter_params=BlockAdapterParams(
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 blocks=(
                     pipe.transformer.transformer_blocks
@@ -166,7 +290,7 @@ class UnifiedCacheAdapter:

         assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
         return UnifiedCacheParams(
-            adapter_params=BlockAdapterParams(
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -180,7 +304,7 @@ class UnifiedCacheAdapter:

         assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
         return UnifiedCacheParams(
-            adapter_params=BlockAdapterParams(
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -194,7 +318,7 @@ class UnifiedCacheAdapter:

         assert isinstance(pipe.transformer, AllegroTransformer3DModel)
         return UnifiedCacheParams(
-            adapter_params=BlockAdapterParams(
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -208,7 +332,7 @@ class UnifiedCacheAdapter:

         assert isinstance(pipe.transformer, CogView3PlusTransformer2DModel)
         return UnifiedCacheParams(
-            adapter_params=BlockAdapterParams(
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -222,7 +346,7 @@ class UnifiedCacheAdapter:

         assert isinstance(pipe.transformer, CogView4Transformer2DModel)
         return UnifiedCacheParams(
-            adapter_params=BlockAdapterParams(
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -236,7 +360,7 @@ class UnifiedCacheAdapter:

         assert isinstance(pipe.transformer, CosmosTransformer3DModel)
         return UnifiedCacheParams(
-            adapter_params=BlockAdapterParams(
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -250,7 +374,7 @@ class UnifiedCacheAdapter:

         assert isinstance(pipe.transformer, EasyAnimateTransformer3DModel)
         return UnifiedCacheParams(
-            adapter_params=BlockAdapterParams(
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -264,7 +388,7 @@ class UnifiedCacheAdapter:

         assert isinstance(pipe.transformer, SkyReelsV2Transformer3DModel)
         return UnifiedCacheParams(
-            adapter_params=BlockAdapterParams(
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.blocks,
@@ -278,7 +402,7 @@ class UnifiedCacheAdapter:

         assert isinstance(pipe.transformer, SD3Transformer2DModel)
         return UnifiedCacheParams(
-            adapter_params=BlockAdapterParams(
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -294,13 +418,13 @@ class UnifiedCacheAdapter:
     def apply(
         cls,
         pipe: DiffusionPipeline = None,
-        adapter_params: BlockAdapterParams = None,
+        block_adapter: BlockAdapter = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
         assert (
-            pipe is not None or
-        ), "pipe or
+            pipe is not None or block_adapter is not None
+        ), "pipe or block_adapter can not both None!"

         if pipe is not None:
             if cls.is_supported(pipe):
@@ -310,7 +434,7 @@ class UnifiedCacheAdapter:
                 )
                 params = cls.get_params(pipe)
                 return cls.cachify(
-                    params.adapter_params,
+                    params.block_adapter,
                     forward_pattern=params.forward_pattern,
                     **cache_context_kwargs,
                 )
@@ -324,7 +448,7 @@ class UnifiedCacheAdapter:
                 "Adapting cache acceleration using custom BlockAdapter!"
             )
             return cls.cachify(
-                adapter_params,
+                block_adapter,
                 forward_pattern=forward_pattern,
                 **cache_context_kwargs,
             )
@@ -332,22 +456,25 @@ class UnifiedCacheAdapter:
     @classmethod
     def cachify(
         cls,
-        adapter_params: BlockAdapterParams,
+        block_adapter: BlockAdapter,
         *,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
-
-
+
+        if block_adapter.auto:
+            block_adapter = BlockAdapter.auto_block_adapter(block_adapter)
+
+        if BlockAdapter.check_block_adapter(block_adapter):
+            assert isinstance(block_adapter.blocks, torch.nn.ModuleList)
             # Apply cache on pipeline: wrap cache context
-            cls.create_context(
+            cls.create_context(block_adapter.pipe, **cache_context_kwargs)
             # Apply cache on transformer: mock cached transformer blocks
             cls.mock_blocks(
-                adapter_params,
+                block_adapter,
                 forward_pattern=forward_pattern,
             )
-
-        return adapter_params.pipe
+            return block_adapter.pipe

     @classmethod
     def has_separate_cfg(
@@ -413,22 +540,33 @@ class UnifiedCacheAdapter:
     @classmethod
     def mock_blocks(
         cls,
-        adapter_params: BlockAdapterParams,
+        block_adapter: BlockAdapter,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     ) -> torch.nn.Module:
-
-
+
+        if (
+            block_adapter.transformer is None
+            or block_adapter.blocks_name is None
+            or block_adapter.blocks is None
+        ):
+            assert block_adapter.auto, (
+                "Please manually set `auto` to True, or, "
+                "manually set transformer blocks configuration."
+            )
+
+        if getattr(block_adapter.transformer, "_is_cached", False):
+            return block_adapter.transformer

         # Firstly, process some specificial cases (TODO: more patches)
-        if
-
-
-            blocks=
+        if block_adapter.transformer.__class__.__name__.startswith("Flux"):
+            block_adapter.transformer = maybe_patch_flux_transformer(
+                block_adapter.transformer,
+                blocks=block_adapter.blocks,
             )

         # Check block forward pattern matching
         assert cls.match_pattern(
-
+            block_adapter.blocks,
             forward_pattern=forward_pattern,
         ), (
             "No block forward pattern matched, "
@@ -439,22 +577,17 @@ class UnifiedCacheAdapter:
         cached_blocks = torch.nn.ModuleList(
             [
                 DBCachedTransformerBlocks(
-
-                    transformer=
+                    block_adapter.blocks,
+                    transformer=block_adapter.transformer,
                     forward_pattern=forward_pattern,
                 )
             ]
         )
         dummy_blocks = torch.nn.ModuleList()

-        original_forward =
+        original_forward = block_adapter.transformer.forward

-        assert isinstance(
-        if adapter_params.blocks_name is None:
-            adapter_params.blocks_name = cls.find_blocks_name(
-                adapter_params.transformer
-            )
-        assert adapter_params.blocks_name is not None
+        assert isinstance(block_adapter.dummy_blocks_names, list)

         @functools.wraps(original_forward)
         def new_forward(self, *args, **kwargs):
@@ -462,11 +595,11 @@ class UnifiedCacheAdapter:
                 stack.enter_context(
                     unittest.mock.patch.object(
                         self,
-
+                        block_adapter.blocks_name,
                         cached_blocks,
                     )
                 )
-                for dummy_name in
+                for dummy_name in block_adapter.dummy_blocks_names:
                     stack.enter_context(
                         unittest.mock.patch.object(
                             self,
@@ -476,12 +609,12 @@ class UnifiedCacheAdapter:
                     )
             return original_forward(*args, **kwargs)

-
-
+        block_adapter.transformer.forward = new_forward.__get__(
+            block_adapter.transformer
         )
-
+        block_adapter.transformer._is_cached = True

-        return
+        return block_adapter.transformer

     @classmethod
     def match_pattern(
@@ -525,22 +658,3 @@ class UnifiedCacheAdapter:
         )

         return pattern_matched
-
-    @classmethod
-    def find_blocks_name(cls, transformer):
-        blocks_name = None
-        allow_prefixes = ["transformer", "blocks"]
-        for attr_name in dir(transformer):
-            if blocks_name is None:
-                for prefix in allow_prefixes:
-                    # transformer_blocks, blocks
-                    if attr_name.startswith(prefix):
-                        blocks_name = attr_name
-                        logger.info(f"Auto selected blocks name: {blocks_name}")
-                        # only find one transformer blocks name
-                        break
-            if blocks_name is None:
-                logger.warning(
-                    "Auto selected blocks name failed, please set it manually."
-                )
-        return blocks_name
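To make the new `auto` path above concrete, here is a small sketch that runs `BlockAdapter.find_blocks` against a toy module; the toy class and its attribute names are invented for illustration and are not part of cache-dit:

```python
import torch
from cache_dit import BlockAdapter


class ToyTransformer(torch.nn.Module):
    """Illustrative stand-in for a DiT-style transformer (not from cache-dit)."""

    def __init__(self):
        super().__init__()
        # Attribute names chosen to match the default allow_prefixes
        # ("transformer", "single_transformer", "blocks", "layers").
        self.transformer_blocks = torch.nn.ModuleList(
            [torch.nn.Linear(8, 8) for _ in range(4)]
        )
        self.single_transformer_blocks = torch.nn.ModuleList(
            [torch.nn.Linear(8, 8) for _ in range(8)]
        )


blocks, name = BlockAdapter.find_blocks(ToyTransformer())
# With the default "max" policy the longer ModuleList wins.
print(name, len(blocks))  # e.g. single_transformer_blocks 8
```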
cache_dit/cache_factory/cache_blocks.py
CHANGED
@@ -1,5 +1,6 @@
 import inspect
 import torch
+import torch.distributed as dist

 from cache_dit.cache_factory import cache_context
 from cache_dit.cache_factory import ForwardPattern
@@ -179,10 +180,15 @@ class DBCachedTransformerBlocks(torch.nn.Module):
     @torch.compiler.disable
     def _is_parallelized(self):
         # Compatible with distributed inference.
-        return
+        return any(
             (
-
-
+                all(
+                    (
+                        self.transformer is not None,
+                        getattr(self.transformer, "_is_parallelized", False),
+                    )
+                ),
+                (dist.is_initialized() and dist.get_world_size() > 1),
             )
         )

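The reworked `_is_parallelized` above now reports distributed inference when either the transformer carries an `_is_parallelized` flag or a multi-rank `torch.distributed` group is active. A standalone sketch of the same predicate (the free function is illustrative; cache-dit implements it as a method on `DBCachedTransformerBlocks`):

```python
import torch.distributed as dist


def is_parallelized(transformer) -> bool:
    # Mirrors the 0.2.23 check: the transformer is explicitly flagged as parallelized...
    flagged = transformer is not None and getattr(
        transformer, "_is_parallelized", False
    )
    # ...or torch.distributed is initialized with more than one rank.
    multi_rank = dist.is_initialized() and dist.get_world_size() > 1
    return flagged or multi_rank
```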
cache_dit/cache_factory/cache_interface.py
CHANGED
@@ -1,7 +1,7 @@
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.cache_types import CacheType
-from cache_dit.cache_factory.cache_adapters import BlockAdapterParams
+from cache_dit.cache_factory.cache_adapters import BlockAdapter
 from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter

 from cache_dit.logger import init_logger
@@ -11,7 +11,7 @@ logger = init_logger(__name__)

 def enable_cache(
     # BlockAdapter & forward pattern
-    pipe_or_adapter: DiffusionPipeline | BlockAdapterParams,
+    pipe_or_adapter: DiffusionPipeline | BlockAdapter,
     forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     # Cache context kwargs
     Fn_compute_blocks: int = 8,
@@ -38,7 +38,7 @@ def enable_cache(
         with F8B0, 8 warmup steps, and unlimited cached steps.

     Args:
-        pipe_or_adapter (`DiffusionPipeline` or `BlockAdapterParams`, *required*):
+        pipe_or_adapter (`DiffusionPipeline` or `BlockAdapter`, *required*):
             The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
             For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
             for the usgae of BlockAdapter.
@@ -128,22 +128,22 @@ def enable_cache(
         "n_derivatives": taylorseer_order
     }

-    if isinstance(pipe_or_adapter, BlockAdapterParams):
+    if isinstance(pipe_or_adapter, BlockAdapter):
         return UnifiedCacheAdapter.apply(
             pipe=None,
-            adapter_params=pipe_or_adapter,
+            block_adapter=pipe_or_adapter,
             forward_pattern=forward_pattern,
             **cache_context_kwargs,
         )
     elif isinstance(pipe_or_adapter, DiffusionPipeline):
         return UnifiedCacheAdapter.apply(
             pipe=pipe_or_adapter,
-            adapter_params=None,
+            block_adapter=None,
             forward_pattern=forward_pattern,
             **cache_context_kwargs,
         )
     else:
         raise ValueError(
-            "Please pass DiffusionPipeline or
-            "
+            "Please pass DiffusionPipeline or BlockAdapter"
+            "for the 1's position param: pipe_or_adapter"
         )
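A minimal usage sketch of the updated `enable_cache` signature; the Flux checkpoint name and the chosen forward pattern are illustrative (the direct-pipeline call mirrors the example in the docstring above):

```python
import cache_dit
from cache_dit import BlockAdapter, ForwardPattern
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")  # example checkpoint

use_block_adapter = False
if not use_block_adapter:
    # Supported pipelines can still be passed directly.
    cache_dit.enable_cache(pipe)
else:
    # Or wrap the pipeline in a BlockAdapter, e.g. with the new auto discovery.
    cache_dit.enable_cache(
        BlockAdapter(pipe=pipe, auto=True),
        forward_pattern=ForwardPattern.Pattern_1,  # illustrative pattern choice
    )
```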
{cache_dit-0.2.22.dist-info → cache_dit-0.2.23.dist-info}/METADATA
CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: cache_dit
-Version: 0.2.22
+Version: 0.2.23
 Summary: 🤗 CacheDiT: An Unified and Training-free Cache Acceleration Toolbox for Diffusion Transformers
 Author: DefTruth, vipshop.com, etc.
 Maintainer: DefTruth, vipshop.com, etc
@@ -173,16 +173,23 @@ But in some cases, you may have a **modified** Diffusion Pipeline or Transformer
 ```python
 from cache_dit import ForwardPattern, BlockAdapter

-#
+# Use BlockAdapter with `auto` mode.
+cache_dit.enable_cache(
+    BlockAdapter(pipe=pipe, auto=True), # Qwen-Image, etc.
+    # Check `📚Forward Pattern Matching` documentation and hack the code of
+    # of Qwen-Image, you will find that it has satisfied `FORWARD_PATTERN_1`.
+    forward_pattern=ForwardPattern.Pattern_1,
+)
+
+# Or, manualy setup transformer configurations.
 cache_dit.enable_cache(
     BlockAdapter(
         pipe=pipe, # Qwen-Image, etc.
         transformer=pipe.transformer,
         blocks=pipe.transformer.transformer_blocks,
-
-
-
-        forward_pattern=ForwardPattern.Pattern_1,
+        blocks_name="transformer_blocks",
+    ),
+    forward_pattern=ForwardPattern.Pattern_1,
 )
 ```
 For such situations, **BlockAdapter** can help you quickly apply various cache acceleration features to your own Diffusion Pipelines and Transformers. Please check the [📚BlockAdapter.md](./docs/BlockAdapter.md) for more details.
{cache_dit-0.2.22.dist-info → cache_dit-0.2.23.dist-info}/RECORD
CHANGED
@@ -1,14 +1,14 @@
-cache_dit/__init__.py,sha256=
-cache_dit/_version.py,sha256=
+cache_dit/__init__.py,sha256=KwhX9NfYkWSvDFuuUVeVjcuiZiGS_22y386l8j4afMo,905
+cache_dit/_version.py,sha256=6GZdGbiFdhndXqR5oFLOd8VGzUvRkESP-NyStAZWYUw,706
 cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
 cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
 cache_dit/utils.py,sha256=3UgVhfmTFG28w6CV-Rfxp5u1uzLrRozocHwLCTGiQ5M,5865
 cache_dit/cache_factory/.gitignore,sha256=5Cb-qT9wsTUoMJ7vACDF7ZcLpAXhi5v-xdcWSRit988,23
-cache_dit/cache_factory/__init__.py,sha256=
-cache_dit/cache_factory/cache_adapters.py,sha256=
-cache_dit/cache_factory/cache_blocks.py,sha256=
+cache_dit/cache_factory/__init__.py,sha256=evWenCin1kuBGa6W5BCKMrDZc1C1R2uVPSg0BjXgdXE,499
+cache_dit/cache_factory/cache_adapters.py,sha256=QTSwjdCmHDeF80TLp6D3KhzQS_oMPna0_bESgJBrdkg,23978
+cache_dit/cache_factory/cache_blocks.py,sha256=ZeazBsYvLIjI5M_OnLL2xP2W7zMeM0rxVfBBwIVHBRs,18661
 cache_dit/cache_factory/cache_context.py,sha256=4thx9NYxVaYZ_Nr2quUVE8bsNmTsXhZK0F960rccOc8,39000
-cache_dit/cache_factory/cache_interface.py,sha256=
+cache_dit/cache_factory/cache_interface.py,sha256=PohG_2oy747O37YSsWz_DwxxTXN7ORhQatyEbg_6fQs,8045
 cache_dit/cache_factory/cache_types.py,sha256=FIFa6ZBfvvSMMHyBBhvarvgg2Y2wbRgITcG_uGylGe0,991
 cache_dit/cache_factory/forward_pattern.py,sha256=B2YeqV2t_zo2Ar8m7qimPBjwQgoXHGp2grPZmEAhi8s,1286
 cache_dit/cache_factory/taylorseer.py,sha256=WeK2WlAJa4Px_pnAKokmnZXeqQYylQkPw4-EDqBIqeQ,3770
@@ -25,9 +25,9 @@ cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,1706
 cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
 cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
 cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
-cache_dit-0.2.
+cache_dit-0.2.23.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
+cache_dit-0.2.23.dist-info/METADATA,sha256=Dq2f8TlyTmv36otIJ2F-fRGkJlZmpW2SY6O14P2AYKo,19772
+cache_dit-0.2.23.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+cache_dit-0.2.23.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
+cache_dit-0.2.23.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
+cache_dit-0.2.23.dist-info/RECORD,,
{cache_dit-0.2.22.dist-info → cache_dit-0.2.23.dist-info}/WHEEL
File without changes
{cache_dit-0.2.22.dist-info → cache_dit-0.2.23.dist-info}/entry_points.txt
File without changes
{cache_dit-0.2.22.dist-info → cache_dit-0.2.23.dist-info}/licenses/LICENSE
File without changes
{cache_dit-0.2.22.dist-info → cache_dit-0.2.23.dist-info}/top_level.txt
File without changes