cache-dit 0.2.9__py3-none-any.whl → 0.2.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of cache-dit might be problematic. Click here for more details.

cache_dit/_version.py CHANGED
@@ -17,5 +17,5 @@ __version__: str
17
17
  __version_tuple__: VERSION_TUPLE
18
18
  version_tuple: VERSION_TUPLE
19
19
 
20
- __version__ = version = '0.2.9'
21
- __version_tuple__ = version_tuple = (0, 2, 9)
20
+ __version__ = version = '0.2.11'
21
+ __version_tuple__ = version_tuple = (0, 2, 11)
cache_dit/cache_factory/__init__.py CHANGED
@@ -1,168 +1,3 @@
1
- from enum import Enum
2
-
3
- from diffusers import DiffusionPipeline
4
-
5
- from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
6
- apply_db_cache_on_pipe,
7
- )
8
- from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
9
- apply_fb_cache_on_pipe,
10
- )
11
- from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters import (
12
- apply_db_prune_on_pipe,
13
- )
14
- from cache_dit.logger import init_logger
15
-
16
-
17
- logger = init_logger(__name__)
18
-
19
-
20
- class CacheType(Enum):
21
- NONE = "NONE"
22
- FBCache = "First_Block_Cache"
23
- DBCache = "Dual_Block_Cache"
24
- DBPrune = "Dynamic_Block_Prune"
25
-
26
- @staticmethod
27
- def type(cache_type: "CacheType | str") -> "CacheType":
28
- if isinstance(cache_type, CacheType):
29
- return cache_type
30
- return CacheType.cache_type(cache_type)
31
-
32
- @staticmethod
33
- def cache_type(cache_type: "CacheType | str") -> "CacheType":
34
- if cache_type is None:
35
- return CacheType.NONE
36
-
37
- if isinstance(cache_type, CacheType):
38
- return cache_type
39
- if cache_type.lower() in (
40
- "first_block_cache",
41
- "fb_cache",
42
- "fbcache",
43
- "fb",
44
- ):
45
- return CacheType.FBCache
46
- elif cache_type.lower() in (
47
- "dual_block_cache",
48
- "db_cache",
49
- "dbcache",
50
- "db",
51
- ):
52
- return CacheType.DBCache
53
- elif cache_type.lower() in (
54
- "dynamic_block_prune",
55
- "db_prune",
56
- "dbprune",
57
- "dbp",
58
- ):
59
- return CacheType.DBPrune
60
- elif cache_type.lower() in (
61
- "none_cache",
62
- "nonecache",
63
- "no_cache",
64
- "nocache",
65
- "none",
66
- "no",
67
- ):
68
- return CacheType.NONE
69
- else:
70
- raise ValueError(f"Unknown cache type: {cache_type}")
71
-
72
- @staticmethod
73
- def range(start: int, end: int, step: int = 1) -> list[int]:
74
- if start > end or end <= 0 or step <= 1:
75
- return []
76
- # Always compute 0 and end - 1 blocks for DB Cache
77
- return list(
78
- sorted(set([0] + list(range(start, end, step)) + [end - 1]))
79
- )
80
-
81
- @staticmethod
82
- def default_options(cache_type: "CacheType | str") -> dict:
83
- _no_options = {
84
- "cache_type": CacheType.NONE,
85
- }
86
-
87
- _fb_options = {
88
- "cache_type": CacheType.FBCache,
89
- "residual_diff_threshold": 0.08,
90
- "warmup_steps": 8,
91
- "max_cached_steps": 8,
92
- }
93
-
94
- _Fn_compute_blocks = 8
95
- _Bn_compute_blocks = 8
96
-
97
- _db_options = {
98
- "cache_type": CacheType.DBCache,
99
- "residual_diff_threshold": 0.12,
100
- "warmup_steps": 8,
101
- "max_cached_steps": -1, # -1 means no limit
102
- # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
103
- "Fn_compute_blocks": _Fn_compute_blocks,
104
- "Bn_compute_blocks": _Bn_compute_blocks,
105
- "max_Fn_compute_blocks": 16,
106
- "max_Bn_compute_blocks": 16,
107
- "Fn_compute_blocks_ids": [], # 0, 1, 2, ..., 7, etc.
108
- "Bn_compute_blocks_ids": [], # 0, 1, 2, ..., 7, etc.
109
- }
110
-
111
- _dbp_options = {
112
- "cache_type": CacheType.DBPrune,
113
- "residual_diff_threshold": 0.08,
114
- "Fn_compute_blocks": _Fn_compute_blocks,
115
- "Bn_compute_blocks": _Bn_compute_blocks,
116
- "warmup_steps": 8,
117
- "max_pruned_steps": -1, # -1 means no limit
118
- }
119
-
120
- if cache_type == CacheType.FBCache:
121
- return _fb_options
122
- elif cache_type == CacheType.DBCache:
123
- return _db_options
124
- elif cache_type == CacheType.DBPrune:
125
- return _dbp_options
126
- elif cache_type == CacheType.NONE:
127
- return _no_options
128
- else:
129
- raise ValueError(f"Unknown cache type: {cache_type}")
130
-
131
-
132
- def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
133
- assert isinstance(pipe, DiffusionPipeline)
134
-
135
- if hasattr(pipe, "_is_cached") and pipe._is_cached:
136
- return pipe
137
-
138
- if hasattr(pipe, "_is_pruned") and pipe._is_pruned:
139
- return pipe
140
-
141
- cache_type = kwargs.pop("cache_type", None)
142
- if cache_type is None:
143
- logger.warning(
144
- "No cache type specified, we will use DBCache by default. "
145
- "Please specify the cache_type explicitly if you want to "
146
- "use a different cache type."
147
- )
148
- # Force to use DBCache with default cache options
149
- return apply_db_cache_on_pipe(
150
- pipe,
151
- **CacheType.default_options(CacheType.DBCache),
152
- )
153
-
154
- cache_type = CacheType.type(cache_type)
155
-
156
- if cache_type == CacheType.FBCache:
157
- return apply_fb_cache_on_pipe(pipe, *args, **kwargs)
158
- elif cache_type == CacheType.DBCache:
159
- return apply_db_cache_on_pipe(pipe, *args, **kwargs)
160
- elif cache_type == CacheType.DBPrune:
161
- return apply_db_prune_on_pipe(pipe, *args, **kwargs)
162
- elif cache_type == CacheType.NONE:
163
- logger.warning(
164
- f"Cache type is {cache_type}, no caching will be applied."
165
- )
166
- return pipe
167
- else:
168
- raise ValueError(f"Unknown cache type: {cache_type}")
1
+ from cache_dit.cache_factory.adapters import CacheType
2
+ from cache_dit.cache_factory.adapters import apply_cache_on_pipe
3
+ from cache_dit.cache_factory.utils import load_cache_options_from_yaml
cache_dit/cache_factory/adapters.py ADDED
@@ -0,0 +1,169 @@
1
+ from enum import Enum
2
+
3
+ from diffusers import DiffusionPipeline
4
+
5
+ from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
6
+ apply_db_cache_on_pipe,
7
+ )
8
+ from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
9
+ apply_fb_cache_on_pipe,
10
+ )
11
+ from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters import (
12
+ apply_db_prune_on_pipe,
13
+ )
14
+
15
+ from cache_dit.logger import init_logger
16
+
17
+
18
+ logger = init_logger(__name__)
19
+
20
+
21
+ class CacheType(Enum):
22
+ NONE = "NONE"
23
+ FBCache = "First_Block_Cache"
24
+ DBCache = "Dual_Block_Cache"
25
+ DBPrune = "Dynamic_Block_Prune"
26
+
27
+ @staticmethod
28
+ def type(cache_type: "CacheType | str") -> "CacheType":
29
+ if isinstance(cache_type, CacheType):
30
+ return cache_type
31
+ return CacheType.cache_type(cache_type)
32
+
33
+ @staticmethod
34
+ def cache_type(cache_type: "CacheType | str") -> "CacheType":
35
+ if cache_type is None:
36
+ return CacheType.NONE
37
+
38
+ if isinstance(cache_type, CacheType):
39
+ return cache_type
40
+ if cache_type.lower() in (
41
+ "first_block_cache",
42
+ "fb_cache",
43
+ "fbcache",
44
+ "fb",
45
+ ):
46
+ return CacheType.FBCache
47
+ elif cache_type.lower() in (
48
+ "dual_block_cache",
49
+ "db_cache",
50
+ "dbcache",
51
+ "db",
52
+ ):
53
+ return CacheType.DBCache
54
+ elif cache_type.lower() in (
55
+ "dynamic_block_prune",
56
+ "db_prune",
57
+ "dbprune",
58
+ "dbp",
59
+ ):
60
+ return CacheType.DBPrune
61
+ elif cache_type.lower() in (
62
+ "none_cache",
63
+ "nonecache",
64
+ "no_cache",
65
+ "nocache",
66
+ "none",
67
+ "no",
68
+ ):
69
+ return CacheType.NONE
70
+ else:
71
+ raise ValueError(f"Unknown cache type: {cache_type}")
72
+
73
+ @staticmethod
74
+ def range(start: int, end: int, step: int = 1) -> list[int]:
75
+ if start > end or end <= 0 or step <= 1:
76
+ return []
77
+ # Always compute 0 and end - 1 blocks for DB Cache
78
+ return list(
79
+ sorted(set([0] + list(range(start, end, step)) + [end - 1]))
80
+ )
81
+
82
+ @staticmethod
83
+ def default_options(cache_type: "CacheType | str") -> dict:
84
+ _no_options = {
85
+ "cache_type": CacheType.NONE,
86
+ }
87
+
88
+ _fb_options = {
89
+ "cache_type": CacheType.FBCache,
90
+ "residual_diff_threshold": 0.08,
91
+ "warmup_steps": 8,
92
+ "max_cached_steps": 8,
93
+ }
94
+
95
+ _Fn_compute_blocks = 8
96
+ _Bn_compute_blocks = 8
97
+
98
+ _db_options = {
99
+ "cache_type": CacheType.DBCache,
100
+ "residual_diff_threshold": 0.12,
101
+ "warmup_steps": 8,
102
+ "max_cached_steps": -1, # -1 means no limit
103
+ # Fn=1, Bn=0, means FB Cache, otherwise, Dual Block Cache
104
+ "Fn_compute_blocks": _Fn_compute_blocks,
105
+ "Bn_compute_blocks": _Bn_compute_blocks,
106
+ "max_Fn_compute_blocks": 16,
107
+ "max_Bn_compute_blocks": 16,
108
+ "Fn_compute_blocks_ids": [], # 0, 1, 2, ..., 7, etc.
109
+ "Bn_compute_blocks_ids": [], # 0, 1, 2, ..., 7, etc.
110
+ }
111
+
112
+ _dbp_options = {
113
+ "cache_type": CacheType.DBPrune,
114
+ "residual_diff_threshold": 0.08,
115
+ "Fn_compute_blocks": _Fn_compute_blocks,
116
+ "Bn_compute_blocks": _Bn_compute_blocks,
117
+ "warmup_steps": 8,
118
+ "max_pruned_steps": -1, # -1 means no limit
119
+ }
120
+
121
+ if cache_type == CacheType.FBCache:
122
+ return _fb_options
123
+ elif cache_type == CacheType.DBCache:
124
+ return _db_options
125
+ elif cache_type == CacheType.DBPrune:
126
+ return _dbp_options
127
+ elif cache_type == CacheType.NONE:
128
+ return _no_options
129
+ else:
130
+ raise ValueError(f"Unknown cache type: {cache_type}")
131
+
132
+
133
+ def apply_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
134
+ assert isinstance(pipe, DiffusionPipeline)
135
+
136
+ if hasattr(pipe, "_is_cached") and pipe._is_cached:
137
+ return pipe
138
+
139
+ if hasattr(pipe, "_is_pruned") and pipe._is_pruned:
140
+ return pipe
141
+
142
+ cache_type = kwargs.pop("cache_type", None)
143
+ if cache_type is None:
144
+ logger.warning(
145
+ "No cache type specified, we will use DBCache by default. "
146
+ "Please specify the cache_type explicitly if you want to "
147
+ "use a different cache type."
148
+ )
149
+ # Force to use DBCache with default cache options
150
+ return apply_db_cache_on_pipe(
151
+ pipe,
152
+ **CacheType.default_options(CacheType.DBCache),
153
+ )
154
+
155
+ cache_type = CacheType.type(cache_type)
156
+
157
+ if cache_type == CacheType.FBCache:
158
+ return apply_fb_cache_on_pipe(pipe, *args, **kwargs)
159
+ elif cache_type == CacheType.DBCache:
160
+ return apply_db_cache_on_pipe(pipe, *args, **kwargs)
161
+ elif cache_type == CacheType.DBPrune:
162
+ return apply_db_prune_on_pipe(pipe, *args, **kwargs)
163
+ elif cache_type == CacheType.NONE:
164
+ logger.warning(
165
+ f"Cache type is {cache_type}, no caching will be applied."
166
+ )
167
+ return pipe
168
+ else:
169
+ raise ValueError(f"Unknown cache type: {cache_type}")
cache_dit/cache_factory/utils.py CHANGED
@@ -0,0 +1,53 @@
1
+ import yaml
2
+ from cache_dit.cache_factory.adapters import CacheType
3
+
4
+
5
+ def load_cache_options_from_yaml(yaml_file_path):
6
+ try:
7
+ with open(yaml_file_path, "r") as f:
8
+ config = yaml.safe_load(f)
9
+
10
+ required_keys = [
11
+ "cache_type",
12
+ "warmup_steps",
13
+ "max_cached_steps",
14
+ "Fn_compute_blocks",
15
+ "Bn_compute_blocks",
16
+ "residual_diff_threshold",
17
+ ]
18
+ for key in required_keys:
19
+ if key not in config:
20
+ raise ValueError(
21
+ f"Configuration file missing required item: {key}"
22
+ )
23
+
24
+ # Convert cache_type to CacheType enum
25
+ if isinstance(config["cache_type"], str):
26
+ try:
27
+ config["cache_type"] = CacheType[config["cache_type"]]
28
+ except KeyError:
29
+ valid_types = [ct.name for ct in CacheType]
30
+ raise ValueError(
31
+ f"Invalid cache_type value: {config['cache_type']}, "
32
+ f"valid values are: {valid_types}"
33
+ )
34
+ elif not isinstance(config["cache_type"], CacheType):
35
+ raise ValueError(
36
+ f"cache_type must be a string or CacheType enum, "
37
+ f"got: {type(config['cache_type'])}"
38
+ )
39
+
40
+ # Handle default value for taylorseer_kwargs
41
+ if "taylorseer_kwargs" not in config and config.get(
42
+ "enable_taylorseer", False
43
+ ):
44
+ config["taylorseer_kwargs"] = {"n_derivatives": 2}
45
+
46
+ return config
47
+
48
+ except FileNotFoundError:
49
+ raise FileNotFoundError(
50
+ f"Configuration file not found: {yaml_file_path}"
51
+ )
52
+ except yaml.YAMLError as e:
53
+ raise yaml.YAMLError(f"YAML file parsing error: {str(e)}")
cache_dit/compile/utils.py CHANGED
@@ -39,6 +39,14 @@ def set_custom_compile_configs(
39
39
  # https://github.com/pytorch/pytorch/issues/153791
40
40
  torch._inductor.config.autotune_local_cache = False
41
41
 
42
+ if dist.is_initialized():
43
+ # Enable compute comm overlap
44
+ torch._inductor.config.reorder_for_compute_comm_overlap = True
45
+ # L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
46
+ torch._inductor.config.intra_node_bw = (
47
+ 64 if "L20" in torch.cuda.get_device_name() else 300
48
+ )
49
+
42
50
  FORCE_DISABLE_CUSTOM_COMPILE_CONFIG = (
43
51
  os.environ.get("CACHE_DIT_FORCE_DISABLE_CUSTOM_COMPILE_CONFIG", "0")
44
52
  == "1"
@@ -51,14 +59,6 @@ def set_custom_compile_configs(
51
59
  )
52
60
  return
53
61
 
54
- if dist.is_initialized():
55
- # Enable compute comm overlap
56
- torch._inductor.config.reorder_for_compute_comm_overlap = True
57
- # L20 64 GB/s, PCIe; A100/A800 NVLink 300 GB/s.
58
- torch._inductor.config.intra_node_bw = (
59
- 64 if "L20" in torch.cuda.get_device_name() else 300
60
- )
61
-
62
62
  # Below are default settings for torch.compile, you can change
63
63
  # them to your needs and test the performance
64
64
  torch._inductor.config.max_fusion_size = 64
cache_dit/metrics/metrics.py CHANGED
@@ -334,22 +334,26 @@ compute_video_mse = partial(
334
334
  )
335
335
 
336
336
 
337
+ METRICS_CHOICES = [
338
+ "psnr",
339
+ "ssim",
340
+ "mse",
341
+ "fid",
342
+ "all",
343
+ ]
344
+
345
+
337
346
  # Entrypoints
338
347
  def get_args():
348
+ global METRICS_CHOICES
339
349
  parser = argparse.ArgumentParser(
340
350
  description="CacheDiT's Metrics CLI",
341
351
  formatter_class=argparse.ArgumentDefaultsHelpFormatter,
342
352
  )
343
- METRICS_CHOICES = [
344
- "psnr",
345
- "ssim",
346
- "mse",
347
- "fid",
348
- "all",
349
- ]
350
353
  parser.add_argument(
351
- "metric",
354
+ "metrics",
352
355
  type=str,
356
+ nargs="+",
353
357
  default="psnr",
354
358
  choices=METRICS_CHOICES,
355
359
  help=f"Metric choices: {METRICS_CHOICES}",
@@ -382,6 +386,49 @@ def get_args():
382
386
  default=None,
383
387
  help="Path to predicted video or Dir to predicted videos",
384
388
  )
389
+
390
+ # Image 1 vs N pattern
391
+ parser.add_argument(
392
+ "--img-source-dir",
393
+ "-d",
394
+ type=str,
395
+ default=None,
396
+ help="Path to dir that contains dirs of images",
397
+ )
398
+ parser.add_argument(
399
+ "--ref-img-dir",
400
+ "-r",
401
+ type=str,
402
+ default=None,
403
+ help="Path to ref dir that contains ground truth images",
404
+ )
405
+
406
+ # Video 1 vs N pattern
407
+ parser.add_argument(
408
+ "--video-source-dir",
409
+ "-vd",
410
+ type=str,
411
+ default=None,
412
+ help="Path to dir that contains many videos",
413
+ )
414
+ parser.add_argument(
415
+ "--ref-video",
416
+ "-rv",
417
+ type=str,
418
+ default=None,
419
+ help="Path to ground truth video",
420
+ )
421
+
422
+ # FID batch size
423
+ parser.add_argument(
424
+ "--fid-batch-size",
425
+ "-b",
426
+ type=int,
427
+ default=1,
428
+ help="Batch size for FID compute",
429
+ )
430
+
431
+ # Verbose
385
432
  parser.add_argument(
386
433
  "--enable-verbose",
387
434
  "-verbose",
@@ -389,10 +436,20 @@ def get_args():
389
436
  default=False,
390
437
  help="Show metrics progress verbose",
391
438
  )
439
+
440
+ # Format output
441
+ parser.add_argument(
442
+ "--sort-output",
443
+ "-sort",
444
+ action="store_true",
445
+ default=False,
446
+ help="Sort the outupt metrics results",
447
+ )
392
448
  return parser.parse_args()
393
449
 
394
450
 
395
451
  def entrypoint():
452
+ global METRICS_CHOICES
396
453
  args = get_args()
397
454
  logger.debug(args)
398
455
 
@@ -401,68 +458,270 @@ def entrypoint():
401
458
  set_metrics_verbose(True)
402
459
  DISABLE_VERBOSE = not get_metrics_verbose()
403
460
 
404
- if args.img_true is not None and args.img_test is not None:
405
- if any(
406
- (
407
- not os.path.exists(args.img_true),
408
- not os.path.exists(args.img_test),
409
- )
410
- ):
461
+ if "all" in args.metrics or "fid" in args.metrics:
462
+ FID = FrechetInceptionDistance(
463
+ disable_tqdm=DISABLE_VERBOSE,
464
+ batch_size=args.fid_batch_size,
465
+ )
466
+
467
+ METRICS_META: dict[str, float] = {}
468
+
469
+ # run one metric
470
+ def _run_metric(
471
+ metric: str,
472
+ img_true: str = None,
473
+ img_test: str = None,
474
+ video_true: str = None,
475
+ video_test: str = None,
476
+ ) -> None:
477
+ nonlocal FID
478
+ nonlocal METRICS_META
479
+ metric = metric.lower()
480
+ if img_true is not None and img_test is not None:
481
+ if any(
482
+ (
483
+ not os.path.exists(img_true),
484
+ not os.path.exists(img_test),
485
+ )
486
+ ):
487
+ return
488
+ # img_true and img_test can be files or dirs
489
+ img_true_info = os.path.basename(img_true)
490
+ img_test_info = os.path.basename(img_test)
491
+
492
+ def _logging_msg(value: float, n: int):
493
+ if value is None or n is None:
494
+ return
495
+ msg = (
496
+ f"{img_true_info} vs {img_test_info}, "
497
+ f"Num: {n}, {metric.upper()}: {value:.5f}"
498
+ )
499
+ METRICS_META[msg] = value
500
+ logger.info(msg)
501
+
502
+ if metric == "psnr" or metric == "all":
503
+ img_psnr, n = compute_psnr(img_true, img_test)
504
+ _logging_msg(img_psnr, n)
505
+ if metric == "ssim" or metric == "all":
506
+ img_ssim, n = compute_ssim(img_true, img_test)
507
+ _logging_msg(img_ssim, n)
508
+ if metric == "mse" or metric == "all":
509
+ img_mse, n = compute_mse(img_true, img_test)
510
+ _logging_msg(img_mse, n)
511
+ if metric == "fid" or metric == "all":
512
+ img_fid, n = FID.compute_fid(img_true, img_test)
513
+ _logging_msg(img_fid, n)
514
+
515
+ if video_true is not None and video_test is not None:
516
+ if any(
517
+ (
518
+ not os.path.exists(video_true),
519
+ not os.path.exists(video_test),
520
+ )
521
+ ):
522
+ return
523
+
524
+ # video_true and video_test can be files or dirs
525
+ video_true_info = os.path.basename(video_true)
526
+ video_test_info = os.path.basename(video_test)
527
+
528
+ def _logging_msg(value: float, n: int):
529
+ if value is None or n is None:
530
+ return
531
+ msg = (
532
+ f"{video_true_info} vs {video_test_info}, "
533
+ f"Frames: {n}, {metric.upper()}: {value:.5f}"
534
+ )
535
+ METRICS_META[msg] = value
536
+ logger.info(msg)
537
+
538
+ if metric == "psnr" or metric == "all":
539
+ video_psnr, n = compute_video_psnr(video_true, video_test)
540
+ _logging_msg(video_psnr, n)
541
+ if metric == "ssim" or metric == "all":
542
+ video_ssim, n = compute_video_ssim(video_true, video_test)
543
+ _logging_msg(video_ssim, n)
544
+ if metric == "mse" or metric == "all":
545
+ video_mse, n = compute_video_mse(video_true, video_test)
546
+ _logging_msg(video_mse, n)
547
+ if metric == "fid" or metric == "all":
548
+ video_fid, n = FID.compute_video_fid(video_true, video_test)
549
+ _logging_msg(video_fid, n)
550
+
551
+ # run selected metrics
552
+ if not DISABLE_VERBOSE:
553
+ logger.info(f"Selected metrics: {args.metrics}")
554
+
555
+ def _is_image_1vsN_pattern() -> bool:
556
+ return args.img_source_dir is not None and args.ref_img_dir is not None
557
+
558
+ def _is_video_1vsN_pattern() -> bool:
559
+ return args.video_source_dir is not None and args.ref_video is not None
560
+
561
+ assert not all((_is_image_1vsN_pattern(), _is_video_1vsN_pattern()))
562
+
563
+ if _is_image_1vsN_pattern():
564
+ # Glob Image dirs
565
+ if not os.path.exists(args.img_source_dir):
566
+ logger.error(f"{args.img_source_dir} not exist!")
411
567
  return
412
- # img_true and img_test can be files or dirs
413
- if args.metric == "psnr" or args.metric == "all":
414
- img_psnr, n = compute_psnr(args.img_true, args.img_test)
415
- logger.info(
416
- f"{args.img_true} vs {args.img_test}, Num: {n}, PSNR: {img_psnr}"
417
- )
418
- if args.metric == "ssim" or args.metric == "all":
419
- img_ssim, n = compute_ssim(args.img_true, args.img_test)
420
- logger.info(
421
- f"{args.img_true} vs {args.img_test}, Num: {n}, SSIM: {img_ssim}"
422
- )
423
- if args.metric == "mse" or args.metric == "all":
424
- img_mse, n = compute_mse(args.img_true, args.img_test)
425
- logger.info(
426
- f"{args.img_true} vs {args.img_test}, Num: {n}, MSE: {img_mse}"
427
- )
428
- if args.metric == "fid" or args.metric == "all":
429
- FID = FrechetInceptionDistance(disable_tqdm=DISABLE_VERBOSE)
430
- img_fid, n = FID.compute_fid(args.img_true, args.img_test)
568
+ if not os.path.exists(args.ref_img_dir):
569
+ logger.error(f"{args.ref_img_dir} not exist!")
570
+ return
571
+
572
+ directories = []
573
+ for item in os.listdir(args.img_source_dir):
574
+ item_path = os.path.join(args.img_source_dir, item)
575
+ if os.path.isdir(item_path):
576
+ if os.path.basename(item_path) == os.path.basename(
577
+ args.ref_img_dir
578
+ ):
579
+ continue
580
+ directories.append(item_path)
581
+
582
+ if len(directories) == 0:
583
+ return
584
+
585
+ directories = sorted(directories)
586
+ if not DISABLE_VERBOSE:
431
587
  logger.info(
432
- f"{args.img_true} vs {args.img_test}, Num: {n}, FID: {img_fid}"
588
+ f"Compare {args.ref_img_dir} vs {directories}, "
589
+ f"Num compares: {len(directories)}"
433
590
  )
434
- if args.video_true is not None and args.video_test is not None:
435
- if any(
436
- (
437
- not os.path.exists(args.video_true),
438
- not os.path.exists(args.video_test),
439
- )
440
- ):
591
+
592
+ for metric in args.metrics:
593
+ for img_test_dir in directories:
594
+ _run_metric(
595
+ metric=metric,
596
+ img_true=args.ref_img_dir,
597
+ img_test=img_test_dir,
598
+ )
599
+
600
+ elif _is_video_1vsN_pattern():
601
+ # Glob videos
602
+ if not os.path.exists(args.video_source_dir):
603
+ logger.error(f"{args.video_source_dir} not exist!")
604
+ return
605
+ if not os.path.exists(args.ref_video):
606
+ logger.error(f"{args.ref_video} not exist!")
607
+ return
608
+
609
+ video_source_dir: pathlib.Path = pathlib.Path(args.video_source_dir)
610
+ video_source_files = sorted(
611
+ [
612
+ file
613
+ for ext in _VIDEO_EXTENSIONS
614
+ for file in video_source_dir.rglob("*.{}".format(ext))
615
+ ]
616
+ )
617
+ video_source_files = [file.as_posix() for file in video_source_files]
618
+
619
+ video_source_selected = []
620
+ for video_source_file in video_source_files:
621
+ if os.path.basename(video_source_file) == os.path.basename(
622
+ args.ref_video
623
+ ):
624
+ continue
625
+ video_source_selected.append(video_source_file)
626
+
627
+ if len(video_source_selected) == 0:
441
628
  return
442
- # video_true and video_test can be files or dirs
443
- if args.metric == "psnr" or args.metric == "all":
444
- video_psnr, n = compute_video_psnr(args.video_true, args.video_test)
629
+
630
+ video_source_selected = sorted(video_source_selected)
631
+ if not DISABLE_VERBOSE:
445
632
  logger.info(
446
- f"{args.video_true} vs {args.video_test}, Frames: {n}, PSNR: {video_psnr}"
633
+ f"Compare {args.ref_video} vs {video_source_selected}, "
634
+ f"Num compares: {len(video_source_selected)}"
447
635
  )
448
- if args.metric == "ssim" or args.metric == "all":
449
- video_ssim, n = compute_video_ssim(args.video_true, args.video_test)
450
- logger.info(
451
- f"{args.video_true} vs {args.video_test}, Frames: {n}, SSIM: {video_ssim}"
636
+
637
+ for metric in args.metrics:
638
+ for video_test in video_source_selected:
639
+ _run_metric(
640
+ metric=metric,
641
+ video_true=args.ref_video,
642
+ video_test=video_test,
643
+ )
644
+
645
+ else:
646
+ for metric in args.metrics:
647
+ _run_metric(
648
+ metric=metric,
649
+ img_true=args.img_true,
650
+ img_test=args.img_test,
651
+ video_true=args.video_true,
652
+ video_test=args.video_test,
452
653
  )
453
- if args.metric == "mse" or args.metric == "all":
454
- video_mse, n = compute_video_mse(args.video_true, args.video_test)
455
- logger.info(
456
- f"{args.video_true} vs {args.video_test}, Frames: {n}, MSE: {video_mse}"
654
+
655
+ if args.sort_output:
656
+
657
+ def _parse_value(
658
+ text: str,
659
+ tag: str = "Num",
660
+ ) -> float:
661
+ import re
662
+
663
+ pattern = re.compile(
664
+ rf"{re.escape(tag)}:\s*(\d+\.?\d*)", re.IGNORECASE
457
665
  )
458
- if args.metric == "fid" or args.metric == "all":
459
- FID = FrechetInceptionDistance(disable_tqdm=DISABLE_VERBOSE)
460
- video_fid, n = FID.compute_video_fid(
461
- args.video_true, args.video_test
666
+
667
+ match = pattern.search(text)
668
+
669
+ if not match:
670
+ return None
671
+
672
+ if tag.lower() in METRICS_CHOICES:
673
+ return float(match.group(1))
674
+ return int(match.group(1))
675
+
676
+ def _format_item(
677
+ key: str,
678
+ metric: str,
679
+ value: float,
680
+ max_key_len: int,
681
+ ):
682
+ # U1-Q0-C0-NONE vs U4-Q1-C1-NONE
683
+ header = key.split(",")[0].strip()
684
+ # Num / Frames
685
+ if n := _parse_value(key, "Num"):
686
+ print(
687
+ f"{header:<{max_key_len}} Num: {n} "
688
+ f"{metric.upper()}: {value:<.4f}"
689
+ )
690
+ elif n := _parse_value(key, "Frames"):
691
+ print(
692
+ f"{header:<{max_key_len}} Frames: {n} "
693
+ f"{metric.upper()}: {value:<.4f}"
694
+ )
695
+ else:
696
+ raise ValueError("Num or Frames can not be NoneType.")
697
+
698
+ for metric in args.metrics:
699
+ selected_items = {}
700
+ for key in METRICS_META.keys():
701
+ if metric.upper() in key or metric.lower() in key:
702
+ selected_items[key] = METRICS_META[key]
703
+
704
+ reverse = True if metric.lower() in ["psnr", "ssim"] else False
705
+ sorted_items = sorted(
706
+ selected_items.items(), key=lambda x: x[1], reverse=reverse
462
707
  )
463
- logger.info(
464
- f"{args.video_true} vs {args.video_test}, Frames: {n}, FID: {video_fid}"
708
+ selected_keys = [
709
+ key.split(",")[0].strip() for key in selected_items.keys()
710
+ ]
711
+ max_key_len = max(len(key) for key in selected_keys)
712
+
713
+ format_len = int(max_key_len * 1.5)
714
+ res_len = format_len - len(f"Summary: {metric.upper()}")
715
+ left_len = res_len // 2
716
+ right_len = res_len - left_len
717
+ print("-" * format_len)
718
+ print(
719
+ " " * left_len + f"Summary: {metric.upper()}" + " " * right_len
465
720
  )
721
+ print("-" * format_len)
722
+ for key, value in sorted_items:
723
+ _format_item(key, metric, value, max_key_len)
724
+ print("-" * format_len)
466
725
 
467
726
 
468
727
  if __name__ == "__main__":
cache_dit-0.2.9.dist-info/METADATA → cache_dit-0.2.11.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: cache_dit
3
- Version: 0.2.9
3
+ Version: 0.2.11
4
4
  Summary: 🤗 CacheDiT: A Training-free and Easy-to-use Cache Acceleration Toolbox for Diffusion Transformers
5
5
  Author: DefTruth, vipshop.com, etc.
6
6
  Maintainer: DefTruth, vipshop.com, etc
@@ -10,6 +10,7 @@ Requires-Python: >=3.10
10
10
  Description-Content-Type: text/markdown
11
11
  License-File: LICENSE
12
12
  Requires-Dist: packaging
13
+ Requires-Dist: pyyaml
13
14
  Requires-Dist: torch>=2.5.1
14
15
  Requires-Dist: transformers>=4.51.3
15
16
  Requires-Dist: diffusers>=0.33.1
@@ -62,9 +63,8 @@ Dynamic: requires-python
62
63
  </div>
63
64
 
64
65
  ## 🔥News🔥
65
-
66
- - [2025-07-13] An end2end speedup example for FLUX using cache-dit is released! **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)**: A forked version of [huggingface/flux-fast](https://github.com/huggingface/flux-fast) that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20 while still maintaining **high precision**.
67
-
66
+ - [2025-07-18] 🎉First caching mechanism in **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** with **[cache-dit](https://github.com/vipshop/cache-dit)**, please check [PR](https://github.com/huggingface/flux-fast/pull/13).
67
+ - [2025-07-13] **[🤗flux-faster](https://github.com/xlite-dev/flux-faster)** is released! A forked version of **[🤗huggingface/flux-fast](https://github.com/huggingface/flux-fast)** that **makes flux-fast even faster** with **[cache-dit](https://github.com/vipshop/cache-dit)**, **3.3x** speedup on NVIDIA L20.
68
68
 
69
69
  ## 🤗 Introduction
70
70
 
cache_dit-0.2.9.dist-info/RECORD → cache_dit-0.2.11.dist-info/RECORD CHANGED
@@ -1,10 +1,11 @@
1
1
  cache_dit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
2
- cache_dit/_version.py,sha256=Iq6CyehddPOWDVsW9Hnb65BEkCEkAnt4bl0MAuqXKLA,511
2
+ cache_dit/_version.py,sha256=Y72g1mojWf0yRnnMW5zEUr6skXUSsqAdPjRJUrxXSYc,513
3
3
  cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
4
4
  cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
5
- cache_dit/cache_factory/__init__.py,sha256=5RNuhWakvvqrOV4vkqrEBA7d-V1LwcNSsjtW14mkqK8,5255
5
+ cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
6
+ cache_dit/cache_factory/adapters.py,sha256=QMCaXnmqM7NT7sx4bCF1mMLn-QcXX9h1RmgLAypDedg,5256
6
7
  cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
7
- cache_dit/cache_factory/utils.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
8
+ cache_dit/cache_factory/utils.py,sha256=V-Mb5Jn07geEUUWo4QAfh6pmSzkL-2OGDn0VAXbG6hQ,1799
8
9
  cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
10
  cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=itVEb6gT2eZuncAHUmP51ZS0r6v6cGtRvnPjyeXqKH8,71156
10
11
  cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=krNAICf-aS3JLmSG8vOB9tpLa04uYRcABsC8PMbVUKY,1870
@@ -30,17 +31,17 @@ cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py,sh
30
31
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
31
32
  cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=dBNzHBECAuTTA1a7kLdvZL20YzaKTAS3iciVLzKKEWA,2638
32
33
  cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k,63
33
- cache_dit/compile/utils.py,sha256=OTvkwcezSrApZ2M1IMkYtkEmFbkfpTknhHMgoBApd6U,3786
34
+ cache_dit/compile/utils.py,sha256=N4A55_8uIbEd-S4xyJPcrdKceI2MGM9BTIhJE63jyL4,3786
34
35
  cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
35
36
  cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
36
37
  cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
37
38
  cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
38
39
  cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
39
40
  cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
40
- cache_dit/metrics/metrics.py,sha256=tzAtG_-fM1xPIBfRVFIBupvOWYzIO3xDq29Vy5rOBWc,14730
41
- cache_dit-0.2.9.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
42
- cache_dit-0.2.9.dist-info/METADATA,sha256=TdvKAftNWwijdCW8K-8iO7fITEcfllWX3FJdZ-qcRqA,28032
43
- cache_dit-0.2.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
44
- cache_dit-0.2.9.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
45
- cache_dit-0.2.9.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
46
- cache_dit-0.2.9.dist-info/RECORD,,
41
+ cache_dit/metrics/metrics.py,sha256=PAzyhJawos1UeMnHsxcu4edkwCSYMBmjDGRR_--I104,22410
42
+ cache_dit-0.2.11.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
43
+ cache_dit-0.2.11.dist-info/METADATA,sha256=HExb-ldgaYzZRSCy6Tg6syONjd0uDoB8_8AShvbfA-0,28213
44
+ cache_dit-0.2.11.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
45
+ cache_dit-0.2.11.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
46
+ cache_dit-0.2.11.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
47
+ cache_dit-0.2.11.dist-info/RECORD,,