cache-dit 0.2.26__py3-none-any.whl → 0.2.27__py3-none-any.whl
This diff shows the changes between two publicly released versions of this package, as published to a supported registry. It is provided for informational purposes only.
Potentially problematic release. This version of cache-dit might be problematic.
- cache_dit/__init__.py +7 -6
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +15 -4
- cache_dit/cache_factory/block_adapters/__init__.py +538 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
- cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
- cache_dit/cache_factory/cache_adapters.py +120 -911
- cache_dit/cache_factory/cache_blocks/__init__.py +7 -9
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +46 -41
- cache_dit/cache_factory/cache_blocks/pattern_base.py +98 -79
- cache_dit/cache_factory/cache_blocks/utils.py +13 -9
- cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
- cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +89 -55
- cache_dit/cache_factory/cache_contexts/cache_manager.py +0 -0
- cache_dit/cache_factory/cache_interface.py +21 -18
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
- cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
- cache_dit/quantize/quantize_ao.py +1 -0
- cache_dit/utils.py +19 -16
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/METADATA +42 -12
- cache_dit-0.2.27.dist-info/RECORD +47 -0
- cache_dit-0.2.26.dist-info/RECORD +0 -42
- /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
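
The substantive change in 0.2.27 is the refactor of cache_dit/cache_factory/cache_adapters.py, diffed below: the BlockAdapter dataclass and all pipeline-specific wiring move into the new block_adapters/ package, UnifiedCacheAdapter becomes CachedAdapter, and each blocks list now gets its own cache context instead of a single shared forward_pattern argument. As a minimal sketch of driving the new adapter, using only names visible in this diff (the checkpoint id is a placeholder, and passing forward_pattern as a BlockAdapter field is inferred from the new patch_params/mock_blocks code rather than shown directly):

from diffusers import DiffusionPipeline

from cache_dit.cache_factory import BlockAdapter, BlockAdapterRegistry, ForwardPattern
from cache_dit.cache_factory.cache_adapters import CachedAdapter

pipe = DiffusionPipeline.from_pretrained("...")  # placeholder checkpoint id

if BlockAdapterRegistry.is_supported(pipe):
    # Officially supported pipelines resolve to a pre-defined BlockAdapter.
    pipe = CachedAdapter.apply(pipe=pipe)
else:
    # Otherwise, describe the transformer blocks by hand (field names as
    # they appear in this diff).
    adapter = BlockAdapter(
        pipe=pipe,
        transformer=pipe.transformer,
        blocks=pipe.transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        forward_pattern=ForwardPattern.Pattern_0,  # assumed field
    )
    pipe = CachedAdapter.apply(block_adapter=adapter)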
--- a/cache_dit/cache_factory/cache_adapters.py
+++ b/cache_dit/cache_factory/cache_adapters.py
@@ -1,865 +1,35 @@
 import torch
 
-import inspect
 import unittest
 import functools
-import dataclasses
 
-from typing import
+from typing import Dict
 from contextlib import ExitStack
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory import CacheType
-from cache_dit.cache_factory import
+from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory
-from cache_dit.cache_factory
-
-)
+from cache_dit.cache_factory import BlockAdapter
+from cache_dit.cache_factory import BlockAdapterRegistry
+from cache_dit.cache_factory import CachedBlocks
 
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-
-class
-    pipe: DiffusionPipeline | Any = None
-    transformer: torch.nn.Module = None
-    blocks: torch.nn.ModuleList = None
-    # transformer_blocks, blocks, etc.
-    blocks_name: str = None
-    dummy_blocks_names: List[str] = dataclasses.field(default_factory=list)
-    # patch functor: Flux, etc.
-    patch_functor: Optional[PatchFunctor] = None
-    # flags to control auto block adapter
-    auto: bool = False
-    allow_prefixes: List[str] = dataclasses.field(
-        default_factory=lambda: [
-            "transformer",
-            "single_transformer",
-            "blocks",
-            "layers",
-            "single_stream_blocks",
-            "double_stream_blocks",
-        ]
-    )
-    check_prefixes: bool = True
-    allow_suffixes: List[str] = dataclasses.field(
-        default_factory=lambda: ["TransformerBlock"]
-    )
-    check_suffixes: bool = False
-    blocks_policy: str = dataclasses.field(
-        default="max", metadata={"allowed_values": ["max", "min"]}
-    )
-
-    def __post_init__(self):
-        assert any((self.pipe is not None, self.transformer is not None))
-        self.patchify()
-
-    def patchify(self, *args, **kwargs):
-        # Process some specificial cases, specific for transformers
-        # that has different forward patterns between single_transformer_blocks
-        # and transformer_blocks , such as Flux (diffusers < 0.35.0).
-        if self.patch_functor is not None:
-            if self.transformer is not None:
-                self.patch_functor.apply(self.transformer, *args, **kwargs)
-            else:
-                assert hasattr(self.pipe, "transformer")
-                self.patch_functor.apply(self.pipe.transformer, *args, **kwargs)
-
-    @staticmethod
-    def auto_block_adapter(
-        adapter: "BlockAdapter",
-        forward_pattern: Optional[ForwardPattern] = None,
-    ) -> "BlockAdapter":
-        assert adapter.auto, (
-            "Please manually set `auto` to True, or, manually "
-            "set all the transformer blocks configuration."
-        )
-        assert adapter.pipe is not None, "adapter.pipe can not be None."
-        pipe = adapter.pipe
-
-        assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
-
-        transformer = pipe.transformer
-
-        # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
-        blocks, blocks_name = BlockAdapter.find_blocks(
-            transformer=transformer,
-            allow_prefixes=adapter.allow_prefixes,
-            allow_suffixes=adapter.allow_suffixes,
-            check_prefixes=adapter.check_prefixes,
-            check_suffixes=adapter.check_suffixes,
-            blocks_policy=adapter.blocks_policy,
-            forward_pattern=forward_pattern,
-        )
-
-        return BlockAdapter(
-            pipe=pipe,
-            transformer=transformer,
-            blocks=blocks,
-            blocks_name=blocks_name,
-        )
-
-    @staticmethod
-    def check_block_adapter(adapter: "BlockAdapter") -> bool:
-        if (
-            # NOTE: pipe may not need to be DiffusionPipeline?
-            # isinstance(adapter.pipe, DiffusionPipeline)
-            adapter.pipe is not None
-            and adapter.transformer is not None
-            and adapter.blocks is not None
-            and adapter.blocks_name is not None
-            and isinstance(adapter.blocks, torch.nn.ModuleList)
-        ):
-            return True
-
-        logger.warning("Check block adapter failed!")
-        return False
-
-    @staticmethod
-    def find_blocks(
-        transformer: torch.nn.Module,
-        allow_prefixes: List[str] = [
-            "transformer",
-            "single_transformer",
-            "blocks",
-            "layers",
-        ],
-        allow_suffixes: List[str] = [
-            "TransformerBlock",
-        ],
-        check_prefixes: bool = True,
-        check_suffixes: bool = False,
-        **kwargs,
-    ) -> Tuple[torch.nn.ModuleList, str]:
-        # Check prefixes
-        if check_prefixes:
-            blocks_names = []
-            for attr_name in dir(transformer):
-                for prefix in allow_prefixes:
-                    if attr_name.startswith(prefix):
-                        blocks_names.append(attr_name)
-        else:
-            blocks_names = dir(transformer)
-
-        # Check ModuleList
-        valid_names = []
-        valid_count = []
-        forward_pattern = kwargs.get("forward_pattern", None)
-        for blocks_name in blocks_names:
-            if blocks := getattr(transformer, blocks_name, None):
-                if isinstance(blocks, torch.nn.ModuleList):
-                    block = blocks[0]
-                    block_cls_name = block.__class__.__name__
-                    # Check suffixes
-                    if isinstance(block, torch.nn.Module) and (
-                        any(
-                            (
-                                block_cls_name.endswith(allow_suffix)
-                                for allow_suffix in allow_suffixes
-                            )
-                        )
-                        or (not check_suffixes)
-                    ):
-                        # May check forward pattern
-                        if forward_pattern is not None:
-                            if BlockAdapter.match_blocks_pattern(
-                                blocks,
-                                forward_pattern,
-                                logging=False,
-                            ):
-                                valid_names.append(blocks_name)
-                                valid_count.append(len(blocks))
-                        else:
-                            valid_names.append(blocks_name)
-                            valid_count.append(len(blocks))
-
-        if not valid_names:
-            raise ValueError(
-                "Auto selected transformer blocks failed, please set it manually."
-            )
-
-        final_name = valid_names[0]
-        final_count = valid_count[0]
-        block_policy = kwargs.get("blocks_policy", "max")
-
-        for blocks_name, count in zip(valid_names, valid_count):
-            blocks = getattr(transformer, blocks_name)
-            logger.info(
-                f"Auto selected transformer blocks: {blocks_name}, "
-                f"class: {blocks[0].__class__.__name__}, "
-                f"num blocks: {count}"
-            )
-            if block_policy == "max":
-                if final_count < count:
-                    final_count = count
-                    final_name = blocks_name
-            else:
-                if final_count > count:
-                    final_count = count
-                    final_name = blocks_name
-
-        final_blocks = getattr(transformer, final_name)
-
-        logger.info(
-            f"Final selected transformer blocks: {final_name}, "
-            f"class: {final_blocks[0].__class__.__name__}, "
-            f"num blocks: {final_count}, block_policy: {block_policy}."
-        )
-
-        return final_blocks, final_name
-
-    @staticmethod
-    def match_block_pattern(
-        block: torch.nn.Module,
-        forward_pattern: ForwardPattern,
-    ) -> bool:
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        forward_parameters = set(
-            inspect.signature(block.forward).parameters.keys()
-        )
-        num_outputs = str(
-            inspect.signature(block.forward).return_annotation
-        ).count("torch.Tensor")
-
-        in_matched = True
-        out_matched = True
-        if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
-            # output pattern not match
-            out_matched = False
-
-        for required_param in forward_pattern.In:
-            if required_param not in forward_parameters:
-                in_matched = False
-
-        return in_matched and out_matched
-
-    @staticmethod
-    def match_blocks_pattern(
-        transformer_blocks: torch.nn.ModuleList,
-        forward_pattern: ForwardPattern,
-        logging: bool = True,
-    ) -> bool:
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        assert isinstance(transformer_blocks, torch.nn.ModuleList)
-
-        pattern_matched_states = []
-        for block in transformer_blocks:
-            pattern_matched_states.append(
-                BlockAdapter.match_block_pattern(
-                    block,
-                    forward_pattern,
-                )
-            )
-
-        pattern_matched = all(pattern_matched_states)  # all block match
-        if pattern_matched and logging:
-            block_cls_name = transformer_blocks[0].__class__.__name__
-            logger.info(
-                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
-                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
-            )
-
-        return pattern_matched
-
-
-@dataclasses.dataclass
-class UnifiedCacheParams:
-    block_adapter: BlockAdapter = None
-    forward_pattern: ForwardPattern = ForwardPattern.Pattern_0
-
-
-class UnifiedCacheAdapter:
-    _supported_pipelines = [
-        "Flux",
-        "Mochi",
-        "CogVideoX",
-        "Wan",
-        "HunyuanVideo",
-        "QwenImage",
-        "LTXVideo",
-        "Allegro",
-        "CogView3Plus",
-        "CogView4",
-        "Cosmos",
-        "EasyAnimate",
-        "SkyReelsV2",
-        "SD3",
-        "ConsisID",
-        "DiT",
-        "Amused",
-        "Bria",
-        "HunyuanDiT",
-        "HunyuanDiTPAG",
-        "Lumina",
-        "Lumina2",
-        "OmniGen",
-        "PixArt",
-        "Sana",
-        "ShapE",
-        "StableAudio",
-        "VisualCloze",
-        "AuraFlow",
-        "Chroma",
-        "HiDream",
-    ]
+# Unified Cached Adapter
+class CachedAdapter:
 
     def __call__(self, *args, **kwargs):
         return self.apply(*args, **kwargs)
 
-    @classmethod
-    def supported_pipelines(cls) -> Tuple[int, List[str]]:
-        return len(cls._supported_pipelines), [
-            p + "*" for p in cls._supported_pipelines
-        ]
-
-    @classmethod
-    def is_supported(cls, pipe: DiffusionPipeline) -> bool:
-        pipe_cls_name: str = pipe.__class__.__name__
-        for prefix in cls._supported_pipelines:
-            if pipe_cls_name.startswith(prefix):
-                return True
-        return False
-
-    @classmethod
-    def get_params(cls, pipe: DiffusionPipeline) -> UnifiedCacheParams:
-        pipe_cls_name: str = pipe.__class__.__name__
-
-        if pipe_cls_name.startswith("Flux"):
-            from diffusers import FluxTransformer2DModel
-            from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
-
-            assert isinstance(pipe.transformer, FluxTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                    patch_functor=FluxPatchFunctor(),
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("Mochi"):
-            from diffusers import MochiTransformer3DModel
-
-            assert isinstance(pipe.transformer, MochiTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("CogVideoX"):
-            from diffusers import CogVideoXTransformer3DModel
-
-            assert isinstance(pipe.transformer, CogVideoXTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("Wan"):
-            from diffusers import (
-                WanTransformer3DModel,
-                WanVACETransformer3DModel,
-            )
-
-            assert isinstance(
-                pipe.transformer,
-                (WanTransformer3DModel, WanVACETransformer3DModel),
-            )
-            if getattr(pipe, "transformer_2", None):
-                # Wan 2.2, cache for low-noise transformer
-                assert isinstance(
-                    pipe.transformer_2,
-                    (WanTransformer3DModel, WanVACETransformer3DModel),
-                )
-                return UnifiedCacheParams(
-                    block_adapter=BlockAdapter(
-                        pipe=pipe,
-                        transformer=pipe.transformer_2,
-                        blocks=pipe.transformer_2.blocks,
-                        blocks_name="blocks",
-                        dummy_blocks_names=[],
-                    ),
-                    forward_pattern=ForwardPattern.Pattern_2,
-                )
-            else:
-                # Wan 2.1
-                return UnifiedCacheParams(
-                    block_adapter=BlockAdapter(
-                        pipe=pipe,
-                        transformer=pipe.transformer,
-                        blocks=pipe.transformer.blocks,
-                        blocks_name="blocks",
-                        dummy_blocks_names=[],
-                    ),
-                    forward_pattern=ForwardPattern.Pattern_2,
-                )
-
-        elif pipe_cls_name.startswith("HunyuanVideo"):
-            from diffusers import HunyuanVideoTransformer3DModel
-
-            assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("QwenImage"):
-            from diffusers import QwenImageTransformer2DModel
-
-            assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("LTXVideo"):
-            from diffusers import LTXVideoTransformer3DModel
-
-            assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-
-        elif pipe_cls_name.startswith("Allegro"):
-            from diffusers import AllegroTransformer3DModel
-
-            assert isinstance(pipe.transformer, AllegroTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-
-        elif pipe_cls_name.startswith("CogView3Plus"):
-            from diffusers import CogView3PlusTransformer2DModel
-
-            assert isinstance(pipe.transformer, CogView3PlusTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("CogView4"):
-            from diffusers import CogView4Transformer2DModel
-
-            assert isinstance(pipe.transformer, CogView4Transformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("Cosmos"):
-            from diffusers import CosmosTransformer3DModel
-
-            assert isinstance(pipe.transformer, CosmosTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-
-        elif pipe_cls_name.startswith("EasyAnimate"):
-            from diffusers import EasyAnimateTransformer3DModel
-
-            assert isinstance(pipe.transformer, EasyAnimateTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("SkyReelsV2"):
-            from diffusers import SkyReelsV2Transformer3DModel
-
-            assert isinstance(pipe.transformer, SkyReelsV2Transformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.blocks,
-                    blocks_name="blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-        elif pipe_cls_name.startswith("SD3"):
-            from diffusers import SD3Transformer2DModel
-
-            assert isinstance(pipe.transformer, SD3Transformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("ConsisID"):
-            from diffusers import ConsisIDTransformer3DModel
-
-            assert isinstance(pipe.transformer, ConsisIDTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("DiT"):
-            from diffusers import DiTTransformer2DModel
-
-            assert isinstance(pipe.transformer, DiTTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Amused"):
-            from diffusers import UVit2DModel
-
-            assert isinstance(pipe.transformer, UVit2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_layers,
-                    blocks_name="transformer_layers",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Bria"):
-            from diffusers import BriaTransformer2DModel
-
-            assert isinstance(pipe.transformer, BriaTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("HunyuanDiT"):
-            from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
-
-            assert isinstance(
-                pipe.transformer,
-                (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
-            )
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.blocks,
-                    blocks_name="blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("HunyuanDiTPAG"):
-            from diffusers import HunyuanDiT2DModel
-
-            assert isinstance(pipe.transformer, HunyuanDiT2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.blocks,
-                    blocks_name="blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Lumina"):
-            from diffusers import LuminaNextDiT2DModel
-
-            assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.layers,
-                    blocks_name="layers",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Lumina2"):
-            from diffusers import Lumina2Transformer2DModel
-
-            assert isinstance(pipe.transformer, Lumina2Transformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.layers,
-                    blocks_name="layers",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("OmniGen"):
-            from diffusers import OmniGenTransformer2DModel
-
-            assert isinstance(pipe.transformer, OmniGenTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.layers,
-                    blocks_name="layers",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("PixArt"):
-            from diffusers import PixArtTransformer2DModel
-
-            assert isinstance(pipe.transformer, PixArtTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Sana"):
-            from diffusers import SanaTransformer2DModel
-
-            assert isinstance(pipe.transformer, SanaTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("ShapE"):
-            from diffusers import PriorTransformer
-
-            assert isinstance(pipe.prior, PriorTransformer)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.prior,
-                    blocks=pipe.prior.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("StableAudio"):
-            from diffusers import StableAudioDiTModel
-
-            assert isinstance(pipe.transformer, StableAudioDiTModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("VisualCloze"):
-            from diffusers import FluxTransformer2DModel
-            from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
-
-            assert isinstance(pipe.transformer, FluxTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                    patch_functor=FluxPatchFunctor(),
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("AuraFlow"):
-            from diffusers import AuraFlowTransformer2DModel
-
-            assert isinstance(pipe.transformer, AuraFlowTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    # Only support caching single_transformer_blocks for AuraFlow now.
-                    # TODO: Support AuraFlowPatchFunctor.
-                    blocks=pipe.transformer.single_transformer_blocks,
-                    blocks_name="single_transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Chroma"):
-            from diffusers import ChromaTransformer2DModel
-            from cache_dit.cache_factory.patch_functors import (
-                ChromaPatchFunctor,
-            )
-
-            assert isinstance(pipe.transformer, ChromaTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                    patch_functor=ChromaPatchFunctor(),
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("HiDream"):
-            from diffusers import HiDreamImageTransformer2DModel
-
-            assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    # Only support caching single_stream_blocks for HiDream now.
-                    # TODO: Support HiDreamPatchFunctor.
-                    blocks=pipe.transformer.single_stream_blocks,
-                    blocks_name="single_stream_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        else:
-            raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
-
     @classmethod
     def apply(
         cls,
         pipe: DiffusionPipeline = None,
         block_adapter: BlockAdapter = None,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        # forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
        **cache_context_kwargs,
     ) -> DiffusionPipeline:
         assert (
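
The hunk above deletes BlockAdapter's auto-discovery heuristic from this file (per the file list, it now lives in block_adapters/block_adapters.py). The essence of the removed find_blocks is a prefix-filtered scan over dir(transformer) for torch.nn.ModuleList attributes, keeping the longest list under the default "max" policy. A condensed, self-contained re-implementation for illustration only, not the library's actual code:

import torch

def find_blocks(
    transformer: torch.nn.Module,
    prefixes=("transformer", "single_transformer", "blocks", "layers"),
):
    # Prefix-filtered scan for ModuleList attributes, mirroring the removed
    # BlockAdapter.find_blocks with blocks_policy="max" (most blocks wins).
    candidates = [
        (name, attr)
        for name in dir(transformer)
        if name.startswith(prefixes)
        and isinstance((attr := getattr(transformer, name, None)), torch.nn.ModuleList)
    ]
    if not candidates:
        raise ValueError("Auto selection failed, please set blocks manually.")
    name, blocks = max(candidates, key=lambda kv: len(kv[1]))
    return blocks, name  # e.g. (ModuleList, "transformer_blocks")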
@@ -867,15 +37,14 @@ class UnifiedCacheAdapter:
         ), "pipe or block_adapter can not both None!"
 
         if pipe is not None:
-            if
+            if BlockAdapterRegistry.is_supported(pipe):
                 logger.info(
                     f"{pipe.__class__.__name__} is officially supported by cache-dit. "
                     "Use it's pre-defined BlockAdapter directly!"
                 )
-
+                block_adapter = BlockAdapterRegistry.get_adapter(pipe)
                 return cls.cachify(
-
-                    forward_pattern=params.forward_pattern,
+                    block_adapter,
                     **cache_context_kwargs,
                 )
         else:
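
apply() now asks BlockAdapterRegistry whether a pipeline is known instead of consulting the deleted _supported_pipelines list. The registry itself is added in block_adapters/block_registers.py, which this page lists but does not expand; the sketch below is a hypothetical minimal version of such a prefix-matching registry, not the actual implementation:

class MiniRegistry:
    _adapters = {}  # pipeline class-name prefix -> adapter factory

    @classmethod
    def register(cls, prefix):
        def decorator(factory):
            cls._adapters[prefix] = factory
            return factory
        return decorator

    @classmethod
    def is_supported(cls, pipe) -> bool:
        name = pipe.__class__.__name__
        return any(name.startswith(p) for p in cls._adapters)

    @classmethod
    def get_adapter(cls, pipe):
        name = pipe.__class__.__name__
        for prefix, factory in cls._adapters.items():
            if name.startswith(prefix):
                return factory(pipe)
        raise ValueError(f"Unknown pipeline class name: {name}")

@MiniRegistry.register("Flux")
def _flux_adapter(pipe):
    # Would build and return a BlockAdapter for Flux pipelines.
    return {"pipe": pipe, "blocks_name": "transformer_blocks"}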
@@ -889,7 +58,6 @@ class UnifiedCacheAdapter:
             )
             return cls.cachify(
                 block_adapter,
-                forward_pattern=forward_pattern,
                 **cache_context_kwargs,
             )
 
@@ -897,31 +65,27 @@ class UnifiedCacheAdapter:
     def cachify(
         cls,
         block_adapter: BlockAdapter,
-        *,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
 
         if block_adapter.auto:
             block_adapter = BlockAdapter.auto_block_adapter(
                 block_adapter,
-                forward_pattern,
             )
 
         if BlockAdapter.check_block_adapter(block_adapter):
-
+            block_adapter = BlockAdapter.normalize(block_adapter)
+            # 0. Apply cache on pipeline: wrap cache context
             cls.create_context(
-                block_adapter
+                block_adapter,
                 **cache_context_kwargs,
             )
-            # Apply cache on transformer: mock cached transformer blocks
+            # 1. Apply cache on transformer: mock cached transformer blocks
             cls.mock_blocks(
                 block_adapter,
-                forward_pattern=forward_pattern,
             )
             cls.patch_params(
                 block_adapter,
-                forward_pattern=forward_pattern,
                 **cache_context_kwargs,
             )
         return block_adapter.pipe
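
cachify() now funnels every adapter through BlockAdapter.normalize() before wrapping. normalize() is defined in the new block_adapters module and is not shown in this diff; judging by the downstream code, which zips block_adapter.blocks with block_adapter.forward_pattern and iterates block_adapter.blocks_name, it plausibly promotes scalar fields to parallel lists. A guess at its shape, purely illustrative:

def normalize(adapter):
    # Assumption: promote single values to one-element lists so the
    # zip()-based loops in patch_params/mock_blocks/collect_cached_blocks
    # can treat every adapter uniformly. Not the real implementation.
    for field in ("blocks", "blocks_name", "forward_pattern"):
        value = getattr(adapter, field)
        if not isinstance(value, list):
            setattr(adapter, field, [value])
    return adapter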
@@ -930,41 +94,36 @@ class UnifiedCacheAdapter:
     def patch_params(
         cls,
         block_adapter: BlockAdapter,
-        forward_pattern: ForwardPattern = None,
         **cache_context_kwargs,
     ):
-        block_adapter.transformer._forward_pattern =
+        block_adapter.transformer._forward_pattern = (
+            block_adapter.forward_pattern
+        )
+        block_adapter.transformer._has_separate_cfg = (
+            block_adapter.has_separate_cfg
+        )
         block_adapter.transformer._cache_context_kwargs = cache_context_kwargs
         block_adapter.pipe.__class__._cache_context_kwargs = (
             cache_context_kwargs
         )
-
-
-
-
-
-    ) -> bool:
-        cls_name = pipe_or_transformer.__class__.__name__
-        if cls_name.startswith("QwenImage"):
-            return True
-        elif cls_name.startswith("Wan"):
-            return True
-        elif cls_name.startswith("CogView4"):
-            return True
-        elif cls_name.startswith("Cosmos"):
-            return True
-        elif cls_name.startswith("SkyReelsV2"):
-            return True
-        elif cls_name.startswith("Chroma"):
-            return True
-        return False
+        for blocks, forward_pattern in zip(
+            block_adapter.blocks, block_adapter.forward_pattern
+        ):
+            blocks._forward_pattern = forward_pattern
+            blocks._cache_context_kwargs = cache_context_kwargs
 
     @classmethod
     def check_context_kwargs(cls, pipe, **cache_context_kwargs):
         # Check cache_context_kwargs
         if not cache_context_kwargs["do_separate_cfg"]:
             # Check cfg for some specific case if users don't set it as True
-            cache_context_kwargs["do_separate_cfg"] =
+            cache_context_kwargs["do_separate_cfg"] = (
+                BlockAdapterRegistry.has_separate_cfg(pipe)
+            )
+            logger.info(
+                f"Use default 'do_separate_cfg': {cache_context_kwargs['do_separate_cfg']}, "
+                f"Pipeline: {pipe.__class__.__name__}."
+            )
 
         if cache_type := cache_context_kwargs.pop("cache_type", None):
             assert (
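
The hunk above also deletes the hard-coded separate-CFG check in favor of BlockAdapterRegistry.has_separate_cfg. For reference, the removed logic reduces to a class-name prefix test, reconstructed here from the deleted lines:

_SEPARATE_CFG_PREFIXES = (
    "QwenImage", "Wan", "CogView4", "Cosmos", "SkyReelsV2", "Chroma",
)

def has_separate_cfg(pipe_or_transformer) -> bool:
    # Equivalent of the removed elif chain; 0.2.27 moves this knowledge
    # behind BlockAdapterRegistry.has_separate_cfg.
    return pipe_or_transformer.__class__.__name__.startswith(_SEPARATE_CFG_PREFIXES)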
@@ -976,65 +135,87 @@ class UnifiedCacheAdapter:
     @classmethod
     def create_context(
         cls,
-
+        block_adapter: BlockAdapter,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
-        if getattr(pipe, "_is_cached", False):
-            return pipe
+        if getattr(block_adapter.pipe, "_is_cached", False):
+            return block_adapter.pipe
 
         # Check cache_context_kwargs
         cache_context_kwargs = cls.check_context_kwargs(
-            pipe,
+            block_adapter.pipe,
             **cache_context_kwargs,
         )
         # Apply cache on pipeline: wrap cache context
-        cache_kwargs, _ =
+        cache_kwargs, _ = CachedContext.collect_cache_kwargs(
             default_attrs={},
             **cache_context_kwargs,
         )
-        original_call = pipe.__class__.__call__
+        original_call = block_adapter.pipe.__class__.__call__
 
         @functools.wraps(original_call)
         def new_call(self, *args, **kwargs):
-            with
-
-
-
-
-
+            with ExitStack() as stack:
+                # cache context will reset for each pipe inference
+                for blocks_name in block_adapter.blocks_name:
+                    stack.enter_context(
+                        CachedContext.cache_context(
+                            CachedContext.reset_cache_context(
+                                blocks_name,
+                                **cache_kwargs,
+                            ),
+                        )
+                    )
+                outputs = original_call(self, *args, **kwargs)
+                cls.patch_stats(block_adapter)
+                return outputs
+
+        block_adapter.pipe.__class__.__call__ = new_call
+        block_adapter.pipe.__class__._is_cached = True
+        return block_adapter.pipe
 
-
-
-
+    @classmethod
+    def patch_stats(cls, block_adapter: BlockAdapter):
+        from cache_dit.cache_factory.cache_blocks.utils import (
+            patch_cached_stats,
+        )
+
+        patch_cached_stats(block_adapter.transformer)
+        for blocks, blocks_name in zip(
+            block_adapter.blocks, block_adapter.blocks_name
+        ):
+            patch_cached_stats(blocks, blocks_name)
 
     @classmethod
     def mock_blocks(
         cls,
         block_adapter: BlockAdapter,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     ) -> torch.nn.Module:
 
         if getattr(block_adapter.transformer, "_is_cached", False):
            return block_adapter.transformer
 
         # Check block forward pattern matching
-
-
-            forward_pattern
-        )
-
-
-
+        block_adapter = BlockAdapter.normalize(block_adapter)
+        for forward_pattern, blocks in zip(
+            block_adapter.forward_pattern, block_adapter.blocks
+        ):
+            assert BlockAdapter.match_blocks_pattern(
+                blocks,
+                forward_pattern=forward_pattern,
+                check_num_outputs=block_adapter.check_num_outputs,
+            ), (
+                "No block forward pattern matched, "
+                f"supported lists: {ForwardPattern.supported_patterns()}"
+            )
 
         # Apply cache on transformer: mock cached transformer blocks
-
-
-
-
-
-
-        )
-        ]
+        # TODO: Use blocks_name to spearate cached context for different
+        # blocks list. For example, single_transformer_blocks and
+        # transformer_blocks should have different cached context and
+        # forward pattern.
+        cached_blocks = cls.collect_cached_blocks(
+            block_adapter=block_adapter,
         )
         dummy_blocks = torch.nn.ModuleList()
 
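
The rewritten create_context wraps the pipeline's __call__ so that one cache context per blocks list is entered before inference and released afterwards. The mechanism is plain stdlib Python; a runnable toy under assumed stand-in names (toy_cache_context and ToyPipe replace CachedContext.cache_context and the real pipeline):

import functools
from contextlib import ExitStack, contextmanager

@contextmanager
def toy_cache_context(name):
    print(f"enter {name}")  # stands in for CachedContext.cache_context(...)
    try:
        yield
    finally:
        print(f"exit {name}")

class ToyPipe:
    def __call__(self, x):
        return x + 1

original_call = ToyPipe.__call__

@functools.wraps(original_call)
def new_call(self, *args, **kwargs):
    # One context per blocks list, all released together: the same ExitStack
    # pattern the new create_context uses to reset caches on every inference.
    with ExitStack() as stack:
        for name in ("transformer_blocks", "single_transformer_blocks"):
            stack.enter_context(toy_cache_context(name))
        return original_call(self, *args, **kwargs)

ToyPipe.__call__ = new_call
assert ToyPipe()(1) == 2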
@@ -1045,13 +226,14 @@ class UnifiedCacheAdapter:
         @functools.wraps(original_forward)
         def new_forward(self, *args, **kwargs):
             with ExitStack() as stack:
-
-
-
-
-
+                for blocks_name in block_adapter.blocks_name:
+                    stack.enter_context(
+                        unittest.mock.patch.object(
+                            self,
+                            blocks_name,
+                            cached_blocks[blocks_name],
+                        )
                     )
-                )
                 for dummy_name in block_adapter.dummy_blocks_names:
                     stack.enter_context(
                         unittest.mock.patch.object(
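
new_forward swaps each real blocks list for its cached counterpart only for the duration of a forward pass, using unittest.mock.patch.object as a reversible setattr. A self-contained toy (ToyTransformer is a stand-in for the real transformer):

import torch
import unittest.mock

class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Identity()])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x

model = ToyTransformer()
stand_in = torch.nn.ModuleList([torch.nn.Identity(), torch.nn.Identity()])

# Temporarily swap the real blocks by attribute name, just as new_forward
# substitutes cached_blocks[blocks_name]; the original ModuleList is
# restored when the context exits.
with unittest.mock.patch.object(model, "blocks", stand_in):
    assert len(model.blocks) == 2
assert len(model.blocks) == 1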
@@ -1068,3 +250,30 @@ class UnifiedCacheAdapter:
         block_adapter.transformer._is_cached = True
 
         return block_adapter.transformer
+
+    @classmethod
+    def collect_cached_blocks(
+        cls,
+        block_adapter: BlockAdapter,
+    ) -> Dict[str, torch.nn.ModuleList]:
+        block_adapter = BlockAdapter.normalize(block_adapter)
+
+        cached_blocks_bind_context = {}
+
+        for i in range(len(block_adapter.blocks)):
+            cached_blocks_bind_context[block_adapter.blocks_name[i]] = (
+                torch.nn.ModuleList(
+                    [
+                        CachedBlocks(
+                            block_adapter.blocks[i],
+                            block_adapter.blocks_name[i],
+                            block_adapter.blocks_name[i],  # context name
+                            transformer=block_adapter.transformer,
+                            forward_pattern=block_adapter.forward_pattern[i],
+                            check_num_outputs=block_adapter.check_num_outputs,
+                        )
+                    ]
+                )
+            )
+
+        return cached_blocks_bind_context
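
collect_cached_blocks pairs each blocks list with a CachedBlocks wrapper and keys the result by the blocks attribute name, which is exactly what new_forward looks up when patching. A usage sketch, assuming pipe was loaded as in the first example on this page and that normalize() turns blocks_name into a list:

from cache_dit.cache_factory import BlockAdapter, ForwardPattern
from cache_dit.cache_factory.cache_adapters import CachedAdapter

adapter = BlockAdapter(
    pipe=pipe,  # assumed to exist, see the first sketch above
    transformer=pipe.transformer,
    blocks=pipe.transformer.transformer_blocks,
    blocks_name="transformer_blocks",
    forward_pattern=ForwardPattern.Pattern_0,
)
cached = CachedAdapter.collect_cached_blocks(block_adapter=adapter)
# e.g. {"transformer_blocks": ModuleList([CachedBlocks(...)])}
assert set(cached.keys()) == {"transformer_blocks"}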