cache-dit 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit has been flagged as possibly problematic by the registry's scanner.
- cache_dit/__init__.py +9 -4
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +16 -3
- cache_dit/cache_factory/block_adapters/__init__.py +538 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
- cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
- cache_dit/cache_factory/cache_adapters.py +121 -563
- cache_dit/cache_factory/cache_blocks/__init__.py +18 -0
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +275 -0
- cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +100 -82
- cache_dit/cache_factory/cache_blocks/utils.py +23 -0
- cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
- cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +94 -56
- cache_dit/cache_factory/cache_interface.py +24 -16
- cache_dit/cache_factory/forward_pattern.py +45 -24
- cache_dit/cache_factory/patch_functors/__init__.py +5 -0
- cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +276 -0
- cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +49 -31
- cache_dit/quantize/quantize_ao.py +19 -4
- cache_dit/quantize/quantize_interface.py +2 -2
- cache_dit/utils.py +19 -15
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/METADATA +76 -19
- cache_dit-0.2.27.dist-info/RECORD +47 -0
- cache_dit-0.2.25.dist-info/RECORD +0 -36
- /cache_dit/cache_factory/{patch/__init__.py → cache_contexts/cache_manager.py} +0 -0
- /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
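The headline change in this release is the split of the old monolithic adapter module: `BlockAdapter`, a `BlockAdapterRegistry`, and the per-pipeline registrations move into the new `cache_factory/block_adapters/` package, while `cache_adapters.py` keeps only a much smaller `CachedAdapter`. Below is a minimal usage sketch of the refactored entry point, assuming the names visible in this diff (`BlockAdapterRegistry.is_supported`, `CachedAdapter.apply`); the import paths and the `do_separate_cfg` keyword are taken from the diff, but the snippet has not been verified against the released wheel.

```python
# Hedged sketch of the 0.2.27 entry point, inferred from names in this diff.
# Import paths and keyword defaults are assumptions, not the verified API.
from diffusers import DiffusionPipeline
from cache_dit.cache_factory import BlockAdapterRegistry
from cache_dit.cache_factory.cache_adapters import CachedAdapter

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

if BlockAdapterRegistry.is_supported(pipe):
    # Officially supported pipelines get a pre-defined BlockAdapter from the
    # registry; cache_context_kwargs are forwarded as-is.
    pipe = CachedAdapter.apply(pipe=pipe, do_separate_cfg=False)
```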
Diff of cache_dit/cache_factory/cache_adapters.py:

```diff
@@ -1,524 +1,35 @@
 import torch
 
-import inspect
 import unittest
 import functools
-import dataclasses
 
-from typing import
+from typing import Dict
 from contextlib import ExitStack
 from diffusers import DiffusionPipeline
-from cache_dit.cache_factory.patch.flux import (
-    maybe_patch_flux_transformer,
-)
 from cache_dit.cache_factory import CacheType
+from cache_dit.cache_factory import CachedContext
 from cache_dit.cache_factory import ForwardPattern
-from cache_dit.cache_factory
-
-
-
+from cache_dit.cache_factory import BlockAdapter
+from cache_dit.cache_factory import BlockAdapterRegistry
+from cache_dit.cache_factory import CachedBlocks
+
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-
-class
-    pipe: DiffusionPipeline = None
-    transformer: torch.nn.Module = None
-    blocks: torch.nn.ModuleList = None
-    # transformer_blocks, blocks, etc.
-    blocks_name: str = None
-    dummy_blocks_names: list[str] = dataclasses.field(default_factory=list)
-    # flags to control auto block adapter
-    auto: bool = False
-    allow_prefixes: List[str] = dataclasses.field(
-        default_factory=lambda: [
-            "transformer",
-            "single_transformer",
-            "blocks",
-            "layers",
-        ]
-    )
-    check_prefixes: bool = True
-    allow_suffixes: List[str] = dataclasses.field(
-        default_factory=lambda: ["TransformerBlock"]
-    )
-    check_suffixes: bool = False
-    blocks_policy: str = dataclasses.field(
-        default="max", metadata={"allowed_values": ["max", "min"]}
-    )
-
-    def __post_init__(self):
-        self.maybe_apply_patch()
-
-    def maybe_apply_patch(self):
-        # Process some specificial cases, specific for transformers
-        # that has different forward patterns between single_transformer_blocks
-        # and transformer_blocks , such as Flux (diffusers < 0.35.0).
-        if self.transformer.__class__.__name__.startswith("Flux"):
-            self.transformer = maybe_patch_flux_transformer(
-                self.transformer,
-                blocks=self.blocks,
-            )
-
-    @staticmethod
-    def auto_block_adapter(
-        adapter: "BlockAdapter",
-        forward_pattern: Optional[ForwardPattern] = None,
-    ) -> "BlockAdapter":
-        assert adapter.auto, (
-            "Please manually set `auto` to True, or, manually "
-            "set all the transformer blocks configuration."
-        )
-        assert adapter.pipe is not None, "adapter.pipe can not be None."
-        pipe = adapter.pipe
-
-        assert hasattr(pipe, "transformer"), "pipe.transformer can not be None."
-
-        transformer = pipe.transformer
-
-        # "transformer_blocks", "blocks", "single_transformer_blocks", "layers"
-        blocks, blocks_name = BlockAdapter.find_blocks(
-            transformer=transformer,
-            allow_prefixes=adapter.allow_prefixes,
-            allow_suffixes=adapter.allow_suffixes,
-            check_prefixes=adapter.check_prefixes,
-            check_suffixes=adapter.check_suffixes,
-            blocks_policy=adapter.blocks_policy,
-            forward_pattern=forward_pattern,
-        )
-
-        return BlockAdapter(
-            pipe=pipe,
-            transformer=transformer,
-            blocks=blocks,
-            blocks_name=blocks_name,
-        )
-
-    @staticmethod
-    def check_block_adapter(adapter: "BlockAdapter") -> bool:
-        if (
-            isinstance(adapter.pipe, DiffusionPipeline)
-            and adapter.transformer is not None
-            and adapter.blocks is not None
-            and adapter.blocks_name is not None
-            and isinstance(adapter.blocks, torch.nn.ModuleList)
-        ):
-            return True
-
-        logger.warning("Check block adapter failed!")
-        return False
-
-    @staticmethod
-    def find_blocks(
-        transformer: torch.nn.Module,
-        allow_prefixes: List[str] = [
-            "transformer",
-            "single_transformer",
-            "blocks",
-            "layers",
-        ],
-        allow_suffixes: List[str] = [
-            "TransformerBlock",
-        ],
-        check_prefixes: bool = True,
-        check_suffixes: bool = False,
-        **kwargs,
-    ) -> Tuple[torch.nn.ModuleList, str]:
-        # Check prefixes
-        if check_prefixes:
-            blocks_names = []
-            for attr_name in dir(transformer):
-                for prefix in allow_prefixes:
-                    if attr_name.startswith(prefix):
-                        blocks_names.append(attr_name)
-        else:
-            blocks_names = dir(transformer)
-
-        # Check ModuleList
-        valid_names = []
-        valid_count = []
-        forward_pattern = kwargs.get("forward_pattern", None)
-        for blocks_name in blocks_names:
-            if blocks := getattr(transformer, blocks_name, None):
-                if isinstance(blocks, torch.nn.ModuleList):
-                    block = blocks[0]
-                    block_cls_name = block.__class__.__name__
-                    # Check suffixes
-                    if isinstance(block, torch.nn.Module) and (
-                        any(
-                            (
-                                block_cls_name.endswith(allow_suffix)
-                                for allow_suffix in allow_suffixes
-                            )
-                        )
-                        or (not check_suffixes)
-                    ):
-                        # May check forward pattern
-                        if forward_pattern is not None:
-                            if BlockAdapter.match_blocks_pattern(
-                                blocks,
-                                forward_pattern,
-                                logging=False,
-                            ):
-                                valid_names.append(blocks_name)
-                                valid_count.append(len(blocks))
-                        else:
-                            valid_names.append(blocks_name)
-                            valid_count.append(len(blocks))
-
-        if not valid_names:
-            raise ValueError(
-                "Auto selected transformer blocks failed, please set it manually."
-            )
-
-        final_name = valid_names[0]
-        final_count = valid_count[0]
-        block_policy = kwargs.get("blocks_policy", "max")
-
-        for blocks_name, count in zip(valid_names, valid_count):
-            blocks = getattr(transformer, blocks_name)
-            logger.info(
-                f"Auto selected transformer blocks: {blocks_name}, "
-                f"class: {blocks[0].__class__.__name__}, "
-                f"num blocks: {count}"
-            )
-            if block_policy == "max":
-                if final_count < count:
-                    final_count = count
-                    final_name = blocks_name
-            else:
-                if final_count > count:
-                    final_count = count
-                    final_name = blocks_name
-
-        final_blocks = getattr(transformer, final_name)
-
-        logger.info(
-            f"Final selected transformer blocks: {final_name}, "
-            f"class: {final_blocks[0].__class__.__name__}, "
-            f"num blocks: {final_count}, block_policy: {block_policy}."
-        )
-
-        return final_blocks, final_name
-
-    @staticmethod
-    def match_block_pattern(
-        block: torch.nn.Module,
-        forward_pattern: ForwardPattern,
-    ) -> bool:
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        forward_parameters = set(
-            inspect.signature(block.forward).parameters.keys()
-        )
-        num_outputs = str(
-            inspect.signature(block.forward).return_annotation
-        ).count("torch.Tensor")
-
-        in_matched = True
-        out_matched = True
-        if num_outputs > 0 and len(forward_pattern.Out) != num_outputs:
-            # output pattern not match
-            out_matched = False
-
-        for required_param in forward_pattern.In:
-            if required_param not in forward_parameters:
-                in_matched = False
-
-        return in_matched and out_matched
-
-    @staticmethod
-    def match_blocks_pattern(
-        transformer_blocks: torch.nn.ModuleList,
-        forward_pattern: ForwardPattern,
-        logging: bool = True,
-    ) -> bool:
-        assert (
-            forward_pattern.Supported
-            and forward_pattern in ForwardPattern.supported_patterns()
-        ), f"Pattern {forward_pattern} is not support now!"
-
-        assert isinstance(transformer_blocks, torch.nn.ModuleList)
-
-        pattern_matched_states = []
-        for block in transformer_blocks:
-            pattern_matched_states.append(
-                BlockAdapter.match_block_pattern(
-                    block,
-                    forward_pattern,
-                )
-            )
-
-        pattern_matched = all(pattern_matched_states)  # all block match
-        if pattern_matched and logging:
-            block_cls_name = transformer_blocks[0].__class__.__name__
-            logger.info(
-                f"Match Block Forward Pattern: {block_cls_name}, {forward_pattern}"
-                f"\nIN:{forward_pattern.In}, OUT:{forward_pattern.Out})"
-            )
-
-        return pattern_matched
-
-
-@dataclasses.dataclass
-class UnifiedCacheParams:
-    block_adapter: BlockAdapter = None
-    forward_pattern: ForwardPattern = ForwardPattern.Pattern_0
-
-
-class UnifiedCacheAdapter:
-    _supported_pipelines = [
-        "Flux",
-        "Mochi",
-        "CogVideoX",
-        "Wan",
-        "HunyuanVideo",
-        "QwenImage",
-        "LTXVideo",
-        "Allegro",
-        "CogView3Plus",
-        "CogView4",
-        "Cosmos",
-        "EasyAnimate",
-        "SkyReelsV2",
-        "SD3",
-    ]
+# Unified Cached Adapter
+class CachedAdapter:
 
     def __call__(self, *args, **kwargs):
         return self.apply(*args, **kwargs)
 
-    @classmethod
-    def is_supported(cls, pipe: DiffusionPipeline) -> bool:
-        pipe_cls_name: str = pipe.__class__.__name__
-        for prefix in cls._supported_pipelines:
-            if pipe_cls_name.startswith(prefix):
-                return True
-        return False
-
-    @classmethod
-    def get_params(cls, pipe: DiffusionPipeline) -> UnifiedCacheParams:
-        pipe_cls_name: str = pipe.__class__.__name__
-        if pipe_cls_name.startswith("Flux"):
-            from diffusers import FluxTransformer2DModel
-
-            assert isinstance(pipe.transformer, FluxTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-        elif pipe_cls_name.startswith("Mochi"):
-            from diffusers import MochiTransformer3DModel
-
-            assert isinstance(pipe.transformer, MochiTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-        elif pipe_cls_name.startswith("CogVideoX"):
-            from diffusers import CogVideoXTransformer3DModel
-
-            assert isinstance(pipe.transformer, CogVideoXTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-        elif pipe_cls_name.startswith("Wan"):
-            from diffusers import (
-                WanTransformer3DModel,
-                WanVACETransformer3DModel,
-            )
-
-            assert isinstance(
-                pipe.transformer,
-                (WanTransformer3DModel, WanVACETransformer3DModel),
-            )
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.blocks,
-                    blocks_name="blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-        elif pipe_cls_name.startswith("HunyuanVideo"):
-            from diffusers import HunyuanVideoTransformer3DModel
-
-            assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    blocks=(
-                        pipe.transformer.transformer_blocks
-                        + pipe.transformer.single_transformer_blocks
-                    ),
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=["single_transformer_blocks"],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-        elif pipe_cls_name.startswith("QwenImage"):
-            from diffusers import QwenImageTransformer2DModel
-
-            assert isinstance(pipe.transformer, QwenImageTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-        elif pipe_cls_name.startswith("LTXVideo"):
-            from diffusers import LTXVideoTransformer3DModel
-
-            assert isinstance(pipe.transformer, LTXVideoTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-        elif pipe_cls_name.startswith("Allegro"):
-            from diffusers import AllegroTransformer3DModel
-
-            assert isinstance(pipe.transformer, AllegroTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-        elif pipe_cls_name.startswith("CogView3Plus"):
-            from diffusers import CogView3PlusTransformer2DModel
-
-            assert isinstance(pipe.transformer, CogView3PlusTransformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-        elif pipe_cls_name.startswith("CogView4"):
-            from diffusers import CogView4Transformer2DModel
-
-            assert isinstance(pipe.transformer, CogView4Transformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-        elif pipe_cls_name.startswith("Cosmos"):
-            from diffusers import CosmosTransformer3DModel
-
-            assert isinstance(pipe.transformer, CosmosTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-        elif pipe_cls_name.startswith("EasyAnimate"):
-            from diffusers import EasyAnimateTransformer3DModel
-
-            assert isinstance(pipe.transformer, EasyAnimateTransformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_0,
-            )
-        elif pipe_cls_name.startswith("SkyReelsV2"):
-            from diffusers import SkyReelsV2Transformer3DModel
-
-            assert isinstance(pipe.transformer, SkyReelsV2Transformer3DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.blocks,
-                    blocks_name="blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_2,
-            )
-        elif pipe_cls_name.startswith("SD3"):
-            from diffusers import SD3Transformer2DModel
-
-            assert isinstance(pipe.transformer, SD3Transformer2DModel)
-            return UnifiedCacheParams(
-                block_adapter=BlockAdapter(
-                    pipe=pipe,
-                    transformer=pipe.transformer,
-                    blocks=pipe.transformer.transformer_blocks,
-                    blocks_name="transformer_blocks",
-                    dummy_blocks_names=[],
-                ),
-                forward_pattern=ForwardPattern.Pattern_1,
-            )
-        else:
-            raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
-
     @classmethod
     def apply(
         cls,
         pipe: DiffusionPipeline = None,
         block_adapter: BlockAdapter = None,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+        # forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
         assert (
```
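The first hunk deletes the in-file `BlockAdapter` dataclass, its auto-discovery helpers, and the long `get_params` per-pipeline dispatch; per the file list they now live under `cache_factory/block_adapters/`. For a pipeline that is not covered by the registry, an adapter can still be wired up by hand. Below is a sketch using the field names of the deleted dataclass (`pipe`, `transformer`, `blocks`, `blocks_name`, `dummy_blocks_names`); whether the relocated class keeps exactly this constructor is an assumption.

```python
# Hedged sketch: a hand-built adapter for a Flux-style pipeline, mirroring the
# removed get_params() branch above. The relocated BlockAdapter in
# block_adapters/block_adapters.py may expose a slightly different API.
from diffusers import FluxPipeline
from cache_dit.cache_factory import BlockAdapter

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

adapter = BlockAdapter(
    pipe=pipe,
    transformer=pipe.transformer,
    blocks=(
        pipe.transformer.transformer_blocks
        + pipe.transformer.single_transformer_blocks
    ),
    blocks_name="transformer_blocks",
    dummy_blocks_names=["single_transformer_blocks"],
)
# The removed branch paired this layout with ForwardPattern.Pattern_1.
```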
```diff
@@ -526,15 +37,14 @@ class UnifiedCacheAdapter:
         ), "pipe or block_adapter can not both None!"
 
         if pipe is not None:
-            if
+            if BlockAdapterRegistry.is_supported(pipe):
                 logger.info(
                     f"{pipe.__class__.__name__} is officially supported by cache-dit. "
                     "Use it's pre-defined BlockAdapter directly!"
                 )
-
+                block_adapter = BlockAdapterRegistry.get_adapter(pipe)
                 return cls.cachify(
-
-                    forward_pattern=params.forward_pattern,
+                    block_adapter,
                     **cache_context_kwargs,
                 )
         else:
@@ -548,7 +58,6 @@ class UnifiedCacheAdapter:
             )
             return cls.cachify(
                 block_adapter,
-                forward_pattern=forward_pattern,
                 **cache_context_kwargs,
             )
 
```
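`apply()` now asks `BlockAdapterRegistry` for support and for a pre-defined adapter instead of checking the hard-coded `_supported_pipelines` list and calling `get_params`. The registry itself lives in the new `block_adapters/block_registers.py`, which this page does not show; the following is a generic, hypothetical illustration of such a prefix-keyed registry, not the actual cache-dit implementation.

```python
# Hypothetical illustration of a prefix-keyed adapter registry; all names
# here are illustrative only and do not come from the cache-dit source.
from typing import Any, Callable, Dict

class AdapterRegistry:
    _factories: Dict[str, Callable[[Any], Any]] = {}

    @classmethod
    def register(cls, prefix: str):
        def decorator(factory: Callable[[Any], Any]):
            cls._factories[prefix] = factory
            return factory
        return decorator

    @classmethod
    def is_supported(cls, pipe) -> bool:
        name = pipe.__class__.__name__
        return any(name.startswith(p) for p in cls._factories)

    @classmethod
    def get_adapter(cls, pipe):
        name = pipe.__class__.__name__
        for prefix, factory in cls._factories.items():
            if name.startswith(prefix):
                return factory(pipe)
        raise ValueError(f"Unknown pipeline class name: {name}")

@AdapterRegistry.register("Flux")
def _flux_adapter(pipe):
    # Stand-in for building a BlockAdapter for Flux-family pipelines.
    return {"blocks_name": "transformer_blocks"}
```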
```diff
@@ -556,31 +65,27 @@ class UnifiedCacheAdapter:
     def cachify(
         cls,
         block_adapter: BlockAdapter,
-        *,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
 
         if block_adapter.auto:
             block_adapter = BlockAdapter.auto_block_adapter(
                 block_adapter,
-                forward_pattern,
             )
 
         if BlockAdapter.check_block_adapter(block_adapter):
-
+            block_adapter = BlockAdapter.normalize(block_adapter)
+            # 0. Apply cache on pipeline: wrap cache context
             cls.create_context(
-                block_adapter
+                block_adapter,
                 **cache_context_kwargs,
             )
-            # Apply cache on transformer: mock cached transformer blocks
+            # 1. Apply cache on transformer: mock cached transformer blocks
             cls.mock_blocks(
                 block_adapter,
-                forward_pattern=forward_pattern,
             )
             cls.patch_params(
                 block_adapter,
-                forward_pattern=forward_pattern,
                 **cache_context_kwargs,
             )
         return block_adapter.pipe
```
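`cachify()` drops the explicit `forward_pattern` argument and instead calls `BlockAdapter.normalize(...)` before creating the context and mocking the blocks. The later hunks iterate `zip(block_adapter.blocks, block_adapter.forward_pattern)` and `zip(block_adapter.blocks, block_adapter.blocks_name)`, which suggests normalization turns the adapter's fields into parallel lists. A hypothetical sketch of that idea; the real implementation lives in the new `block_adapters.py` and is not shown here.

```python
# Hypothetical sketch of what BlockAdapter.normalize appears to guarantee:
# blocks, blocks_name and forward_pattern become parallel lists that the
# later zip() loops can iterate. Purely illustrative, not the actual code.
def normalize(adapter):
    def as_list(value):
        return list(value) if isinstance(value, (list, tuple)) else [value]

    adapter.blocks = as_list(adapter.blocks)
    adapter.blocks_name = as_list(adapter.blocks_name)
    adapter.forward_pattern = as_list(adapter.forward_pattern)
    assert (
        len(adapter.blocks)
        == len(adapter.blocks_name)
        == len(adapter.forward_pattern)
    ), "normalized fields must be parallel lists"
    return adapter
```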
```diff
@@ -589,33 +94,36 @@ class UnifiedCacheAdapter:
     def patch_params(
         cls,
         block_adapter: BlockAdapter,
-        forward_pattern: ForwardPattern = None,
         **cache_context_kwargs,
     ):
-        block_adapter.transformer._forward_pattern =
+        block_adapter.transformer._forward_pattern = (
+            block_adapter.forward_pattern
+        )
+        block_adapter.transformer._has_separate_cfg = (
+            block_adapter.has_separate_cfg
+        )
         block_adapter.transformer._cache_context_kwargs = cache_context_kwargs
         block_adapter.pipe.__class__._cache_context_kwargs = (
             cache_context_kwargs
         )
-
-
-
-
-
-    ) -> bool:
-        cls_name = pipe_or_transformer.__class__.__name__
-        if cls_name.startswith("QwenImage"):
-            return True
-        elif cls_name.startswith("Wan"):
-            return True
-        return False
+        for blocks, forward_pattern in zip(
+            block_adapter.blocks, block_adapter.forward_pattern
+        ):
+            blocks._forward_pattern = forward_pattern
+            blocks._cache_context_kwargs = cache_context_kwargs
 
     @classmethod
     def check_context_kwargs(cls, pipe, **cache_context_kwargs):
         # Check cache_context_kwargs
         if not cache_context_kwargs["do_separate_cfg"]:
             # Check cfg for some specific case if users don't set it as True
-            cache_context_kwargs["do_separate_cfg"] =
+            cache_context_kwargs["do_separate_cfg"] = (
+                BlockAdapterRegistry.has_separate_cfg(pipe)
+            )
+            logger.info(
+                f"Use default 'do_separate_cfg': {cache_context_kwargs['do_separate_cfg']}, "
+                f"Pipeline: {pipe.__class__.__name__}."
+            )
 
         if cache_type := cache_context_kwargs.pop("cache_type", None):
             assert (
```
```diff
@@ -627,65 +135,87 @@ class UnifiedCacheAdapter:
     @classmethod
     def create_context(
         cls,
-
+        block_adapter: BlockAdapter,
         **cache_context_kwargs,
     ) -> DiffusionPipeline:
-        if getattr(pipe, "_is_cached", False):
-            return pipe
+        if getattr(block_adapter.pipe, "_is_cached", False):
+            return block_adapter.pipe
 
         # Check cache_context_kwargs
         cache_context_kwargs = cls.check_context_kwargs(
-            pipe,
+            block_adapter.pipe,
             **cache_context_kwargs,
         )
         # Apply cache on pipeline: wrap cache context
-        cache_kwargs, _ =
+        cache_kwargs, _ = CachedContext.collect_cache_kwargs(
             default_attrs={},
             **cache_context_kwargs,
         )
-        original_call = pipe.__class__.__call__
+        original_call = block_adapter.pipe.__class__.__call__
 
         @functools.wraps(original_call)
         def new_call(self, *args, **kwargs):
-            with
-
-
-
-
-
+            with ExitStack() as stack:
+                # cache context will reset for each pipe inference
+                for blocks_name in block_adapter.blocks_name:
+                    stack.enter_context(
+                        CachedContext.cache_context(
+                            CachedContext.reset_cache_context(
+                                blocks_name,
+                                **cache_kwargs,
+                            ),
+                        )
+                    )
+                outputs = original_call(self, *args, **kwargs)
+                cls.patch_stats(block_adapter)
+                return outputs
 
-        pipe.__class__.__call__ = new_call
-        pipe.__class__._is_cached = True
-        return pipe
+        block_adapter.pipe.__class__.__call__ = new_call
+        block_adapter.pipe.__class__._is_cached = True
+        return block_adapter.pipe
+
+    @classmethod
+    def patch_stats(cls, block_adapter: BlockAdapter):
+        from cache_dit.cache_factory.cache_blocks.utils import (
+            patch_cached_stats,
+        )
+
+        patch_cached_stats(block_adapter.transformer)
+        for blocks, blocks_name in zip(
+            block_adapter.blocks, block_adapter.blocks_name
+        ):
+            patch_cached_stats(blocks, blocks_name)
 
     @classmethod
     def mock_blocks(
         cls,
         block_adapter: BlockAdapter,
-        forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     ) -> torch.nn.Module:
 
         if getattr(block_adapter.transformer, "_is_cached", False):
             return block_adapter.transformer
 
         # Check block forward pattern matching
-
-
-            forward_pattern
-        )
-
-
-
+        block_adapter = BlockAdapter.normalize(block_adapter)
+        for forward_pattern, blocks in zip(
+            block_adapter.forward_pattern, block_adapter.blocks
+        ):
+            assert BlockAdapter.match_blocks_pattern(
+                blocks,
+                forward_pattern=forward_pattern,
+                check_num_outputs=block_adapter.check_num_outputs,
+            ), (
+                "No block forward pattern matched, "
+                f"supported lists: {ForwardPattern.supported_patterns()}"
+            )
 
         # Apply cache on transformer: mock cached transformer blocks
-
-
-
-
-
-
-            )
-        ]
+        # TODO: Use blocks_name to spearate cached context for different
+        # blocks list. For example, single_transformer_blocks and
+        # transformer_blocks should have different cached context and
+        # forward pattern.
+        cached_blocks = cls.collect_cached_blocks(
+            block_adapter=block_adapter,
         )
         dummy_blocks = torch.nn.ModuleList()
 
```
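The rewritten `create_context` wraps `pipe.__class__.__call__` with `functools.wraps` and, on every inference, opens one fresh cache context per tracked blocks list through an `ExitStack`, then records statistics via the new `patch_stats` helper. Below is a self-contained, plain-Python sketch of that wrapping technique using toy names (no cache-dit imports); it demonstrates the mechanism, not the library's actual context objects.

```python
# Standalone sketch of the call-wrapping technique used by create_context:
# patch the pipeline class __call__ so every inference enters one fresh
# context per tracked blocks list via an ExitStack. Toy names throughout.
import functools
from contextlib import ExitStack, contextmanager

@contextmanager
def cache_context(name: str, **kwargs):
    print(f"enter cache context: {name}")
    try:
        yield
    finally:
        print(f"exit cache context: {name}")

class ToyPipeline:
    def __call__(self, prompt: str) -> str:
        return f"generated({prompt})"

def install_cached_call(pipe, blocks_names, **cache_kwargs):
    original_call = pipe.__class__.__call__

    @functools.wraps(original_call)
    def new_call(self, *args, **kwargs):
        with ExitStack() as stack:
            for blocks_name in blocks_names:
                stack.enter_context(cache_context(blocks_name, **cache_kwargs))
            return original_call(self, *args, **kwargs)

    # Patch the class, as create_context does, so every instance is affected.
    pipe.__class__.__call__ = new_call
    pipe.__class__._is_cached = True
    return pipe

pipe = install_cached_call(ToyPipeline(), ["transformer_blocks"])
print(pipe("a cat"))  # enters and exits the toy context around the call
```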
```diff
@@ -696,13 +226,14 @@ class UnifiedCacheAdapter:
         @functools.wraps(original_forward)
         def new_forward(self, *args, **kwargs):
             with ExitStack() as stack:
-
-
-
-
-
+                for blocks_name in block_adapter.blocks_name:
+                    stack.enter_context(
+                        unittest.mock.patch.object(
+                            self,
+                            blocks_name,
+                            cached_blocks[blocks_name],
+                        )
                     )
-                )
                 for dummy_name in block_adapter.dummy_blocks_names:
                     stack.enter_context(
                         unittest.mock.patch.object(
```
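`new_forward` swaps the real block lists for their cached counterparts only for the duration of a single forward pass, by entering `unittest.mock.patch.object` context managers on the same `ExitStack` (the dummy blocks lists are emptied the same way). A self-contained sketch of that mechanism with a toy module follows; all names here are illustrative.

```python
# Standalone sketch of the mocking technique new_forward relies on:
# patch.object temporarily replaces the named ModuleList attribute for one
# forward call, then restores the original blocks on exit.
import unittest.mock
from contextlib import ExitStack
import torch

class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.transformer_blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

    def forward(self, x):
        for block in self.transformer_blocks:
            x = block(x)
        return x

model = ToyTransformer()
# Stand-in for the cached wrappers built by collect_cached_blocks().
cached_blocks = {"transformer_blocks": torch.nn.ModuleList([torch.nn.Identity()])}

with ExitStack() as stack:
    for name, blocks in cached_blocks.items():
        stack.enter_context(unittest.mock.patch.object(model, name, blocks))
    out = model(torch.randn(1, 4))  # runs the (mocked) identity blocks

# Outside the stack the original Linear blocks are back in place.
print(type(model.transformer_blocks[0]).__name__)  # -> Linear
```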
```diff
@@ -719,3 +250,30 @@ class UnifiedCacheAdapter:
         block_adapter.transformer._is_cached = True
 
         return block_adapter.transformer
+
+    @classmethod
+    def collect_cached_blocks(
+        cls,
+        block_adapter: BlockAdapter,
+    ) -> Dict[str, torch.nn.ModuleList]:
+        block_adapter = BlockAdapter.normalize(block_adapter)
+
+        cached_blocks_bind_context = {}
+
+        for i in range(len(block_adapter.blocks)):
+            cached_blocks_bind_context[block_adapter.blocks_name[i]] = (
+                torch.nn.ModuleList(
+                    [
+                        CachedBlocks(
+                            block_adapter.blocks[i],
+                            block_adapter.blocks_name[i],
+                            block_adapter.blocks_name[i],  # context name
+                            transformer=block_adapter.transformer,
+                            forward_pattern=block_adapter.forward_pattern[i],
+                            check_num_outputs=block_adapter.check_num_outputs,
+                        )
+                    ]
+                )
+            )
+
+        return cached_blocks_bind_context
```