cache-dit 0.2.25__py3-none-any.whl → 0.2.27__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (32)
  1. cache_dit/__init__.py +9 -4
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +16 -3
  4. cache_dit/cache_factory/block_adapters/__init__.py +538 -0
  5. cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
  6. cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
  7. cache_dit/cache_factory/cache_adapters.py +121 -563
  8. cache_dit/cache_factory/cache_blocks/__init__.py +18 -0
  9. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +16 -0
  10. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +275 -0
  11. cache_dit/cache_factory/{cache_blocks.py → cache_blocks/pattern_base.py} +100 -82
  12. cache_dit/cache_factory/cache_blocks/utils.py +23 -0
  13. cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
  14. cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +94 -56
  15. cache_dit/cache_factory/cache_interface.py +24 -16
  16. cache_dit/cache_factory/forward_pattern.py +45 -24
  17. cache_dit/cache_factory/patch_functors/__init__.py +5 -0
  18. cache_dit/cache_factory/patch_functors/functor_base.py +18 -0
  19. cache_dit/cache_factory/patch_functors/functor_chroma.py +276 -0
  20. cache_dit/cache_factory/{patch/flux.py → patch_functors/functor_flux.py} +49 -31
  21. cache_dit/quantize/quantize_ao.py +19 -4
  22. cache_dit/quantize/quantize_interface.py +2 -2
  23. cache_dit/utils.py +19 -15
  24. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/METADATA +76 -19
  25. cache_dit-0.2.27.dist-info/RECORD +47 -0
  26. cache_dit-0.2.25.dist-info/RECORD +0 -36
  27. /cache_dit/cache_factory/{patch/__init__.py → cache_contexts/cache_manager.py} +0 -0
  28. /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
  29. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
  30. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
  31. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
  32. {cache_dit-0.2.25.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py}
@@ -7,14 +7,15 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
 import torch
 import torch.distributed as dist
 
-from cache_dit.cache_factory.taylorseer import TaylorSeer
+from cache_dit.cache_factory.cache_contexts.taylorseer import TaylorSeer
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
 @dataclasses.dataclass
-class DBCacheContext:
+class _CachedContext:  # Internal CachedContext Impl class
+    name: str = "default"
     # Dual Block Cache
     # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
     Fn_compute_blocks: int = 1
@@ -99,6 +100,8 @@ class DBCacheContext:
 
     @torch.compiler.disable
     def __post_init__(self):
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.info(f"Created _CacheContext: {self.name}")
         # Some checks for settings
         if self.do_separate_cfg:
             assert self.enable_alter_cache is False, (
@@ -329,26 +332,60 @@ class DBCacheContext:
 
 
 # TODO: Support context manager for different cache_context
+_current_cache_context: _CachedContext = None
+
+_cache_context_manager: Dict[str, _CachedContext] = {}
 
 
 def create_cache_context(*args, **kwargs):
-    return DBCacheContext(*args, **kwargs)
+    global _cache_context_manager
+    _context = _CachedContext(*args, **kwargs)
+    _cache_context_manager[_context.name] = _context
+    return _context
 
 
-def get_current_cache_context():
+def get_cache_context():
     return _current_cache_context
 
 
-def set_current_cache_context(cache_context=None):
-    global _current_cache_context
-    _current_cache_context = cache_context
+def set_cache_context(cache_context: _CachedContext | str):
+    global _current_cache_context, _cache_context_manager
+    if isinstance(cache_context, _CachedContext):
+        _current_cache_context = cache_context
+    else:
+        _current_cache_context = _cache_context_manager[cache_context]
+
+
+def reset_cache_context(cache_context: _CachedContext | str, *args, **kwargs):
+    global _cache_context_manager
+    if isinstance(cache_context, _CachedContext):
+        old_context_name = cache_context.name
+        if cache_context.name in _cache_context_manager:
+            del _cache_context_manager[cache_context.name]
+        # force use old_context name
+        kwargs["name"] = old_context_name
+        _context = _CachedContext(*args, **kwargs)
+        _cache_context_manager[_context.name] = _context
+    else:
+        old_context_name = cache_context
+        if cache_context in _cache_context_manager:
+            del _cache_context_manager[cache_context]
+        # force use old_context name
+        kwargs["name"] = old_context_name
+        _context = _CachedContext(*args, **kwargs)
+        _cache_context_manager[_context.name] = _context
+
+    return _context
 
 
 @contextlib.contextmanager
-def cache_context(cache_context):
-    global _current_cache_context
+def cache_context(cache_context: _CachedContext | str):
+    global _current_cache_context, _cache_context_manager
     old_cache_context = _current_cache_context
-    _current_cache_context = cache_context
+    if isinstance(cache_context, _CachedContext):
+        _current_cache_context = cache_context
+    else:
+        _current_cache_context = _cache_context_manager[cache_context]
     try:
         yield
     finally:
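For readers skimming this hunk: create_cache_context now registers each _CachedContext under its name, and set_cache_context, reset_cache_context, and the cache_context manager accept either the object or its registered name. A minimal sketch of how these helpers compose, assuming the import path of the renamed file above; the context name and kwargs are illustrative:

```python
# Assumed import path, based on the renamed file shown in this diff
# (cache_dit/cache_factory/cache_contexts/cache_context.py).
from cache_dit.cache_factory.cache_contexts.cache_context import (
    create_cache_context,
    cache_context,
    reset_cache_context,
    mark_step_begin,
    get_current_step,
)

# Register a named context; kwargs map to _CachedContext fields.
ctx = create_cache_context(name="transformer_0", Fn_compute_blocks=8)

with cache_context("transformer_0"):   # look up by name ...
    mark_step_begin()                  # ... helpers now resolve to ctx
    print(get_current_step())

with cache_context(ctx):               # or pass the object directly
    mark_step_begin()

# Rebuild a registered context in place; its old name is kept even
# when the kwargs do not provide one.
ctx = reset_cache_context("transformer_0", Fn_compute_blocks=4)
```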
@@ -357,49 +394,49 @@ def cache_context(cache_context):
 
 @torch.compiler.disable
 def get_residual_diff_threshold():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_residual_diff_threshold()
 
 
 @torch.compiler.disable
 def get_buffer(name):
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_buffer(name)
 
 
 @torch.compiler.disable
 def set_buffer(name, buffer):
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     cache_context.set_buffer(name, buffer)
 
 
 @torch.compiler.disable
 def remove_buffer(name):
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     cache_context.remove_buffer(name)
 
 
 @torch.compiler.disable
 def mark_step_begin():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     cache_context.mark_step_begin()
 
 
 @torch.compiler.disable
 def get_current_step():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_current_step()
 
 
 @torch.compiler.disable
 def get_current_step_residual_diff():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     step = str(get_current_step())
     residual_diffs = get_residual_diffs()
@@ -410,7 +447,7 @@ def get_current_step_residual_diff():
 
 @torch.compiler.disable
 def get_current_step_cfg_residual_diff():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     step = str(get_current_step())
     cfg_residual_diffs = get_cfg_residual_diffs()
@@ -421,110 +458,110 @@ def get_current_step_cfg_residual_diff():
 
 @torch.compiler.disable
 def get_current_transformer_step():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_current_transformer_step()
 
 
 @torch.compiler.disable
 def get_cached_steps():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_cached_steps()
 
 
 @torch.compiler.disable
 def get_cfg_cached_steps():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_cfg_cached_steps()
 
 
 @torch.compiler.disable
 def get_max_cached_steps():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.max_cached_steps
 
 
 @torch.compiler.disable
 def get_max_continuous_cached_steps():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.max_continuous_cached_steps
 
 
 @torch.compiler.disable
 def get_continuous_cached_steps():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.continuous_cached_steps
 
 
 @torch.compiler.disable
 def get_cfg_continuous_cached_steps():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.cfg_continuous_cached_steps
 
 
 @torch.compiler.disable
 def add_cached_step():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     cache_context.add_cached_step()
 
 
 @torch.compiler.disable
 def add_residual_diff(diff):
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     cache_context.add_residual_diff(diff)
 
 
 @torch.compiler.disable
 def get_residual_diffs():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_residual_diffs()
 
 
 @torch.compiler.disable
 def get_cfg_residual_diffs():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_cfg_residual_diffs()
 
 
 @torch.compiler.disable
 def is_taylorseer_enabled():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.enable_taylorseer
 
 
 @torch.compiler.disable
 def is_encoder_taylorseer_enabled():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.enable_encoder_taylorseer
 
 
 def get_taylorseers() -> Tuple[TaylorSeer, TaylorSeer]:
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_taylorseers()
 
 
 def get_cfg_taylorseers() -> Tuple[TaylorSeer, TaylorSeer]:
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_cfg_taylorseers()
 
 
 @torch.compiler.disable
 def is_taylorseer_cache_residual():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.taylorseer_cache_type == "residual"
 
@@ -547,28 +584,28 @@ def is_encoder_cache_residual():
 
 @torch.compiler.disable
 def is_alter_cache_enabled():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.enable_alter_cache
 
 
 @torch.compiler.disable
 def is_alter_cache():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.is_alter_cache
 
 
 @torch.compiler.disable
 def is_in_warmup():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.is_in_warmup()
 
 
 @torch.compiler.disable
 def is_l1_diff_enabled():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return (
         cache_context.l1_hidden_states_diff_threshold is not None
@@ -578,21 +615,21 @@ def is_l1_diff_enabled():
 
 @torch.compiler.disable
 def get_important_condition_threshold():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.important_condition_threshold
 
 
 @torch.compiler.disable
 def non_compute_blocks_diff_threshold():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.non_compute_blocks_diff_threshold
 
 
 @torch.compiler.disable
 def Fn_compute_blocks():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     assert (
         cache_context.Fn_compute_blocks >= 1
@@ -612,7 +649,7 @@ def Fn_compute_blocks():
 
 @torch.compiler.disable
 def Fn_compute_blocks_ids():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     assert (
         len(cache_context.Fn_compute_blocks_ids)
@@ -627,7 +664,7 @@ def Fn_compute_blocks_ids():
 
 @torch.compiler.disable
 def Bn_compute_blocks():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     assert (
         cache_context.Bn_compute_blocks >= 0
@@ -647,7 +684,7 @@ def Bn_compute_blocks():
 
 @torch.compiler.disable
 def Bn_compute_blocks_ids():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     assert (
         len(cache_context.Bn_compute_blocks_ids)
@@ -662,44 +699,41 @@ def Bn_compute_blocks_ids():
 
 @torch.compiler.disable
 def do_separate_cfg():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.do_separate_cfg
 
 
 @torch.compiler.disable
 def is_separate_cfg_step():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.is_separate_cfg_step()
 
 
 @torch.compiler.disable
 def cfg_diff_compute_separate():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.cfg_diff_compute_separate
 
 
-_current_cache_context: DBCacheContext = None
-
-
 def collect_cache_kwargs(default_attrs: dict, **kwargs):
     # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
     # default_attrs: specific settings for different pipelines
-    cache_attrs = dataclasses.fields(DBCacheContext)
+    cache_attrs = dataclasses.fields(_CachedContext)
     cache_attrs = [
         attr
         for attr in cache_attrs
         if hasattr(
-            DBCacheContext,
+            _CachedContext,
            attr.name,
        )
    ]
     cache_kwargs = {
         attr.name: kwargs.pop(
             attr.name,
-            getattr(DBCacheContext, attr.name),
+            getattr(_CachedContext, attr.name),
        )
         for attr in cache_attrs
    }
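collect_cache_kwargs splits arbitrary keyword arguments by checking them against the dataclass fields of _CachedContext: field names go to the cache context, everything else is left for the caller. A self-contained sketch of the same splitting idea (the stand-in dataclass and helper name below are illustrative, not the package's code):

```python
import dataclasses


@dataclasses.dataclass
class _Ctx:  # stand-in for _CachedContext
    Fn_compute_blocks: int = 1
    Bn_compute_blocks: int = 0


def split_kwargs(**kwargs):
    # Same idea as collect_cache_kwargs above: anything matching a
    # dataclass field is routed to the cache context, the rest is
    # returned untouched for the caller.
    names = {f.name for f in dataclasses.fields(_Ctx)}
    cache_kwargs = {k: kwargs.pop(k) for k in list(kwargs) if k in names}
    return cache_kwargs, kwargs


cache_kwargs, other = split_kwargs(Fn_compute_blocks=8, height=1024)
assert cache_kwargs == {"Fn_compute_blocks": 8}
assert other == {"height": 1024}
```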
@@ -941,7 +975,11 @@ def get_Bn_buffer(prefix: str = "Bn"):
 
 
 @torch.compiler.disable
-def set_Bn_encoder_buffer(buffer: torch.Tensor, prefix: str = "Bn"):
+def set_Bn_encoder_buffer(buffer: torch.Tensor | None, prefix: str = "Bn"):
+    # DON'T set None Buffer
+    if buffer is None:
+        return
+
     # This buffer is use for encoder hidden states approximation.
     if is_encoder_taylorseer_enabled():
         # taylorseer, encoder_taylorseer
@@ -1053,7 +1091,7 @@ def apply_hidden_states_residual(
 
 @torch.compiler.disable
 def get_downsample_factor():
-    cache_context = get_current_cache_context()
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.downsample_factor
 
@@ -1100,7 +1138,7 @@ def get_can_use_cache(
                 "can not use cache."
            )
             # reset continuous cached steps stats
-            cache_context = get_current_cache_context()
+            cache_context = get_cache_context()
             if not is_separate_cfg_step():
                 cache_context.continuous_cached_steps = 0
             else:
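All of the module-level helpers renamed throughout this file proxy to whichever _CachedContext is currently active. A hedged sketch of how per-step bookkeeping might be driven through them (the import path is assumed from the renamed file, and the loop and buffer name are illustrative):

```python
import torch

# Assumed import path; every helper below appears in the hunks above
# and resolves against the currently active cache context.
from cache_dit.cache_factory.cache_contexts.cache_context import (
    create_cache_context,
    cache_context,
    mark_step_begin,
    set_buffer,
    get_buffer,
    get_current_step,
)

create_cache_context(name="demo")
with cache_context("demo"):
    for _ in range(3):                       # stand-in for denoising steps
        mark_step_begin()
        set_buffer("Fn_residual", torch.zeros(4))   # illustrative buffer name
        _ = get_buffer("Fn_residual")
        print(get_current_step())
```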
cache_dit/cache_factory/cache_interface.py
@@ -1,8 +1,9 @@
+from typing import Any, Tuple, List
 from diffusers import DiffusionPipeline
-from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.cache_types import CacheType
-from cache_dit.cache_factory.cache_adapters import BlockAdapter
-from cache_dit.cache_factory.cache_adapters import UnifiedCacheAdapter
+from cache_dit.cache_factory.block_adapters import BlockAdapter
+from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
+from cache_dit.cache_factory.cache_adapters import CachedAdapter
 
 from cache_dit.logger import init_logger
 
@@ -10,9 +11,8 @@ logger = init_logger(__name__)
 
 
 def enable_cache(
-    # BlockAdapter & forward pattern
-    pipe_or_adapter: DiffusionPipeline | BlockAdapter,
-    forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
+    # DiffusionPipeline or BlockAdapter
+    pipe_or_adapter: DiffusionPipeline | BlockAdapter | Any,
     # Cache context kwargs
     Fn_compute_blocks: int = 8,
     Bn_compute_blocks: int = 0,
@@ -29,8 +29,8 @@
     enable_encoder_taylorseer: bool = False,
     taylorseer_cache_type: str = "residual",
     taylorseer_order: int = 2,
-    **other_cache_kwargs,
-) -> DiffusionPipeline:
+    **other_cache_context_kwargs,
+) -> DiffusionPipeline | Any:
     r"""
     Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks
     that match the specific Input and Output patterns).
@@ -43,9 +43,6 @@ def enable_cache(
             The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
             For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
             for the usgae of BlockAdapter.
-        forward_pattern (`ForwardPattern`, *required*, defaults to `ForwardPattern.Pattern_0`):
-            The forward pattern of Transformer block, please check https://github.com/vipshop/cache-dit/tree/main?tab=readme-ov-file#forward-pattern-matching
-            for more details.
         Fn_compute_blocks (`int`, *required*, defaults to 8):
             Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
             at time step t, enabling the calculation of a more stable L1 diff and delivering more
@@ -106,7 +103,7 @@
     """
 
     # Collect cache context kwargs
-    cache_context_kwargs = other_cache_kwargs.copy()
+    cache_context_kwargs = other_cache_context_kwargs.copy()
     cache_context_kwargs["cache_type"] = CacheType.DBCache
     cache_context_kwargs["Fn_compute_blocks"] = Fn_compute_blocks
     cache_context_kwargs["Bn_compute_blocks"] = Bn_compute_blocks
@@ -136,21 +133,32 @@
    }
 
     if isinstance(pipe_or_adapter, BlockAdapter):
-        return UnifiedCacheAdapter.apply(
+        return CachedAdapter.apply(
             pipe=None,
             block_adapter=pipe_or_adapter,
-            forward_pattern=forward_pattern,
             **cache_context_kwargs,
        )
     elif isinstance(pipe_or_adapter, DiffusionPipeline):
-        return UnifiedCacheAdapter.apply(
+        return CachedAdapter.apply(
             pipe=pipe_or_adapter,
             block_adapter=None,
-            forward_pattern=forward_pattern,
             **cache_context_kwargs,
        )
     else:
         raise ValueError(
+            f"type: {type(pipe_or_adapter)} is not valid, "
             "Please pass DiffusionPipeline or BlockAdapter"
             "for the 1's position param: pipe_or_adapter"
        )
+
+
+def supported_pipelines(
+    **kwargs,
+) -> Tuple[int, List[str]]:
+    return BlockAdapterRegistry.supported_pipelines(**kwargs)
+
+
+def get_adapter(
+    pipe: DiffusionPipeline | str | Any,
+) -> BlockAdapter:
+    return BlockAdapterRegistry.get_adapter(pipe)
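With this change enable_cache no longer takes a forward_pattern argument: the pattern is resolved through the BlockAdapter registry, and two registry helpers are added alongside it. A usage sketch under the assumption that these functions are re-exported at the package top level, as the docstring's cache_dit.enable_cache(...) example suggests; the checkpoint id is illustrative:

```python
import cache_dit
from diffusers import FluxPipeline

# Illustrative checkpoint; any pipeline the registry knows about works.
pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

# No forward_pattern argument anymore; the pattern comes from the
# registered BlockAdapter for this pipeline class.
cache_dit.enable_cache(pipe, Fn_compute_blocks=8, Bn_compute_blocks=0)

# New registry helpers from the hunk above (if they are not re-exported
# at top level, import them from cache_dit.cache_factory.cache_interface).
num, names = cache_dit.supported_pipelines()
adapter = cache_dit.get_adapter(pipe)
```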
cache_dit/cache_factory/forward_pattern.py
@@ -19,39 +19,57 @@ class ForwardPattern(Enum):
         self.Supported = Supported
 
     Pattern_0 = (
-        True,
-        False,
-        False,
-        ("hidden_states", "encoder_hidden_states"),
-        ("hidden_states", "encoder_hidden_states"),
-        True,
+        True,  # Return_H_First
+        False,  # Return_H_Only
+        False,  # Forward_H_only
+        ("hidden_states", "encoder_hidden_states"),  # In
+        ("hidden_states", "encoder_hidden_states"),  # Out
+        True,  # Supported
    )
 
     Pattern_1 = (
-        False,
-        False,
-        False,
-        ("hidden_states", "encoder_hidden_states"),
-        ("encoder_hidden_states", "hidden_states"),
-        True,
+        False,  # Return_H_First
+        False,  # Return_H_Only
+        False,  # Forward_H_only
+        ("hidden_states", "encoder_hidden_states"),  # In
+        ("encoder_hidden_states", "hidden_states"),  # Out
+        True,  # Supported
    )
 
     Pattern_2 = (
-        False,
-        True,
-        False,
-        ("hidden_states", "encoder_hidden_states"),
-        ("hidden_states",),
-        True,
+        False,  # Return_H_First
+        True,  # Return_H_Only
+        False,  # Forward_H_only
+        ("hidden_states", "encoder_hidden_states"),  # In
+        ("hidden_states",),  # Out
+        True,  # Supported
    )
 
     Pattern_3 = (
-        False,
-        True,
-        False,
-        ("hidden_states",),
-        ("hidden_states",),
-        False,
+        False,  # Return_H_First
+        True,  # Return_H_Only
+        True,  # Forward_H_only
+        ("hidden_states",),  # In
+        ("hidden_states",),  # Out
+        True,  # Supported
+    )
+
+    Pattern_4 = (
+        True,  # Return_H_First
+        False,  # Return_H_Only
+        True,  # Forward_H_only
+        ("hidden_states",),  # In
+        ("hidden_states", "encoder_hidden_states"),  # Out
+        True,  # Supported
+    )
+
+    Pattern_5 = (
+        False,  # Return_H_First
+        False,  # Return_H_Only
+        True,  # Forward_H_only
+        ("hidden_states",),  # In
+        ("encoder_hidden_states", "hidden_states"),  # Out
+        True,  # Supported
    )
 
     @staticmethod
@@ -60,4 +78,7 @@ class ForwardPattern(Enum):
             ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_1,
             ForwardPattern.Pattern_2,
+            ForwardPattern.Pattern_3,
+            ForwardPattern.Pattern_4,
+            ForwardPattern.Pattern_5,
        ]
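Patterns 3 to 5 are now marked Supported; they describe blocks whose forward takes only hidden_states (Forward_H_only) and returns one or both tensors. A minimal illustrative block whose signature matches Pattern_3, single tensor in and single tensor out (the class below is a toy for illustration, not part of cache-dit):

```python
import torch


class Pattern3Block(torch.nn.Module):
    # Illustrative only: a transformer-style block matching Pattern_3 as
    # described above. It consumes only hidden_states ("In") and returns
    # only hidden_states ("Out").
    def __init__(self, dim: int = 64):
        super().__init__()
        self.proj = torch.nn.Linear(dim, dim)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        return hidden_states + self.proj(hidden_states)


block = Pattern3Block()
out = block(torch.randn(1, 16, 64))  # single tensor in, single tensor out
```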
cache_dit/cache_factory/patch_functors/__init__.py (new file)
@@ -0,0 +1,5 @@
+from cache_dit.cache_factory.patch_functors.functor_base import PatchFunctor
+from cache_dit.cache_factory.patch_functors.functor_flux import FluxPatchFunctor
+from cache_dit.cache_factory.patch_functors.functor_chroma import (
+    ChromaPatchFunctor,
+)
cache_dit/cache_factory/patch_functors/functor_base.py (new file)
@@ -0,0 +1,18 @@
+import torch
+from abc import abstractmethod
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class PatchFunctor:
+
+    @abstractmethod
+    def apply(
+        self,
+        transformer: torch.nn.Module,
+        *args,
+        **kwargs,
+    ) -> torch.nn.Module:
+        raise NotImplementedError("apply method is not implemented.")
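PatchFunctor is a small abstract base; concrete functors such as FluxPatchFunctor and ChromaPatchFunctor (see the __init__ above) rewrite a transformer's blocks inside apply(). A hypothetical no-op subclass, just to show the contract; it is not part of the package:

```python
import torch

from cache_dit.cache_factory.patch_functors import PatchFunctor


class NoOpPatchFunctor(PatchFunctor):
    # Hypothetical subclass: leaves the module untouched, only showing
    # the apply() contract defined by the base class above. Real functors
    # (e.g. FluxPatchFunctor) patch block forwards here before caching.
    def apply(
        self,
        transformer: torch.nn.Module,
        *args,
        **kwargs,
    ) -> torch.nn.Module:
        return transformer
```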