cache-dit 0.2.22-py3-none-any.whl → 0.2.24-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit has been flagged; see the registry page for details.
- cache_dit/__init__.py +1 -3
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -1
- cache_dit/cache_factory/cache_adapters.py +298 -123
- cache_dit/cache_factory/cache_blocks.py +9 -3
- cache_dit/cache_factory/cache_context.py +85 -15
- cache_dit/cache_factory/cache_interface.py +18 -11
- cache_dit/cache_factory/taylorseer.py +5 -4
- cache_dit/cache_factory/utils.py +1 -1
- cache_dit/utils.py +25 -22
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/METADATA +19 -10
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/RECORD +16 -17
- cache_dit/primitives.py +0 -152
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.22.dist-info → cache_dit-0.2.24.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
@@ -9,8 +9,8 @@ from cache_dit.cache_factory import enable_cache
 from cache_dit.cache_factory import cache_type
 from cache_dit.cache_factory import block_range
 from cache_dit.cache_factory import CacheType
+from cache_dit.cache_factory import BlockAdapter
 from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory import BlockAdapterParams
 from cache_dit.compile import set_compile_configs
 from cache_dit.utils import summary
 from cache_dit.utils import strify
@@ -19,8 +19,6 @@ from cache_dit.logger import init_logger
 NONE = CacheType.NONE
 DBCache = CacheType.DBCache
 
-BlockAdapter = BlockAdapterParams
-
 Forward_Pattern_0 = ForwardPattern.Pattern_0
 Forward_Pattern_1 = ForwardPattern.Pattern_1
 Forward_Pattern_2 = ForwardPattern.Pattern_2
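Note on the hunk above: the compatibility alias `BlockAdapter = BlockAdapterParams` is gone, so 0.2.24 exposes the dataclass under its new name only. A minimal caller-side sketch (not taken from the package's docs; only names visible in this diff are assumed):

# Hypothetical caller-side snippet.
from cache_dit import BlockAdapter, ForwardPattern   # new top-level name in 0.2.24

# from cache_dit import BlockAdapterParams           # removed; raises ImportError on 0.2.24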
cache_dit/_version.py
CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.22'
-__version_tuple__ = version_tuple = (0, 2, 22)
+__version__ = version = '0.2.24'
+__version_tuple__ = version_tuple = (0, 2, 24)
 
 __commit_id__ = commit_id = None
cache_dit/cache_factory/__init__.py
CHANGED
@@ -2,7 +2,7 @@ from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.cache_types import CacheType
 from cache_dit.cache_factory.cache_types import cache_type
 from cache_dit.cache_factory.cache_types import block_range
-from cache_dit.cache_factory.cache_adapters import BlockAdapterParams
+from cache_dit.cache_factory.cache_adapters import BlockAdapter
 from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
 from cache_dit.cache_factory.cache_interface import enable_cache
 from cache_dit.cache_factory.utils import load_options
cache_dit/cache_factory/cache_adapters.py
CHANGED
@@ -5,7 +5,7 @@ import unittest
 import functools
 import dataclasses
 
-from typing import Any
+from typing import Any, Tuple, List, Optional
 from contextlib import ExitStack
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.patch.flux import (
@@ -23,28 +23,251 @@ logger = init_logger(__name__)
 
 
 @dataclasses.dataclass
-class BlockAdapterParams:
+class BlockAdapter:
     pipe: DiffusionPipeline = None
     transformer: torch.nn.Module = None
     blocks: torch.nn.ModuleList = None
     # transformer_blocks, blocks, etc.
     blocks_name: str = None
     dummy_blocks_names: list[str] = dataclasses.field(default_factory=list)
+    # flags to control auto block adapter
+    auto: bool = False
+    allow_prefixes: List[str] = dataclasses.field(
+        default_factory=lambda: [
+            "transformer",
+            "single_transformer",
+            "blocks",
+            "layers",
+        ]
+    )
+    check_prefixes: bool = True
+    allow_suffixes: List[str] = dataclasses.field(
+        default_factory=lambda: ["TransformerBlock"]
+    )
+    check_suffixes: bool = False
+    blocks_policy: str = dataclasses.field(
+        default="max", metadata={"allowed_values": ["max", "min"]}
+    )
+
+    def __post_init__(self):
+        self.maybe_apply_patch()
+
+    def maybe_apply_patch(self):
+        # Process some specificial cases, specific for transformers
+        # that has different forward patterns between single_transformer_blocks
+        # and transformer_blocks , such as Flux (diffusers < 0.35.0).
+        if self.transformer.__class__.__name__.startswith("Flux"):
+            self.transformer = maybe_patch_flux_transformer(
+                self.transformer,
+                blocks=self.blocks,
+            )
+
+    @staticmethod
+    def auto_block_adapter(
+        adapter: "BlockAdapter",
+        forward_pattern: Optional[ForwardPattern] = None,
+    ) -> "BlockAdapter":
+        assert adapter.auto, (
+            "Please manually set `auto` to True, or, manually "
+            "set all the transformer blocks configuration."
+        )
+        assert adapter.pipe is not None, "adapter.pipe can not be None."
+        pipe = adapter.pipe
+
+        assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
+
+        transformer = pipe.transformer
+
+        # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
+        blocks, blocks_name = BlockAdapter.find_blocks(
+            transformer=transformer,
+            allow_prefixes=adapter.allow_prefixes,
+            allow_suffixes=adapter.allow_suffixes,
+            check_prefixes=adapter.check_prefixes,
+            check_suffixes=adapter.check_suffixes,
+            blocks_policy=adapter.blocks_policy,
+            forward_pattern=forward_pattern,
+        )
+
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=transformer,
+            blocks=blocks,
+            blocks_name=blocks_name,
+        )
 
-
+    @staticmethod
+    def check_block_adapter(adapter: "BlockAdapter") -> bool:
         if (
-            isinstance(
-            and
-            and
-            and
+            isinstance(adapter.pipe, DiffusionPipeline)
+            and adapter.transformer is not None
+            and adapter.blocks is not None
+            and adapter.blocks_name is not None
+            and isinstance(adapter.blocks, torch.nn.ModuleList)
         ):
             return True
+
+        logger.warning("Check block adapter failed!")
         return False
 
+    @staticmethod
+    def find_blocks(
+        transformer: torch.nn.Module,
+        allow_prefixes: List[str] = [
+            "transformer",
+            "single_transformer",
+            "blocks",
+            "layers",
+        ],
+        allow_suffixes: List[str] = [
+            "TransformerBlock",
+        ],
+        check_prefixes: bool = True,
+        check_suffixes: bool = False,
+        **kwargs,
+    ) -> Tuple[torch.nn.ModuleList, str]:
+        # Check prefixes
+        if check_prefixes:
+            blocks_names = []
+            for attr_name in dir(transformer):
+                for prefix in allow_prefixes:
+                    if attr_name.startswith(prefix):
+                        blocks_names.append(attr_name)
+        else:
+            blocks_names = dir(transformer)
+
+        # Check ModuleList
+        valid_names = []
+        valid_count = []
+        forward_pattern = kwargs.get("forward_pattern", None)
+        for blocks_name in blocks_names:
+            if blocks := getattr(transformer, blocks_name, None):
+                if isinstance(blocks, torch.nn.ModuleList):
+                    block = blocks[0]
+                    block_cls_name = block.__class__.__name__
+                    # Check suffixes
+                    if isinstance(block, torch.nn.Module) and (
+                        any(
+                            (
+                                block_cls_name.endswith(allow_suffix)
+                                for allow_suffix in allow_suffixes
+                            )
+                        )
+                        or (not check_suffixes)
+                    ):
+                        # May check forward pattern
+                        if forward_pattern is not None:
+                            if BlockAdapter.match_blocks_pattern(
+                                blocks,
+                                forward_pattern,
+                                logging=False,
+                            ):
+                                valid_names.append(blocks_name)
+                                valid_count.append(len(blocks))
+                        else:
+                            valid_names.append(blocks_name)
+                            valid_count.append(len(blocks))
+
+        if not valid_names:
+            raise ValueError(
+                "Auto selected transformer blocks failed, please set it manually."
+            )
+
+        final_name = valid_names[0]
+        final_count = valid_count[0]
+        block_policy = kwargs.get("blocks_policy", "max")
+
+        for blocks_name, count in zip(valid_names, valid_count):
+            blocks = getattr(transformer, blocks_name)
+            logger.info(
+                f"Auto selected transformer blocks: {blocks_name}, "
+                f"class: {blocks[0].__class__.__name__}, "
+                f"num blocks: {count}"
+            )
+            if block_policy == "max":
+                if final_count < count:
+                    final_count = count
+                    final_name = blocks_name
+            else:
+                if final_count > count:
+                    final_count = count
+                    final_name = blocks_name
+
+        final_blocks = getattr(transformer, final_name)
+
+        logger.info(
+            f"Final selected transformer blocks: {final_name}, "
+            f"class: {final_blocks[0].__class__.__name__}, "
+            f"num blocks: {final_count}, block_policy: {block_policy}."
+        )
+
+        return final_blocks, final_name
+
+    @staticmethod
+    def match_block_pattern(
+        block: torch.nn.Module,
+        forward_pattern: ForwardPattern,
+    ) -> bool:
+        assert (
+            forward_pattern.Supported
+            and forward_pattern in ForwardPattern.supported_patterns()
+        ), f"Pattern {forward_pattern} is not support now!"
+
+        forward_parameters = set(
+            inspect.signature(block.forward).parameters.keys()
+        )
+        num_outputs = str(
+            inspect.signature(block.forward).return_annotation
+        ).count("torch.Tensor")
+
+        in_matched = True
+        out_matched = True
+        if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
+            # output pattern not match
+            out_matched = False
+
+        for required_param in forward_pattern.In:
+            if required_param not in forward_parameters:
+                in_matched = False
+
+        return in_matched and out_matched
+
+    @staticmethod
+    def match_blocks_pattern(
+        transformer_blocks: torch.nn.ModuleList,
+        forward_pattern: ForwardPattern,
+        logging: bool = True,
+    ) -> bool:
+        assert (
+            forward_pattern.Supported
+            and forward_pattern in ForwardPattern.supported_patterns()
+        ), f"Pattern {forward_pattern} is not support now!"
+
+        assert isinstance(transformer_blocks, torch.nn.ModuleList)
+
+        pattern_matched_states = []
+        for block in transformer_blocks:
+            pattern_matched_states.append(
+                BlockAdapter.match_block_pattern(
+                    block,
+                    forward_pattern,
+                )
+            )
+
+        pattern_matched = all(pattern_matched_states)  # all block match
+        if pattern_matched and logging:
+            block_cls_name = transformer_blocks[0].__class__.__name__
+            logger.info(
+                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
+                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
+            )
+
+        return pattern_matched
+
 
 @dataclasses.dataclass
 class UnifiedCacheParams:
-
+    block_adapter: BlockAdapter = None
     forward_pattern: ForwardPattern = ForwardPattern.Pattern_0
 
 
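For orientation, the auto-discovery added above boils down to: collect `torch.nn.ModuleList` attributes whose names match the allowed prefixes (optionally filtered by class-name suffix and forward pattern), then keep the largest or smallest list according to `blocks_policy`. A self-contained toy sketch of that selection idea; `DummyTransformer` and `pick_blocks` are made-up names for illustration, not cache-dit code:

import torch

class DummyBlock(torch.nn.Module):
    def forward(self, hidden_states, encoder_hidden_states):
        return hidden_states, encoder_hidden_states

class DummyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList(DummyBlock() for _ in range(4))
        self.single_transformer_blocks = torch.nn.ModuleList(DummyBlock() for _ in range(2))

def pick_blocks(transformer, prefixes=("transformer", "single_transformer"), policy="max"):
    # Mirrors the prefix + ModuleList filtering of find_blocks, without the
    # suffix and forward-pattern checks.
    candidates = [
        (name, getattr(transformer, name))
        for name in dir(transformer)
        if any(name.startswith(p) for p in prefixes)
        and isinstance(getattr(transformer, name, None), torch.nn.ModuleList)
    ]
    choose = max if policy == "max" else min
    return choose(candidates, key=lambda item: len(item[1]))

name, blocks = pick_blocks(DummyTransformer())
print(name, len(blocks))  # -> transformer_blocks 4 (the "max" policy keeps the longer list)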
@@ -85,7 +308,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, FluxTransformer2DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=(
@@ -102,7 +325,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, MochiTransformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -116,7 +339,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, CogVideoXTransformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -136,7 +359,7 @@ class UnifiedCacheAdapter:
             (WanTransformer3DModel, WanVACETransformer3DModel),
         )
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.blocks,
@@ -150,7 +373,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 blocks=(
                     pipe.transformer.transformer_blocks
@@ -166,7 +389,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -180,7 +403,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -194,7 +417,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, AllegroTransformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -208,7 +431,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, CogView3PlusTransformer2DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -222,7 +445,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, CogView4Transformer2DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -236,7 +459,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, CosmosTransformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -250,7 +473,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, EasyAnimateTransformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -264,7 +487,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, SkyReelsV2Transformer3DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.blocks,
@@ -278,7 +501,7 @@ class UnifiedCacheAdapter:
 
         assert isinstance(pipe.transformer, SD3Transformer2DModel)
         return UnifiedCacheParams(
-
+            block_adapter=BlockAdapter(
                 pipe=pipe,
                 transformer=pipe.transformer,
                 blocks=pipe.transformer.transformer_blocks,
@@ -294,13 +517,13 @@ class UnifiedCacheAdapter:
     def apply(
         cls,
         pipe: DiffusionPipeline = None,
-
+        block_adapter: BlockAdapter = None,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
         assert (
-            pipe is not None or
-        ), "pipe or
+            pipe is not None or block_adapter is not None
+        ), "pipe or block_adapter can not both None!"
 
         if pipe is not None:
             if cls.is_supported(pipe):
@@ -310,7 +533,7 @@ class UnifiedCacheAdapter:
                 )
                 params = cls.get_params(pipe)
                 return cls.cachify(
-                    params.
+                    params.block_adapter,
                     forward_pattern=params.forward_pattern,
                     **cache_context_kwargs,
                 )
@@ -324,7 +547,7 @@ class UnifiedCacheAdapter:
                 "Adapting cache acceleration using custom BlockAdapter!"
             )
             return cls.cachify(
-
+                block_adapter,
                 forward_pattern=forward_pattern,
                 **cache_context_kwargs,
             )
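With the hunks above, `UnifiedCacheAdapter.apply` now accepts either a supported `pipe` (resolved internally through `get_params`) or an explicit `block_adapter`. A hedged usage sketch; the model id is only an example and the preferred public entry point may differ from what this diff shows:

from diffusers import DiffusionPipeline
from cache_dit import BlockAdapter, ForwardPattern
from cache_dit.cache_factory import UnifiedCacheAdapter

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")  # example model id

# Option A: a pipeline cache-dit already knows about; the adapter is resolved internally.
pipe = UnifiedCacheAdapter.apply(pipe=pipe)

# Option B (instead of A): describe the blocks yourself, or let auto-discovery find them.
# pipe = UnifiedCacheAdapter.apply(
#     block_adapter=BlockAdapter(pipe=pipe, auto=True),
#     forward_pattern=ForwardPattern.Pattern_0,
# )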
@@ -332,22 +555,48 @@ class UnifiedCacheAdapter:
     @classmethod
     def cachify(
         cls,
-
+        block_adapter: BlockAdapter,
         *,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
-
-
+
+        if block_adapter.auto:
+            block_adapter = BlockAdapter.auto_block_adapter(
+                block_adapter,
+                forward_pattern,
+            )
+
+        if BlockAdapter.check_block_adapter(block_adapter):
             # Apply cache on pipeline: wrap cache context
-            cls.create_context(
+            cls.create_context(
+                block_adapter.pipe,
+                **cache_context_kwargs,
+            )
             # Apply cache on transformer: mock cached transformer blocks
             cls.mock_blocks(
-
+                block_adapter,
                 forward_pattern=forward_pattern,
             )
+            cls.patch_params(
+                block_adapter,
+                forward_pattern=forward_pattern,
+                **cache_context_kwargs,
+            )
+            return block_adapter.pipe
 
-
+    @classmethod
+    def patch_params(
+        cls,
+        block_adapter: BlockAdapter,
+        forward_pattern: ForwardPattern = None,
+        **cache_context_kwargs,
+    ):
+        block_adapter.transformer._forward_pattern = forward_pattern
+        block_adapter.transformer._cache_context_kwargs = cache_context_kwargs
+        block_adapter.pipe.__class__._cache_context_kwargs = (
+            cache_context_kwargs
+        )
 
     @classmethod
     def has_separate_cfg(
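One detail in the new `patch_params` above: `_forward_pattern` and `_cache_context_kwargs` are stamped on the transformer instance, while `_cache_context_kwargs` is also stored on the pipeline class (consistent with `__call__` and `_is_cached` being patched at class level in the next hunk). A tiny standalone sketch of the instance-vs-class distinction, using toy objects rather than cache-dit classes:

class ToyPipe:
    pass

p1, p2 = ToyPipe(), ToyPipe()
p1.__class__._cache_context_kwargs = {"example_option": 8}  # class attribute: visible on every instance
p1._forward_pattern = "Pattern_0"                            # instance attribute: only on p1

print(p2._cache_context_kwargs)         # {'example_option': 8}
print(hasattr(p2, "_forward_pattern"))  # False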
@@ -407,28 +656,21 @@ class UnifiedCacheAdapter:
 
         pipe.__class__.__call__ = new_call
         pipe.__class__._is_cached = True
-        pipe.__class__._cache_options = cache_kwargs
         return pipe
 
     @classmethod
     def mock_blocks(
         cls,
-
+        block_adapter: BlockAdapter,
         forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     ) -> torch.nn.Module:
-
-
-
-        # Firstly, process some specificial cases (TODO: more patches)
-        if adapter_params.transformer.__class__.__name__.startswith("Flux"):
-            adapter_params.transformer = maybe_patch_flux_transformer(
-                adapter_params.transformer,
-                blocks=adapter_params.blocks,
-            )
+
+        if getattr(block_adapter.transformer, "_is_cached", False):
+            return block_adapter.transformer
 
         # Check block forward pattern matching
-        assert
-
+        assert BlockAdapter.match_blocks_pattern(
+            block_adapter.blocks,
             forward_pattern=forward_pattern,
         ), (
             "No block forward pattern matched, "
@@ -439,22 +681,17 @@ class UnifiedCacheAdapter:
         cached_blocks = torch.nn.ModuleList(
             [
                 DBCachedTransformerBlocks(
-
-                    transformer=
+                    block_adapter.blocks,
+                    transformer=block_adapter.transformer,
                     forward_pattern=forward_pattern,
                 )
             ]
         )
         dummy_blocks = torch.nn.ModuleList()
 
-        original_forward =
+        original_forward = block_adapter.transformer.forward
 
-        assert isinstance(
-        if adapter_params.blocks_name is None:
-            adapter_params.blocks_name = cls.find_blocks_name(
-                adapter_params.transformer
-            )
-        assert adapter_params.blocks_name is not None
+        assert isinstance(block_adapter.dummy_blocks_names, list)
 
         @functools.wraps(original_forward)
         def new_forward(self, *args, **kwargs):
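The assertion above reuses the signature-based matcher shown earlier in this file: the pattern's required input names must appear among `block.forward`'s parameters, and the number of `torch.Tensor` mentions in the return annotation must agree with the pattern's outputs. A self-contained illustration of that inspection (toy block, made-up names):

import inspect
from typing import Tuple
import torch

class ToyBlock(torch.nn.Module):
    def forward(
        self,
        hidden_states: torch.Tensor,
        encoder_hidden_states: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        return hidden_states, encoder_hidden_states

block = ToyBlock()
forward_parameters = set(inspect.signature(block.forward).parameters.keys())
num_outputs = str(inspect.signature(block.forward).return_annotation).count("torch.Tensor")
print(forward_parameters)  # {'hidden_states', 'encoder_hidden_states'}
print(num_outputs)         # 2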
@@ -462,11 +699,11 @@ class UnifiedCacheAdapter:
                 stack.enter_context(
                     unittest.mock.patch.object(
                         self,
-
+                        block_adapter.blocks_name,
                         cached_blocks,
                     )
                 )
-                for dummy_name in
+                for dummy_name in block_adapter.dummy_blocks_names:
                     stack.enter_context(
                         unittest.mock.patch.object(
                             self,
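The `new_forward` wrapper above leans on `unittest.mock.patch.object` inside an `ExitStack`, so the real block list is swapped for the cached wrapper only while a single forward call is running. A standalone sketch of that mechanism with toy modules (none of these names are cache-dit's):

import unittest.mock
from contextlib import ExitStack
import torch

class ToyNet(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList(torch.nn.Linear(4, 4) for _ in range(3))

    def forward(self, x):
        for blk in self.blocks:
            x = blk(x)
        return x

net = ToyNet()
mocked = torch.nn.ModuleList([torch.nn.Identity()])  # stand-in for the cached blocks wrapper

with ExitStack() as stack:
    stack.enter_context(unittest.mock.patch.object(net, "blocks", mocked))
    y = net(torch.randn(2, 4))  # this call runs through the single mocked "block"

assert len(net.blocks) == 3     # the original blocks are restored once the stack exits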
@@ -476,71 +713,9 @@ class UnifiedCacheAdapter:
                     )
                 return original_forward(*args, **kwargs)
 
-
-
+        block_adapter.transformer.forward = new_forward.__get__(
+            block_adapter.transformer
         )
-
+        block_adapter.transformer._is_cached = True
 
-        return
-
-    @classmethod
-    def match_pattern(
-        cls,
-        transformer_blocks: torch.nn.ModuleList,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
-    ) -> bool:
-        pattern_matched_states = []
-
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        for block in transformer_blocks:
-            forward_parameters = set(
-                inspect.signature(block.forward).parameters.keys()
-            )
-            num_outputs = str(
-                inspect.signature(block.forward).return_annotation
-            ).count("torch.Tensor")
-
-            in_matched = True
-            out_matched = True
-            if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
-                # output pattern not match
-                out_matched = False
-
-            for required_param in forward_pattern.In:
-                if required_param not in forward_parameters:
-                    in_matched = False
-
-            pattern_matched_states.append(in_matched and out_matched)
-
-        pattern_matched = all(pattern_matched_states)  # all block match
-        if pattern_matched:
-            block_cls_name = transformer_blocks[0].__class__.__name__
-            logger.info(
-                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
-                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
-            )
-
-        return pattern_matched
-
-    @classmethod
-    def find_blocks_name(cls, transformer):
-        blocks_name = None
-        allow_prefixes = ["transformer", "blocks"]
-        for attr_name in dir(transformer):
-            if blocks_name is None:
-                for prefix in allow_prefixes:
-                    # transformer_blocks, blocks
-                    if attr_name.startswith(prefix):
-                        blocks_name = attr_name
-                        logger.info(f"Auto selected blocks name: {blocks_name}")
-                        # only find one transformer blocks name
-                        break
-        if blocks_name is None:
-            logger.warning(
-                "Auto selected blocks name failed, please set it manually."
-            )
-        return blocks_name
+        return block_adapter.transformer
cache_dit/cache_factory/cache_blocks.py
CHANGED
@@ -1,5 +1,6 @@
 import inspect
 import torch
+import torch.distributed as dist
 
 from cache_dit.cache_factory import cache_context
 from cache_dit.cache_factory import ForwardPattern
@@ -179,10 +180,15 @@ class DBCachedTransformerBlocks(torch.nn.Module):
     @torch.compiler.disable
     def _is_parallelized(self):
         # Compatible with distributed inference.
-        return
+        return any(
             (
-
-
+                all(
+                    (
+                        self.transformer is not None,
+                        getattr(self.transformer, "_is_parallelized", False),
+                    )
+                ),
+                (dist.is_initialized() and dist.get_world_size() > 1),
             )
         )
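Rewritten as a free function, the new check reads as follows; this is a sketch that mirrors the diff, not the library's own helper:

import torch.distributed as dist

def is_parallelized(transformer=None) -> bool:
    # Parallel if the transformer was explicitly marked as parallelized, or if
    # torch.distributed is initialized with more than one process.
    return any(
        (
            all(
                (
                    transformer is not None,
                    getattr(transformer, "_is_parallelized", False),
                )
            ),
            (dist.is_initialized() and dist.get_world_size() > 1),
        )
    )

print(is_parallelized())  # False in a plain single-process run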