cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (104)
  1. cache_dit/__init__.py +37 -19
  2. cache_dit/_version.py +2 -2
  3. cache_dit/caching/__init__.py +36 -0
  4. cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
  5. cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
  6. cache_dit/caching/block_adapters/block_registers.py +118 -0
  7. cache_dit/caching/cache_adapters/__init__.py +1 -0
  8. cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
  9. cache_dit/caching/cache_blocks/__init__.py +226 -0
  10. cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
  11. cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
  12. cache_dit/caching/cache_blocks/pattern_base.py +748 -0
  13. cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
  14. cache_dit/caching/cache_contexts/__init__.py +28 -0
  15. cache_dit/caching/cache_contexts/cache_config.py +120 -0
  16. cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
  17. cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
  18. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
  19. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
  20. cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
  21. cache_dit/caching/cache_contexts/context_manager.py +36 -0
  22. cache_dit/caching/cache_contexts/prune_config.py +63 -0
  23. cache_dit/caching/cache_contexts/prune_context.py +155 -0
  24. cache_dit/caching/cache_contexts/prune_manager.py +167 -0
  25. cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
  26. cache_dit/{cache_factory → caching}/cache_types.py +19 -2
  27. cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
  28. cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
  29. cache_dit/caching/patch_functors/__init__.py +15 -0
  30. cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
  31. cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
  32. cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
  33. cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
  34. cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
  35. cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
  36. cache_dit/{cache_factory → caching}/utils.py +19 -8
  37. cache_dit/metrics/__init__.py +11 -0
  38. cache_dit/parallelism/__init__.py +3 -0
  39. cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
  40. cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
  41. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
  42. cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
  43. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
  44. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
  45. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
  46. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
  47. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
  48. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
  49. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
  50. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
  51. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
  52. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
  53. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
  54. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
  55. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
  56. cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
  57. cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
  58. cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
  59. cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
  60. cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
  61. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
  62. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
  63. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
  64. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
  65. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
  66. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
  67. cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
  68. cache_dit/parallelism/parallel_backend.py +26 -0
  69. cache_dit/parallelism/parallel_config.py +88 -0
  70. cache_dit/parallelism/parallel_interface.py +77 -0
  71. cache_dit/quantize/__init__.py +7 -0
  72. cache_dit/quantize/backends/__init__.py +1 -0
  73. cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
  74. cache_dit/quantize/backends/torchao/__init__.py +1 -0
  75. cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
  76. cache_dit/quantize/quantize_backend.py +0 -0
  77. cache_dit/quantize/quantize_config.py +0 -0
  78. cache_dit/quantize/quantize_interface.py +3 -16
  79. cache_dit/summary.py +593 -0
  80. cache_dit/utils.py +46 -290
  81. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
  82. cache_dit-1.0.14.dist-info/RECORD +102 -0
  83. cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
  84. cache_dit/cache_factory/__init__.py +0 -28
  85. cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
  86. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  87. cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
  88. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
  89. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
  90. cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
  91. cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
  92. cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
  93. cache_dit/cache_factory/patch_functors/__init__.py +0 -15
  94. cache_dit-1.0.3.dist-info/RECORD +0 -58
  95. cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
  96. /cache_dit/{cache_factory → caching}/.gitignore +0 -0
  97. /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
  98. /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
  99. /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
  100. /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
  101. /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
  102. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
  103. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
  104. {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,86 @@
+ import torch
+
+ from typing import Any
+ from cache_dit.caching import CachedContext
+ from cache_dit.caching import CachedContextManager
+ from cache_dit.caching import PrunedContextManager
+
+
+ def apply_stats(
+     module: torch.nn.Module | Any,
+     cache_context: CachedContext | str = None,
+     context_manager: CachedContextManager | PrunedContextManager = None,
+ ):
+     # Patch the cached stats to the module; the cached stats
+     # will be reset on each call of pipe.__call__(**kwargs).
+     if module is None or context_manager is None:
+         return
+
+     if cache_context is not None:
+         context_manager.set_context(cache_context)
+
+     # Cache stats for Dual Block Cache
+     module._cached_steps = context_manager.get_cached_steps()
+     module._residual_diffs = context_manager.get_residual_diffs()
+     module._cfg_cached_steps = context_manager.get_cfg_cached_steps()
+     module._cfg_residual_diffs = context_manager.get_cfg_residual_diffs()
+     # Pruned stats for Dynamic Block Prune
+     if not isinstance(context_manager, PrunedContextManager):
+         return
+     module._pruned_steps = context_manager.get_pruned_steps()
+     module._cfg_pruned_steps = context_manager.get_cfg_pruned_steps()
+     module._pruned_blocks = context_manager.get_pruned_blocks()
+     module._cfg_pruned_blocks = context_manager.get_cfg_pruned_blocks()
+     module._actual_blocks = context_manager.get_actual_blocks()
+     module._cfg_actual_blocks = context_manager.get_cfg_actual_blocks()
+     # Calculate pruned ratio
+     if len(module._pruned_blocks) > 0 and sum(module._actual_blocks) > 0:
+         module._pruned_ratio = sum(module._pruned_blocks) / sum(
+             module._actual_blocks
+         )
+     else:
+         module._pruned_ratio = None
+     if (
+         len(module._cfg_pruned_blocks) > 0
+         and sum(module._cfg_actual_blocks) > 0
+     ):
+         module._cfg_pruned_ratio = sum(module._cfg_pruned_blocks) / sum(
+             module._cfg_actual_blocks
+         )
+     else:
+         module._cfg_pruned_ratio = None
+
+
+ def remove_stats(
+     module: torch.nn.Module | Any,
+ ):
+     if module is None:
+         return
+
+     # Dual Block Cache
+     if hasattr(module, "_cached_steps"):
+         del module._cached_steps
+     if hasattr(module, "_residual_diffs"):
+         del module._residual_diffs
+     if hasattr(module, "_cfg_cached_steps"):
+         del module._cfg_cached_steps
+     if hasattr(module, "_cfg_residual_diffs"):
+         del module._cfg_residual_diffs
+
+     # Dynamic Block Prune
+     if hasattr(module, "_pruned_steps"):
+         del module._pruned_steps
+     if hasattr(module, "_cfg_pruned_steps"):
+         del module._cfg_pruned_steps
+     if hasattr(module, "_pruned_blocks"):
+         del module._pruned_blocks
+     if hasattr(module, "_cfg_pruned_blocks"):
+         del module._cfg_pruned_blocks
+     if hasattr(module, "_actual_blocks"):
+         del module._actual_blocks
+     if hasattr(module, "_cfg_actual_blocks"):
+         del module._cfg_actual_blocks
+     if hasattr(module, "_pruned_ratio"):
+         del module._pruned_ratio
+     if hasattr(module, "_cfg_pruned_ratio"):
+         del module._cfg_pruned_ratio
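The new helpers above patch per-run cache and prune statistics onto the transformer module as private attributes. A minimal inspection sketch; only the attribute names come from the code above, while `report_stats` and the `transformer` argument are illustrative, not part of the package:

```python
def report_stats(transformer) -> None:
    # Attributes set by apply_stats(); they may be absent if caching was not enabled.
    cached_steps = getattr(transformer, "_cached_steps", [])
    residual_diffs = getattr(transformer, "_residual_diffs", {})
    pruned_ratio = getattr(transformer, "_pruned_ratio", None)  # DBPrune only

    print(f"cached steps   : {len(cached_steps)}")
    print(f"recorded diffs : {len(residual_diffs)}")
    if pruned_ratio is not None:
        print(f"pruned ratio   : {pruned_ratio:.2%}")
```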
@@ -0,0 +1,28 @@
+ from cache_dit.caching.cache_contexts.calibrators import (
+     Calibrator,
+     CalibratorBase,
+     CalibratorConfig,
+     TaylorSeerCalibratorConfig,
+     FoCaCalibratorConfig,
+ )
+ from cache_dit.caching.cache_contexts.cache_config import (
+     BasicCacheConfig,
+     DBCacheConfig,
+ )
+ from cache_dit.caching.cache_contexts.cache_context import (
+     CachedContext,
+ )
+ from cache_dit.caching.cache_contexts.cache_manager import (
+     CachedContextManager,
+     ContextNotExistError,
+ )
+ from cache_dit.caching.cache_contexts.prune_config import DBPruneConfig
+ from cache_dit.caching.cache_contexts.prune_context import (
+     PrunedContext,
+ )
+ from cache_dit.caching.cache_contexts.prune_manager import (
+     PrunedContextManager,
+ )
+ from cache_dit.caching.cache_contexts.context_manager import (
+     ContextManager,
+ )
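Since this new `__init__.py` re-exports the config, context, and manager classes, downstream code can import them from the subpackage rather than the individual modules. A short sketch of the import surface; the constructor arguments are illustrative:

```python
from cache_dit.caching.cache_contexts import CachedContextManager, DBCacheConfig

manager = CachedContextManager(name="demo")
config = DBCacheConfig(residual_diff_threshold=0.12)
print(manager.name)  # "demo"
```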
@@ -0,0 +1,120 @@
+ import torch
+ import dataclasses
+ from typing import Optional, Union
+ from cache_dit.caching.cache_types import CacheType
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ @dataclasses.dataclass
+ class BasicCacheConfig:
+     # Default: Dual Block Cache with Flexible FnBn configuration.
+     cache_type: CacheType = CacheType.DBCache  # DBCache, DBPrune, NONE
+
+     # Fn_compute_blocks: (`int`, *required*, defaults to 8):
+     #     Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
+     #     at time step t, enabling the calculation of a more stable L1 diff and delivering more
+     #     accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
+     #     for more details of DBCache.
+     Fn_compute_blocks: int = 8
+     # Bn_compute_blocks: (`int`, *required*, defaults to 0):
+     #     Further fuses approximate information in the **last n** Transformer blocks to enhance
+     #     prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
+     #     that use residual cache.
+     Bn_compute_blocks: int = 0
+     # residual_diff_threshold (`float`, *required*, defaults to 0.08):
+     #     The residual diff threshold; a higher value leads to faster performance at the
+     #     cost of lower precision.
+     residual_diff_threshold: Union[torch.Tensor, float] = 0.08
+     # max_warmup_steps (`int`, *required*, defaults to 8):
+     #     DBCache does not apply the caching strategy when the number of running steps is less than
+     #     or equal to this value, ensuring the model sufficiently learns basic features during warmup.
+     max_warmup_steps: int = 8  # DON'T Cache in warmup steps
+     # warmup_interval (`int`, *required*, defaults to 1):
+     #     Skip interval in warmup steps, e.g., when warmup_interval is 2, only steps 0, 2, 4, ...
+     #     of the warmup will be computed; the others will use the dynamic cache.
+     warmup_interval: int = 1  # skip interval in warmup steps
+     # max_cached_steps (`int`, *required*, defaults to -1):
+     #     DBCache disables the caching strategy when the previous cached steps exceed this value to
+     #     prevent precision degradation.
+     max_cached_steps: int = -1  # for both CFG and non-CFG
+     # max_continuous_cached_steps (`int`, *required*, defaults to -1):
+     #     DBCache disables the caching strategy when the previous continuous cached steps exceed this
+     #     value to prevent precision degradation.
+     max_continuous_cached_steps: int = -1  # the max continuous cached steps
+     # enable_separate_cfg (`bool`, *required*, defaults to None):
+     #     Whether to run separate CFG or not, e.g., for Wan 2.1 and Qwen-Image. For models that fuse
+     #     CFG and non-CFG into a single forward step, set enable_separate_cfg to False, for example:
+     #     CogVideoX, HunyuanVideo, Mochi, etc.
+     enable_separate_cfg: Optional[bool] = None
+     # cfg_compute_first (`bool`, *required*, defaults to False):
+     #     Whether to compute the CFG forward first; defaults to False, namely, 0, 2, 4, ... -> non-CFG step;
+     #     1, 3, 5, ... -> CFG step.
+     cfg_compute_first: bool = False
+     # cfg_diff_compute_separate (`bool`, *required*, defaults to True):
+     #     Compute separate diff values for CFG and non-CFG steps; defaults to True. If False, the
+     #     computed diff from the current non-CFG transformer step is reused for the current CFG step.
+     cfg_diff_compute_separate: bool = True
+     # num_inference_steps (`int`, *optional*, defaults to None):
+     #     num_inference_steps of the DiffusionPipeline, used to adjust some internal settings
+     #     for better caching performance. For example, the cache is refreshed once the
+     #     executed steps exceed num_inference_steps, if num_inference_steps is provided.
+     num_inference_steps: Optional[int] = None
+
+     def update(self, **kwargs) -> "BasicCacheConfig":
+         for key, value in kwargs.items():
+             if hasattr(self, key):
+                 if value is not None:
+                     setattr(self, key, value)
+         return self
+
+     def empty(self, **kwargs) -> "BasicCacheConfig":
+         # Set all fields to None
+         for field in dataclasses.fields(self):
+             if hasattr(self, field.name):
+                 setattr(self, field.name, None)
+         if kwargs:
+             self.update(**kwargs)
+         return self
+
+     def reset(self, **kwargs) -> "BasicCacheConfig":
+         return self.empty(**kwargs)
+
+     def as_dict(self) -> dict:
+         return dataclasses.asdict(self)
+
+     def strify(self) -> str:
+         return (
+             f"{self.cache_type}_"
+             f"F{self.Fn_compute_blocks}"
+             f"B{self.Bn_compute_blocks}_"
+             f"W{self.max_warmup_steps}"
+             f"I{self.warmup_interval}"
+             f"M{max(0, self.max_cached_steps)}"
+             f"MC{max(0, self.max_continuous_cached_steps)}_"
+             f"R{self.residual_diff_threshold}"
+         )
+
+
+ @dataclasses.dataclass
+ class ExtraCacheConfig:
+     # Some other, less important settings for Dual Block Cache.
+     # NOTE: These flags may be deprecated in the future and users
+     # should not rely on these extra configurations in their own code.
+
+     # l1_hidden_states_diff_threshold (`float`, *optional*, defaults to None):
+     #     The hidden states diff threshold for DBCache when hidden_states are used as
+     #     the cache (instead of the residual).
+     l1_hidden_states_diff_threshold: float = None
+     # important_condition_threshold (`float`, *optional*, defaults to 0.0):
+     #     Only select the most important tokens while calculating the L1 diff.
+     important_condition_threshold: float = 0.0
+     # downsample_factor (`int`, *optional*, defaults to 1):
+     #     Downsample factor for the Fn buffer, in order to save GPU memory.
+     downsample_factor: int = 1
+
+
+ @dataclasses.dataclass
+ class DBCacheConfig(BasicCacheConfig):
+     pass  # Just an alias for BasicCacheConfig
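`BasicCacheConfig` (and its `DBCacheConfig` alias) now carries `cache_type`, `num_inference_steps`, and the `update`/`empty`/`reset`/`as_dict`/`strify` helpers shown above. A configuration round-trip might look like the following sketch; the field values are arbitrary examples:

```python
from cache_dit.caching.cache_contexts.cache_config import DBCacheConfig

config = DBCacheConfig(
    Fn_compute_blocks=8,
    Bn_compute_blocks=0,
    residual_diff_threshold=0.08,
    max_warmup_steps=8,
    num_inference_steps=50,  # enables cache refresh between pipeline runs
)

# update() only overwrites fields whose new value is not None.
config.update(residual_diff_threshold=0.12, max_cached_steps=None)

print(config.strify())   # e.g. "DBCache_F8B0_W8I1M0MC0_R0.12" (prefix depends on CacheType's str())
print(config.as_dict())  # plain dict, handy for logging or serialization
```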
@@ -1,3 +1,5 @@
+ # The cache context codebase is adapted from FBCache. Over time its codebase
+ # diverged a lot, and the context API is no longer compatible with FBCache.
  import logging
  import dataclasses
  from collections import defaultdict
@@ -5,7 +7,12 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple

  import torch

- from cache_dit.cache_factory.cache_contexts.calibrators import (
+ from cache_dit.caching.cache_contexts.cache_config import (
+     BasicCacheConfig,
+     ExtraCacheConfig,
+     DBCacheConfig,
+ )
+ from cache_dit.caching.cache_contexts.calibrators import (
      Calibrator,
      CalibratorBase,
      CalibratorConfig,
@@ -15,101 +22,16 @@ from cache_dit.logger import init_logger
  logger = init_logger(__name__)


- @dataclasses.dataclass
- class BasicCacheConfig:
-     # Dual Block Cache with Flexible FnBn configuration.
-
-     # Fn_compute_blocks: (`int`, *required*, defaults to 8):
-     #     Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
-     #     at time step t, enabling the calculation of a more stable L1 diff and delivering more
-     #     accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
-     #     for more details of DBCache.
-     Fn_compute_blocks: int = 8
-     # Bn_compute_blocks: (`int`, *required*, defaults to 0):
-     #     Further fuses approximate information in the **last n** Transformer blocks to enhance
-     #     prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
-     #     that use residual cache.
-     Bn_compute_blocks: int = 0
-     # residual_diff_threshold (`float`, *required*, defaults to 0.08):
-     #     the value of residual diff threshold, a higher value leads to faster performance at the
-     #     cost of lower precision.
-     residual_diff_threshold: Union[torch.Tensor, float] = 0.08
-     # max_warmup_steps (`int`, *required*, defaults to 8):
-     #     DBCache does not apply the caching strategy when the number of running steps is less than
-     #     or equal to this value, ensuring the model sufficiently learns basic features during warmup.
-     max_warmup_steps: int = 8  # DON'T Cache in warmup steps
-     # warmup_interval (`int`, *required*, defaults to 1):
-     #     Skip interval in warmup steps, e.g., when warmup_interval is 2, only 0, 2, 4, ... steps
-     #     in warmup steps will be computed, others will use dynamic cache.
-     warmup_interval: int = 1  # skip interval in warmup steps
-     # max_cached_steps (`int`, *required*, defaults to -1):
-     #     DBCache disables the caching strategy when the previous cached steps exceed this value to
-     #     prevent precision degradation.
-     max_cached_steps: int = -1  # for both CFG and non-CFG
-     # max_continuous_cached_steps (`int`, *required*, defaults to -1):
-     #     DBCache disables the caching strategy when the previous continous cached steps exceed this value to
-     #     prevent precision degradation.
-     max_continuous_cached_steps: int = -1  # the max continuous cached steps
-     # enable_separate_cfg (`bool`, *required*, defaults to None):
-     #     Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
-     #     and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
-     #     CogVideoX, HunyuanVideo, Mochi, etc.
-     enable_separate_cfg: Optional[bool] = None
-     # cfg_compute_first (`bool`, *required*, defaults to False):
-     #     Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
-     #     1, 3, 5, ... -> CFG step.
-     cfg_compute_first: bool = False
-     # cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-     #     Compute separate diff values for CFG and non-CFG step, default True. If False, we will
-     #     use the computed diff from current non-CFG transformer step for current CFG step.
-     cfg_diff_compute_separate: bool = True
-
-     def update(self, **kwargs) -> "BasicCacheConfig":
-         for key, value in kwargs.items():
-             if hasattr(self, key):
-                 setattr(self, key, value)
-         return self
-
-     def strify(self) -> str:
-         return (
-             f"DBCACHE_F{self.Fn_compute_blocks}"
-             f"B{self.Bn_compute_blocks}_"
-             f"W{self.max_warmup_steps}"
-             f"I{self.warmup_interval}"
-             f"M{max(0, self.max_cached_steps)}"
-             f"MC{max(0, self.max_continuous_cached_steps)}_"
-             f"R{self.residual_diff_threshold}"
-         )
-
-
- @dataclasses.dataclass
- class ExtraCacheConfig:
-     # Some other not very important settings for Dual Block Cache.
-     # NOTE: These flags maybe deprecated in the future and users
-     # should never use these extra configurations in their cases.
-
-     # l1_hidden_states_diff_threshold (`float`, *optional*, defaults to None):
-     #     The hidden states diff threshold for DBCache if use hidden_states as
-     #     cache (not residual).
-     l1_hidden_states_diff_threshold: float = None
-     # important_condition_threshold (`float`, *optional*, defaults to 0.0):
-     #     Only select the most important tokens while calculating the l1 diff.
-     important_condition_threshold: float = 0.0
-     # downsample_factor (`int`, *optional*, defaults to 1):
-     #     Downsample factor for Fn buffer, in order the save GPU memory.
-     downsample_factor: int = 1
-     # num_inference_steps (`int`, *optional*, defaults to -1):
-     #     num_inference_steps for DiffusionPipeline, for future use.
-     num_inference_steps: int = -1
-
-
  @dataclasses.dataclass
  class CachedContext:
      name: str = "default"
      # Buffer for storing the residuals and other tensors
      buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
      # Basic Dual Block Cache Config
-     cache_config: BasicCacheConfig = dataclasses.field(
+     cache_config: Union[
+         BasicCacheConfig,
+         DBCacheConfig,
+     ] = dataclasses.field(
          default_factory=BasicCacheConfig,
      )
      # NOTE: Users should never use these extra configurations.
@@ -131,14 +53,14 @@ class CachedContext:
      # be double of executed_steps.
      transformer_executed_steps: int = 0

-     # CFG & non-CFG cached steps
+     # CFG & non-CFG cached/pruned steps
      cached_steps: List[int] = dataclasses.field(default_factory=list)
-     residual_diffs: DefaultDict[str, float] = dataclasses.field(
+     residual_diffs: DefaultDict[str, float | list] = dataclasses.field(
          default_factory=lambda: defaultdict(float),
      )
      continuous_cached_steps: int = 0
      cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
-     cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
+     cfg_residual_diffs: DefaultDict[str, float | list] = dataclasses.field(
          default_factory=lambda: defaultdict(float),
      )
      cfg_continuous_cached_steps: int = 0
@@ -286,7 +208,9 @@ class CachedContext:
      def get_cfg_calibrators(self) -> Tuple[CalibratorBase, CalibratorBase]:
          return self.cfg_calibrator, self.cfg_encoder_calibrator

-     def add_residual_diff(self, diff):
+     def add_residual_diff(self, diff: float | torch.Tensor):
+         if isinstance(diff, torch.Tensor):
+             diff = diff.item()
          # step: executed_steps - 1, not transformer_steps - 1
          step = str(self.get_current_step())
          # Only add the diff if it is not already recorded for this step
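`add_residual_diff` now accepts either a Python float or a `torch.Tensor` and calls `.item()` before recording. A standalone sketch of the same conversion, useful when the diff is computed on the GPU; the helper name and the diff formula are illustrative, not part of the library:

```python
import torch


def to_recordable_diff(diff: float | torch.Tensor) -> float:
    # Mirrors the new add_residual_diff() behavior: tensors are collapsed
    # to plain floats so they can be stored in the per-step stats dict.
    if isinstance(diff, torch.Tensor):
        diff = diff.item()
    return diff


prev_residual = torch.randn(2, 256, 64)
curr_residual = prev_residual + 0.01 * torch.randn_like(prev_residual)
# Relative L1 difference between consecutive residuals (illustrative formula).
diff = (curr_residual - prev_residual).abs().mean() / prev_residual.abs().mean()
print(to_recordable_diff(diff))
```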
@@ -5,8 +5,9 @@ from typing import Dict, Optional, Tuple, Union, List
  import torch
  import torch.distributed as dist

- from cache_dit.cache_factory.cache_contexts.calibrators import CalibratorBase
- from cache_dit.cache_factory.cache_contexts.cache_context import (
+ from cache_dit.caching.cache_contexts.calibrators import CalibratorBase
+ from cache_dit.caching.cache_contexts.cache_context import (
+     BasicCacheConfig,
      CachedContext,
  )
  from cache_dit.logger import init_logger
@@ -14,36 +15,156 @@ from cache_dit.logger import init_logger
  logger = init_logger(__name__)


- class CacheNotExistError(Exception):
+ class ContextNotExistError(Exception):
      pass


  class CachedContextManager:
      # Each Pipeline should have it's own context manager instance.

-     def __init__(self, name: str = None):
+     def __init__(self, name: str = None, persistent_context: bool = False):
          self.name = name
          self._current_context: CachedContext = None
          self._cached_context_manager: Dict[str, CachedContext] = {}
+         # Whether to create new context automatically when setting
+         # a non-exist context name. Persistent context is useful when
+         # the pipeline class is not provided and users want to use
+         # cache-dit in a transformer-only way.
+         self._persistent_context = persistent_context
+         self._current_step_refreshed: bool = False
+
+     @property
+     def persistent_context(self) -> bool:
+         return self._persistent_context
+
+     @property
+     def current_context(self) -> CachedContext:
+         return self._current_context
+
+     @property
+     @torch.compiler.disable
+     def current_step_refreshed(self) -> bool:
+         return self._current_step_refreshed

+     @torch.compiler.disable
+     def is_pre_refreshed(self) -> bool:
+         _context = self._current_context
+         if _context is None:
+             return False
+
+         num_inference_steps = _context.cache_config.num_inference_steps
+         if num_inference_steps is not None:
+             current_step = _context.get_current_step()  # e.g, 0~49,50~99,...
+             return current_step == num_inference_steps - 1
+         return False
+
+     @torch.compiler.disable
      def new_context(self, *args, **kwargs) -> CachedContext:
+         if self._persistent_context:
+             cache_config: BasicCacheConfig = kwargs.get("cache_config", None)
+             assert (
+                 cache_config is not None
+                 and cache_config.num_inference_steps is not None
+             ), (
+                 "When persistent_context is True, num_inference_steps "
+                 "must be set in cache_config for proper cache refreshing."
+                 f"\nkwargs: {kwargs}"
+             )
          _context = CachedContext(*args, **kwargs)
+         # NOTE: Patch args and kwargs for implicit refresh.
+         _context._init_args = args  # maybe empty tuple: ()
+         _context._init_kwargs = kwargs  # maybe empty dict: {}
          self._cached_context_manager[_context.name] = _context
          return _context

-     def set_context(self, cached_context: CachedContext | str) -> CachedContext:
+     @torch.compiler.disable
+     def maybe_refresh(
+         self,
+         cached_context: Optional[CachedContext | str] = None,
+     ) -> bool:
+         if cached_context is None:
+             _context = self._current_context
+             assert _context is not None, "Current context is not set!"
+
+         if isinstance(cached_context, CachedContext):
+             _context = cached_context
+         else:
+             if cached_context not in self._cached_context_manager:
+                 raise ContextNotExistError("Context not exist!")
+             _context = self._cached_context_manager[cached_context]
+
+         if self._persistent_context:
+             assert _context.cache_config.num_inference_steps is not None, (
+                 "When persistent_context is True, num_inference_steps must be set "
+                 "in cache_config for proper cache refreshing."
+             )
+
+         num_inference_steps = _context.cache_config.num_inference_steps
+         if num_inference_steps is not None:
+             current_step = _context.get_current_step()  # e.g, 0~49,50~99,...
+             # Another round of inference, need to refresh cache context.
+             if current_step >= num_inference_steps:
+                 if logger.isEnabledFor(logging.DEBUG):
+                     logger.debug(
+                         f"Refreshing cache context '{_context.name}' "
+                         f"as current step: {current_step} >= "
+                         f"num_inference_steps: {num_inference_steps}."
+                     )
+                 return True
+         return False
+
+     @torch.compiler.disable
+     def set_context(
+         self,
+         cached_context: CachedContext | str,
+         *args,
+         **kwargs,
+     ) -> CachedContext:
          if isinstance(cached_context, CachedContext):
              self._current_context = cached_context
          else:
              if cached_context not in self._cached_context_manager:
-                 raise CacheNotExistError("Context not exist!")
-             self._current_context = self._cached_context_manager[cached_context]
+                 if not self._persistent_context:
+                     raise ContextNotExistError(
+                         "Context not exist and persistent_context is False. Please "
+                         "create new context first or set persistent_context=True."
+                     )
+                 else:
+                     # Create new context if not exist
+                     if any((bool(args), bool(kwargs))):
+                         kwargs["name"] = cached_context
+                         self._current_context = self.new_context(
+                             *args, **kwargs
+                         )
+                     else:
+                         raise ValueError(
+                             "To create new context, please provide args and kwargs."
+                         )
+             else:
+                 self._current_context = self._cached_context_manager[
+                     cached_context
+                 ]
+
+         if self.maybe_refresh(self._current_context):
+             if not any((bool(args), bool(kwargs))):
+                 assert hasattr(self._current_context, "_init_args")
+                 assert hasattr(self._current_context, "_init_kwargs")
+                 args = self._current_context._init_args
+                 kwargs = self._current_context._init_kwargs
+
+             self._current_context = self.reset_context(
+                 self._current_context, *args, **kwargs
+             )
+             self._current_step_refreshed = True
+         else:
+             self._current_step_refreshed = False
+
          return self._current_context

      def get_context(self, name: str = None) -> CachedContext:
          if name is not None:
              if name not in self._cached_context_manager:
-                 raise CacheNotExistError("Context not exist!")
+                 raise ContextNotExistError("Context not exist!")
              return self._cached_context_manager[name]
          return self._current_context

@@ -482,7 +603,7 @@ class CachedContextManager:

          if calibrator is not None:
              # Use calibrator to update the buffer
-             calibrator.update(buffer)
+             calibrator.update(buffer, name=prefix)
          else:
              if logger.isEnabledFor(logging.DEBUG):
                  logger.debug(
@@ -513,7 +634,7 @@ class CachedContextManager:
          calibrator, _ = self.get_calibrator()

          if calibrator is not None:
-             return calibrator.approximate()
+             return calibrator.approximate(name=prefix)
          else:
              if logger.isEnabledFor(logging.DEBUG):
                  logger.debug(
@@ -551,7 +672,7 @@ class CachedContextManager:

          if encoder_calibrator is not None:
              # Use CalibratorBase to update the buffer
-             encoder_calibrator.update(buffer)
+             encoder_calibrator.update(buffer, name=prefix)
          else:
              if logger.isEnabledFor(logging.DEBUG):
                  logger.debug(
@@ -582,7 +703,7 @@ class CachedContextManager:

          if encoder_calibrator is not None:
              # Use calibrator to approximate the value
-             return encoder_calibrator.approximate()
+             return encoder_calibrator.approximate(name=prefix)
          else:
              if logger.isEnabledFor(logging.DEBUG):
                  logger.debug(
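With the rename to `ContextNotExistError` and the new `persistent_context` flag, a manager can lazily create a named context on first `set_context()` call, provided `cache_config.num_inference_steps` is supplied. A hedged sketch of both paths; it assumes `CachedContext` can be built from just a name and a `cache_config`, which this diff implies but does not show in full:

```python
from cache_dit.caching.cache_contexts import (
    BasicCacheConfig,
    CachedContextManager,
    ContextNotExistError,
)

# Default behavior: unknown context names still raise.
strict = CachedContextManager(name="strict")
try:
    strict.set_context("missing")
except ContextNotExistError:
    print("context must be created explicitly when persistent_context=False")

# Persistent mode: the context is created on first use and refreshed
# automatically once get_current_step() reaches num_inference_steps.
persistent = CachedContextManager(name="persistent", persistent_context=True)
persistent.set_context(
    "transformer_only",
    cache_config=BasicCacheConfig(num_inference_steps=50),
)
print(persistent.current_context.name)  # "transformer_only"
```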
@@ -1,10 +1,10 @@
- from cache_dit.cache_factory.cache_contexts.calibrators.base import (
+ from cache_dit.caching.cache_contexts.calibrators.base import (
      CalibratorBase,
  )
- from cache_dit.cache_factory.cache_contexts.calibrators.taylorseer import (
+ from cache_dit.caching.cache_contexts.calibrators.taylorseer import (
      TaylorSeerCalibrator,
  )
- from cache_dit.cache_factory.cache_contexts.calibrators.foca import (
+ from cache_dit.caching.cache_contexts.calibrators.foca import (
      FoCaCalibrator,
  )

@@ -45,6 +45,28 @@ class CalibratorConfig:
      def to_kwargs(self) -> Dict:
          return self.calibrator_kwargs.copy()

+     def as_dict(self) -> dict:
+         return dataclasses.asdict(self)
+
+     def update(self, **kwargs) -> "CalibratorConfig":
+         for key, value in kwargs.items():
+             if hasattr(self, key):
+                 if value is not None:
+                     setattr(self, key, value)
+         return self
+
+     def empty(self, **kwargs) -> "CalibratorConfig":
+         # Set all fields to None
+         for field in dataclasses.fields(self):
+             if hasattr(self, field.name):
+                 setattr(self, field.name, None)
+         if kwargs:
+             self.update(**kwargs)
+         return self
+
+     def reset(self, **kwargs) -> "CalibratorConfig":
+         return self.empty(**kwargs)
+

  @dataclasses.dataclass
  class TaylorSeerCalibratorConfig(CalibratorConfig):
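`CalibratorConfig` now mirrors the `update`/`empty`/`reset`/`as_dict` helpers on `BasicCacheConfig`, so calibrator settings can be inspected and cleared in the same way. A small sketch using `TaylorSeerCalibratorConfig`; its concrete fields are not shown in this hunk, so only the shared helpers are exercised and default-constructibility is assumed:

```python
from cache_dit.caching.cache_contexts.calibrators import TaylorSeerCalibratorConfig

calib = TaylorSeerCalibratorConfig()  # assumes all fields have defaults
print(calib.as_dict())   # dataclass fields as a plain dict

calib.reset()            # alias for empty(): sets every field to None
print(calib.as_dict())
```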
@@ -1,4 +1,4 @@
- from cache_dit.cache_factory.cache_contexts.calibrators.base import (
+ from cache_dit.caching.cache_contexts.calibrators.base import (
      CalibratorBase,
  )