cache-dit 0.2.26__py3-none-any.whl → 0.2.28__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +8 -6
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +17 -4
- cache_dit/cache_factory/block_adapters/__init__.py +555 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +538 -0
- cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
- cache_dit/cache_factory/cache_adapters.py +262 -938
- cache_dit/cache_factory/cache_blocks/__init__.py +60 -11
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +45 -41
- cache_dit/cache_factory/cache_blocks/pattern_base.py +106 -80
- cache_dit/cache_factory/cache_blocks/utils.py +16 -10
- cache_dit/cache_factory/cache_contexts/__init__.py +5 -0
- cache_dit/cache_factory/cache_contexts/cache_context.py +327 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +833 -0
- cache_dit/cache_factory/cache_interface.py +31 -31
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
- cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
- cache_dit/quantize/quantize_ao.py +1 -0
- cache_dit/utils.py +26 -26
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/METADATA +59 -23
- cache_dit-0.2.28.dist-info/RECORD +47 -0
- cache_dit/cache_factory/cache_context.py +0 -1155
- cache_dit-0.2.26.dist-info/RECORD +0 -42
- /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.28.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_adapters.py

@@ -1,865 +1,36 @@
 import torch

-import inspect
 import unittest
 import functools
-import dataclasses

-from typing import Any, Tuple, List, Optional
 from contextlib import ExitStack
+from typing import Dict, List, Tuple, Any
+
 from diffusers import DiffusionPipeline
+
 from cache_dit.cache_factory import CacheType
-from cache_dit.cache_factory import
-from cache_dit.cache_factory import
-from cache_dit.cache_factory
-from cache_dit.cache_factory
-
-)
+from cache_dit.cache_factory import BlockAdapter
+from cache_dit.cache_factory import ParamsModifier
+from cache_dit.cache_factory import BlockAdapterRegistry
+from cache_dit.cache_factory import CachedContextManager
+from cache_dit.cache_factory import CachedBlocks

 from cache_dit.logger import init_logger

 logger = init_logger(__name__)


-
-class
-    pipe: DiffusionPipeline | Any = None
-    transformer: torch.nn.Module = None
-    blocks: torch.nn.ModuleList = None
-    # transformer_blocks, blocks, etc.
-    blocks_name: str = None
-    dummy_blocks_names: List[str] = dataclasses.field(default_factory=list)
-    # patch functor: Flux, etc.
-    patch_functor: Optional[PatchFunctor] = None
-    # flags to control auto block adapter
-    auto: bool = False
-    allow_prefixes: List[str] = dataclasses.field(
-        default_factory=lambda: [
-            "transformer",
-            "single_transformer",
-            "blocks",
-            "layers",
-            "single_stream_blocks",
-            "double_stream_blocks",
-        ]
-    )
-    check_prefixes: bool = True
-    allow_suffixes: List[str] = dataclasses.field(
-        default_factory=lambda: ["TransformerBlock"]
-    )
-    check_suffixes: bool = False
-    blocks_policy: str = dataclasses.field(
-        default="max", metadata={"allowed_values": ["max", "min"]}
-    )
-
-    def __post_init__(self):
-        assert any((self.pipe is not None, self.transformer is not None))
-        self.patchify()
-
-    def patchify(self, *args, **kwargs):
-        # Process some specificial cases, specific for transformers
-        # that has different forward patterns between single_transformer_blocks
-        # and transformer_blocks , such as Flux (diffusers < 0.35.0).
-        if self.patch_functor is not None:
-            if self.transformer is not None:
-                self.patch_functor.apply(self.transformer, *args, **kwargs)
-            else:
-                assert hasattr(self.pipe, "transformer")
-                self.patch_functor.apply(self.pipe.transformer, *args, **kwargs)
-
-    @staticmethod
-    def auto_block_adapter(
-        adapter: "BlockAdapter",
-        forward_pattern: Optional[ForwardPattern] = None,
-    ) -> "BlockAdapter":
-        assert adapter.auto, (
-            "Please manually set `auto` to True, or, manually "
-            "set all the transformer blocks configuration."
-        )
-        assert adapter.pipe is not None, "adapter.pipe can not be None."
-        pipe = adapter.pipe
-
-        assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
-
-        transformer = pipe.transformer
-
-        # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
-        blocks, blocks_name = BlockAdapter.find_blocks(
-            transformer=transformer,
-            allow_prefixes=adapter.allow_prefixes,
-            allow_suffixes=adapter.allow_suffixes,
-            check_prefixes=adapter.check_prefixes,
-            check_suffixes=adapter.check_suffixes,
-            blocks_policy=adapter.blocks_policy,
-            forward_pattern=forward_pattern,
-        )
-
-        return BlockAdapter(
-            pipe=pipe,
-            transformer=transformer,
-            blocks=blocks,
-            blocks_name=blocks_name,
-        )
-
-    @staticmethod
-    def check_block_adapter(adapter: "BlockAdapter") -> bool:
-        if (
-            # NOTE: pipe may not need to be DiffusionPipeline?
-            # isinstance(adapter.pipe, DiffusionPipeline)
-            adapter.pipe is not None
-            and adapter.transformer is not None
-            and adapter.blocks is not None
-            and adapter.blocks_name is not None
-            and isinstance(adapter.blocks, torch.nn.ModuleList)
-        ):
-            return True
-
-        logger.warning("Check block adapter failed!")
-        return False
-
-    @staticmethod
-    def find_blocks(
-        transformer: torch.nn.Module,
-        allow_prefixes: List[str] = [
-            "transformer",
-            "single_transformer",
-            "blocks",
-            "layers",
-        ],
-        allow_suffixes: List[str] = [
-            "TransformerBlock",
-        ],
-        check_prefixes: bool = True,
-        check_suffixes: bool = False,
-        **kwargs,
-    ) -> Tuple[torch.nn.ModuleList, str]:
-        # Check prefixes
-        if check_prefixes:
-            blocks_names = []
-            for attr_name in dir(transformer):
-                for prefix in allow_prefixes:
-                    if attr_name.startswith(prefix):
-                        blocks_names.append(attr_name)
-        else:
-            blocks_names = dir(transformer)
-
-        # Check ModuleList
-        valid_names = []
-        valid_count = []
-        forward_pattern = kwargs.get("forward_pattern", None)
-        for blocks_name in blocks_names:
-            if blocks := getattr(transformer, blocks_name, None):
-                if isinstance(blocks, torch.nn.ModuleList):
-                    block = blocks[0]
-                    block_cls_name = block.__class__.__name__
-                    # Check suffixes
-                    if isinstance(block, torch.nn.Module) and (
-                        any(
-                            (
-                                block_cls_name.endswith(allow_suffix)
-                                for allow_suffix in allow_suffixes
-                            )
-                        )
-                        or (not check_suffixes)
-                    ):
-                        # May check forward pattern
-                        if forward_pattern is not None:
-                            if BlockAdapter.match_blocks_pattern(
-                                blocks,
-                                forward_pattern,
-                                logging=False,
-                            ):
-                                valid_names.append(blocks_name)
-                                valid_count.append(len(blocks))
-                        else:
-                            valid_names.append(blocks_name)
-                            valid_count.append(len(blocks))
-
-        if not valid_names:
-            raise ValueError(
-                "Auto selected transformer blocks failed, please set it manually."
-            )
-
-        final_name = valid_names[0]
-        final_count = valid_count[0]
-        block_policy = kwargs.get("blocks_policy", "max")
-
-        for blocks_name, count in zip(valid_names, valid_count):
-            blocks = getattr(transformer, blocks_name)
-            logger.info(
-                f"Auto selected transformer blocks: {blocks_name}, "
-                f"class: {blocks[0].__class__.__name__}, "
-                f"num blocks: {count}"
-            )
-            if block_policy == "max":
-                if final_count < count:
-                    final_count = count
-                    final_name = blocks_name
-            else:
-                if final_count > count:
-                    final_count = count
-                    final_name = blocks_name
-
-        final_blocks = getattr(transformer, final_name)
-
-        logger.info(
-            f"Final selected transformer blocks: {final_name}, "
-            f"class: {final_blocks[0].__class__.__name__}, "
-            f"num blocks: {final_count}, block_policy: {block_policy}."
-        )
-
-        return final_blocks, final_name
-
-    @staticmethod
-    def match_block_pattern(
-        block: torch.nn.Module,
-        forward_pattern: ForwardPattern,
-    ) -> bool:
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        forward_parameters = set(
-            inspect.signature(block.forward).parameters.keys()
-        )
-        num_outputs = str(
-            inspect.signature(block.forward).return_annotation
-        ).count("torch.Tensor")
-
-        in_matched = True
-        out_matched = True
-        if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
-            # output pattern not match
-            out_matched = False
-
-        for required_param in forward_pattern.In:
-            if required_param not in forward_parameters:
-                in_matched = False
-
-        return in_matched and out_matched
-
-    @staticmethod
-    def match_blocks_pattern(
-        transformer_blocks: torch.nn.ModuleList,
-        forward_pattern: ForwardPattern,
-        logging: bool = True,
-    ) -> bool:
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        assert isinstance(transformer_blocks, torch.nn.ModuleList)
-
-        pattern_matched_states = []
-        for block in transformer_blocks:
-            pattern_matched_states.append(
-                BlockAdapter.match_block_pattern(
-                    block,
-                    forward_pattern,
-                )
-            )
-
-        pattern_matched = all(pattern_matched_states) # all block match
-        if pattern_matched and logging:
-            block_cls_name = transformer_blocks[0].__class__.__name__
-            logger.info(
-                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
-                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
-            )
-
-        return pattern_matched
-
-
-@dataclasses.dataclass
-class UnifiedCacheParams:
-    block_adapter: BlockAdapter = None
-    forward_pattern: ForwardPattern = ForwardPattern.Pattern_0
-
-
-class UnifiedCacheAdapter:
-    _supported_pipelines = [
-        "Flux",
-        "Mochi",
-        "CogVideoX",
-        "Wan",
-        "HunyuanVideo",
-        "QwenImage",
-        "LTXVideo",
-        "Allegro",
-        "CogView3Plus",
-        "CogView4",
-        "Cosmos",
-        "EasyAnimate",
-        "SkyReelsV2",
-        "SD3",
-        "ConsisID",
-        "DiT",
-        "Amused",
-        "Bria",
-        "HunyuanDiT",
-        "HunyuanDiTPAG",
-        "Lumina",
-        "Lumina2",
-        "OmniGen",
-        "PixArt",
-        "Sana",
-        "ShapE",
-        "StableAudio",
-        "VisualCloze",
-        "AuraFlow",
-        "Chroma",
-        "HiDream",
-    ]
+# Unified Cached Adapter
+class CachedAdapter:

     def __call__(self, *args, **kwargs):
         return self.apply(*args, **kwargs)

-    @classmethod
-    def supported_pipelines(cls) -> Tuple[int, List[str]]:
-        return len(cls._supported_pipelines), [
-            p + "*" for p in cls._supported_pipelines
-        ]
-
-    @classmethod
-    def is_supported(cls, pipe: DiffusionPipeline) -> bool:
-        pipe_cls_name: str = pipe.__class__.__name__
-        for prefix in cls._supported_pipelines:
-            if pipe_cls_name.startswith(prefix):
-                return True
-        return False
-
-    @classmethod
-    def get_params(cls, pipe: DiffusionPipeline) -> UnifiedCacheParams:
-        pipe_cls_name: str = pipe.__class__.__name__
-
-        if pipe_cls_name.startswith("Flux"):
-            from diffusers import FluxTransformer2DModel
-            from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
-
-            assert isinstance(pipe.transformer, FluxTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                    patch_functor=FluxPatchFunctor(),
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("Mochi"):
-            from diffusers import MochiTransformer3DModel
-
-            assert isinstance(pipe.transformer, MochiTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("CogVideoX"):
-            from diffusers import CogVideoXTransformer3DModel
-
-            assert isinstance(pipe.transformer, CogVideoXTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("Wan"):
-            from diffusers import (
-                WanTransformer3DModel,
-                WanVACETransformer3DModel,
-            )
-
-            assert isinstance(
-                pipe.transformer,
-                (WanTransformer3DModel, WanVACETransformer3DModel),
-            )
-            if getattr(pipe, "transformer_2", None):
-                # Wan 2.2, cache for low-noise transformer
-                assert isinstance(
-                    pipe.transformer_2,
-                    (WanTransformer3DModel, WanVACETransformer3DModel),
-                )
-                return UnifiedCacheParams(
-                    block_adapter=BlockAdapter(
-                        pipe=pipe,
-                        transformer=pipe.transformer_2,
-                        blocks=pipe.transformer_2.blocks,
-                        blocks_name="blocks",
-                        dummy_blocks_names=[],
-                    ),
-                    forward_pattern=ForwardPattern.Pattern_2,
-                )
-            else:
-                # Wan 2.1
-                return UnifiedCacheParams(
-                    block_adapter=BlockAdapter(
-                        pipe=pipe,
-                        transformer=pipe.transformer,
-                        blocks=pipe.transformer.blocks,
-                        blocks_name="blocks",
-                        dummy_blocks_names=[],
-                    ),
-                    forward_pattern=ForwardPattern.Pattern_2,
-                )
-
-        elif pipe_cls_name.startswith("HunyuanVideo"):
-            from diffusers import HunyuanVideoTransformer3DModel
-
-            assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("QwenImage"):
-            from diffusers import QwenImageTransformer2DModel
-
-            assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("LTXVideo"):
-            from diffusers import LTXVideoTransformer3DModel
-
-            assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-
-        elif pipe_cls_name.startswith("Allegro"):
-            from diffusers import AllegroTransformer3DModel
-
-            assert isinstance(pipe.transformer, AllegroTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-
-        elif pipe_cls_name.startswith("CogView3Plus"):
-            from diffusers import CogView3PlusTransformer2DModel
-
-            assert isinstance(pipe.transformer, CogView3PlusTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("CogView4"):
-            from diffusers import CogView4Transformer2DModel
-
-            assert isinstance(pipe.transformer, CogView4Transformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("Cosmos"):
-            from diffusers import CosmosTransformer3DModel
-
-            assert isinstance(pipe.transformer, CosmosTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-
-        elif pipe_cls_name.startswith("EasyAnimate"):
-            from diffusers import EasyAnimateTransformer3DModel
-
-            assert isinstance(pipe.transformer, EasyAnimateTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("SkyReelsV2"):
-            from diffusers import SkyReelsV2Transformer3DModel
-
-            assert isinstance(pipe.transformer, SkyReelsV2Transformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.blocks,
-                    blocks_name="blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-        elif pipe_cls_name.startswith("SD3"):
-            from diffusers import SD3Transformer2DModel
-
-            assert isinstance(pipe.transformer, SD3Transformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("ConsisID"):
-            from diffusers import ConsisIDTransformer3DModel
-
-            assert isinstance(pipe.transformer, ConsisIDTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("DiT"):
-            from diffusers import DiTTransformer2DModel
-
-            assert isinstance(pipe.transformer, DiTTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Amused"):
-            from diffusers import UVit2DModel
-
-            assert isinstance(pipe.transformer, UVit2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_layers,
-                    blocks_name="transformer_layers",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Bria"):
-            from diffusers import BriaTransformer2DModel
-
-            assert isinstance(pipe.transformer, BriaTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-
-        elif pipe_cls_name.startswith("HunyuanDiT"):
-            from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
-
-            assert isinstance(
-                pipe.transformer,
-                (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
-            )
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.blocks,
-                    blocks_name="blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("HunyuanDiTPAG"):
-            from diffusers import HunyuanDiT2DModel
-
-            assert isinstance(pipe.transformer, HunyuanDiT2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.blocks,
-                    blocks_name="blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Lumina"):
-            from diffusers import LuminaNextDiT2DModel
-
-            assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.layers,
-                    blocks_name="layers",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Lumina2"):
-            from diffusers import Lumina2Transformer2DModel
-
-            assert isinstance(pipe.transformer, Lumina2Transformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.layers,
-                    blocks_name="layers",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("OmniGen"):
-            from diffusers import OmniGenTransformer2DModel
-
-            assert isinstance(pipe.transformer, OmniGenTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.layers,
-                    blocks_name="layers",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("PixArt"):
-            from diffusers import PixArtTransformer2DModel
-
-            assert isinstance(pipe.transformer, PixArtTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Sana"):
-            from diffusers import SanaTransformer2DModel
-
-            assert isinstance(pipe.transformer, SanaTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("ShapE"):
-            from diffusers import PriorTransformer
-
-            assert isinstance(pipe.prior, PriorTransformer)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.prior,
-                    blocks=pipe.prior.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("StableAudio"):
-            from diffusers import StableAudioDiTModel
-
-            assert isinstance(pipe.transformer, StableAudioDiTModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("VisualCloze"):
-            from diffusers import FluxTransformer2DModel
-            from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
-
-            assert isinstance(pipe.transformer, FluxTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                    patch_functor=FluxPatchFunctor(),
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("AuraFlow"):
-            from diffusers import AuraFlowTransformer2DModel
-
-            assert isinstance(pipe.transformer, AuraFlowTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    # Only support caching single_transformer_blocks for AuraFlow now.
-                    # TODO: Support AuraFlowPatchFunctor.
-                    blocks=pipe.transformer.single_transformer_blocks,
-                    blocks_name="single_transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        elif pipe_cls_name.startswith("Chroma"):
-            from diffusers import ChromaTransformer2DModel
-            from cache_dit.cache_factory.patch_functors import (
-                ChromaPatchFunctor,
-            )
-
-            assert isinstance(pipe.transformer, ChromaTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                    patch_functor=ChromaPatchFunctor(),
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-
-        elif pipe_cls_name.startswith("HiDream"):
-            from diffusers import HiDreamImageTransformer2DModel
-
-            assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    # Only support caching single_stream_blocks for HiDream now.
-                    # TODO: Support HiDreamPatchFunctor.
-                    blocks=pipe.transformer.single_stream_blocks,
-                    blocks_name="single_stream_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_3,
-            )
-
-        else:
-            raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
-
     @classmethod
     def apply(
         cls,
         pipe: DiffusionPipeline = None,
         block_adapter: BlockAdapter = None,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
         assert (
@@ -867,15 +38,14 @@ class UnifiedCacheAdapter:
         ), "pipe or block_adapter can not both None!"

         if pipe is not None:
-            if
+            if BlockAdapterRegistry.is_supported(pipe):
                 logger.info(
                     f"{pipe.__class__.__name__} is officially supported by cache-dit. "
                     "Use it's pre-defined BlockAdapter directly!"
                 )
-
+                block_adapter = BlockAdapterRegistry.get_adapter(pipe)
                 return cls.cachify(
-
-                    forward_pattern=params.forward_pattern,
+                    block_adapter,
                     **cache_context_kwargs,
                 )
             else:
@@ -889,7 +59,6 @@ class UnifiedCacheAdapter:
                 )
                 return cls.cachify(
                     block_adapter,
-                    forward_pattern=forward_pattern,
                     **cache_context_kwargs,
                 )

@@ -897,74 +66,78 @@ class UnifiedCacheAdapter:
     def cachify(
         cls,
         block_adapter: BlockAdapter,
-        *,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:

         if block_adapter.auto:
             block_adapter = BlockAdapter.auto_block_adapter(
                 block_adapter,
-                forward_pattern,
             )

         if BlockAdapter.check_block_adapter(block_adapter):
-
+
+            # 0. Must normalize block_adapter before apply cache
+            block_adapter = BlockAdapter.normalize(block_adapter)
+            if BlockAdapter.is_cached(block_adapter):
+                return block_adapter.pipe
+
+            # 1. Apply cache on pipeline: wrap cache context, must
+            # call create_context before mock_blocks.
             cls.create_context(
-                block_adapter
+                block_adapter,
                 **cache_context_kwargs,
             )
-
+
+            # 2. Apply cache on transformer: mock cached blocks
             cls.mock_blocks(
                 block_adapter,
-                forward_pattern=forward_pattern,
-            )
-            cls.patch_params(
-                block_adapter,
-                forward_pattern=forward_pattern,
-                **cache_context_kwargs,
             )
+
         return block_adapter.pipe

     @classmethod
     def patch_params(
         cls,
         block_adapter: BlockAdapter,
-
-        **cache_context_kwargs,
+        contexts_kwargs: List[Dict],
     ):
-        block_adapter.
-        block_adapter.transformer._cache_context_kwargs = cache_context_kwargs
-        block_adapter.pipe.__class__._cache_context_kwargs = (
-            cache_context_kwargs
-        )
+        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]

-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        params_shift = 0
+        for i in range(len(block_adapter.transformer)):
+
+            block_adapter.transformer[i]._forward_pattern = (
+                block_adapter.forward_pattern
+            )
+            block_adapter.transformer[i]._has_separate_cfg = (
+                block_adapter.has_separate_cfg
+            )
+            block_adapter.transformer[i]._cache_context_kwargs = (
+                contexts_kwargs[params_shift]
+            )
+
+            blocks = block_adapter.blocks[i]
+            for j in range(len(blocks)):
+                blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
+                blocks[j]._cache_context_kwargs = contexts_kwargs[
+                    params_shift + j
+                ]
+
+            params_shift += len(blocks)

     @classmethod
     def check_context_kwargs(cls, pipe, **cache_context_kwargs):
         # Check cache_context_kwargs
-        if not cache_context_kwargs["
+        if not cache_context_kwargs["enable_spearate_cfg"]:
             # Check cfg for some specific case if users don't set it as True
-            cache_context_kwargs["
+            cache_context_kwargs["enable_spearate_cfg"] = (
+                BlockAdapterRegistry.has_separate_cfg(pipe)
+            )
+            logger.info(
+                f"Use default 'enable_spearate_cfg': "
+                f"{cache_context_kwargs['enable_spearate_cfg']}, "
+                f"Pipeline: {pipe.__class__.__name__}."
+            )

         if cache_type := cache_context_kwargs.pop("cache_type", None):
             assert (
@@ -976,95 +149,246 @@ class UnifiedCacheAdapter:
     @classmethod
     def create_context(
         cls,
-
+        block_adapter: BlockAdapter,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
-
-
+
+        BlockAdapter.assert_normalized(block_adapter)
+
+        if BlockAdapter.is_cached(block_adapter.pipe):
+            return block_adapter.pipe

         # Check cache_context_kwargs
         cache_context_kwargs = cls.check_context_kwargs(
-            pipe,
+            block_adapter.pipe,
             **cache_context_kwargs,
         )
         # Apply cache on pipeline: wrap cache context
-
-
-
+        pipe_cls_name = block_adapter.pipe.__class__.__name__
+
+        # Each Pipeline should have it's own context manager instance.
+        # Different transformers (Wan2.2, etc) should shared the same
+        # cache manager but with different cache context (according
+        # to their unique instance id).
+        cache_manager = CachedContextManager(
+            name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
         )
-
+        block_adapter.pipe._cache_manager = cache_manager # instance level
+
+        flatten_contexts, contexts_kwargs = cls.modify_context_params(
+            block_adapter, cache_manager, **cache_context_kwargs
+        )
+
+        original_call = block_adapter.pipe.__class__.__call__

         @functools.wraps(original_call)
         def new_call(self, *args, **kwargs):
-            with
-
-
-
-
-
+            with ExitStack() as stack:
+                # cache context will be reset for each pipe inference
+                for context_name, context_kwargs in zip(
+                    flatten_contexts, contexts_kwargs
+                ):
+                    stack.enter_context(
+                        cache_manager.enter_context(
+                            cache_manager.reset_context(
+                                context_name,
+                                **context_kwargs,
+                            ),
+                        )
+                    )
+                outputs = original_call(self, *args, **kwargs)
+                cls.patch_stats(block_adapter)
+                return outputs
+
+        block_adapter.pipe.__class__.__call__ = new_call
+        block_adapter.pipe.__class__._original_call = original_call
+        block_adapter.pipe.__class__._is_cached = True

-
-
-        return pipe
+        cls.patch_params(block_adapter, contexts_kwargs)
+
+        return block_adapter.pipe

     @classmethod
-    def
+    def modify_context_params(
         cls,
         block_adapter: BlockAdapter,
-
-
+        cache_manager: CachedContextManager,
+        **cache_context_kwargs,
+    ) -> Tuple[List[str], List[Dict[str, Any]]]:

-
-
+        flatten_contexts = BlockAdapter.flatten(
+            block_adapter.unique_blocks_name
+        )
+        contexts_kwargs = [
+            cache_context_kwargs.copy()
+            for _ in range(
+                len(flatten_contexts),
+            )
+        ]

-
-
-
-
-
-
-
+        for i in range(len(contexts_kwargs)):
+            contexts_kwargs[i]["name"] = flatten_contexts[i]
+
+        if block_adapter.params_modifiers is None:
+            return flatten_contexts, contexts_kwargs
+
+        flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
+            block_adapter.params_modifiers,
         )

-
-
-
-
-
-
-
-
-
+        for i in range(
+            min(len(contexts_kwargs), len(flatten_modifiers)),
+        ):
+            contexts_kwargs[i].update(
+                flatten_modifiers[i]._context_kwargs,
+            )
+            contexts_kwargs[i], _ = cache_manager.collect_cache_kwargs(
+                default_attrs={}, **contexts_kwargs[i]
+            )
+
+        return flatten_contexts, contexts_kwargs
+
+    @classmethod
+    def patch_stats(
+        cls,
+        block_adapter: BlockAdapter,
+    ):
+        from cache_dit.cache_factory.cache_blocks.utils import (
+            patch_cached_stats,
         )
+
+        cache_manager = block_adapter.pipe._cache_manager
+
+        for i in range(len(block_adapter.transformer)):
+            patch_cached_stats(
+                block_adapter.transformer[i],
+                cache_context=block_adapter.unique_blocks_name[i][-1],
+                cache_manager=cache_manager,
+            )
+            for blocks, unique_name in zip(
+                block_adapter.blocks[i],
+                block_adapter.unique_blocks_name[i],
+            ):
+                patch_cached_stats(
+                    blocks,
+                    cache_context=unique_name,
+                    cache_manager=cache_manager,
+                )
+
+    @classmethod
+    def mock_blocks(
+        cls,
+        block_adapter: BlockAdapter,
+    ) -> List[torch.nn.Module]:
+
+        BlockAdapter.assert_normalized(block_adapter)
+
+        if BlockAdapter.is_cached(block_adapter.transformer):
+            return block_adapter.transformer
+
+        # Apply cache on transformer: mock cached transformer blocks
+        for (
+            cached_blocks,
+            transformer,
+            blocks_name,
+            unique_blocks_name,
+            dummy_blocks_names,
+        ) in zip(
+            cls.collect_cached_blocks(block_adapter),
+            block_adapter.transformer,
+            block_adapter.blocks_name,
+            block_adapter.unique_blocks_name,
+            block_adapter.dummy_blocks_names,
+        ):
+            cls.mock_transformer(
+                cached_blocks,
+                transformer,
+                blocks_name,
+                unique_blocks_name,
+                dummy_blocks_names,
+            )
+
+        return block_adapter.transformer
+
+    @classmethod
+    def mock_transformer(
+        cls,
+        cached_blocks: Dict[str, torch.nn.ModuleList],
+        transformer: torch.nn.Module,
+        blocks_name: List[str],
+        unique_blocks_name: List[str],
+        dummy_blocks_names: List[str],
+    ) -> torch.nn.Module:
         dummy_blocks = torch.nn.ModuleList()

-        original_forward =
+        original_forward = transformer.forward

-        assert isinstance(
+        assert isinstance(dummy_blocks_names, list)

         @functools.wraps(original_forward)
         def new_forward(self, *args, **kwargs):
             with ExitStack() as stack:
-
-
-
-
-
+                for name, context_name in zip(
+                    blocks_name,
+                    unique_blocks_name,
+                ):
+                    stack.enter_context(
+                        unittest.mock.patch.object(
+                            self, name, cached_blocks[context_name]
+                        )
                     )
-
-                for dummy_name in block_adapter.dummy_blocks_names:
+                for dummy_name in dummy_blocks_names:
                     stack.enter_context(
                         unittest.mock.patch.object(
-                            self,
-                            dummy_name,
-                            dummy_blocks,
+                            self, dummy_name, dummy_blocks
                         )
                     )
                 return original_forward(*args, **kwargs)

-
-
+        transformer.forward = new_forward.__get__(transformer)
+        transformer._original_forward = original_forward
+        transformer._is_cached = True
+
+        return transformer
+
+    @classmethod
+    def collect_cached_blocks(
+        cls,
+        block_adapter: BlockAdapter,
+    ) -> List[Dict[str, torch.nn.ModuleList]]:
+
+        BlockAdapter.assert_normalized(block_adapter)
+
+        total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
+        assert hasattr(block_adapter.pipe, "_cache_manager")
+        assert isinstance(
+            block_adapter.pipe._cache_manager, CachedContextManager
         )
-        block_adapter.transformer._is_cached = True

-
+        for i in range(len(block_adapter.transformer)):
+
+            cached_blocks_bind_context = {}
+            for j in range(len(block_adapter.blocks[i])):
+                cached_blocks_bind_context[
+                    block_adapter.unique_blocks_name[i][j]
+                ] = torch.nn.ModuleList(
+                    [
+                        CachedBlocks(
+                            # 0. Transformer blocks configuration
+                            block_adapter.blocks[i][j],
+                            transformer=block_adapter.transformer[i],
+                            forward_pattern=block_adapter.forward_pattern[i][j],
+                            check_num_outputs=block_adapter.check_num_outputs,
+                            # 1. Cache context configuration
+                            cache_prefix=block_adapter.blocks_name[i][j],
+                            cache_context=block_adapter.unique_blocks_name[i][
+                                j
+                            ],
+                            cache_manager=block_adapter.pipe._cache_manager,
+                        )
+                    ]
+                )
+
+            total_cached_blocks.append(cached_blocks_bind_context)
+
+        return total_cached_blocks