cache-dit 0.2.24__py3-none-any.whl → 0.2.26__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (29)
  1. cache_dit/__init__.py +5 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +2 -0
  4. cache_dit/cache_factory/cache_adapters.py +375 -26
  5. cache_dit/cache_factory/cache_blocks/__init__.py +20 -0
  6. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
  7. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +270 -0
  8. cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +17 -18
  9. cache_dit/cache_factory/cache_blocks/utils.py +19 -0
  10. cache_dit/cache_factory/cache_context.py +32 -25
  11. cache_dit/cache_factory/cache_interface.py +8 -3
  12. cache_dit/cache_factory/forward_pattern.py +45 -24
  13. cache_dit/cache_factory/patch_functors/__init__.py +5 -0
  14. cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
  15. cache_dit/cache_factory/patch_functors/functor_chroma.py +273 -0
  16. cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +45 -31
  17. cache_dit/compile/utils.py +1 -1
  18. cache_dit/quantize/__init__.py +1 -0
  19. cache_dit/quantize/quantize_ao.py +196 -0
  20. cache_dit/quantize/quantize_interface.py +46 -0
  21. cache_dit/utils.py +49 -17
  22. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/METADATA +43 -18
  23. cache_dit-0.2.26.dist-info/RECORD +42 -0
  24. cache_dit-0.2.24.dist-info/RECORD +0 -32
  25. /cache_dit/{cache_factory/patch/__init__.py → quantize/quantize_svdq.py} +0 -0
  26. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/WHEEL +0 -0
  27. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-0.2.24.dist-info → cache_dit-0.2.26.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py CHANGED
@@ -11,7 +11,10 @@ from cache_dit.cache_factory import block_range
  from cache_dit.cache_factory import CacheType
  from cache_dit.cache_factory import BlockAdapter
  from cache_dit.cache_factory import ForwardPattern
+ from cache_dit.cache_factory import PatchFunctor
+ from cache_dit.cache_factory import supported_pipelines
  from cache_dit.compile import set_compile_configs
+ from cache_dit.quantize import quantize
  from cache_dit.utils import summary
  from cache_dit.utils import strify
  from cache_dit.logger import init_logger
@@ -23,3 +26,5 @@ Forward_Pattern_0 = ForwardPattern.Pattern_0
  Forward_Pattern_1 = ForwardPattern.Pattern_1
  Forward_Pattern_2 = ForwardPattern.Pattern_2
  Forward_Pattern_3 = ForwardPattern.Pattern_3
+ Forward_Pattern_4 = ForwardPattern.Pattern_4
+ Forward_Pattern_5 = ForwardPattern.Pattern_5
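
The new top-level exports (PatchFunctor, supported_pipelines, quantize, and the Pattern_4/Pattern_5 aliases) round out the public API. A minimal usage sketch, assuming the 0.2.26 interface shown in this diff; the model id is a placeholder and the keyword arguments of quantize() are not visible in this diff, so they are omitted:

    import cache_dit
    from diffusers import DiffusionPipeline

    # Count and name prefixes of the pipelines the unified adapter claims
    # to support, e.g. (N, ["Flux*", "Mochi*", ...]).
    count, names = cache_dit.supported_pipelines()

    pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")  # placeholder model id
    cache_dit.enable_cache(pipe)      # attach DBCache to the transformer blocks
    # cache_dit.quantize(...)         # new entry point (quantize_ao.py suggests a torchao backend)

    image = pipe("a photo of a cat").images[0]
    print(cache_dit.summary(pipe))    # cache statistics helper, also exported at top level
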
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID

- __version__ = version = '0.2.24'
- __version_tuple__ = version_tuple = (0, 2, 24)
+ __version__ = version = '0.2.26'
+ __version_tuple__ = version_tuple = (0, 2, 26)

  __commit_id__ = commit_id = None
cache_dit/cache_factory/__init__.py CHANGED
@@ -5,4 +5,6 @@ from cache_dit.cache_factory.cache_types import block_range
  from cache_dit.cache_factory.cache_adapters import BlockAdapter
  from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
  from cache_dit.cache_factory.cache_interface import enable_cache
+ from cache_dit.cache_factory.cache_interface import supported_pipelines
+ from cache_dit.cache_factory.patch_functors import PatchFunctor
  from cache_dit.cache_factory.utils import load_options
cache_dit/cache_factory/cache_adapters.py CHANGED
@@ -8,15 +8,14 @@ import dataclasses
  from typing import Any, Tuple, List, Optional
  from contextlib import ExitStack
  from diffusers import DiffusionPipeline
- from cache_dit.cache_factory.patch.flux import (
-     maybe_patch_flux_transformer,
- )
  from cache_dit.cache_factory import CacheType
+ from cache_dit.cache_factory import cache_context
  from cache_dit.cache_factory import ForwardPattern
+ from cache_dit.cache_factory.patch_functors import PatchFunctor
  from cache_dit.cache_factory.cache_blocks import (
-     cache_context,
-     DBCachedTransformerBlocks,
+     DBCachedBlocks,
  )
+
  from cache_dit.logger import init_logger

  logger = init_logger(__name__)
@@ -24,12 +23,14 @@ logger = init_logger(__name__)

  @dataclasses.dataclass
  class BlockAdapter:
-     pipe: DiffusionPipeline = None
+     pipe: DiffusionPipeline | Any = None
      transformer: torch.nn.Module = None
      blocks: torch.nn.ModuleList = None
      # transformer_blocks, blocks, etc.
      blocks_name: str = None
-     dummy_blocks_names: list[str] = dataclasses.field(default_factory=list)
+     dummy_blocks_names: List[str] = dataclasses.field(default_factory=list)
+     # patch functor: Flux, etc.
+     patch_functor: Optional[PatchFunctor] = None

      # flags to control auto block adapter
      auto: bool = False
@@ -38,6 +39,8 @@ class BlockAdapter:
              "single_transformer",
              "blocks",
              "layers",
+             "single_stream_blocks",
+             "double_stream_blocks",
          ]
      )
      check_prefixes: bool = True
@@ -50,17 +53,19 @@
      )

      def __post_init__(self):
-         self.maybe_apply_patch()
+         assert any((self.pipe is not None, self.transformer is not None))
+         self.patchify()

-     def maybe_apply_patch(self):
+     def patchify(self, *args, **kwargs):
          # Process some specificial cases, specific for transformers
          # that has different forward patterns between single_transformer_blocks
          # and transformer_blocks , such as Flux (diffusers < 0.35.0).
-         if self.transformer.__class__.__name__.startswith("Flux"):
-             self.transformer = maybe_patch_flux_transformer(
-                 self.transformer,
-                 blocks=self.blocks,
-             )
+         if self.patch_functor is not None:
+             if self.transformer is not None:
+                 self.patch_functor.apply(self.transformer, *args, **kwargs)
+             else:
+                 assert hasattr(self.pipe, "transformer")
+                 self.patch_functor.apply(self.pipe.transformer, *args, **kwargs)

      @staticmethod
      def auto_block_adapter(
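
The Flux-specific monkey patch is generalized here: any BlockAdapter can now carry a patch_functor, which __post_init__ applies through patchify(). A minimal sketch of wiring in a custom functor; the functor body is hypothetical, only the apply() hook and the BlockAdapter fields come from this diff:

    import torch
    from cache_dit import BlockAdapter
    from cache_dit.cache_factory.patch_functors import PatchFunctor

    class MyPatchFunctor(PatchFunctor):
        # Hypothetical functor: rewrite the transformer/block forwards so all
        # blocks expose a single, cacheable forward pattern.
        def apply(self, transformer: torch.nn.Module, *args, **kwargs) -> torch.nn.Module:
            # ... monkey-patch transformer.forward or its blocks here ...
            return transformer

    adapter = BlockAdapter(
        pipe=pipe,  # an existing DiffusionPipeline
        transformer=pipe.transformer,
        blocks=pipe.transformer.transformer_blocks,
        blocks_name="transformer_blocks",
        patch_functor=MyPatchFunctor(),  # applied automatically via __post_init__ -> patchify()
    )
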
@@ -99,7 +104,9 @@
      @staticmethod
      def check_block_adapter(adapter: "BlockAdapter") -> bool:
          if (
-             isinstance(adapter.pipe, DiffusionPipeline)
+             # NOTE: pipe may not need to be DiffusionPipeline?
+             # isinstance(adapter.pipe, DiffusionPipeline)
+             adapter.pipe is not None
              and adapter.transformer is not None
              and adapter.blocks is not None
              and adapter.blocks_name is not None
@@ -287,11 +294,34 @@ class UnifiedCacheAdapter:
          "EasyAnimate",
          "SkyReelsV2",
          "SD3",
+         "ConsisID",
+         "DiT",
+         "Amused",
+         "Bria",
+         "HunyuanDiT",
+         "HunyuanDiTPAG",
+         "Lumina",
+         "Lumina2",
+         "OmniGen",
+         "PixArt",
+         "Sana",
+         "ShapE",
+         "StableAudio",
+         "VisualCloze",
+         "AuraFlow",
+         "Chroma",
+         "HiDream",
      ]

      def __call__(self, *args, **kwargs):
          return self.apply(*args, **kwargs)

+     @classmethod
+     def supported_pipelines(cls) -> Tuple[int, List[str]]:
+         return len(cls._supported_pipelines), [
+             p + "*" for p in cls._supported_pipelines
+         ]
+
      @classmethod
      def is_supported(cls, pipe: DiffusionPipeline) -> bool:
          pipe_cls_name: str = pipe.__class__.__name__
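
The new supported_pipelines() classmethod simply reports the allow-list above, appending a wildcard to each class-name prefix:

    num, names = UnifiedCacheAdapter.supported_pipelines()
    # num   -> number of entries in _supported_pipelines
    # names -> ["Flux*", "Mochi*", ..., "Chroma*", "HiDream*"]
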
@@ -303,8 +333,10 @@ class UnifiedCacheAdapter:
      @classmethod
      def get_params(cls, pipe: DiffusionPipeline) -> UnifiedCacheParams:
          pipe_cls_name: str = pipe.__class__.__name__
+
          if pipe_cls_name.startswith("Flux"):
              from diffusers import FluxTransformer2DModel
+             from cache_dit.cache_factory.patch_functors import FluxPatchFunctor

              assert isinstance(pipe.transformer, FluxTransformer2DModel)
              return UnifiedCacheParams(
@@ -317,9 +349,11 @@
                      ),
                      blocks_name="transformer_blocks",
                      dummy_blocks_names=["single_transformer_blocks"],
+                     patch_functor=FluxPatchFunctor(),
                  ),
                  forward_pattern=ForwardPattern.Pattern_1,
              )
+
          elif pipe_cls_name.startswith("Mochi"):
              from diffusers import MochiTransformer3DModel

@@ -334,6 +368,7 @@ class UnifiedCacheAdapter:
                  ),
                  forward_pattern=ForwardPattern.Pattern_0,
              )
+
          elif pipe_cls_name.startswith("CogVideoX"):
              from diffusers import CogVideoXTransformer3DModel

@@ -348,6 +383,7 @@ class UnifiedCacheAdapter:
                  ),
                  forward_pattern=ForwardPattern.Pattern_0,
              )
+
          elif pipe_cls_name.startswith("Wan"):
              from diffusers import (
                  WanTransformer3DModel,
@@ -358,16 +394,35 @@ class UnifiedCacheAdapter:
                  pipe.transformer,
                  (WanTransformer3DModel, WanVACETransformer3DModel),
              )
-             return UnifiedCacheParams(
-                 block_adapter=BlockAdapter(
-                     pipe=pipe,
-                     transformer=pipe.transformer,
-                     blocks=pipe.transformer.blocks,
-                     blocks_name="blocks",
-                     dummy_blocks_names=[],
-                 ),
-                 forward_pattern=ForwardPattern.Pattern_2,
-             )
+             if getattr(pipe, "transformer_2", None):
+                 # Wan 2.2, cache for low-noise transformer
+                 assert isinstance(
+                     pipe.transformer_2,
+                     (WanTransformer3DModel, WanVACETransformer3DModel),
+                 )
+                 return UnifiedCacheParams(
+                     block_adapter=BlockAdapter(
+                         pipe=pipe,
+                         transformer=pipe.transformer_2,
+                         blocks=pipe.transformer_2.blocks,
+                         blocks_name="blocks",
+                         dummy_blocks_names=[],
+                     ),
+                     forward_pattern=ForwardPattern.Pattern_2,
+                 )
+             else:
+                 # Wan 2.1
+                 return UnifiedCacheParams(
+                     block_adapter=BlockAdapter(
+                         pipe=pipe,
+                         transformer=pipe.transformer,
+                         blocks=pipe.transformer.blocks,
+                         blocks_name="blocks",
+                         dummy_blocks_names=[],
+                     ),
+                     forward_pattern=ForwardPattern.Pattern_2,
+                 )
+
          elif pipe_cls_name.startswith("HunyuanVideo"):
              from diffusers import HunyuanVideoTransformer3DModel
@@ -384,6 +439,7 @@ class UnifiedCacheAdapter:
                  ),
                  forward_pattern=ForwardPattern.Pattern_0,
              )
+
          elif pipe_cls_name.startswith("QwenImage"):
              from diffusers import QwenImageTransformer2DModel

@@ -398,6 +454,7 @@ class UnifiedCacheAdapter:
                  ),
                  forward_pattern=ForwardPattern.Pattern_1,
              )
+
          elif pipe_cls_name.startswith("LTXVideo"):
              from diffusers import LTXVideoTransformer3DModel

@@ -412,6 +469,7 @@ class UnifiedCacheAdapter:
                  ),
                  forward_pattern=ForwardPattern.Pattern_2,
              )
+
          elif pipe_cls_name.startswith("Allegro"):
              from diffusers import AllegroTransformer3DModel

@@ -426,6 +484,7 @@ class UnifiedCacheAdapter:
                  ),
                  forward_pattern=ForwardPattern.Pattern_2,
              )
+
          elif pipe_cls_name.startswith("CogView3Plus"):
              from diffusers import CogView3PlusTransformer2DModel

@@ -440,6 +499,7 @@ class UnifiedCacheAdapter:
                  ),
                  forward_pattern=ForwardPattern.Pattern_0,
              )
+
          elif pipe_cls_name.startswith("CogView4"):
              from diffusers import CogView4Transformer2DModel

@@ -454,6 +514,7 @@ class UnifiedCacheAdapter:
                  ),
                  forward_pattern=ForwardPattern.Pattern_0,
              )
+
          elif pipe_cls_name.startswith("Cosmos"):
              from diffusers import CosmosTransformer3DModel

@@ -468,6 +529,7 @@ class UnifiedCacheAdapter:
                  ),
                  forward_pattern=ForwardPattern.Pattern_2,
              )
+
          elif pipe_cls_name.startswith("EasyAnimate"):
              from diffusers import EasyAnimateTransformer3DModel

@@ -482,6 +544,7 @@ class UnifiedCacheAdapter:
                  ),
                  forward_pattern=ForwardPattern.Pattern_0,
              )
+
          elif pipe_cls_name.startswith("SkyReelsV2"):
              from diffusers import SkyReelsV2Transformer3DModel

@@ -510,6 +573,284 @@ class UnifiedCacheAdapter:
                  ),
                  forward_pattern=ForwardPattern.Pattern_1,
              )
+
+         elif pipe_cls_name.startswith("ConsisID"):
+             from diffusers import ConsisIDTransformer3DModel
+
+             assert isinstance(pipe.transformer, ConsisIDTransformer3DModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     blocks=pipe.transformer.transformer_blocks,
+                     blocks_name="transformer_blocks",
+                     dummy_blocks_names=[],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_0,
+             )
+
+         elif pipe_cls_name.startswith("DiT"):
+             from diffusers import DiTTransformer2DModel
+
+             assert isinstance(pipe.transformer, DiTTransformer2DModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     blocks=pipe.transformer.transformer_blocks,
+                     blocks_name="transformer_blocks",
+                     dummy_blocks_names=[],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_3,
+             )
+
+         elif pipe_cls_name.startswith("Amused"):
+             from diffusers import UVit2DModel
+
+             assert isinstance(pipe.transformer, UVit2DModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     blocks=pipe.transformer.transformer_layers,
+                     blocks_name="transformer_layers",
+                     dummy_blocks_names=[],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_3,
+             )
+
+         elif pipe_cls_name.startswith("Bria"):
+             from diffusers import BriaTransformer2DModel
+
+             assert isinstance(pipe.transformer, BriaTransformer2DModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     blocks=(
+                         pipe.transformer.transformer_blocks
+                         + pipe.transformer.single_transformer_blocks
+                     ),
+                     blocks_name="transformer_blocks",
+                     dummy_blocks_names=["single_transformer_blocks"],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_0,
+             )
+
+         elif pipe_cls_name.startswith("HunyuanDiT"):
+             from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
+
+             assert isinstance(
+                 pipe.transformer,
+                 (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
+             )
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     blocks=pipe.transformer.blocks,
+                     blocks_name="blocks",
+                     dummy_blocks_names=[],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_3,
+             )
+
+         elif pipe_cls_name.startswith("HunyuanDiTPAG"):
+             from diffusers import HunyuanDiT2DModel
+
+             assert isinstance(pipe.transformer, HunyuanDiT2DModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     blocks=pipe.transformer.blocks,
+                     blocks_name="blocks",
+                     dummy_blocks_names=[],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_3,
+             )
+
+         elif pipe_cls_name.startswith("Lumina"):
+             from diffusers import LuminaNextDiT2DModel
+
+             assert isinstance(pipe.transformer, LuminaNextDiT2DModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     blocks=pipe.transformer.layers,
+                     blocks_name="layers",
+                     dummy_blocks_names=[],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_3,
+             )
+
+         elif pipe_cls_name.startswith("Lumina2"):
+             from diffusers import Lumina2Transformer2DModel
+
+             assert isinstance(pipe.transformer, Lumina2Transformer2DModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     blocks=pipe.transformer.layers,
+                     blocks_name="layers",
+                     dummy_blocks_names=[],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_3,
+             )
+
+         elif pipe_cls_name.startswith("OmniGen"):
+             from diffusers import OmniGenTransformer2DModel
+
+             assert isinstance(pipe.transformer, OmniGenTransformer2DModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     blocks=pipe.transformer.layers,
+                     blocks_name="layers",
+                     dummy_blocks_names=[],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_3,
+             )
+
+         elif pipe_cls_name.startswith("PixArt"):
+             from diffusers import PixArtTransformer2DModel
+
+             assert isinstance(pipe.transformer, PixArtTransformer2DModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     blocks=pipe.transformer.transformer_blocks,
+                     blocks_name="transformer_blocks",
+                     dummy_blocks_names=[],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_3,
+             )
+
+         elif pipe_cls_name.startswith("Sana"):
+             from diffusers import SanaTransformer2DModel
+
+             assert isinstance(pipe.transformer, SanaTransformer2DModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     blocks=pipe.transformer.transformer_blocks,
+                     blocks_name="transformer_blocks",
+                     dummy_blocks_names=[],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_3,
+             )
+
+         elif pipe_cls_name.startswith("ShapE"):
+             from diffusers import PriorTransformer
+
+             assert isinstance(pipe.prior, PriorTransformer)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.prior,
+                     blocks=pipe.prior.transformer_blocks,
+                     blocks_name="transformer_blocks",
+                     dummy_blocks_names=[],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_3,
+             )
+
+         elif pipe_cls_name.startswith("StableAudio"):
+             from diffusers import StableAudioDiTModel
+
+             assert isinstance(pipe.transformer, StableAudioDiTModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     blocks=pipe.transformer.transformer_blocks,
+                     blocks_name="transformer_blocks",
+                     dummy_blocks_names=[],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_3,
+             )
+
+         elif pipe_cls_name.startswith("VisualCloze"):
+             from diffusers import FluxTransformer2DModel
+             from cache_dit.cache_factory.patch_functors import FluxPatchFunctor
+
+             assert isinstance(pipe.transformer, FluxTransformer2DModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     blocks=(
+                         pipe.transformer.transformer_blocks
+                         + pipe.transformer.single_transformer_blocks
+                     ),
+                     blocks_name="transformer_blocks",
+                     dummy_blocks_names=["single_transformer_blocks"],
+                     patch_functor=FluxPatchFunctor(),
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_1,
+             )
+
+         elif pipe_cls_name.startswith("AuraFlow"):
+             from diffusers import AuraFlowTransformer2DModel
+
+             assert isinstance(pipe.transformer, AuraFlowTransformer2DModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     # Only support caching single_transformer_blocks for AuraFlow now.
+                     # TODO: Support AuraFlowPatchFunctor.
+                     blocks=pipe.transformer.single_transformer_blocks,
+                     blocks_name="single_transformer_blocks",
+                     dummy_blocks_names=[],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_3,
+             )
+
+         elif pipe_cls_name.startswith("Chroma"):
+             from diffusers import ChromaTransformer2DModel
+             from cache_dit.cache_factory.patch_functors import (
+                 ChromaPatchFunctor,
+             )
+
+             assert isinstance(pipe.transformer, ChromaTransformer2DModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     blocks=(
+                         pipe.transformer.transformer_blocks
+                         + pipe.transformer.single_transformer_blocks
+                     ),
+                     blocks_name="transformer_blocks",
+                     dummy_blocks_names=["single_transformer_blocks"],
+                     patch_functor=ChromaPatchFunctor(),
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_1,
+             )
+
+         elif pipe_cls_name.startswith("HiDream"):
+             from diffusers import HiDreamImageTransformer2DModel
+
+             assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
+             return UnifiedCacheParams(
+                 block_adapter=BlockAdapter(
+                     pipe=pipe,
+                     transformer=pipe.transformer,
+                     # Only support caching single_stream_blocks for HiDream now.
+                     # TODO: Support HiDreamPatchFunctor.
+                     blocks=pipe.transformer.single_stream_blocks,
+                     blocks_name="single_stream_blocks",
+                     dummy_blocks_names=[],
+                 ),
+                 forward_pattern=ForwardPattern.Pattern_3,
+             )
+
          else:
              raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")

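Pipelines whose class name matches none of these prefixes still end in the ValueError branch. For such models, a BlockAdapter can be assembled by hand or discovered via the auto flag shown earlier; a rough sketch assuming only the dataclass fields in this diff (my_pipe is a hypothetical unsupported pipeline):

    from cache_dit import BlockAdapter

    # Point the adapter at the transformer blocks explicitly ...
    adapter = BlockAdapter(
        pipe=my_pipe,
        transformer=my_pipe.transformer,
        blocks=my_pipe.transformer.transformer_blocks,
        blocks_name="transformer_blocks",
    )

    # ... or let the library look for them under common attribute prefixes
    # ("transformer", "single_transformer", "blocks", "layers",
    #  "single_stream_blocks", "double_stream_blocks").
    auto_adapter = BlockAdapter(pipe=my_pipe, auto=True)
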
@@ -608,6 +949,14 @@ class UnifiedCacheAdapter:
              return True
          elif cls_name.startswith("Wan"):
              return True
+         elif cls_name.startswith("CogView4"):
+             return True
+         elif cls_name.startswith("Cosmos"):
+             return True
+         elif cls_name.startswith("SkyReelsV2"):
+             return True
+         elif cls_name.startswith("Chroma"):
+             return True
          return False

      @classmethod
@@ -680,7 +1029,7 @@ class UnifiedCacheAdapter:
          # Apply cache on transformer: mock cached transformer blocks
          cached_blocks = torch.nn.ModuleList(
              [
-                 DBCachedTransformerBlocks(
+                 DBCachedBlocks(
                      block_adapter.blocks,
                      transformer=block_adapter.transformer,
                      forward_pattern=forward_pattern,
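
The only change in apply() is the rename to DBCachedBlocks, but the surrounding lines show the caching mechanism: the original blocks are wrapped in a single cached-blocks module inside a one-element ModuleList, which is then swapped in under the original attribute name. A simplified sketch of that idea; the setattr step is an assumption on my part, and the real apply() also handles dummy blocks and the cache context:

    cached_blocks = torch.nn.ModuleList([
        DBCachedBlocks(
            block_adapter.blocks,
            transformer=block_adapter.transformer,
            forward_pattern=forward_pattern,
        )
    ])
    # Swap the cached wrapper in where the transformer expects its blocks,
    # e.g. transformer.transformer_blocks -> ModuleList([DBCachedBlocks(...)]).
    setattr(block_adapter.transformer, block_adapter.blocks_name, cached_blocks)
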
cache_dit/cache_factory/cache_blocks/__init__.py ADDED
@@ -0,0 +1,20 @@
+ from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
+     DBCachedBlocks_Pattern_0_1_2,
+ )
+ from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
+     DBCachedBlocks_Pattern_3_4_5,
+ )
+
+
+ class DBCachedBlocks:
+     def __new__(cls, *args, **kwargs):
+         forward_pattern = kwargs.get("forward_pattern", None)
+         assert forward_pattern is not None, "forward_pattern can't be None."
+         if forward_pattern in DBCachedBlocks_Pattern_0_1_2._supported_patterns:
+             return DBCachedBlocks_Pattern_0_1_2(*args, **kwargs)
+         elif (
+             forward_pattern in DBCachedBlocks_Pattern_3_4_5._supported_patterns
+         ):
+             return DBCachedBlocks_Pattern_3_4_5(*args, **kwargs)
+         else:
+             raise ValueError(f"Pattern {forward_pattern} is not supported now!")
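
DBCachedBlocks is now a thin factory: __new__ inspects the forward_pattern keyword and returns the matching pattern-specific implementation, so call sites such as UnifiedCacheAdapter.apply above keep a single constructor. A usage sketch mirroring that call site, assuming the keyword-only dispatch shown here:

    from cache_dit.cache_factory import ForwardPattern
    from cache_dit.cache_factory.cache_blocks import DBCachedBlocks

    cached = DBCachedBlocks(
        transformer.transformer_blocks,            # the original blocks
        transformer=transformer,
        forward_pattern=ForwardPattern.Pattern_0,  # must be passed as a keyword
    )
    # -> DBCachedBlocks_Pattern_0_1_2 instance for Pattern_0/1/2,
    #    DBCachedBlocks_Pattern_3_4_5 instance for Pattern_3/4/5,
    #    ValueError otherwise.
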
cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py ADDED
@@ -0,0 +1,16 @@
+ from cache_dit.cache_factory import ForwardPattern
+ from cache_dit.cache_factory.cache_blocks.pattern_base import (
+     DBCachedBlocks_Pattern_Base,
+ )
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ class DBCachedBlocks_Pattern_0_1_2(DBCachedBlocks_Pattern_Base):
+     _supported_patterns = [
+         ForwardPattern.Pattern_0,
+         ForwardPattern.Pattern_1,
+         ForwardPattern.Pattern_2,
+     ]
+     ...