cache-dit 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/__init__.py +5 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +2 -0
- cache_dit/cache_factory/cache_adapters.py +375 -26
- cache_dit/cache_factory/cache_blocks/__init__.py +20 -0
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +270 -0
- cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +17 -18
- cache_dit/cache_factory/cache_blocks/utils.py +19 -0
- cache_dit/cache_factory/cache_context.py +32 -25
- cache_dit/cache_factory/cache_interface.py +8 -3
- cache_dit/cache_factory/forward_pattern.py +45 -24
- cache_dit/cache_factory/patch_functors/__init__.py +5 -0
- cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +273 -0
- cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +45 -31
- cache_dit/compile/utils.py +1 -1
- cache_dit/quantize/__init__.py +1 -0
- cache_dit/quantize/quantize_ao.py +196 -0
- cache_dit/quantize/quantize_interface.py +46 -0
- cache_dit/utils.py +49 -17
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/METADATA +43 -18
- cache_dit-0.2.26.dist-info/RECORD +42 -0
- cache_dit-0.2.24.dist-info/RECORD +0 -32
- /cache_dit/{cache_factory/patch/__init__.py → quantize/quantize_svdq.py} +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py
CHANGED
|
@@ -11,7 +11,10 @@ from cache_dit.cache_factory import block_range
|
|
|
11
11
|
from cache_dit.cache_factory import CacheType
|
|
12
12
|
from cache_dit.cache_factory import BlockAdapter
|
|
13
13
|
from cache_dit.cache_factory import ForwardPattern
|
|
14
|
+
from cache_dit.cache_factory import PatchFunctor
|
|
15
|
+
from cache_dit.cache_factory import supported_pipelines
|
|
14
16
|
from cache_dit.compile import set_compile_configs
|
|
17
|
+
from cache_dit.quantize import quantize
|
|
15
18
|
from cache_dit.utils import summary
|
|
16
19
|
from cache_dit.utils import strify
|
|
17
20
|
from cache_dit.logger import init_logger
|
|
@@ -23,3 +26,5 @@ Forward_Pattern_0 = ForwardPattern.Pattern_0
|
|
|
23
26
|
Forward_Pattern_1 = ForwardPattern.Pattern_1
|
|
24
27
|
Forward_Pattern_2 = ForwardPattern.Pattern_2
|
|
25
28
|
Forward_Pattern_3 = ForwardPattern.Pattern_3
|
|
29
|
+
Forward_Pattern_4 = ForwardPattern.Pattern_4
|
|
30
|
+
Forward_Pattern_5 = ForwardPattern.Pattern_5
|
cache_dit/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.2.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 2,
|
|
31
|
+
__version__ = version = '0.2.26'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 2, 26)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -5,4 +5,6 @@ from cache_dit.cache_factory.cache_types import block_range
|
|
|
5
5
|
from cache_dit.cache_factory.cache_adapters import BlockAdapter
|
|
6
6
|
from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
|
|
7
7
|
from cache_dit.cache_factory.cache_interface import enable_cache
|
|
8
|
+
from cache_dit.cache_factory.cache_interface import supported_pipelines
|
|
9
|
+
from cache_dit.cache_factory.patch_functors import PatchFunctor
|
|
8
10
|
from cache_dit.cache_factory.utils import load_options
|
|
@@ -8,15 +8,14 @@ import dataclasses
|
|
|
8
8
|
from typing import Any, Tuple, List, Optional
|
|
9
9
|
from contextlib import ExitStack
|
|
10
10
|
from diffusers import DiffusionPipeline
|
|
11
|
-
from cache_dit.cache_factory.patch.flux import (
|
|
12
|
-
maybe_patch_flux_transformer,
|
|
13
|
-
)
|
|
14
11
|
from cache_dit.cache_factory import CacheType
|
|
12
|
+
from cache_dit.cache_factory import cache_context
|
|
15
13
|
from cache_dit.cache_factory import ForwardPattern
|
|
14
|
+
from cache_dit.cache_factory.patch_functors import PatchFunctor
|
|
16
15
|
from cache_dit.cache_factory.cache_blocks import (
|
|
17
|
-
|
|
18
|
-
DBCachedTransformerBlocks,
|
|
16
|
+
DBCachedBlocks,
|
|
19
17
|
)
|
|
18
|
+
|
|
20
19
|
from cache_dit.logger import init_logger
|
|
21
20
|
|
|
22
21
|
logger = init_logger(__name__)
|
|
@@ -24,12 +23,14 @@ logger = init_logger(__name__)
|
|
|
24
23
|
|
|
25
24
|
@dataclasses.dataclass
|
|
26
25
|
class BlockAdapter:
|
|
27
|
-
pipe: DiffusionPipeline = None
|
|
26
|
+
pipe: DiffusionPipeline | Any = None
|
|
28
27
|
transformer: torch.nn.Module = None
|
|
29
28
|
blocks: torch.nn.ModuleList = None
|
|
30
29
|
# transformer_blocks, blocks, etc.
|
|
31
30
|
blocks_name: str = None
|
|
32
|
-
dummy_blocks_names:
|
|
31
|
+
dummy_blocks_names: List[str] = dataclasses.field(default_factory=list)
|
|
32
|
+
# patch functor: Flux, etc.
|
|
33
|
+
patch_functor: Optional[PatchFunctor] = None
|
|
33
34
|
# flags to control auto block adapter
|
|
34
35
|
auto: bool = False
|
|
35
36
|
allow_prefixes: List[str] = dataclasses.field(
|
|
@@ -38,6 +39,8 @@ class BlockAdapter:
|
|
|
38
39
|
"single_transformer",
|
|
39
40
|
"blocks",
|
|
40
41
|
"layers",
|
|
42
|
+
"single_stream_blocks",
|
|
43
|
+
"double_stream_blocks",
|
|
41
44
|
]
|
|
42
45
|
)
|
|
43
46
|
check_prefixes: bool = True
|
|
@@ -50,17 +53,19 @@ class BlockAdapter:
|
|
|
50
53
|
)
|
|
51
54
|
|
|
52
55
|
def __post_init__(self):
|
|
53
|
-
self.
|
|
56
|
+
assert any((self.pipe is not None, self.transformer is not None))
|
|
57
|
+
self.patchify()
|
|
54
58
|
|
|
55
|
-
def
|
|
59
|
+
def patchify(self, *args, **kwargs):
|
|
56
60
|
# Process some specificial cases, specific for transformers
|
|
57
61
|
# that has different forward patterns between single_transformer_blocks
|
|
58
62
|
# and transformer_blocks , such as Flux (diffusers < 0.35.0).
|
|
59
|
-
if self.
|
|
60
|
-
self.transformer
|
|
61
|
-
self.transformer,
|
|
62
|
-
|
|
63
|
-
|
|
63
|
+
if self.patch_functor is not None:
|
|
64
|
+
if self.transformer is not None:
|
|
65
|
+
self.patch_functor.apply(self.transformer, *args, **kwargs)
|
|
66
|
+
else:
|
|
67
|
+
assert hasattr(self.pipe, "transformer")
|
|
68
|
+
self.patch_functor.apply(self.pipe.transformer, *args, **kwargs)
|
|
64
69
|
|
|
65
70
|
@staticmethod
|
|
66
71
|
def auto_block_adapter(
|
|
@@ -99,7 +104,9 @@ class BlockAdapter:
|
|
|
99
104
|
@staticmethod
|
|
100
105
|
def check_block_adapter(adapter: "BlockAdapter") -> bool:
|
|
101
106
|
if (
|
|
102
|
-
|
|
107
|
+
# NOTE: pipe may not need to be DiffusionPipeline?
|
|
108
|
+
# isinstance(adapter.pipe, DiffusionPipeline)
|
|
109
|
+
adapter.pipe is not None
|
|
103
110
|
and adapter.transformer is not None
|
|
104
111
|
and adapter.blocks is not None
|
|
105
112
|
and adapter.blocks_name is not None
|
|
@@ -287,11 +294,34 @@ class UnifiedCacheAdapter:
|
|
|
287
294
|
"EasyAnimate",
|
|
288
295
|
"SkyReelsV2",
|
|
289
296
|
"SD3",
|
|
297
|
+
"ConsisID",
|
|
298
|
+
"DiT",
|
|
299
|
+
"Amused",
|
|
300
|
+
"Bria",
|
|
301
|
+
"HunyuanDiT",
|
|
302
|
+
"HunyuanDiTPAG",
|
|
303
|
+
"Lumina",
|
|
304
|
+
"Lumina2",
|
|
305
|
+
"OmniGen",
|
|
306
|
+
"PixArt",
|
|
307
|
+
"Sana",
|
|
308
|
+
"ShapE",
|
|
309
|
+
"StableAudio",
|
|
310
|
+
"VisualCloze",
|
|
311
|
+
"AuraFlow",
|
|
312
|
+
"Chroma",
|
|
313
|
+
"HiDream",
|
|
290
314
|
]
|
|
291
315
|
|
|
292
316
|
def __call__(self, *args, **kwargs):
|
|
293
317
|
return self.apply(*args, **kwargs)
|
|
294
318
|
|
|
319
|
+
@classmethod
|
|
320
|
+
def supported_pipelines(cls) -> Tuple[int, List[str]]:
|
|
321
|
+
return len(cls._supported_pipelines), [
|
|
322
|
+
p + "*" for p in cls._supported_pipelines
|
|
323
|
+
]
|
|
324
|
+
|
|
295
325
|
@classmethod
|
|
296
326
|
def is_supported(cls, pipe: DiffusionPipeline) -> bool:
|
|
297
327
|
pipe_cls_name: str = pipe.__class__.__name__
|
|
@@ -303,8 +333,10 @@ class UnifiedCacheAdapter:
|
|
|
303
333
|
@classmethod
|
|
304
334
|
def get_params(cls, pipe: DiffusionPipeline) -> UnifiedCacheParams:
|
|
305
335
|
pipe_cls_name: str = pipe.__class__.__name__
|
|
336
|
+
|
|
306
337
|
if pipe_cls_name.startswith("Flux"):
|
|
307
338
|
from diffusers import FluxTransformer2DModel
|
|
339
|
+
from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
|
|
308
340
|
|
|
309
341
|
assert isinstance(pipe.transformer, FluxTransformer2DModel)
|
|
310
342
|
return UnifiedCacheParams(
|
|
@@ -317,9 +349,11 @@ class UnifiedCacheAdapter:
|
|
|
317
349
|
),
|
|
318
350
|
blocks_name="transformer_blocks",
|
|
319
351
|
dummy_blocks_names=["single_transformer_blocks"],
|
|
352
|
+
patch_functor=FluxPatchFunctor(),
|
|
320
353
|
),
|
|
321
354
|
forward_pattern=ForwardPattern.Pattern_1,
|
|
322
355
|
)
|
|
356
|
+
|
|
323
357
|
elif pipe_cls_name.startswith("Mochi"):
|
|
324
358
|
from diffusers import MochiTransformer3DModel
|
|
325
359
|
|
|
@@ -334,6 +368,7 @@ class UnifiedCacheAdapter:
|
|
|
334
368
|
),
|
|
335
369
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
336
370
|
)
|
|
371
|
+
|
|
337
372
|
elif pipe_cls_name.startswith("CogVideoX"):
|
|
338
373
|
from diffusers import CogVideoXTransformer3DModel
|
|
339
374
|
|
|
@@ -348,6 +383,7 @@ class UnifiedCacheAdapter:
|
|
|
348
383
|
),
|
|
349
384
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
350
385
|
)
|
|
386
|
+
|
|
351
387
|
elif pipe_cls_name.startswith("Wan"):
|
|
352
388
|
from diffusers import (
|
|
353
389
|
WanTransformer3DModel,
|
|
@@ -358,16 +394,35 @@ class UnifiedCacheAdapter:
|
|
|
358
394
|
pipe.transformer,
|
|
359
395
|
(WanTransformer3DModel, WanVACETransformer3DModel),
|
|
360
396
|
)
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
397
|
+
if getattr(pipe, "transformer_2", None):
|
|
398
|
+
# Wan 2.2, cache for low-noise transformer
|
|
399
|
+
assert isinstance(
|
|
400
|
+
pipe.transformer_2,
|
|
401
|
+
(WanTransformer3DModel, WanVACETransformer3DModel),
|
|
402
|
+
)
|
|
403
|
+
return UnifiedCacheParams(
|
|
404
|
+
block_adapter=BlockAdapter(
|
|
405
|
+
pipe=pipe,
|
|
406
|
+
transformer=pipe.transformer_2,
|
|
407
|
+
blocks=pipe.transformer_2.blocks,
|
|
408
|
+
blocks_name="blocks",
|
|
409
|
+
dummy_blocks_names=[],
|
|
410
|
+
),
|
|
411
|
+
forward_pattern=ForwardPattern.Pattern_2,
|
|
412
|
+
)
|
|
413
|
+
else:
|
|
414
|
+
# Wan 2.1
|
|
415
|
+
return UnifiedCacheParams(
|
|
416
|
+
block_adapter=BlockAdapter(
|
|
417
|
+
pipe=pipe,
|
|
418
|
+
transformer=pipe.transformer,
|
|
419
|
+
blocks=pipe.transformer.blocks,
|
|
420
|
+
blocks_name="blocks",
|
|
421
|
+
dummy_blocks_names=[],
|
|
422
|
+
),
|
|
423
|
+
forward_pattern=ForwardPattern.Pattern_2,
|
|
424
|
+
)
|
|
425
|
+
|
|
371
426
|
elif pipe_cls_name.startswith("HunyuanVideo"):
|
|
372
427
|
from diffusers import HunyuanVideoTransformer3DModel
|
|
373
428
|
|
|
@@ -384,6 +439,7 @@ class UnifiedCacheAdapter:
|
|
|
384
439
|
),
|
|
385
440
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
386
441
|
)
|
|
442
|
+
|
|
387
443
|
elif pipe_cls_name.startswith("QwenImage"):
|
|
388
444
|
from diffusers import QwenImageTransformer2DModel
|
|
389
445
|
|
|
@@ -398,6 +454,7 @@ class UnifiedCacheAdapter:
|
|
|
398
454
|
),
|
|
399
455
|
forward_pattern=ForwardPattern.Pattern_1,
|
|
400
456
|
)
|
|
457
|
+
|
|
401
458
|
elif pipe_cls_name.startswith("LTXVideo"):
|
|
402
459
|
from diffusers import LTXVideoTransformer3DModel
|
|
403
460
|
|
|
@@ -412,6 +469,7 @@ class UnifiedCacheAdapter:
|
|
|
412
469
|
),
|
|
413
470
|
forward_pattern=ForwardPattern.Pattern_2,
|
|
414
471
|
)
|
|
472
|
+
|
|
415
473
|
elif pipe_cls_name.startswith("Allegro"):
|
|
416
474
|
from diffusers import AllegroTransformer3DModel
|
|
417
475
|
|
|
@@ -426,6 +484,7 @@ class UnifiedCacheAdapter:
|
|
|
426
484
|
),
|
|
427
485
|
forward_pattern=ForwardPattern.Pattern_2,
|
|
428
486
|
)
|
|
487
|
+
|
|
429
488
|
elif pipe_cls_name.startswith("CogView3Plus"):
|
|
430
489
|
from diffusers import CogView3PlusTransformer2DModel
|
|
431
490
|
|
|
@@ -440,6 +499,7 @@ class UnifiedCacheAdapter:
|
|
|
440
499
|
),
|
|
441
500
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
442
501
|
)
|
|
502
|
+
|
|
443
503
|
elif pipe_cls_name.startswith("CogView4"):
|
|
444
504
|
from diffusers import CogView4Transformer2DModel
|
|
445
505
|
|
|
@@ -454,6 +514,7 @@ class UnifiedCacheAdapter:
|
|
|
454
514
|
),
|
|
455
515
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
456
516
|
)
|
|
517
|
+
|
|
457
518
|
elif pipe_cls_name.startswith("Cosmos"):
|
|
458
519
|
from diffusers import CosmosTransformer3DModel
|
|
459
520
|
|
|
@@ -468,6 +529,7 @@ class UnifiedCacheAdapter:
|
|
|
468
529
|
),
|
|
469
530
|
forward_pattern=ForwardPattern.Pattern_2,
|
|
470
531
|
)
|
|
532
|
+
|
|
471
533
|
elif pipe_cls_name.startswith("EasyAnimate"):
|
|
472
534
|
from diffusers import EasyAnimateTransformer3DModel
|
|
473
535
|
|
|
@@ -482,6 +544,7 @@ class UnifiedCacheAdapter:
|
|
|
482
544
|
),
|
|
483
545
|
forward_pattern=ForwardPattern.Pattern_0,
|
|
484
546
|
)
|
|
547
|
+
|
|
485
548
|
elif pipe_cls_name.startswith("SkyReelsV2"):
|
|
486
549
|
from diffusers import SkyReelsV2Transformer3DModel
|
|
487
550
|
|
|
@@ -510,6 +573,284 @@ class UnifiedCacheAdapter:
|
|
|
510
573
|
),
|
|
511
574
|
forward_pattern=ForwardPattern.Pattern_1,
|
|
512
575
|
)
|
|
576
|
+
|
|
577
|
+
elif pipe_cls_name.startswith("ConsisID"):
|
|
578
|
+
from diffusers import ConsisIDTransformer3DModel
|
|
579
|
+
|
|
580
|
+
assert isinstance(pipe.transformer, ConsisIDTransformer3DModel)
|
|
581
|
+
return UnifiedCacheParams(
|
|
582
|
+
block_adapter=BlockAdapter(
|
|
583
|
+
pipe=pipe,
|
|
584
|
+
transformer=pipe.transformer,
|
|
585
|
+
blocks=pipe.transformer.transformer_blocks,
|
|
586
|
+
blocks_name="transformer_blocks",
|
|
587
|
+
dummy_blocks_names=[],
|
|
588
|
+
),
|
|
589
|
+
forward_pattern=ForwardPattern.Pattern_0,
|
|
590
|
+
)
|
|
591
|
+
|
|
592
|
+
elif pipe_cls_name.startswith("DiT"):
|
|
593
|
+
from diffusers import DiTTransformer2DModel
|
|
594
|
+
|
|
595
|
+
assert isinstance(pipe.transformer, DiTTransformer2DModel)
|
|
596
|
+
return UnifiedCacheParams(
|
|
597
|
+
block_adapter=BlockAdapter(
|
|
598
|
+
pipe=pipe,
|
|
599
|
+
transformer=pipe.transformer,
|
|
600
|
+
blocks=pipe.transformer.transformer_blocks,
|
|
601
|
+
blocks_name="transformer_blocks",
|
|
602
|
+
dummy_blocks_names=[],
|
|
603
|
+
),
|
|
604
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
605
|
+
)
|
|
606
|
+
|
|
607
|
+
elif pipe_cls_name.startswith("Amused"):
|
|
608
|
+
from diffusers import UVit2DModel
|
|
609
|
+
|
|
610
|
+
assert isinstance(pipe.transformer, UVit2DModel)
|
|
611
|
+
return UnifiedCacheParams(
|
|
612
|
+
block_adapter=BlockAdapter(
|
|
613
|
+
pipe=pipe,
|
|
614
|
+
transformer=pipe.transformer,
|
|
615
|
+
blocks=pipe.transformer.transformer_layers,
|
|
616
|
+
blocks_name="transformer_layers",
|
|
617
|
+
dummy_blocks_names=[],
|
|
618
|
+
),
|
|
619
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
620
|
+
)
|
|
621
|
+
|
|
622
|
+
elif pipe_cls_name.startswith("Bria"):
|
|
623
|
+
from diffusers import BriaTransformer2DModel
|
|
624
|
+
|
|
625
|
+
assert isinstance(pipe.transformer, BriaTransformer2DModel)
|
|
626
|
+
return UnifiedCacheParams(
|
|
627
|
+
block_adapter=BlockAdapter(
|
|
628
|
+
pipe=pipe,
|
|
629
|
+
transformer=pipe.transformer,
|
|
630
|
+
blocks=(
|
|
631
|
+
pipe.transformer.transformer_blocks
|
|
632
|
+
+ pipe.transformer.single_transformer_blocks
|
|
633
|
+
),
|
|
634
|
+
blocks_name="transformer_blocks",
|
|
635
|
+
dummy_blocks_names=["single_transformer_blocks"],
|
|
636
|
+
),
|
|
637
|
+
forward_pattern=ForwardPattern.Pattern_0,
|
|
638
|
+
)
|
|
639
|
+
|
|
640
|
+
elif pipe_cls_name.startswith("HunyuanDiT"):
|
|
641
|
+
from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
|
|
642
|
+
|
|
643
|
+
assert isinstance(
|
|
644
|
+
pipe.transformer,
|
|
645
|
+
(HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
|
|
646
|
+
)
|
|
647
|
+
return UnifiedCacheParams(
|
|
648
|
+
block_adapter=BlockAdapter(
|
|
649
|
+
pipe=pipe,
|
|
650
|
+
transformer=pipe.transformer,
|
|
651
|
+
blocks=pipe.transformer.blocks,
|
|
652
|
+
blocks_name="blocks",
|
|
653
|
+
dummy_blocks_names=[],
|
|
654
|
+
),
|
|
655
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
656
|
+
)
|
|
657
|
+
|
|
658
|
+
elif pipe_cls_name.startswith("HunyuanDiTPAG"):
|
|
659
|
+
from diffusers import HunyuanDiT2DModel
|
|
660
|
+
|
|
661
|
+
assert isinstance(pipe.transformer, HunyuanDiT2DModel)
|
|
662
|
+
return UnifiedCacheParams(
|
|
663
|
+
block_adapter=BlockAdapter(
|
|
664
|
+
pipe=pipe,
|
|
665
|
+
transformer=pipe.transformer,
|
|
666
|
+
blocks=pipe.transformer.blocks,
|
|
667
|
+
blocks_name="blocks",
|
|
668
|
+
dummy_blocks_names=[],
|
|
669
|
+
),
|
|
670
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
671
|
+
)
|
|
672
|
+
|
|
673
|
+
elif pipe_cls_name.startswith("Lumina"):
|
|
674
|
+
from diffusers import LuminaNextDiT2DModel
|
|
675
|
+
|
|
676
|
+
assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
|
|
677
|
+
return UnifiedCacheParams(
|
|
678
|
+
block_adapter=BlockAdapter(
|
|
679
|
+
pipe=pipe,
|
|
680
|
+
transformer=pipe.transformer,
|
|
681
|
+
blocks=pipe.transformer.layers,
|
|
682
|
+
blocks_name="layers",
|
|
683
|
+
dummy_blocks_names=[],
|
|
684
|
+
),
|
|
685
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
686
|
+
)
|
|
687
|
+
|
|
688
|
+
elif pipe_cls_name.startswith("Lumina2"):
|
|
689
|
+
from diffusers import Lumina2Transformer2DModel
|
|
690
|
+
|
|
691
|
+
assert isinstance(pipe.transformer, Lumina2Transformer2DModel)
|
|
692
|
+
return UnifiedCacheParams(
|
|
693
|
+
block_adapter=BlockAdapter(
|
|
694
|
+
pipe=pipe,
|
|
695
|
+
transformer=pipe.transformer,
|
|
696
|
+
blocks=pipe.transformer.layers,
|
|
697
|
+
blocks_name="layers",
|
|
698
|
+
dummy_blocks_names=[],
|
|
699
|
+
),
|
|
700
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
701
|
+
)
|
|
702
|
+
|
|
703
|
+
elif pipe_cls_name.startswith("OmniGen"):
|
|
704
|
+
from diffusers import OmniGenTransformer2DModel
|
|
705
|
+
|
|
706
|
+
assert isinstance(pipe.transformer, OmniGenTransformer2DModel)
|
|
707
|
+
return UnifiedCacheParams(
|
|
708
|
+
block_adapter=BlockAdapter(
|
|
709
|
+
pipe=pipe,
|
|
710
|
+
transformer=pipe.transformer,
|
|
711
|
+
blocks=pipe.transformer.layers,
|
|
712
|
+
blocks_name="layers",
|
|
713
|
+
dummy_blocks_names=[],
|
|
714
|
+
),
|
|
715
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
716
|
+
)
|
|
717
|
+
|
|
718
|
+
elif pipe_cls_name.startswith("PixArt"):
|
|
719
|
+
from diffusers import PixArtTransformer2DModel
|
|
720
|
+
|
|
721
|
+
assert isinstance(pipe.transformer, PixArtTransformer2DModel)
|
|
722
|
+
return UnifiedCacheParams(
|
|
723
|
+
block_adapter=BlockAdapter(
|
|
724
|
+
pipe=pipe,
|
|
725
|
+
transformer=pipe.transformer,
|
|
726
|
+
blocks=pipe.transformer.transformer_blocks,
|
|
727
|
+
blocks_name="transformer_blocks",
|
|
728
|
+
dummy_blocks_names=[],
|
|
729
|
+
),
|
|
730
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
731
|
+
)
|
|
732
|
+
|
|
733
|
+
elif pipe_cls_name.startswith("Sana"):
|
|
734
|
+
from diffusers import SanaTransformer2DModel
|
|
735
|
+
|
|
736
|
+
assert isinstance(pipe.transformer, SanaTransformer2DModel)
|
|
737
|
+
return UnifiedCacheParams(
|
|
738
|
+
block_adapter=BlockAdapter(
|
|
739
|
+
pipe=pipe,
|
|
740
|
+
transformer=pipe.transformer,
|
|
741
|
+
blocks=pipe.transformer.transformer_blocks,
|
|
742
|
+
blocks_name="transformer_blocks",
|
|
743
|
+
dummy_blocks_names=[],
|
|
744
|
+
),
|
|
745
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
746
|
+
)
|
|
747
|
+
|
|
748
|
+
elif pipe_cls_name.startswith("ShapE"):
|
|
749
|
+
from diffusers import PriorTransformer
|
|
750
|
+
|
|
751
|
+
assert isinstance(pipe.prior, PriorTransformer)
|
|
752
|
+
return UnifiedCacheParams(
|
|
753
|
+
block_adapter=BlockAdapter(
|
|
754
|
+
pipe=pipe,
|
|
755
|
+
transformer=pipe.prior,
|
|
756
|
+
blocks=pipe.prior.transformer_blocks,
|
|
757
|
+
blocks_name="transformer_blocks",
|
|
758
|
+
dummy_blocks_names=[],
|
|
759
|
+
),
|
|
760
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
761
|
+
)
|
|
762
|
+
|
|
763
|
+
elif pipe_cls_name.startswith("StableAudio"):
|
|
764
|
+
from diffusers import StableAudioDiTModel
|
|
765
|
+
|
|
766
|
+
assert isinstance(pipe.transformer, StableAudioDiTModel)
|
|
767
|
+
return UnifiedCacheParams(
|
|
768
|
+
block_adapter=BlockAdapter(
|
|
769
|
+
pipe=pipe,
|
|
770
|
+
transformer=pipe.transformer,
|
|
771
|
+
blocks=pipe.transformer.transformer_blocks,
|
|
772
|
+
blocks_name="transformer_blocks",
|
|
773
|
+
dummy_blocks_names=[],
|
|
774
|
+
),
|
|
775
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
776
|
+
)
|
|
777
|
+
|
|
778
|
+
elif pipe_cls_name.startswith("VisualCloze"):
|
|
779
|
+
from diffusers import FluxTransformer2DModel
|
|
780
|
+
from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
|
|
781
|
+
|
|
782
|
+
assert isinstance(pipe.transformer, FluxTransformer2DModel)
|
|
783
|
+
return UnifiedCacheParams(
|
|
784
|
+
block_adapter=BlockAdapter(
|
|
785
|
+
pipe=pipe,
|
|
786
|
+
transformer=pipe.transformer,
|
|
787
|
+
blocks=(
|
|
788
|
+
pipe.transformer.transformer_blocks
|
|
789
|
+
+ pipe.transformer.single_transformer_blocks
|
|
790
|
+
),
|
|
791
|
+
blocks_name="transformer_blocks",
|
|
792
|
+
dummy_blocks_names=["single_transformer_blocks"],
|
|
793
|
+
patch_functor=FluxPatchFunctor(),
|
|
794
|
+
),
|
|
795
|
+
forward_pattern=ForwardPattern.Pattern_1,
|
|
796
|
+
)
|
|
797
|
+
|
|
798
|
+
elif pipe_cls_name.startswith("AuraFlow"):
|
|
799
|
+
from diffusers import AuraFlowTransformer2DModel
|
|
800
|
+
|
|
801
|
+
assert isinstance(pipe.transformer, AuraFlowTransformer2DModel)
|
|
802
|
+
return UnifiedCacheParams(
|
|
803
|
+
block_adapter=BlockAdapter(
|
|
804
|
+
pipe=pipe,
|
|
805
|
+
transformer=pipe.transformer,
|
|
806
|
+
# Only support caching single_transformer_blocks for AuraFlow now.
|
|
807
|
+
# TODO: Support AuraFlowPatchFunctor.
|
|
808
|
+
blocks=pipe.transformer.single_transformer_blocks,
|
|
809
|
+
blocks_name="single_transformer_blocks",
|
|
810
|
+
dummy_blocks_names=[],
|
|
811
|
+
),
|
|
812
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
813
|
+
)
|
|
814
|
+
|
|
815
|
+
elif pipe_cls_name.startswith("Chroma"):
|
|
816
|
+
from diffusers import ChromaTransformer2DModel
|
|
817
|
+
from cache_dit.cache_factory.patch_functors import (
|
|
818
|
+
ChromaPatchFunctor,
|
|
819
|
+
)
|
|
820
|
+
|
|
821
|
+
assert isinstance(pipe.transformer, ChromaTransformer2DModel)
|
|
822
|
+
return UnifiedCacheParams(
|
|
823
|
+
block_adapter=BlockAdapter(
|
|
824
|
+
pipe=pipe,
|
|
825
|
+
transformer=pipe.transformer,
|
|
826
|
+
blocks=(
|
|
827
|
+
pipe.transformer.transformer_blocks
|
|
828
|
+
+ pipe.transformer.single_transformer_blocks
|
|
829
|
+
),
|
|
830
|
+
blocks_name="transformer_blocks",
|
|
831
|
+
dummy_blocks_names=["single_transformer_blocks"],
|
|
832
|
+
patch_functor=ChromaPatchFunctor(),
|
|
833
|
+
),
|
|
834
|
+
forward_pattern=ForwardPattern.Pattern_1,
|
|
835
|
+
)
|
|
836
|
+
|
|
837
|
+
elif pipe_cls_name.startswith("HiDream"):
|
|
838
|
+
from diffusers import HiDreamImageTransformer2DModel
|
|
839
|
+
|
|
840
|
+
assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
|
|
841
|
+
return UnifiedCacheParams(
|
|
842
|
+
block_adapter=BlockAdapter(
|
|
843
|
+
pipe=pipe,
|
|
844
|
+
transformer=pipe.transformer,
|
|
845
|
+
# Only support caching single_stream_blocks for HiDream now.
|
|
846
|
+
# TODO: Support HiDreamPatchFunctor.
|
|
847
|
+
blocks=pipe.transformer.single_stream_blocks,
|
|
848
|
+
blocks_name="single_stream_blocks",
|
|
849
|
+
dummy_blocks_names=[],
|
|
850
|
+
),
|
|
851
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
852
|
+
)
|
|
853
|
+
|
|
513
854
|
else:
|
|
514
855
|
raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
|
|
515
856
|
|
|
@@ -608,6 +949,14 @@ class UnifiedCacheAdapter:
|
|
|
608
949
|
return True
|
|
609
950
|
elif cls_name.startswith("Wan"):
|
|
610
951
|
return True
|
|
952
|
+
elif cls_name.startswith("CogView4"):
|
|
953
|
+
return True
|
|
954
|
+
elif cls_name.startswith("Cosmos"):
|
|
955
|
+
return True
|
|
956
|
+
elif cls_name.startswith("SkyReelsV2"):
|
|
957
|
+
return True
|
|
958
|
+
elif cls_name.startswith("Chroma"):
|
|
959
|
+
return True
|
|
611
960
|
return False
|
|
612
961
|
|
|
613
962
|
@classmethod
|
|
@@ -680,7 +1029,7 @@ class UnifiedCacheAdapter:
|
|
|
680
1029
|
# Apply cache on transformer: mock cached transformer blocks
|
|
681
1030
|
cached_blocks = torch.nn.ModuleList(
|
|
682
1031
|
[
|
|
683
|
-
|
|
1032
|
+
DBCachedBlocks(
|
|
684
1033
|
block_adapter.blocks,
|
|
685
1034
|
transformer=block_adapter.transformer,
|
|
686
1035
|
forward_pattern=forward_pattern,
|
|
@@ -0,0 +1,20 @@
|
|
|
1
|
+
from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
|
|
2
|
+
DBCachedBlocks_Pattern_0_1_2,
|
|
3
|
+
)
|
|
4
|
+
from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
|
|
5
|
+
DBCachedBlocks_Pattern_3_4_5,
|
|
6
|
+
)
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class DBCachedBlocks:
|
|
10
|
+
def __new__(cls, *args, **kwargs):
|
|
11
|
+
forward_pattern = kwargs.get("forward_pattern", None)
|
|
12
|
+
assert forward_pattern is not None, "forward_pattern can't be None."
|
|
13
|
+
if forward_pattern in DBCachedBlocks_Pattern_0_1_2._supported_patterns:
|
|
14
|
+
return DBCachedBlocks_Pattern_0_1_2(*args, **kwargs)
|
|
15
|
+
elif (
|
|
16
|
+
forward_pattern in DBCachedBlocks_Pattern_3_4_5._supported_patterns
|
|
17
|
+
):
|
|
18
|
+
return DBCachedBlocks_Pattern_3_4_5(*args, **kwargs)
|
|
19
|
+
else:
|
|
20
|
+
raise ValueError(f"Pattern {forward_pattern} is not supported now!")
|
|
@@ -0,0 +1,16 @@
|
|
|
1
|
+
from cache_dit.cache_factory import ForwardPattern
|
|
2
|
+
from cache_dit.cache_factory.cache_blocks.pattern_base import (
|
|
3
|
+
DBCachedBlocks_Pattern_Base,
|
|
4
|
+
)
|
|
5
|
+
from cache_dit.logger import init_logger
|
|
6
|
+
|
|
7
|
+
logger = init_logger(__name__)
|
|
8
|
+
|
|
9
|
+
|
|
10
|
+
class DBCachedBlocks_Pattern_0_1_2(DBCachedBlocks_Pattern_Base):
|
|
11
|
+
_supported_patterns = [
|
|
12
|
+
ForwardPattern.Pattern_0,
|
|
13
|
+
ForwardPattern.Pattern_1,
|
|
14
|
+
ForwardPattern.Pattern_2,
|
|
15
|
+
]
|
|
16
|
+
...
|