cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/offload_utils.py +115 -0
  11. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  12. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  13. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  14. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  15. cache_dit/caching/cache_contexts/__init__.py +28 -0
  16. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
  18. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  21. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  22. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  23. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  24. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  25. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  26. cache_dit/caching/cache_interface.py +358 -0
  27. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  28. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  29. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  30. cache_dit/caching/patch_functors/__init__.py +15 -0
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  36. cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
  37. cache_dit/caching/utils.py +68 -0
  38. cache_dit/metrics/__init__.py +11 -0
  39. cache_dit/metrics/metrics.py +3 -0
  40. cache_dit/parallelism/__init__.py +3 -0
  41. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  57. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  58. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  59. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  60. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  61. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  62. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  68. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  69. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  70. cache_dit/parallelism/parallel_backend.py +26 -0
  71. cache_dit/parallelism/parallel_config.py +88 -0
  72. cache_dit/parallelism/parallel_interface.py +77 -0
  73. cache_dit/quantize/__init__.py +7 -0
  74. cache_dit/quantize/backends/__init__.py +1 -0
  75. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  76. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  77. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
  78. cache_dit/quantize/quantize_backend.py +0 -0
  79. cache_dit/quantize/quantize_config.py +0 -0
  80. cache_dit/quantize/quantize_interface.py +3 -16
  81. cache_dit/summary.py +593 -0
  82. cache_dit/utils.py +46 -290
  83. cache_dit-1.0.14.dist-info/METADATA +301 -0
  84. cache_dit-1.0.14.dist-info/RECORD +102 -0
  85. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  86. cache_dit/cache_factory/__init__.py +0 -28
  87. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  88. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  89. cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
  90. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  91. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
  92. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
  93. cache_dit/cache_factory/cache_blocks/utils.py +0 -41
  94. cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
  95. cache_dit/cache_factory/cache_interface.py +0 -217
  96. cache_dit/cache_factory/patch_functors/__init__.py +0 -12
  97. cache_dit/cache_factory/utils.py +0 -57
  98. cache_dit-0.3.2.dist-info/METADATA +0 -753
  99. cache_dit-0.3.2.dist-info/RECORD +0 -56
  100. cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
  101. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  102. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  103. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  104. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  105. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  106. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  107. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  108. {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
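The dominant change in this release is structural: the cache_factory package is renamed to caching (and custom_ops to kernels), the torchao quantization code moves under quantize/backends, and new parallelism, summary, prune and context-parallel modules are added. A minimal, hypothetical migration sketch based only on the paths in the listing above (the manager and config names come from the new modules shown further down in this diff):

    # cache-dit 0.3.2 layout (old import path):
    # from cache_dit.cache_factory.cache_contexts.cache_manager import CachedContextManager

    # cache-dit 1.0.14 layout (new import path):
    from cache_dit.caching.cache_contexts.cache_manager import CachedContextManager
    from cache_dit.caching.cache_contexts.cache_config import DBCacheConfig

    manager = CachedContextManager(name="pipeline-cache")        # illustrative name
    config = DBCacheConfig(Fn_compute_blocks=8, Bn_compute_blocks=0)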
cache_dit/caching/cache_blocks/pattern_utils.py (new file)
@@ -0,0 +1,86 @@
+ import torch
+
+ from typing import Any
+ from cache_dit.caching import CachedContext
+ from cache_dit.caching import CachedContextManager
+ from cache_dit.caching import PrunedContextManager
+
+
+ def apply_stats(
+     module: torch.nn.Module | Any,
+     cache_context: CachedContext | str = None,
+     context_manager: CachedContextManager | PrunedContextManager = None,
+ ):
+     # Patch the cached stats onto the module; the cached stats
+     # are reset on each call of pipe.__call__(**kwargs).
+     if module is None or context_manager is None:
+         return
+
+     if cache_context is not None:
+         context_manager.set_context(cache_context)
+
+     # Cache stats for Dual Block Cache
+     module._cached_steps = context_manager.get_cached_steps()
+     module._residual_diffs = context_manager.get_residual_diffs()
+     module._cfg_cached_steps = context_manager.get_cfg_cached_steps()
+     module._cfg_residual_diffs = context_manager.get_cfg_residual_diffs()
+     # Pruned stats for Dynamic Block Prune
+     if not isinstance(context_manager, PrunedContextManager):
+         return
+     module._pruned_steps = context_manager.get_pruned_steps()
+     module._cfg_pruned_steps = context_manager.get_cfg_pruned_steps()
+     module._pruned_blocks = context_manager.get_pruned_blocks()
+     module._cfg_pruned_blocks = context_manager.get_cfg_pruned_blocks()
+     module._actual_blocks = context_manager.get_actual_blocks()
+     module._cfg_actual_blocks = context_manager.get_cfg_actual_blocks()
+     # Calculate pruned ratio
+     if len(module._pruned_blocks) > 0 and sum(module._actual_blocks) > 0:
+         module._pruned_ratio = sum(module._pruned_blocks) / sum(
+             module._actual_blocks
+         )
+     else:
+         module._pruned_ratio = None
+     if (
+         len(module._cfg_pruned_blocks) > 0
+         and sum(module._cfg_actual_blocks) > 0
+     ):
+         module._cfg_pruned_ratio = sum(module._cfg_pruned_blocks) / sum(
+             module._cfg_actual_blocks
+         )
+     else:
+         module._cfg_pruned_ratio = None
+
+
+ def remove_stats(
+     module: torch.nn.Module | Any,
+ ):
+     if module is None:
+         return
+
+     # Dual Block Cache
+     if hasattr(module, "_cached_steps"):
+         del module._cached_steps
+     if hasattr(module, "_residual_diffs"):
+         del module._residual_diffs
+     if hasattr(module, "_cfg_cached_steps"):
+         del module._cfg_cached_steps
+     if hasattr(module, "_cfg_residual_diffs"):
+         del module._cfg_residual_diffs
+
+     # Dynamic Block Prune
+     if hasattr(module, "_pruned_steps"):
+         del module._pruned_steps
+     if hasattr(module, "_cfg_pruned_steps"):
+         del module._cfg_pruned_steps
+     if hasattr(module, "_pruned_blocks"):
+         del module._pruned_blocks
+     if hasattr(module, "_cfg_pruned_blocks"):
+         del module._cfg_pruned_blocks
+     if hasattr(module, "_actual_blocks"):
+         del module._actual_blocks
+     if hasattr(module, "_cfg_actual_blocks"):
+         del module._cfg_actual_blocks
+     if hasattr(module, "_pruned_ratio"):
+         del module._pruned_ratio
+     if hasattr(module, "_cfg_pruned_ratio"):
+         del module._cfg_pruned_ratio
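The pattern_utils helpers above patch per-step cache (and, for PrunedContextManager, prune) statistics onto a module as underscore-prefixed attributes and strip them again afterwards. A minimal usage sketch; the context name "demo" and the bare torch.nn.Module are purely illustrative stand-ins for the real cached transformer and the manager owned by the cached pipeline:

    import torch

    from cache_dit.caching import CachedContextManager
    from cache_dit.caching.cache_blocks.pattern_utils import apply_stats, remove_stats

    transformer = torch.nn.Module()                        # stand-in for the cached transformer
    manager = CachedContextManager(name="demo-manager")
    manager.new_context(name="demo")                       # register a context to read stats from

    apply_stats(transformer, cache_context="demo", context_manager=manager)
    assert hasattr(transformer, "_cached_steps")           # stats now live on the module

    remove_stats(transformer)                              # strip them once inspected
    assert not hasattr(transformer, "_cached_steps")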
cache_dit/caching/cache_contexts/__init__.py (new file)
@@ -0,0 +1,28 @@
+ from cache_dit.caching.cache_contexts.calibrators import (
+     Calibrator,
+     CalibratorBase,
+     CalibratorConfig,
+     TaylorSeerCalibratorConfig,
+     FoCaCalibratorConfig,
+ )
+ from cache_dit.caching.cache_contexts.cache_config import (
+     BasicCacheConfig,
+     DBCacheConfig,
+ )
+ from cache_dit.caching.cache_contexts.cache_context import (
+     CachedContext,
+ )
+ from cache_dit.caching.cache_contexts.cache_manager import (
+     CachedContextManager,
+     ContextNotExistError,
+ )
+ from cache_dit.caching.cache_contexts.prune_config import DBPruneConfig
+ from cache_dit.caching.cache_contexts.prune_context import (
+     PrunedContext,
+ )
+ from cache_dit.caching.cache_contexts.prune_manager import (
+     PrunedContextManager,
+ )
+ from cache_dit.caching.cache_contexts.context_manager import (
+     ContextManager,
+ )
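With these re-exports, the cache and prune primitives share a single import surface under cache_dit.caching.cache_contexts, including the new ContextNotExistError raised by the context manager shown further below. A small sketch of how a missing context can be handled (the names are illustrative):

    from cache_dit.caching.cache_contexts import (
        CachedContextManager,
        ContextNotExistError,
    )

    manager = CachedContextManager(name="demo-manager")
    try:
        ctx = manager.get_context("missing")        # unknown names now raise a dedicated error
    except ContextNotExistError:
        ctx = manager.new_context(name="missing")   # so callers can create the context lazily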
cache_dit/caching/cache_contexts/cache_config.py (new file)
@@ -0,0 +1,120 @@
+ import torch
+ import dataclasses
+ from typing import Optional, Union
+ from cache_dit.caching.cache_types import CacheType
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @dataclasses.dataclass
+ class BasicCacheConfig:
+     # Default: Dual Block Cache with flexible FnBn configuration.
+     cache_type: CacheType = CacheType.DBCache  # DBCache, DBPrune, NONE
+
+     # Fn_compute_blocks (`int`, *required*, defaults to 8):
+     # Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
+     # at time step t, enabling the calculation of a more stable L1 diff and delivering more
+     # accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
+     # for more details of DBCache.
+     Fn_compute_blocks: int = 8
+     # Bn_compute_blocks (`int`, *required*, defaults to 0):
+     # Further fuses approximate information in the **last n** Transformer blocks to enhance
+     # prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
+     # that use residual cache.
+     Bn_compute_blocks: int = 0
+     # residual_diff_threshold (`float`, *required*, defaults to 0.08):
+     # The residual diff threshold; a higher value leads to faster performance at the
+     # cost of lower precision.
+     residual_diff_threshold: Union[torch.Tensor, float] = 0.08
+     # max_warmup_steps (`int`, *required*, defaults to 8):
+     # DBCache does not apply the caching strategy when the number of running steps is less than
+     # or equal to this value, ensuring the model sufficiently learns basic features during warmup.
+     max_warmup_steps: int = 8  # DON'T cache in warmup steps
+     # warmup_interval (`int`, *required*, defaults to 1):
+     # Skip interval within the warmup steps; e.g., when warmup_interval is 2, only steps
+     # 0, 2, 4, ... of the warmup steps are computed, while the others use the dynamic cache.
+     warmup_interval: int = 1  # skip interval in warmup steps
+     # max_cached_steps (`int`, *required*, defaults to -1):
+     # DBCache disables the caching strategy when the previous cached steps exceed this value to
+     # prevent precision degradation.
+     max_cached_steps: int = -1  # for both CFG and non-CFG
+     # max_continuous_cached_steps (`int`, *required*, defaults to -1):
+     # DBCache disables the caching strategy when the previous continuous cached steps exceed
+     # this value to prevent precision degradation.
+     max_continuous_cached_steps: int = -1  # the max continuous cached steps
+     # enable_separate_cfg (`bool`, *required*, defaults to None):
+     # Whether to run CFG separately, as in Wan 2.1 and Qwen-Image. For models that fuse CFG
+     # and non-CFG into a single forward step, enable_separate_cfg should be set to False,
+     # for example: CogVideoX, HunyuanVideo, Mochi, etc.
+     enable_separate_cfg: Optional[bool] = None
+     # cfg_compute_first (`bool`, *required*, defaults to False):
+     # Whether to compute the CFG forward pass first; defaults to False, namely, 0, 2, 4, ... -> non-CFG step;
+     # 1, 3, 5, ... -> CFG step.
+     cfg_compute_first: bool = False
+     # cfg_diff_compute_separate (`bool`, *required*, defaults to True):
+     # Compute separate diff values for CFG and non-CFG steps, default True. If False, the
+     # computed diff from the current non-CFG transformer step is reused for the current CFG step.
+     cfg_diff_compute_separate: bool = True
+     # num_inference_steps (`int`, *optional*, defaults to None):
+     # num_inference_steps for the DiffusionPipeline, used to adjust some internal settings
+     # for better caching performance. For example, the cache is refreshed once the
+     # executed steps exceed num_inference_steps, if num_inference_steps is provided.
+     num_inference_steps: Optional[int] = None
+
+     def update(self, **kwargs) -> "BasicCacheConfig":
+         for key, value in kwargs.items():
+             if hasattr(self, key):
+                 if value is not None:
+                     setattr(self, key, value)
+         return self
+
+     def empty(self, **kwargs) -> "BasicCacheConfig":
+         # Set all fields to None
+         for field in dataclasses.fields(self):
+             if hasattr(self, field.name):
+                 setattr(self, field.name, None)
+         if kwargs:
+             self.update(**kwargs)
+         return self
+
+     def reset(self, **kwargs) -> "BasicCacheConfig":
+         return self.empty(**kwargs)
+
+     def as_dict(self) -> dict:
+         return dataclasses.asdict(self)
+
+     def strify(self) -> str:
+         return (
+             f"{self.cache_type}_"
+             f"F{self.Fn_compute_blocks}"
+             f"B{self.Bn_compute_blocks}_"
+             f"W{self.max_warmup_steps}"
+             f"I{self.warmup_interval}"
+             f"M{max(0, self.max_cached_steps)}"
+             f"MC{max(0, self.max_continuous_cached_steps)}_"
+             f"R{self.residual_diff_threshold}"
+         )
+
+
+ @dataclasses.dataclass
+ class ExtraCacheConfig:
+     # Some other, less important settings for Dual Block Cache.
+     # NOTE: These flags may be deprecated in the future and users
+     # should not rely on these extra configurations in their own code.
+
+     # l1_hidden_states_diff_threshold (`float`, *optional*, defaults to None):
+     # The hidden states diff threshold for DBCache if hidden_states are used as
+     # the cache (instead of the residual).
+     l1_hidden_states_diff_threshold: float = None
+     # important_condition_threshold (`float`, *optional*, defaults to 0.0):
+     # Only select the most important tokens while calculating the L1 diff.
+     important_condition_threshold: float = 0.0
+     # downsample_factor (`int`, *optional*, defaults to 1):
+     # Downsample factor for the Fn buffer, in order to save GPU memory.
+     downsample_factor: int = 1
+
+
+ @dataclasses.dataclass
+ class DBCacheConfig(BasicCacheConfig):
+     pass  # Just an alias for BasicCacheConfig
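BasicCacheConfig (and its DBCacheConfig alias) now carries the cache_type, warmup_interval and num_inference_steps knobs and gains update/empty/reset/as_dict helpers; update ignores None values, which makes it convenient for merging partial overrides. A short sketch, assuming default values otherwise (the printed tag follows strify above, but the exact cache_type prefix depends on how CacheType renders):

    from cache_dit.caching.cache_contexts.cache_config import DBCacheConfig

    config = DBCacheConfig(
        Fn_compute_blocks=8,
        Bn_compute_blocks=0,
        residual_diff_threshold=0.08,
        max_warmup_steps=8,
        warmup_interval=2,          # compute only steps 0, 2, 4, 6 during warmup
        num_inference_steps=50,     # enables per-round cache refreshing in the manager
    )

    config.update(residual_diff_threshold=0.12, max_cached_steps=None)  # None overrides are ignored
    assert config.as_dict()["residual_diff_threshold"] == 0.12
    print(config.strify())          # e.g. "<cache_type>_F8B0_W8I2M0MC0_R0.12"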
cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py
@@ -1,3 +1,5 @@
+ # The cache context codebase is adapted from FBCache. Over time its codebase
+ # has diverged a lot, and the context API is no longer compatible with FBCache.
  import logging
  import dataclasses
  from collections import defaultdict
@@ -5,7 +7,12 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple

  import torch

- from cache_dit.cache_factory.cache_contexts.calibrators import (
+ from cache_dit.caching.cache_contexts.cache_config import (
+     BasicCacheConfig,
+     ExtraCacheConfig,
+     DBCacheConfig,
+ )
+ from cache_dit.caching.cache_contexts.calibrators import (
      Calibrator,
      CalibratorBase,
      CalibratorConfig,
@@ -15,96 +22,16 @@ from cache_dit.logger import init_logger
  logger = init_logger(__name__)


- @dataclasses.dataclass
- class BasicCacheConfig:
-     # Dual Block Cache with Flexible FnBn configuration.
-
-     # Fn_compute_blocks: (`int`, *required*, defaults to 8):
-     # Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
-     # at time step t, enabling the calculation of a more stable L1 diff and delivering more
-     # accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
-     # for more details of DBCache.
-     Fn_compute_blocks: int = 8
-     # Bn_compute_blocks: (`int`, *required*, defaults to 0):
-     # Further fuses approximate information in the **last n** Transformer blocks to enhance
-     # prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
-     # that use residual cache.
-     Bn_compute_blocks: int = 0
-     # residual_diff_threshold (`float`, *required*, defaults to 0.08):
-     # the value of residual diff threshold, a higher value leads to faster performance at the
-     # cost of lower precision.
-     residual_diff_threshold: Union[torch.Tensor, float] = 0.08
-     # max_warmup_steps (`int`, *required*, defaults to 8):
-     # DBCache does not apply the caching strategy when the number of running steps is less than
-     # or equal to this value, ensuring the model sufficiently learns basic features during warmup.
-     max_warmup_steps: int = 8  # DON'T Cache in warmup steps
-     # max_cached_steps (`int`, *required*, defaults to -1):
-     # DBCache disables the caching strategy when the previous cached steps exceed this value to
-     # prevent precision degradation.
-     max_cached_steps: int = -1  # for both CFG and non-CFG
-     # max_continuous_cached_steps (`int`, *required*, defaults to -1):
-     # DBCache disables the caching strategy when the previous continous cached steps exceed this value to
-     # prevent precision degradation.
-     max_continuous_cached_steps: int = -1  # the max continuous cached steps
-     # enable_separate_cfg (`bool`, *required*, defaults to None):
-     # Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
-     # and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
-     # CogVideoX, HunyuanVideo, Mochi, etc.
-     enable_separate_cfg: Optional[bool] = None
-     # cfg_compute_first (`bool`, *required*, defaults to False):
-     # Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
-     # 1, 3, 5, ... -> CFG step.
-     cfg_compute_first: bool = False
-     # cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-     # Compute separate diff values for CFG and non-CFG step, default True. If False, we will
-     # use the computed diff from current non-CFG transformer step for current CFG step.
-     cfg_diff_compute_separate: bool = True
-
-     def update(self, **kwargs) -> "BasicCacheConfig":
-         for key, value in kwargs.items():
-             if hasattr(self, key):
-                 setattr(self, key, value)
-         return self
-
-     def strify(self) -> str:
-         return (
-             f"DBCACHE_F{self.Fn_compute_blocks}"
-             f"B{self.Bn_compute_blocks}_"
-             f"W{self.max_warmup_steps}"
-             f"M{max(0, self.max_cached_steps)}"
-             f"MC{max(0, self.max_continuous_cached_steps)}_"
-             f"R{self.residual_diff_threshold}"
-         )
-
-
- @dataclasses.dataclass
- class ExtraCacheConfig:
-     # Some other not very important settings for Dual Block Cache.
-     # NOTE: These flags maybe deprecated in the future and users
-     # should never use these extra configurations in their cases.
-
-     # l1_hidden_states_diff_threshold (`float`, *optional*, defaults to None):
-     # The hidden states diff threshold for DBCache if use hidden_states as
-     # cache (not residual).
-     l1_hidden_states_diff_threshold: float = None
-     # important_condition_threshold (`float`, *optional*, defaults to 0.0):
-     # Only select the most important tokens while calculating the l1 diff.
-     important_condition_threshold: float = 0.0
-     # downsample_factor (`int`, *optional*, defaults to 1):
-     # Downsample factor for Fn buffer, in order the save GPU memory.
-     downsample_factor: int = 1
-     # num_inference_steps (`int`, *optional*, defaults to -1):
-     # num_inference_steps for DiffusionPipeline, for future use.
-     num_inference_steps: int = -1
-
-
  @dataclasses.dataclass
  class CachedContext:
      name: str = "default"
      # Buffer for storing the residuals and other tensors
      buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
      # Basic Dual Block Cache Config
-     cache_config: BasicCacheConfig = dataclasses.field(
+     cache_config: Union[
+         BasicCacheConfig,
+         DBCacheConfig,
+     ] = dataclasses.field(
          default_factory=BasicCacheConfig,
      )
      # NOTE: Users should never use these extra configurations.
@@ -126,14 +53,14 @@ class CachedContext:
      # be double of executed_steps.
      transformer_executed_steps: int = 0

-     # CFG & non-CFG cached steps
+     # CFG & non-CFG cached/pruned steps
      cached_steps: List[int] = dataclasses.field(default_factory=list)
-     residual_diffs: DefaultDict[str, float] = dataclasses.field(
+     residual_diffs: DefaultDict[str, float | list] = dataclasses.field(
          default_factory=lambda: defaultdict(float),
      )
      continuous_cached_steps: int = 0
      cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
-     cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
+     cfg_residual_diffs: DefaultDict[str, float | list] = dataclasses.field(
          default_factory=lambda: defaultdict(float),
      )
      cfg_continuous_cached_steps: int = 0
@@ -281,7 +208,9 @@
      def get_cfg_calibrators(self) -> Tuple[CalibratorBase, CalibratorBase]:
          return self.cfg_calibrator, self.cfg_encoder_calibrator

-     def add_residual_diff(self, diff):
+     def add_residual_diff(self, diff: float | torch.Tensor):
+         if isinstance(diff, torch.Tensor):
+             diff = diff.item()
          # step: executed_steps - 1, not transformer_steps - 1
          step = str(self.get_current_step())
          # Only add the diff if it is not already recorded for this step
@@ -346,5 +275,15 @@ class CachedContext:
          # CFG steps: 1, 3, 5, 7, ...
          return self.get_current_transformer_step() % 2 != 0

+     @property
+     def warmup_steps(self) -> List[int]:
+         return list(
+             range(
+                 0,
+                 self.cache_config.max_warmup_steps,
+                 self.cache_config.warmup_interval,
+             )
+         )
+
      def is_in_warmup(self):
-         return self.get_current_step() < self.cache_config.max_warmup_steps
+         return self.get_current_step() in self.warmup_steps
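The warmup check above is no longer a plain `current_step < max_warmup_steps` comparison: the warmup steps are now the strided range `range(0, max_warmup_steps, warmup_interval)`, so only every warmup_interval-th early step is forced to compute. A tiny sketch of the resulting schedule using the defaults shown earlier (max_warmup_steps=8, warmup_interval=2):

    max_warmup_steps, warmup_interval = 8, 2
    warmup_steps = list(range(0, max_warmup_steps, warmup_interval))

    print(warmup_steps)        # [0, 2, 4, 6] -> these steps always compute
    print(3 in warmup_steps)   # False -> step 3 may be served from the dynamic cache,
                               # whereas the old `< max_warmup_steps` check treated steps 0..7 as warmup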
cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py
@@ -5,8 +5,9 @@ from typing import Dict, Optional, Tuple, Union, List
  import torch
  import torch.distributed as dist

- from cache_dit.cache_factory.cache_contexts.calibrators import CalibratorBase
- from cache_dit.cache_factory.cache_contexts.cache_context import (
+ from cache_dit.caching.cache_contexts.calibrators import CalibratorBase
+ from cache_dit.caching.cache_contexts.cache_context import (
+     BasicCacheConfig,
      CachedContext,
  )
  from cache_dit.logger import init_logger
@@ -14,29 +15,156 @@ from cache_dit.logger import init_logger
  logger = init_logger(__name__)


+ class ContextNotExistError(Exception):
+     pass
+
+
  class CachedContextManager:
      # Each Pipeline should have its own context manager instance.

-     def __init__(self, name: str = None):
+     def __init__(self, name: str = None, persistent_context: bool = False):
          self.name = name
          self._current_context: CachedContext = None
          self._cached_context_manager: Dict[str, CachedContext] = {}
+         # Whether to create a new context automatically when setting
+         # a non-existent context name. A persistent context is useful when
+         # the pipeline class is not provided and users want to use
+         # cache-dit in a transformer-only way.
+         self._persistent_context = persistent_context
+         self._current_step_refreshed: bool = False
+
+     @property
+     def persistent_context(self) -> bool:
+         return self._persistent_context
+
+     @property
+     def current_context(self) -> CachedContext:
+         return self._current_context
+
+     @property
+     @torch.compiler.disable
+     def current_step_refreshed(self) -> bool:
+         return self._current_step_refreshed

+     @torch.compiler.disable
+     def is_pre_refreshed(self) -> bool:
+         _context = self._current_context
+         if _context is None:
+             return False
+
+         num_inference_steps = _context.cache_config.num_inference_steps
+         if num_inference_steps is not None:
+             current_step = _context.get_current_step()  # e.g., 0~49, 50~99, ...
+             return current_step == num_inference_steps - 1
+         return False
+
+     @torch.compiler.disable
      def new_context(self, *args, **kwargs) -> CachedContext:
+         if self._persistent_context:
+             cache_config: BasicCacheConfig = kwargs.get("cache_config", None)
+             assert (
+                 cache_config is not None
+                 and cache_config.num_inference_steps is not None
+             ), (
+                 "When persistent_context is True, num_inference_steps "
+                 "must be set in cache_config for proper cache refreshing."
+                 f"\nkwargs: {kwargs}"
+             )
          _context = CachedContext(*args, **kwargs)
+         # NOTE: Patch args and kwargs for implicit refresh.
+         _context._init_args = args  # maybe an empty tuple: ()
+         _context._init_kwargs = kwargs  # maybe an empty dict: {}
          self._cached_context_manager[_context.name] = _context
          return _context

-     def set_context(self, cached_context: CachedContext | str):
+     @torch.compiler.disable
+     def maybe_refresh(
+         self,
+         cached_context: Optional[CachedContext | str] = None,
+     ) -> bool:
+         if cached_context is None:
+             _context = self._current_context
+             assert _context is not None, "Current context is not set!"
+
+         if isinstance(cached_context, CachedContext):
+             _context = cached_context
+         else:
+             if cached_context not in self._cached_context_manager:
+                 raise ContextNotExistError("Context not exist!")
+             _context = self._cached_context_manager[cached_context]
+
+         if self._persistent_context:
+             assert _context.cache_config.num_inference_steps is not None, (
+                 "When persistent_context is True, num_inference_steps must be set "
+                 "in cache_config for proper cache refreshing."
+             )
+
+         num_inference_steps = _context.cache_config.num_inference_steps
+         if num_inference_steps is not None:
+             current_step = _context.get_current_step()  # e.g., 0~49, 50~99, ...
+             # Another round of inference, need to refresh the cache context.
+             if current_step >= num_inference_steps:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         f"Refreshing cache context '{_context.name}' "
+                         f"as current step: {current_step} >= "
+                         f"num_inference_steps: {num_inference_steps}."
+                     )
+                 return True
+         return False
+
+     @torch.compiler.disable
+     def set_context(
+         self,
+         cached_context: CachedContext | str,
+         *args,
+         **kwargs,
+     ) -> CachedContext:
          if isinstance(cached_context, CachedContext):
              self._current_context = cached_context
          else:
-             self._current_context = self._cached_context_manager[cached_context]
+             if cached_context not in self._cached_context_manager:
+                 if not self._persistent_context:
+                     raise ContextNotExistError(
+                         "Context not exist and persistent_context is False. Please "
+                         "create new context first or set persistent_context=True."
+                     )
+                 else:
+                     # Create a new context if it does not exist
+                     if any((bool(args), bool(kwargs))):
+                         kwargs["name"] = cached_context
+                         self._current_context = self.new_context(
+                             *args, **kwargs
+                         )
+                     else:
+                         raise ValueError(
+                             "To create new context, please provide args and kwargs."
+                         )
+             else:
+                 self._current_context = self._cached_context_manager[
+                     cached_context
+                 ]
+
+         if self.maybe_refresh(self._current_context):
+             if not any((bool(args), bool(kwargs))):
+                 assert hasattr(self._current_context, "_init_args")
+                 assert hasattr(self._current_context, "_init_kwargs")
+                 args = self._current_context._init_args
+                 kwargs = self._current_context._init_kwargs
+
+             self._current_context = self.reset_context(
+                 self._current_context, *args, **kwargs
+             )
+             self._current_step_refreshed = True
+         else:
+             self._current_step_refreshed = False
+
+         return self._current_context

      def get_context(self, name: str = None) -> CachedContext:
          if name is not None:
              if name not in self._cached_context_manager:
-                 raise ValueError("Context not exist!")
+                 raise ContextNotExistError("Context not exist!")
              return self._cached_context_manager[name]
          return self._current_context
@@ -475,7 +603,7 @@

          if calibrator is not None:
              # Use calibrator to update the buffer
-             calibrator.update(buffer)
+             calibrator.update(buffer, name=prefix)
          else:
              if logger.isEnabledFor(logging.DEBUG):
                  logger.debug(
@@ -506,7 +634,7 @@
          calibrator, _ = self.get_calibrator()

          if calibrator is not None:
-             return calibrator.approximate()
+             return calibrator.approximate(name=prefix)
          else:
              if logger.isEnabledFor(logging.DEBUG):
                  logger.debug(
@@ -544,7 +672,7 @@

          if encoder_calibrator is not None:
              # Use CalibratorBase to update the buffer
-             encoder_calibrator.update(buffer)
+             encoder_calibrator.update(buffer, name=prefix)
          else:
              if logger.isEnabledFor(logging.DEBUG):
                  logger.debug(
@@ -575,7 +703,7 @@

          if encoder_calibrator is not None:
              # Use calibrator to approximate the value
-             return encoder_calibrator.approximate()
+             return encoder_calibrator.approximate(name=prefix)
          else:
              if logger.isEnabledFor(logging.DEBUG):
                  logger.debug(
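The set_context/maybe_refresh changes above make the manager usable without a pipeline wrapper: with persistent_context=True, an unknown context name is created on the fly from the supplied kwargs, its init args are remembered, and the context is reset once get_current_step() reaches cache_config.num_inference_steps. A minimal sketch, assuming reset_context behaves as its name suggests (it is not shown in this diff) and using illustrative names:

    from cache_dit.caching.cache_contexts import BasicCacheConfig, CachedContextManager

    manager = CachedContextManager(name="transformer-only", persistent_context=True)

    # First call: the context "dit-ctx" does not exist yet, so it is created from the kwargs;
    # num_inference_steps is mandatory here because of the asserts in new_context/maybe_refresh.
    ctx = manager.set_context(
        "dit-ctx",
        cache_config=BasicCacheConfig(num_inference_steps=50),
    )

    # Later calls can omit the kwargs: the stored _init_args/_init_kwargs are reused, and once
    # a full round of 50 steps has run, set_context resets the context and flags the refresh.
    ctx = manager.set_context("dit-ctx")
    if manager.current_step_refreshed:
        pass  # a new inference round has just started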
cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py
@@ -1,10 +1,10 @@
- from cache_dit.cache_factory.cache_contexts.calibrators.base import (
+ from cache_dit.caching.cache_contexts.calibrators.base import (
      CalibratorBase,
  )
- from cache_dit.cache_factory.cache_contexts.calibrators.taylorseer import (
+ from cache_dit.caching.cache_contexts.calibrators.taylorseer import (
      TaylorSeerCalibrator,
  )
- from cache_dit.cache_factory.cache_contexts.calibrators.foca import (
+ from cache_dit.caching.cache_contexts.calibrators.foca import (
      FoCaCalibrator,
  )

@@ -45,6 +45,28 @@ class CalibratorConfig:
      def to_kwargs(self) -> Dict:
          return self.calibrator_kwargs.copy()

+     def as_dict(self) -> dict:
+         return dataclasses.asdict(self)
+
+     def update(self, **kwargs) -> "CalibratorConfig":
+         for key, value in kwargs.items():
+             if hasattr(self, key):
+                 if value is not None:
+                     setattr(self, key, value)
+         return self
+
+     def empty(self, **kwargs) -> "CalibratorConfig":
+         # Set all fields to None
+         for field in dataclasses.fields(self):
+             if hasattr(self, field.name):
+                 setattr(self, field.name, None)
+         if kwargs:
+             self.update(**kwargs)
+         return self
+
+     def reset(self, **kwargs) -> "CalibratorConfig":
+         return self.empty(**kwargs)
+

  @dataclasses.dataclass
  class TaylorSeerCalibratorConfig(CalibratorConfig):
cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py
@@ -1,4 +1,4 @@
- from cache_dit.cache_factory.cache_contexts.calibrators.base import (
+ from cache_dit.caching.cache_contexts.calibrators.base import (
      CalibratorBase,
  )