cache-dit 0.2.8__py3-none-any.whl → 0.2.10__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.8'
21
- __version_tuple__ = version_tuple = (0, 2, 8)
20
+ __version__ = version = '0.2.10'
21
+ __version_tuple__ = version_tuple = (0, 2, 10)
@@ -1,168 +1,3 @@
1
- from enum import Enum
2
-
3
- from diffusers import DiffusionPipeline
4
-
5
- from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
6
- apply_db_cache_on_pipe,
7
- )
8
- from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
9
- apply_fb_cache_on_pipe,
10
- )
11
- from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters import (
12
- apply_db_prune_on_pipe,
13
- )
14
- from cache_dit.logger import init_logger
15
-
16
-
17
- logger = init_logger(__name__)
18
-
19
-
20
- class CacheType(Enum):
21
- NONE = "NONE"
22
- FBCache = "First_Block_Cache"
23
- DBCache = "Dual_Block_Cache"
24
- DBPrune = "Dynamic_Block_Prune"
25
-
26
- @staticmethod
27
- def type(cache_type: "CacheType | str") -> "CacheType":
28
- if isinstance(cache_type, CacheType):
29
- return cache_type
30
- return CacheType.cache_type(cache_type)
31
-
32
- @staticmethod
33
- def cache_type(cache_type: "CacheType | str") -> "CacheType":
34
- if cache_type is None:
35
- return CacheType.NONE
36
-
37
- if isinstance(cache_type, CacheType):
38
- return cache_type
39
- if cache_type.lower() in (
40
- "first_block_cache",
41
- "fb_cache",
42
- "fbcache",
43
- "fb",
44
- ):
45
- return CacheType.FBCache
46
- elif cache_type.lower() in (
47
- "dual_block_cache",
48
- "db_cache",
49
- "dbcache",
50
- "db",
51
- ):
52
- return CacheType.DBCache
53
- elif cache_type.lower() in (
54
- "dynamic_block_prune",
55
- "db_prune",
56
- "dbprune",
57
- "dbp",
58
- ):
59
- return CacheType.DBPrune
60
- elif cache_type.lower() in (
61
- "none_cache",
62
- "nonecache",
63
- "no_cache",
64
- "nocache",
65
- "none",
66
- "no",
67
- ):
68
- return CacheType.NONE
69
- else:
70
- raise ValueError(f"Unknown cache type: {cache_type}")
71
-
72
- @staticmethod
73
- def range(start: int, end: int, step: int = 1) -> list[int]:
74
- if start > end or end <= 0 or step <= 1:
75
- return []
76
- # Always compute 0 and end - 1 blocks for DB Cache
77
- return list(
78
- sorted(set([0] + list(range(start, end, step)) + [end - 1]))
79
- )
80
-
81
- @staticmethod
82
- def default_options(cache_type: "CacheType | str") -> dict:
83
- _no_options = {
84
- "cache_type": CacheType.NONE,
85
- }
86
-
87
- _fb_options = {
88
- "cache_type": CacheType.FBCache,
89
- "residual_diff_threshold": 0.08,
90
- "warmup_steps": 8,
91
- "max_cached_steps": 8,
92
- }
93
-
94
- _Fn_compute_blocks = 8
95
- _Bn_compute_blocks = 8
96
-
97
- _db_options = {
98
- "cache_type": CacheType.DBCache,
99
- "residual_diff_threshold": 0.12,
100
- "warmup_steps": 8,
101
- "max_cached_steps": -1, # -1 means no limit
102
- # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
103
- "Fn_compute_blocks": _Fn_compute_blocks,
104
- "Bn_compute_blocks": _Bn_compute_blocks,
105
- "max_Fn_compute_blocks": 16,
106
- "max_Bn_compute_blocks": 16,
107
- "Fn_compute_blocks_ids": [], # 0, 1, 2, ..., 7, etc.
108
- "Bn_compute_blocks_ids": [], # 0, 1, 2, ..., 7, etc.
109
- }
110
-
111
- _dbp_options = {
112
- "cache_type": CacheType.DBPrune,
113
- "residual_diff_threshold": 0.08,
114
- "Fn_compute_blocks": _Fn_compute_blocks,
115
- "Bn_compute_blocks": _Bn_compute_blocks,
116
- "warmup_steps": 8,
117
- "max_pruned_steps": -1, # -1 means no limit
118
- }
119
-
120
- if cache_type == CacheType.FBCache:
121
- return _fb_options
122
- elif cache_type == CacheType.DBCache:
123
- return _db_options
124
- elif cache_type == CacheType.DBPrune:
125
- return _dbp_options
126
- elif cache_type == CacheType.NONE:
127
- return _no_options
128
- else:
129
- raise ValueError(f"Unknown cache type: {cache_type}")
130
-
131
-
132
- def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
133
- assert isinstance(pipe, DiffusionPipeline)
134
-
135
- if hasattr(pipe, "_is_cached") and pipe._is_cached:
136
- return pipe
137
-
138
- if hasattr(pipe, "_is_pruned") and pipe._is_pruned:
139
- return pipe
140
-
141
- cache_type = kwargs.pop("cache_type", None)
142
- if cache_type is None:
143
- logger.warning(
144
- "No cache type specified, we will use DBCache by default. "
145
- "Please specify the cache_type explicitly if you want to "
146
- "use a different cache type."
147
- )
148
- # Force to use DBCache with default cache options
149
- return apply_db_cache_on_pipe(
150
- pipe,
151
- **CacheType.default_options(CacheType.DBCache),
152
- )
153
-
154
- cache_type = CacheType.type(cache_type)
155
-
156
- if cache_type == CacheType.FBCache:
157
- return apply_fb_cache_on_pipe(pipe, *args, **kwargs)
158
- elif cache_type == CacheType.DBCache:
159
- return apply_db_cache_on_pipe(pipe, *args, **kwargs)
160
- elif cache_type == CacheType.DBPrune:
161
- return apply_db_prune_on_pipe(pipe, *args, **kwargs)
162
- elif cache_type == CacheType.NONE:
163
- logger.warning(
164
- f"Cache type is {cache_type}, no caching will be applied."
165
- )
166
- return pipe
167
- else:
168
- raise ValueError(f"Unknown cache type: {cache_type}")
1
+ from cache_dit.cache_factory.adapters import CacheType
2
+ from cache_dit.cache_factory.adapters import apply_cache_on_pipe
3
+ from cache_dit.cache_factory.utils import load_cache_options_from_yaml
@@ -0,0 +1,169 @@
1
+ from enum import Enum
2
+
3
+ from diffusers import DiffusionPipeline
4
+
5
+ from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
6
+ apply_db_cache_on_pipe,
7
+ )
8
+ from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
9
+ apply_fb_cache_on_pipe,
10
+ )
11
+ from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters import (
12
+ apply_db_prune_on_pipe,
13
+ )
14
+
15
+ from cache_dit.logger import init_logger
16
+
17
+
18
+ logger = init_logger(__name__)
19
+
20
+
21
+ class CacheType(Enum):
22
+ NONE = "NONE"
23
+ FBCache = "First_Block_Cache"
24
+ DBCache = "Dual_Block_Cache"
25
+ DBPrune = "Dynamic_Block_Prune"
26
+
27
+ @staticmethod
28
+ def type(cache_type: "CacheType | str") -> "CacheType":
29
+ if isinstance(cache_type, CacheType):
30
+ return cache_type
31
+ return CacheType.cache_type(cache_type)
32
+
33
+ @staticmethod
34
+ def cache_type(cache_type: "CacheType | str") -> "CacheType":
35
+ if cache_type is None:
36
+ return CacheType.NONE
37
+
38
+ if isinstance(cache_type, CacheType):
39
+ return cache_type
40
+ if cache_type.lower() in (
41
+ "first_block_cache",
42
+ "fb_cache",
43
+ "fbcache",
44
+ "fb",
45
+ ):
46
+ return CacheType.FBCache
47
+ elif cache_type.lower() in (
48
+ "dual_block_cache",
49
+ "db_cache",
50
+ "dbcache",
51
+ "db",
52
+ ):
53
+ return CacheType.DBCache
54
+ elif cache_type.lower() in (
55
+ "dynamic_block_prune",
56
+ "db_prune",
57
+ "dbprune",
58
+ "dbp",
59
+ ):
60
+ return CacheType.DBPrune
61
+ elif cache_type.lower() in (
62
+ "none_cache",
63
+ "nonecache",
64
+ "no_cache",
65
+ "nocache",
66
+ "none",
67
+ "no",
68
+ ):
69
+ return CacheType.NONE
70
+ else:
71
+ raise ValueError(f"Unknown cache type: {cache_type}")
72
+
73
+ @staticmethod
74
+ def range(start: int, end: int, step: int = 1) -> list[int]:
75
+ if start > end or end <= 0 or step <= 1:
76
+ return []
77
+ # Always compute 0 and end - 1 blocks for DB Cache
78
+ return list(
79
+ sorted(set([0] + list(range(start, end, step)) + [end - 1]))
80
+ )
81
+
82
+ @staticmethod
83
+ def default_options(cache_type: "CacheType | str") -> dict:
84
+ _no_options = {
85
+ "cache_type": CacheType.NONE,
86
+ }
87
+
88
+ _fb_options = {
89
+ "cache_type": CacheType.FBCache,
90
+ "residual_diff_threshold": 0.08,
91
+ "warmup_steps": 8,
92
+ "max_cached_steps": 8,
93
+ }
94
+
95
+ _Fn_compute_blocks = 8
96
+ _Bn_compute_blocks = 8
97
+
98
+ _db_options = {
99
+ "cache_type": CacheType.DBCache,
100
+ "residual_diff_threshold": 0.12,
101
+ "warmup_steps": 8,
102
+ "max_cached_steps": -1, # -1 means no limit
103
+ # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
104
+ "Fn_compute_blocks": _Fn_compute_blocks,
105
+ "Bn_compute_blocks": _Bn_compute_blocks,
106
+ "max_Fn_compute_blocks": 16,
107
+ "max_Bn_compute_blocks": 16,
108
+ "Fn_compute_blocks_ids": [], # 0, 1, 2, ..., 7, etc.
109
+ "Bn_compute_blocks_ids": [], # 0, 1, 2, ..., 7, etc.
110
+ }
111
+
112
+ _dbp_options = {
113
+ "cache_type": CacheType.DBPrune,
114
+ "residual_diff_threshold": 0.08,
115
+ "Fn_compute_blocks": _Fn_compute_blocks,
116
+ "Bn_compute_blocks": _Bn_compute_blocks,
117
+ "warmup_steps": 8,
118
+ "max_pruned_steps": -1, # -1 means no limit
119
+ }
120
+
121
+ if cache_type == CacheType.FBCache:
122
+ return _fb_options
123
+ elif cache_type == CacheType.DBCache:
124
+ return _db_options
125
+ elif cache_type == CacheType.DBPrune:
126
+ return _dbp_options
127
+ elif cache_type == CacheType.NONE:
128
+ return _no_options
129
+ else:
130
+ raise ValueError(f"Unknown cache type: {cache_type}")
131
+
132
+
133
+ def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
134
+ assert isinstance(pipe, DiffusionPipeline)
135
+
136
+ if hasattr(pipe, "_is_cached") and pipe._is_cached:
137
+ return pipe
138
+
139
+ if hasattr(pipe, "_is_pruned") and pipe._is_pruned:
140
+ return pipe
141
+
142
+ cache_type = kwargs.pop("cache_type", None)
143
+ if cache_type is None:
144
+ logger.warning(
145
+ "No cache type specified, we will use DBCache by default. "
146
+ "Please specify the cache_type explicitly if you want to "
147
+ "use a different cache type."
148
+ )
149
+ # Force to use DBCache with default cache options
150
+ return apply_db_cache_on_pipe(
151
+ pipe,
152
+ **CacheType.default_options(CacheType.DBCache),
153
+ )
154
+
155
+ cache_type = CacheType.type(cache_type)
156
+
157
+ if cache_type == CacheType.FBCache:
158
+ return apply_fb_cache_on_pipe(pipe, *args, **kwargs)
159
+ elif cache_type == CacheType.DBCache:
160
+ return apply_db_cache_on_pipe(pipe, *args, **kwargs)
161
+ elif cache_type == CacheType.DBPrune:
162
+ return apply_db_prune_on_pipe(pipe, *args, **kwargs)
163
+ elif cache_type == CacheType.NONE:
164
+ logger.warning(
165
+ f"Cache type is {cache_type}, no caching will be applied."
166
+ )
167
+ return pipe
168
+ else:
169
+ raise ValueError(f"Unknown cache type: {cache_type}")
@@ -0,0 +1,53 @@
1
+ import yaml
2
+ from cache_dit.cache_factory.adapters import CacheType
3
+
4
+
5
+ def load_cache_options_from_yaml(yaml_file_path):
6
+ try:
7
+ with open(yaml_file_path, "r") as f:
8
+ config = yaml.safe_load(f)
9
+
10
+ required_keys = [
11
+ "cache_type",
12
+ "warmup_steps",
13
+ "max_cached_steps",
14
+ "Fn_compute_blocks",
15
+ "Bn_compute_blocks",
16
+ "residual_diff_threshold",
17
+ ]
18
+ for key in required_keys:
19
+ if key not in config:
20
+ raise ValueError(
21
+ f"Configuration file missing required item: {key}"
22
+ )
23
+
24
+ # Convert cache_type to CacheType enum
25
+ if isinstance(config["cache_type"], str):
26
+ try:
27
+ config["cache_type"] = CacheType[config["cache_type"]]
28
+ except KeyError:
29
+ valid_types = [ct.name for ct in CacheType]
30
+ raise ValueError(
31
+ f"Invalid cache_type value: {config['cache_type']}, "
32
+ f"valid values are: {valid_types}"
33
+ )
34
+ elif not isinstance(config["cache_type"], CacheType):
35
+ raise ValueError(
36
+ f"cache_type must be a string or CacheType enum, "
37
+ f"got: {type(config['cache_type'])}"
38
+ )
39
+
40
+ # Handle default value for taylorseer_kwargs
41
+ if "taylorseer_kwargs" not in config and config.get(
42
+ "enable_taylorseer", False
43
+ ):
44
+ config["taylorseer_kwargs"] = {"n_derivatives": 2}
45
+
46
+ return config
47
+
48
+ except FileNotFoundError:
49
+ raise FileNotFoundError(
50
+ f"Configuration file not found: {yaml_file_path}"
51
+ )
52
+ except yaml.YAMLError as e:
53
+ raise yaml.YAMLError(f"YAML file parsing error: {str(e)}")
@@ -1,6 +1,7 @@
1
1
  import os
2
2
 
3
3
  import torch
4
+ import torch.distributed as dist
4
5
  from cache_dit.logger import init_logger, logging_rank_0
5
6
 
6
7
  logger = init_logger(__name__)
@@ -50,12 +51,13 @@ def set_custom_compile_configs(
50
51
  )
51
52
  return
52
53
 
53
- # Enable compute comm overlap
54
- torch._inductor.config.reorder_for_compute_comm_overlap = True
55
- # L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
56
- torch._inductor.config.intra_node_bw = (
57
- 64 if "L20" in torch.cuda.get_device_name() else 300
58
- )
54
+ if dist.is_initialized():
55
+ # Enable compute comm overlap
56
+ torch._inductor.config.reorder_for_compute_comm_overlap = True
57
+ # L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
58
+ torch._inductor.config.intra_node_bw = (
59
+ 64 if "L20" in torch.cuda.get_device_name() else 300
60
+ )
59
61
 
60
62
  # Below are default settings for torch.compile, you can change
61
63
  # them to your needs and test the performance
@@ -348,8 +348,9 @@ def get_args():
348
348
  "all",
349
349
  ]
350
350
  parser.add_argument(
351
- "metric",
351
+ "metrics",
352
352
  type=str,
353
+ nargs="+",
353
354
  default="psnr",
354
355
  choices=METRICS_CHOICES,
355
356
  help=f"Metric choices: {METRICS_CHOICES}",
@@ -389,6 +390,47 @@ def get_args():
389
390
  default=False,
390
391
  help="Show metrics progress verbose",
391
392
  )
393
+
394
+ # Image 1 vs N pattern
395
+ parser.add_argument(
396
+ "--img-source-dir",
397
+ "-d",
398
+ type=str,
399
+ default=None,
400
+ help="Path to dir that contains dirs of images",
401
+ )
402
+ parser.add_argument(
403
+ "--ref-img-dir",
404
+ "-r",
405
+ type=str,
406
+ default=None,
407
+ help="Path to ref dir that contains ground truth images",
408
+ )
409
+
410
+ # Video 1 vs N pattern
411
+ parser.add_argument(
412
+ "--video-source-dir",
413
+ "-vd",
414
+ type=str,
415
+ default=None,
416
+ help="Path to dir that contains many videos",
417
+ )
418
+ parser.add_argument(
419
+ "--ref-video",
420
+ "-rv",
421
+ type=str,
422
+ default=None,
423
+ help="Path to ground truth video",
424
+ )
425
+
426
+ # FID batch size
427
+ parser.add_argument(
428
+ "--fid-batch-size",
429
+ "-b",
430
+ type=int,
431
+ default=1,
432
+ help="Batch size for FID compute",
433
+ )
392
434
  return parser.parse_args()
393
435
 
394
436
 
@@ -401,67 +443,195 @@ def entrypoint():
401
443
  set_metrics_verbose(True)
402
444
  DISABLE_VERBOSE = not get_metrics_verbose()
403
445
 
404
- if args.img_true is not None and args.img_test is not None:
405
- if any(
406
- (
407
- not os.path.exists(args.img_true),
408
- not os.path.exists(args.img_test),
409
- )
410
- ):
446
+ if "all" in args.metrics or "fid" in args.metrics:
447
+ FID = FrechetInceptionDistance(
448
+ disable_tqdm=DISABLE_VERBOSE,
449
+ batch_size=args.fid_batch_size,
450
+ )
451
+
452
+ # run one metric
453
+ def _run_metric(
454
+ mertric: str,
455
+ img_true: str = None,
456
+ img_test: str = None,
457
+ video_true: str = None,
458
+ video_test: str = None,
459
+ ) -> None:
460
+ nonlocal FID
461
+ mertric = mertric.lower()
462
+ if img_true is not None and img_test is not None:
463
+ if any(
464
+ (
465
+ not os.path.exists(img_true),
466
+ not os.path.exists(img_test),
467
+ )
468
+ ):
469
+ return
470
+ # img_true and img_test can be files or dirs
471
+ img_true_info = os.path.basename(img_true)
472
+ img_test_info = os.path.basename(img_test)
473
+ if mertric == "psnr" or mertric == "all":
474
+ img_psnr, n = compute_psnr(img_true, img_test)
475
+ logger.info(
476
+ f"{img_true_info} vs {img_test_info}, "
477
+ f"Num: {n}, PSNR: {img_psnr}"
478
+ )
479
+ if mertric == "ssim" or mertric == "all":
480
+ img_ssim, n = compute_ssim(img_true, img_test)
481
+ logger.info(
482
+ f"{img_true_info} vs {img_test_info}, "
483
+ f"Num: {n}, SSIM: {img_ssim}"
484
+ )
485
+ if mertric == "mse" or mertric == "all":
486
+ img_mse, n = compute_mse(img_true, img_test)
487
+ logger.info(
488
+ f"{img_true_info} vs {img_test_info}, "
489
+ f"Num: {n}, MSE: {img_mse}"
490
+ )
491
+ if mertric == "fid" or mertric == "all":
492
+ img_fid, n = FID.compute_fid(img_true, img_test)
493
+ logger.info(
494
+ f"{img_true_info} vs {img_test_info}, "
495
+ f"Num: {n}, FID: {img_fid}"
496
+ )
497
+ if video_true is not None and video_test is not None:
498
+ if any(
499
+ (
500
+ not os.path.exists(video_true),
501
+ not os.path.exists(video_test),
502
+ )
503
+ ):
504
+ return
505
+ # video_true and video_test can be files or dirs
506
+ video_true_info = os.path.basename(video_true)
507
+ video_test_info = os.path.basename(video_test)
508
+ if mertric == "psnr" or mertric == "all":
509
+ video_psnr, n = compute_video_psnr(video_true, video_test)
510
+ logger.info(
511
+ f"{video_true_info} vs {video_test_info}, "
512
+ f"Frames: {n}, PSNR: {video_psnr}"
513
+ )
514
+ if mertric == "ssim" or mertric == "all":
515
+ video_ssim, n = compute_video_ssim(video_true, video_test)
516
+ logger.info(
517
+ f"{video_true_info} vs {video_test_info}, "
518
+ f"Frames: {n}, SSIM: {video_ssim}"
519
+ )
520
+ if mertric == "mse" or mertric == "all":
521
+ video_mse, n = compute_video_mse(video_true, video_test)
522
+ logger.info(
523
+ f"{video_true_info} vs {video_test_info}, "
524
+ f"Frames: {n}, MSE: {video_mse}"
525
+ )
526
+ if mertric == "fid" or mertric == "all":
527
+ video_fid, n = FID.compute_video_fid(video_true, video_test)
528
+ logger.info(
529
+ f"{video_true_info} vs {video_test_info}, "
530
+ f"Frames: {n}, FID: {video_fid}"
531
+ )
532
+
533
+ # run selected metrics
534
+ if not DISABLE_VERBOSE:
535
+ logger.info(f"Selected metrics: {args.metrics}")
536
+
537
+ def _is_image_1vsN_pattern() -> bool:
538
+ return args.img_source_dir is not None and args.ref_img_dir is not None
539
+
540
+ def _is_video_1vsN_pattern() -> bool:
541
+ return args.video_source_dir is not None and args.ref_video is not None
542
+
543
+ assert not all((_is_image_1vsN_pattern(), _is_video_1vsN_pattern()))
544
+
545
+ if _is_image_1vsN_pattern():
546
+ # Glob Image dirs
547
+ if not os.path.exists(args.img_source_dir):
548
+ logger.error(f"{args.img_source_dir} not exist!")
411
549
  return
412
- # img_true and img_test can be files or dirs
413
- if args.metric == "psnr" or args.metric == "all":
414
- img_psnr, n = compute_psnr(args.img_true, args.img_test)
415
- logger.info(
416
- f"{args.img_true} vs {args.img_test}, Num: {n}, PSNR: {img_psnr}"
417
- )
418
- if args.metric == "ssim" or args.metric == "all":
419
- img_ssim, n = compute_ssim(args.img_true, args.img_test)
420
- logger.info(
421
- f"{args.img_true} vs {args.img_test}, Num: {n}, SSIM: {img_ssim}"
422
- )
423
- if args.metric == "mse" or args.metric == "all":
424
- img_mse, n = compute_mse(args.img_true, args.img_test)
425
- logger.info(
426
- f"{args.img_true} vs {args.img_test}, Num: {n}, MSE: {img_mse}"
427
- )
428
- if args.metric == "fid" or args.metric == "all":
429
- FID = FrechetInceptionDistance(disable_tqdm=DISABLE_VERBOSE)
430
- img_fid, n = FID.compute_fid(args.img_true, args.img_test)
431
- logger.info(
432
- f"{args.img_true} vs {args.img_test}, Num: {n}, FID: {img_fid}"
433
- )
434
- if args.video_true is not None and args.video_test is not None:
435
- if any(
436
- (
437
- not os.path.exists(args.video_true),
438
- not os.path.exists(args.video_test),
439
- )
440
- ):
550
+ if not os.path.exists(args.ref_img_dir):
551
+ logger.error(f"{args.ref_img_dir} not exist!")
441
552
  return
442
- # video_true and video_test can be files or dirs
443
- if args.metric == "psnr" or args.metric == "all":
444
- video_psnr, n = compute_video_psnr(args.video_true, args.video_test)
445
- logger.info(
446
- f"{args.video_true} vs {args.video_test}, Frames: {n}, PSNR: {video_psnr}"
447
- )
448
- if args.metric == "ssim" or args.metric == "all":
449
- video_ssim, n = compute_video_ssim(args.video_true, args.video_test)
553
+
554
+ directories = []
555
+ for item in os.listdir(args.img_source_dir):
556
+ item_path = os.path.join(args.img_source_dir, item)
557
+ if os.path.isdir(item_path):
558
+ if os.path.basename(item_path) == os.path.basename(
559
+ args.ref_img_dir
560
+ ):
561
+ continue
562
+ directories.append(item_path)
563
+
564
+ if len(directories) == 0:
565
+ return
566
+
567
+ directories = sorted(directories)
568
+ if not DISABLE_VERBOSE:
450
569
  logger.info(
451
- f"{args.video_true} vs {args.video_test}, Frames: {n}, SSIM: {video_ssim}"
570
+ f"Compare {args.ref_img_dir} vs {directories}, "
571
+ f"Num compares: {len(directories)}"
452
572
  )
453
- if args.metric == "mse" or args.metric == "all":
454
- video_mse, n = compute_video_mse(args.video_true, args.video_test)
573
+
574
+ for metric in args.metrics:
575
+ for img_test_dir in directories:
576
+ _run_metric(
577
+ mertric=metric,
578
+ img_true=args.ref_img_dir,
579
+ img_test=img_test_dir,
580
+ )
581
+
582
+ elif _is_video_1vsN_pattern():
583
+ # Glob videos
584
+ if not os.path.exists(args.video_source_dir):
585
+ logger.error(f"{args.video_source_dir} not exist!")
586
+ return
587
+ if not os.path.exists(args.ref_video):
588
+ logger.error(f"{args.ref_video} not exist!")
589
+ return
590
+
591
+ video_source_dir: pathlib.Path = pathlib.Path(args.video_source_dir)
592
+ video_source_files = sorted(
593
+ [
594
+ file
595
+ for ext in _VIDEO_EXTENSIONS
596
+ for file in video_source_dir.rglob("*.{}".format(ext))
597
+ ]
598
+ )
599
+ video_source_files = [file.as_posix() for file in video_source_files]
600
+
601
+ video_source_selected = []
602
+ for video_source_file in video_source_files:
603
+ if os.path.basename(video_source_file) == os.path.basename(
604
+ args.ref_video
605
+ ):
606
+ continue
607
+ video_source_selected.append(video_source_file)
608
+
609
+ if len(video_source_selected) == 0:
610
+ return
611
+
612
+ video_source_selected = sorted(video_source_selected)
613
+ if not DISABLE_VERBOSE:
455
614
  logger.info(
456
- f"{args.video_true} vs {args.video_test}, Frames: {n}, MSE: {video_mse}"
457
- )
458
- if args.metric == "fid" or args.metric == "all":
459
- FID = FrechetInceptionDistance(disable_tqdm=DISABLE_VERBOSE)
460
- video_fid, n = FID.compute_video_fid(
461
- args.video_true, args.video_test
615
+ f"Compare {args.ref_video} vs {video_source_selected}, "
616
+ f"Num compares: {len(video_source_selected)}"
462
617
  )
463
- logger.info(
464
- f"{args.video_true} vs {args.video_test}, Frames: {n}, FID: {video_fid}"
618
+
619
+ for metric in args.metrics:
620
+ for video_test in video_source_selected:
621
+ _run_metric(
622
+ mertric=metric,
623
+ video_true=args.ref_video,
624
+ video_test=video_test,
625
+ )
626
+
627
+ else:
628
+ for metric in args.metrics:
629
+ _run_metric(
630
+ mertric=metric,
631
+ img_true=args.img_true,
632
+ img_test=args.img_test,
633
+ video_true=args.video_true,
634
+ video_test=args.video_test,
465
635
  )
466
636
 
467
637
 
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cache_dit
3
- Version: 0.2.8
3
+ Version: 0.2.10
4
4
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
5
5
  Author: DefTruth, vipshop.com, etc.
6
6
  Maintainer: DefTruth, vipshop.com, etc
@@ -10,6 +10,7 @@ Requires-Python: >=3.10
10
10
  Description-Content-Type: text/markdown
11
11
  License-File: LICENSE
12
12
  Requires-Dist: packaging
13
+ Requires-Dist: pyyaml
13
14
  Requires-Dist: torch>=2.5.1
14
15
  Requires-Dist: transformers>=4.51.3
15
16
  Requires-Dist: diffusers>=0.33.1
@@ -61,6 +62,11 @@ Dynamic: requires-python
61
62
  </p>
62
63
  </div>
63
64
 
65
+ ## 🔥News🔥
66
+
67
+ - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of [huggingface/flux-fast](https://github.com/huggingface/flux-fast) that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20 while still maintaining **high precision**.
68
+
69
+
64
70
  ## 🤗 Introduction
65
71
 
66
72
  <div align="center">
@@ -1,10 +1,11 @@
1
1
  cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- cache_dit/_version.py,sha256=zkhRarrvPoGA1yWjS9_zVM80dWqpDesNn9DiHcF4JWM,511
2
+ cache_dit/_version.py,sha256=fMOyoyXAggjNTgl2YJ-8HW1bnjjDPiNACUsDoNufScI,513
3
3
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
4
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
5
- cache_dit/cache_factory/__init__.py,sha256=5RNuhWakvvqrOV4vkqrEBA7d-V1LwcNSsjtW14mkqK8,5255
5
+ cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
6
+ cache_dit/cache_factory/adapters.py,sha256=QMCaXnmqM7NT7sx4bCF1mMLn-QcXX9h1RmgLAypDedg,5256
6
7
  cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
7
- cache_dit/cache_factory/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
+ cache_dit/cache_factory/utils.py,sha256=V-Mb5Jn07geEUUWo4QAfh6pmSzkL-2OGDn0VAXbG6hQ,1799
8
9
  cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
10
  cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=itVEb6gT2eZuncAHUmP51ZS0r6v6cGtRvnPjyeXqKH8,71156
10
11
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=krNAICf-aS3JLmSG8vOB9tpLa04uYRcABsC8PMbVUKY,1870
@@ -30,17 +31,17 @@ cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py,sh
30
31
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
31
32
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=dBNzHBECAuTTA1a7kLdvZL20YzaKTAS3iciVLzKKEWA,2638
32
33
  cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k,63
33
- cache_dit/compile/utils.py,sha256=KU60xc474Anbj7Y_FLRFmNxEjVYLLXkhbtCLXO7o_Tc,3699
34
+ cache_dit/compile/utils.py,sha256=OTvkwcezSrApZ2M1IMkYtkEmFbkfpTknhHMgoBApd6U,3786
34
35
  cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
35
36
  cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
37
  cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
37
38
  cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
38
39
  cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
39
40
  cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
40
- cache_dit/metrics/metrics.py,sha256=tzAtG_-fM1xPIBfRVFIBupvOWYzIO3xDq29Vy5rOBWc,14730
41
- cache_dit-0.2.8.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
42
- cache_dit-0.2.8.dist-info/METADATA,sha256=8E51DpSKDGqk3_cG9buahoXN-7fub6M8VCiPb_Idg64,27608
43
- cache_dit-0.2.8.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
44
- cache_dit-0.2.8.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
45
- cache_dit-0.2.8.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
46
- cache_dit-0.2.8.dist-info/RECORD,,
41
+ cache_dit/metrics/metrics.py,sha256=O9a8qV6deQDWEoez7UZ_aqDLQ9rJXAJUMHGnJM7RUMs,19927
42
+ cache_dit-0.2.10.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
43
+ cache_dit-0.2.10.dist-info/METADATA,sha256=f5PF-lhexcLdB2HmBWETBEyg011eZOc99tlVI1lozYA,28002
44
+ cache_dit-0.2.10.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
+ cache_dit-0.2.10.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
46
+ cache_dit-0.2.10.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
47
+ cache_dit-0.2.10.dist-info/RECORD,,