cache-dit 0.2.26__py3-none-any.whl → 0.2.27__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +7 -6
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +15 -4
- cache_dit/cache_factory/block_adapters/__init__.py +538 -0
- cache_dit/cache_factory/block_adapters/block_adapters.py +333 -0
- cache_dit/cache_factory/block_adapters/block_registers.py +77 -0
- cache_dit/cache_factory/cache_adapters.py +120 -911
- cache_dit/cache_factory/cache_blocks/__init__.py +7 -9
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +2 -2
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +46 -41
- cache_dit/cache_factory/cache_blocks/pattern_base.py +98 -79
- cache_dit/cache_factory/cache_blocks/utils.py +13 -9
- cache_dit/cache_factory/cache_contexts/__init__.py +2 -0
- cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} +89 -55
- cache_dit/cache_factory/cache_contexts/cache_manager.py +0 -0
- cache_dit/cache_factory/cache_interface.py +21 -18
- cache_dit/cache_factory/patch_functors/functor_chroma.py +3 -0
- cache_dit/cache_factory/patch_functors/functor_flux.py +4 -0
- cache_dit/quantize/quantize_ao.py +1 -0
- cache_dit/utils.py +19 -16
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/METADATA +42 -12
- cache_dit-0.2.27.dist-info/RECORD +47 -0
- cache_dit-0.2.26.dist-info/RECORD +0 -42
- /cache_dit/cache_factory/{taylorseer.py → cache_contexts/taylorseer.py} +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.26.dist-info → cache_dit-0.2.27.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/{cache_context.py → cache_contexts/cache_context.py} CHANGED

@@ -7,14 +7,15 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
 import torch
 import torch.distributed as dist
 
-from cache_dit.cache_factory.taylorseer import TaylorSeer
+from cache_dit.cache_factory.cache_contexts.taylorseer import TaylorSeer
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
 @dataclasses.dataclass
-class
+class _CachedContext:  # Internal CachedContext Impl class
+    name: str = "default"
     # Dual Block Cache
     # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
     Fn_compute_blocks: int = 1
@@ -99,6 +100,8 @@ class DBCacheContext:
 
     @torch.compiler.disable
     def __post_init__(self):
+        if logger.isEnabledFor(logging.DEBUG):
+            logger.info(f"Created _CacheContext: {self.name}")
         # Some checks for settings
         if self.do_separate_cfg:
             assert self.enable_alter_cache is False, (
@@ -329,26 +332,60 @@ class DBCacheContext:
 
 
 # TODO: Support context manager for different cache_context
+_current_cache_context: _CachedContext = None
+
+_cache_context_manager: Dict[str, _CachedContext] = {}
 
 
 def create_cache_context(*args, **kwargs):
-
+    global _cache_context_manager
+    _context = _CachedContext(*args, **kwargs)
+    _cache_context_manager[_context.name] = _context
+    return _context
 
 
-def
+def get_cache_context():
     return _current_cache_context
 
 
-def
-    global _current_cache_context
-
+def set_cache_context(cache_context: _CachedContext | str):
+    global _current_cache_context, _cache_context_manager
+    if isinstance(cache_context, _CachedContext):
+        _current_cache_context = cache_context
+    else:
+        _current_cache_context = _cache_context_manager[cache_context]
+
+
+def reset_cache_context(cache_context: _CachedContext | str, *args, **kwargs):
+    global _cache_context_manager
+    if isinstance(cache_context, _CachedContext):
+        old_context_name = cache_context.name
+        if cache_context.name in _cache_context_manager:
+            del _cache_context_manager[cache_context.name]
+        # force use old_context name
+        kwargs["name"] = old_context_name
+        _context = _CachedContext(*args, **kwargs)
+        _cache_context_manager[_context.name] = _context
+    else:
+        old_context_name = cache_context
+        if cache_context in _cache_context_manager:
+            del _cache_context_manager[cache_context]
+        # force use old_context name
+        kwargs["name"] = old_context_name
+        _context = _CachedContext(*args, **kwargs)
+        _cache_context_manager[_context.name] = _context
+
+    return _context
 
 
 @contextlib.contextmanager
-def cache_context(cache_context):
-    global _current_cache_context
+def cache_context(cache_context: _CachedContext | str):
+    global _current_cache_context, _cache_context_manager
     old_cache_context = _current_cache_context
-
+    if isinstance(cache_context, _CachedContext):
+        _current_cache_context = cache_context
+    else:
+        _current_cache_context = _cache_context_manager[cache_context]
     try:
         yield
     finally:
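The hunk above replaces the single module-level context with a name-keyed registry (_cache_context_manager) plus helpers to create, activate, reset, and scope contexts. Below is a minimal usage sketch based only on the signatures shown in this hunk; the import path follows the file rename above, and importing these helpers directly from that module is an assumption, not something the diff guarantees.

from cache_dit.cache_factory.cache_contexts.cache_context import (
    create_cache_context,
    set_cache_context,
    get_cache_context,
    cache_context,
)

# Create and register a context; "name" and the Fn/Bn knobs are dataclass
# fields of _CachedContext shown in this diff.
ctx = create_cache_context(name="cond", Fn_compute_blocks=8, Bn_compute_blocks=0)

# Activate it either by object or by its registered name.
set_cache_context(ctx)
assert get_cache_context() is ctx

# Or scope it: the previous context is saved before the yield, and the
# try/finally shown above is there to restore state afterwards.
with cache_context("cond"):
    pass  # run transformer blocks that read get_cache_context()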
@@ -357,49 +394,49 @@ def cache_context(cache_context):
 
 @torch.compiler.disable
 def get_residual_diff_threshold():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_residual_diff_threshold()
 
 
 @torch.compiler.disable
 def get_buffer(name):
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_buffer(name)
 
 
 @torch.compiler.disable
 def set_buffer(name, buffer):
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     cache_context.set_buffer(name, buffer)
 
 
 @torch.compiler.disable
 def remove_buffer(name):
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     cache_context.remove_buffer(name)
 
 
 @torch.compiler.disable
 def mark_step_begin():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     cache_context.mark_step_begin()
 
 
 @torch.compiler.disable
 def get_current_step():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_current_step()
 
 
 @torch.compiler.disable
 def get_current_step_residual_diff():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     step = str(get_current_step())
     residual_diffs = get_residual_diffs()
@@ -410,7 +447,7 @@ def get_current_step_residual_diff():
 
 @torch.compiler.disable
 def get_current_step_cfg_residual_diff():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     step = str(get_current_step())
     cfg_residual_diffs = get_cfg_residual_diffs()
@@ -421,110 +458,110 @@ def get_current_step_cfg_residual_diff():
 
 @torch.compiler.disable
 def get_current_transformer_step():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_current_transformer_step()
 
 
 @torch.compiler.disable
 def get_cached_steps():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_cached_steps()
 
 
 @torch.compiler.disable
 def get_cfg_cached_steps():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_cfg_cached_steps()
 
 
 @torch.compiler.disable
 def get_max_cached_steps():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.max_cached_steps
 
 
 @torch.compiler.disable
 def get_max_continuous_cached_steps():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.max_continuous_cached_steps
 
 
 @torch.compiler.disable
 def get_continuous_cached_steps():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.continuous_cached_steps
 
 
 @torch.compiler.disable
 def get_cfg_continuous_cached_steps():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.cfg_continuous_cached_steps
 
 
 @torch.compiler.disable
 def add_cached_step():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     cache_context.add_cached_step()
 
 
 @torch.compiler.disable
 def add_residual_diff(diff):
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     cache_context.add_residual_diff(diff)
 
 
 @torch.compiler.disable
 def get_residual_diffs():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_residual_diffs()
 
 
 @torch.compiler.disable
 def get_cfg_residual_diffs():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_cfg_residual_diffs()
 
 
 @torch.compiler.disable
 def is_taylorseer_enabled():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.enable_taylorseer
 
 
 @torch.compiler.disable
 def is_encoder_taylorseer_enabled():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.enable_encoder_taylorseer
 
 
 def get_taylorseers() -> Tuple[TaylorSeer, TaylorSeer]:
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_taylorseers()
 
 
 def get_cfg_taylorseers() -> Tuple[TaylorSeer, TaylorSeer]:
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.get_cfg_taylorseers()
 
 
 @torch.compiler.disable
 def is_taylorseer_cache_residual():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.taylorseer_cache_type == "residual"
 
@@ -547,28 +584,28 @@ def is_encoder_cache_residual():
 
 @torch.compiler.disable
 def is_alter_cache_enabled():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.enable_alter_cache
 
 
 @torch.compiler.disable
 def is_alter_cache():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.is_alter_cache
 
 
 @torch.compiler.disable
 def is_in_warmup():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.is_in_warmup()
 
 
 @torch.compiler.disable
 def is_l1_diff_enabled():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return (
         cache_context.l1_hidden_states_diff_threshold is not None
@@ -578,21 +615,21 @@ def is_l1_diff_enabled():
 
 @torch.compiler.disable
 def get_important_condition_threshold():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.important_condition_threshold
 
 
 @torch.compiler.disable
 def non_compute_blocks_diff_threshold():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.non_compute_blocks_diff_threshold
 
 
 @torch.compiler.disable
 def Fn_compute_blocks():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     assert (
         cache_context.Fn_compute_blocks >= 1
@@ -612,7 +649,7 @@ def Fn_compute_blocks():
 
 @torch.compiler.disable
 def Fn_compute_blocks_ids():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     assert (
         len(cache_context.Fn_compute_blocks_ids)
@@ -627,7 +664,7 @@ def Fn_compute_blocks_ids():
 
 @torch.compiler.disable
 def Bn_compute_blocks():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     assert (
         cache_context.Bn_compute_blocks >= 0
@@ -647,7 +684,7 @@ def Bn_compute_blocks():
 
 @torch.compiler.disable
 def Bn_compute_blocks_ids():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     assert (
         len(cache_context.Bn_compute_blocks_ids)
@@ -662,44 +699,41 @@ def Bn_compute_blocks_ids():
 
 @torch.compiler.disable
 def do_separate_cfg():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.do_separate_cfg
 
 
 @torch.compiler.disable
 def is_separate_cfg_step():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.is_separate_cfg_step()
 
 
 @torch.compiler.disable
 def cfg_diff_compute_separate():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.cfg_diff_compute_separate
 
 
-_current_cache_context: DBCacheContext = None
-
-
 def collect_cache_kwargs(default_attrs: dict, **kwargs):
     # NOTE: This API will split kwargs into cache_kwargs and other_kwargs
     # default_attrs: specific settings for different pipelines
-    cache_attrs = dataclasses.fields(
+    cache_attrs = dataclasses.fields(_CachedContext)
     cache_attrs = [
         attr
         for attr in cache_attrs
         if hasattr(
-
+            _CachedContext,
             attr.name,
         )
     ]
     cache_kwargs = {
         attr.name: kwargs.pop(
             attr.name,
-            getattr(
+            getattr(_CachedContext, attr.name),
         )
         for attr in cache_attrs
     }
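collect_cache_kwargs (above) now reads its known option names from the _CachedContext dataclass and pops them out of **kwargs, falling back to the class defaults. A self-contained sketch of that split-kwargs pattern with a toy dataclass; the names below are illustrative, not the library's:

import dataclasses

@dataclasses.dataclass
class _ToyContext:  # stand-in for _CachedContext
    name: str = "default"
    Fn_compute_blocks: int = 1
    Bn_compute_blocks: int = 0

def split_kwargs(**kwargs):
    # hasattr() filters out fields declared with default_factory, which have
    # no plain class attribute to read a default from.
    attrs = [f for f in dataclasses.fields(_ToyContext) if hasattr(_ToyContext, f.name)]
    cache_kwargs = {f.name: kwargs.pop(f.name, getattr(_ToyContext, f.name)) for f in attrs}
    return cache_kwargs, kwargs  # kwargs now holds only the leftover options

cache_kwargs, other_kwargs = split_kwargs(Fn_compute_blocks=8, do_compile=True)
# cache_kwargs -> {"name": "default", "Fn_compute_blocks": 8, "Bn_compute_blocks": 0}
# other_kwargs -> {"do_compile": True}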
@@ -1057,7 +1091,7 @@ def apply_hidden_states_residual(
 
 @torch.compiler.disable
 def get_downsample_factor():
-    cache_context =
+    cache_context = get_cache_context()
     assert cache_context is not None, "cache_context must be set before"
     return cache_context.downsample_factor
 
@@ -1104,7 +1138,7 @@ def get_can_use_cache(
             "can not use cache."
         )
         # reset continuous cached steps stats
-        cache_context =
+        cache_context = get_cache_context()
         if not is_separate_cfg_step():
             cache_context.continuous_cached_steps = 0
         else:
File without changes
cache_dit/cache_factory/cache_interface.py CHANGED

@@ -1,23 +1,18 @@
 from typing import Any, Tuple, List
 from diffusers import DiffusionPipeline
-from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.cache_types import CacheType
-from cache_dit.cache_factory.
-from cache_dit.cache_factory.
+from cache_dit.cache_factory.block_adapters import BlockAdapter
+from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
+from cache_dit.cache_factory.cache_adapters import CachedAdapter
 
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
 
 
-def supported_pipelines() -> Tuple[int, List[str]]:
-    return UnifiedCacheAdapter.supported_pipelines()
-
-
 def enable_cache(
-    #
+    # DiffusionPipeline or BlockAdapter
     pipe_or_adapter: DiffusionPipeline | BlockAdapter | Any,
-    forward_pattern: ForwardPattern = ForwardPattern.Pattern_0,
     # Cache context kwargs
     Fn_compute_blocks: int = 8,
     Bn_compute_blocks: int = 0,

@@ -34,7 +29,7 @@ def enable_cache(
     enable_encoder_taylorseer: bool = False,
     taylorseer_cache_type: str = "residual",
     taylorseer_order: int = 2,
-    **
+    **other_cache_context_kwargs,
 ) -> DiffusionPipeline | Any:
     r"""
     Unified Cache API for almost Any Diffusion Transformers (with Transformer Blocks

@@ -48,9 +43,6 @@ def enable_cache(
             The standard Diffusion Pipeline or custom BlockAdapter (from cache-dit or user-defined).
             For example: cache_dit.enable_cache(FluxPipeline(...)). Please check https://github.com/vipshop/cache-dit/blob/main/docs/BlockAdapter.md
             for the usgae of BlockAdapter.
-        forward_pattern (`ForwardPattern`, *required*, defaults to `ForwardPattern.Pattern_0`):
-            The forward pattern of Transformer block, please check https://github.com/vipshop/cache-dit/tree/main?tab=readme-ov-file#forward-pattern-matching
-            for more details.
         Fn_compute_blocks (`int`, *required*, defaults to 8):
             Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
             at time step t, enabling the calculation of a more stable L1 diff and delivering more

@@ -111,7 +103,7 @@ def enable_cache(
     """
 
     # Collect cache context kwargs
-    cache_context_kwargs =
+    cache_context_kwargs = other_cache_context_kwargs.copy()
     cache_context_kwargs["cache_type"] = CacheType.DBCache
     cache_context_kwargs["Fn_compute_blocks"] = Fn_compute_blocks
     cache_context_kwargs["Bn_compute_blocks"] = Bn_compute_blocks

@@ -141,21 +133,32 @@ def enable_cache(
     }
 
     if isinstance(pipe_or_adapter, BlockAdapter):
-        return
+        return CachedAdapter.apply(
             pipe=None,
             block_adapter=pipe_or_adapter,
-            forward_pattern=forward_pattern,
             **cache_context_kwargs,
         )
     elif isinstance(pipe_or_adapter, DiffusionPipeline):
-        return
+        return CachedAdapter.apply(
             pipe=pipe_or_adapter,
             block_adapter=None,
-            forward_pattern=forward_pattern,
             **cache_context_kwargs,
         )
     else:
         raise ValueError(
+            f"type: {type(pipe_or_adapter)} is not valid, "
             "Please pass DiffusionPipeline or BlockAdapter"
             "for the 1's position param: pipe_or_adapter"
         )
+
+
+def supported_pipelines(
+    **kwargs,
+) -> Tuple[int, List[str]]:
+    return BlockAdapterRegistry.supported_pipelines(**kwargs)
+
+
+def get_adapter(
+    pipe: DiffusionPipeline | str | Any,
+) -> BlockAdapter:
+    return BlockAdapterRegistry.get_adapter(pipe)
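After these changes, enable_cache dispatches to CachedAdapter.apply and no longer takes a forward_pattern argument, while supported_pipelines and get_adapter become thin wrappers over BlockAdapterRegistry. A hedged usage sketch follows; it assumes these functions remain re-exported at the package top level (as the docstring's cache_dit.enable_cache(FluxPipeline(...)) example suggests), and the pipeline and option values are illustrative only.

import cache_dit
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

# Which pipelines does the BlockAdapter registry know about?
count, names = cache_dit.supported_pipelines()

# Enable DBCache on a stock DiffusionPipeline; unknown keyword arguments flow
# through **other_cache_context_kwargs into the cache context.
cache_dit.enable_cache(
    pipe,
    Fn_compute_blocks=8,
    Bn_compute_blocks=0,
)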
cache_dit/cache_factory/patch_functors/functor_chroma.py CHANGED

@@ -30,6 +30,9 @@ class ChromaPatchFunctor(PatchFunctor):
         blocks: torch.nn.ModuleList = None,
         **kwargs,
     ) -> ChromaTransformer2DModel:
+        if getattr(transformer, "_is_patched", False):
+            return transformer
+
         if blocks is None:
             blocks = transformer.single_transformer_blocks
 
cache_dit/cache_factory/patch_functors/functor_flux.py CHANGED

@@ -30,6 +30,10 @@ class FluxPatchFunctor(PatchFunctor):
         blocks: torch.nn.ModuleList = None,
         **kwargs,
     ) -> FluxTransformer2DModel:
+
+        if getattr(transformer, "_is_patched", False):
+            return transformer
+
         if blocks is None:
             blocks = transformer.single_transformer_blocks
 
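Both patch functors now short-circuit when the transformer was already patched. A generic sketch of that idempotency guard; only the _is_patched attribute comes from the diff, the class below is illustrative and not the library's PatchFunctor.

import torch

class IdempotentPatchFunctor:  # illustrative stand-in
    def apply(self, transformer: torch.nn.Module, **kwargs) -> torch.nn.Module:
        # Applying the patch twice would re-wrap the blocks, so bail out early
        # if a previous call already marked this module.
        if getattr(transformer, "_is_patched", False):
            return transformer

        # ... rewrite blocks / forward here ...

        transformer._is_patched = True
        return transformer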
cache_dit/quantize/quantize_ao.py CHANGED

@@ -179,6 +179,7 @@ def quantize_ao(
     force_empty_cache()
 
     logger.info(
+        f"Quantized Method: {quant_type:>5}\n"
         f"Quantized Linear Layers: {num_quant_linear:>5}\n"
         f"Skipped Linear Layers: {num_skip_linear:>5}\n"
         f"Total Linear Layers: {num_linear_layers:>5}\n"
cache_dit/utils.py CHANGED

@@ -30,27 +30,32 @@ class CacheStats:
 
 
 def summary(
-
+    pipe_or_module: DiffusionPipeline | torch.nn.Module | Any,
     details: bool = False,
     logging: bool = True,
 ) -> CacheStats:
     cache_stats = CacheStats()
-
-    if not isinstance(
-    assert hasattr(
-
+
+    if not isinstance(pipe_or_module, torch.nn.Module):
+        assert hasattr(pipe_or_module, "transformer")
+        module = pipe_or_module.transformer
+        cls_name = module.__class__.__name__
     else:
-
+        module = pipe_or_module
+
+        cls_name = module.__class__.__name__
+        if isinstance(module, torch.nn.ModuleList):
+            cls_name = module[0].__class__.__name__
 
-    if hasattr(
-    cache_options =
+    if hasattr(module, "_cache_context_kwargs"):
+        cache_options = module._cache_context_kwargs
         cache_stats.cache_options = cache_options
         if logging:
             print(f"\n🤗Cache Options: {cls_name}\n\n{cache_options}")
 
-    if hasattr(
-    cached_steps: list[int] =
-    residual_diffs: dict[str, float] = dict(
+    if hasattr(module, "_cached_steps"):
+        cached_steps: list[int] = module._cached_steps
+        residual_diffs: dict[str, float] = dict(module._residual_diffs)
         cache_stats.cached_steps = cached_steps
         cache_stats.residual_diffs = residual_diffs
 

@@ -91,11 +96,9 @@ def summary(
             compact=True,
         )
 
-    if hasattr(
-    cfg_cached_steps: list[int] =
-    cfg_residual_diffs: dict[str, float] = dict(
-        transformer._cfg_residual_diffs
-    )
+    if hasattr(module, "_cfg_cached_steps"):
+        cfg_cached_steps: list[int] = module._cfg_cached_steps
+        cfg_residual_diffs: dict[str, float] = dict(module._cfg_residual_diffs)
     cache_stats.cfg_cached_steps = cfg_cached_steps
     cache_stats.cfg_residual_diffs = cfg_residual_diffs
 
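With this change, summary() accepts either a pipeline (anything exposing .transformer) or a bare torch.nn.Module, including a ModuleList of blocks. A minimal sketch, assuming summary remains re-exported as cache_dit.summary; `pipe` here is the cached pipeline from the enable_cache sketch above.

import cache_dit

# After running inference with caching enabled:
stats = cache_dit.summary(pipe)              # a DiffusionPipeline exposing .transformer
stats = cache_dit.summary(pipe.transformer)  # or the transformer module directly
print(stats.cached_steps, stats.residual_diffs)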