cache-dit 1.0.3__py3-none-any.whl → 1.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (29)
  1. cache_dit/__init__.py +3 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +8 -1
  4. cache_dit/cache_factory/cache_adapters/cache_adapter.py +90 -76
  5. cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
  6. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
  7. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +256 -24
  8. cache_dit/cache_factory/cache_blocks/pattern_base.py +273 -38
  9. cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
  10. cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
  11. cache_dit/cache_factory/cache_contexts/cache_config.py +102 -0
  12. cache_dit/cache_factory/cache_contexts/cache_context.py +15 -93
  13. cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
  14. cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
  15. cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
  16. cache_dit/cache_factory/cache_contexts/prune_config.py +69 -0
  17. cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
  18. cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
  19. cache_dit/cache_factory/cache_interface.py +20 -14
  20. cache_dit/cache_factory/cache_types.py +19 -2
  21. cache_dit/cache_factory/params_modifier.py +7 -7
  22. cache_dit/cache_factory/utils.py +18 -7
  23. cache_dit/utils.py +191 -54
  24. {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/METADATA +9 -9
  25. {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/RECORD +29 -24
  26. {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/WHEEL +0 -0
  27. {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/entry_points.txt +0 -0
  28. {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/licenses/LICENSE +0 -0
  29. {cache_dit-1.0.3.dist-info → cache_dit-1.0.4.dist-info}/top_level.txt +0 -0

cache_dit/cache_factory/cache_contexts/cache_context.py

@@ -5,6 +5,11 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
 
 import torch
 
+from cache_dit.cache_factory.cache_contexts.cache_config import (
+    BasicCacheConfig,
+    ExtraCacheConfig,
+    DBCacheConfig,
+)
 from cache_dit.cache_factory.cache_contexts.calibrators import (
     Calibrator,
     CalibratorBase,
@@ -15,101 +20,16 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
-@dataclasses.dataclass
-class BasicCacheConfig:
-    # Dual Block Cache with Flexible FnBn configuration.
-
-    # Fn_compute_blocks: (`int`, *required*, defaults to 8):
-    # Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
-    # at time step t, enabling the calculation of a more stable L1 diff and delivering more
-    # accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
-    # for more details of DBCache.
-    Fn_compute_blocks: int = 8
-    # Bn_compute_blocks: (`int`, *required*, defaults to 0):
-    # Further fuses approximate information in the **last n** Transformer blocks to enhance
-    # prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
-    # that use residual cache.
-    Bn_compute_blocks: int = 0
-    # residual_diff_threshold (`float`, *required*, defaults to 0.08):
-    # the value of residual diff threshold, a higher value leads to faster performance at the
-    # cost of lower precision.
-    residual_diff_threshold: Union[torch.Tensor, float] = 0.08
-    # max_warmup_steps (`int`, *required*, defaults to 8):
-    # DBCache does not apply the caching strategy when the number of running steps is less than
-    # or equal to this value, ensuring the model sufficiently learns basic features during warmup.
-    max_warmup_steps: int = 8  # DON'T Cache in warmup steps
-    # warmup_interval (`int`, *required*, defaults to 1):
-    # Skip interval in warmup steps, e.g., when warmup_interval is 2, only 0, 2, 4, ... steps
-    # in warmup steps will be computed, others will use dynamic cache.
-    warmup_interval: int = 1  # skip interval in warmup steps
-    # max_cached_steps (`int`, *required*, defaults to -1):
-    # DBCache disables the caching strategy when the previous cached steps exceed this value to
-    # prevent precision degradation.
-    max_cached_steps: int = -1  # for both CFG and non-CFG
-    # max_continuous_cached_steps (`int`, *required*, defaults to -1):
-    # DBCache disables the caching strategy when the previous continous cached steps exceed this value to
-    # prevent precision degradation.
-    max_continuous_cached_steps: int = -1  # the max continuous cached steps
-    # enable_separate_cfg (`bool`, *required*, defaults to None):
-    # Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
-    # and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
-    # CogVideoX, HunyuanVideo, Mochi, etc.
-    enable_separate_cfg: Optional[bool] = None
-    # cfg_compute_first (`bool`, *required*, defaults to False):
-    # Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
-    # 1, 3, 5, ... -> CFG step.
-    cfg_compute_first: bool = False
-    # cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-    # Compute separate diff values for CFG and non-CFG step, default True. If False, we will
-    # use the computed diff from current non-CFG transformer step for current CFG step.
-    cfg_diff_compute_separate: bool = True
-
-    def update(self, **kwargs) -> "BasicCacheConfig":
-        for key, value in kwargs.items():
-            if hasattr(self, key):
-                setattr(self, key, value)
-        return self
-
-    def strify(self) -> str:
-        return (
-            f"DBCACHE_F{self.Fn_compute_blocks}"
-            f"B{self.Bn_compute_blocks}_"
-            f"W{self.max_warmup_steps}"
-            f"I{self.warmup_interval}"
-            f"M{max(0, self.max_cached_steps)}"
-            f"MC{max(0, self.max_continuous_cached_steps)}_"
-            f"R{self.residual_diff_threshold}"
-        )
-
-
-@dataclasses.dataclass
-class ExtraCacheConfig:
-    # Some other not very important settings for Dual Block Cache.
-    # NOTE: These flags maybe deprecated in the future and users
-    # should never use these extra configurations in their cases.
-
-    # l1_hidden_states_diff_threshold (`float`, *optional*, defaults to None):
-    # The hidden states diff threshold for DBCache if use hidden_states as
-    # cache (not residual).
-    l1_hidden_states_diff_threshold: float = None
-    # important_condition_threshold (`float`, *optional*, defaults to 0.0):
-    # Only select the most important tokens while calculating the l1 diff.
-    important_condition_threshold: float = 0.0
-    # downsample_factor (`int`, *optional*, defaults to 1):
-    # Downsample factor for Fn buffer, in order the save GPU memory.
-    downsample_factor: int = 1
-    # num_inference_steps (`int`, *optional*, defaults to -1):
-    # num_inference_steps for DiffusionPipeline, for future use.
-    num_inference_steps: int = -1
-
-
 @dataclasses.dataclass
 class CachedContext:
     name: str = "default"
     # Buffer for storing the residuals and other tensors
     buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
     # Basic Dual Block Cache Config
-    cache_config: BasicCacheConfig = dataclasses.field(
+    cache_config: Union[
+        BasicCacheConfig,
+        DBCacheConfig,
+    ] = dataclasses.field(
         default_factory=BasicCacheConfig,
     )
     # NOTE: Users should never use these extra configurations.
@@ -131,14 +51,14 @@ class CachedContext:
     # be double of executed_steps.
     transformer_executed_steps: int = 0
 
-    # CFG & non-CFG cached steps
+    # CFG & non-CFG cached/pruned steps
     cached_steps: List[int] = dataclasses.field(default_factory=list)
-    residual_diffs: DefaultDict[str, float] = dataclasses.field(
+    residual_diffs: DefaultDict[str, float | list] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
     continuous_cached_steps: int = 0
     cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
-    cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
+    cfg_residual_diffs: DefaultDict[str, float | list] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
     cfg_continuous_cached_steps: int = 0
@@ -286,7 +206,9 @@ class CachedContext:
     def get_cfg_calibrators(self) -> Tuple[CalibratorBase, CalibratorBase]:
         return self.cfg_calibrator, self.cfg_encoder_calibrator
 
-    def add_residual_diff(self, diff):
+    def add_residual_diff(self, diff: float | torch.Tensor):
+        if isinstance(diff, torch.Tensor):
+            diff = diff.item()
         # step: executed_steps - 1, not transformer_steps - 1
         step = str(self.get_current_step())
         # Only add the diff if it is not already recorded for this step
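
The BasicCacheConfig removed from cache_context.py above now lives in the new cache_config.py module and is imported at the top of the file. For orientation only, here is a minimal sketch (not part of the diff) of constructing and updating it, using just the fields and methods shown in the removed block; the printed tag is an example of what strify() produces:

    from cache_dit.cache_factory.cache_contexts.cache_config import BasicCacheConfig

    cfg = BasicCacheConfig(
        Fn_compute_blocks=8,           # first n blocks are always computed
        Bn_compute_blocks=0,           # last n blocks refine approximate hidden states
        residual_diff_threshold=0.08,  # higher -> faster, lower precision
        max_warmup_steps=8,            # no caching during warmup
    )
    cfg.update(max_cached_steps=20)    # update() only sets attributes that already exist
    print(cfg.strify())                # e.g. "DBCACHE_F8B0_W8I1M20MC0_R0.08"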

cache_dit/cache_factory/cache_contexts/cache_manager.py

@@ -14,7 +14,7 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
-class CacheNotExistError(Exception):
+class ContextNotExistError(Exception):
     pass
 
 
@@ -36,14 +36,14 @@ class CachedContextManager:
             self._current_context = cached_context
         else:
             if cached_context not in self._cached_context_manager:
-                raise CacheNotExistError("Context not exist!")
+                raise ContextNotExistError("Context not exist!")
             self._current_context = self._cached_context_manager[cached_context]
         return self._current_context
 
     def get_context(self, name: str = None) -> CachedContext:
         if name is not None:
             if name not in self._cached_context_manager:
-                raise CacheNotExistError("Context not exist!")
+                raise ContextNotExistError("Context not exist!")
             return self._cached_context_manager[name]
         return self._current_context
 
@@ -482,7 +482,7 @@ class CachedContextManager:
 
         if calibrator is not None:
             # Use calibrator to update the buffer
-            calibrator.update(buffer)
+            calibrator.update(buffer, name=prefix)
         else:
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(
@@ -513,7 +513,7 @@ class CachedContextManager:
         calibrator, _ = self.get_calibrator()
 
         if calibrator is not None:
-            return calibrator.approximate()
+            return calibrator.approximate(name=prefix)
         else:
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(
@@ -551,7 +551,7 @@ class CachedContextManager:
 
         if encoder_calibrator is not None:
             # Use CalibratorBase to update the buffer
-            encoder_calibrator.update(buffer)
+            encoder_calibrator.update(buffer, name=prefix)
         else:
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(
@@ -582,7 +582,7 @@ class CachedContextManager:
 
         if encoder_calibrator is not None:
             # Use calibrator to approximate the value
-            return encoder_calibrator.approximate()
+            return encoder_calibrator.approximate(name=prefix)
         else:
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(

cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py

@@ -10,13 +10,12 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
-class TaylorSeerCalibrator(CalibratorBase):
+class TaylorSeerState:
     def __init__(
         self,
         n_derivatives=1,
         max_warmup_steps=1,
         skip_interval_steps=1,
-        **kwargs,
     ):
         self.n_derivatives = n_derivatives
         self.order = n_derivatives + 1
@@ -28,9 +27,8 @@ class TaylorSeerCalibrator(CalibratorBase):
             "dY_prev": [None] * self.order,
             "dY_current": [None] * self.order,
         }
-        self.reset_cache()
 
-    def reset_cache(self):  # NEED
+    def reset(self):
         self.state: Dict[str, List[torch.Tensor]] = {
             "dY_prev": [None] * self.order,
             "dY_current": [None] * self.order,
@@ -38,6 +36,9 @@ class TaylorSeerCalibrator(CalibratorBase):
         self.current_step = -1
         self.last_non_approximated_step = -1
 
+    def mark_step_begin(self):  # NEED
+        self.current_step += 1
+
     def should_compute(self, step=None):
         step = self.current_step if step is None else step
         if (
@@ -56,7 +57,7 @@ class TaylorSeerCalibrator(CalibratorBase):
         window = self.current_step - self.last_non_approximated_step
         if self.state["dY_prev"][0] is not None:
             if dY_current[0].shape != self.state["dY_prev"][0].shape:
-                self.reset_cache()
+                self.reset()
 
         for i in range(self.n_derivatives):
             if self.state["dY_prev"][i] is not None and self.current_step > 1:
@@ -77,9 +78,6 @@ class TaylorSeerCalibrator(CalibratorBase):
                 break
         return output
 
-    def mark_step_begin(self):  # NEED
-        self.current_step += 1
-
     def update(self, Y: torch.Tensor):  # NEED
         # Directly call this method will ingnore the warmup
         # policy and force full computation.
@@ -106,5 +104,77 @@ class TaylorSeerCalibrator(CalibratorBase):
         else:
             return self.approximate()
 
+
+class TaylorSeerCalibrator(CalibratorBase):
+    def __init__(
+        self,
+        n_derivatives=1,
+        max_warmup_steps=1,
+        skip_interval_steps=1,
+        **kwargs,
+    ):
+        self.n_derivatives = n_derivatives
+        self.max_warmup_steps = max_warmup_steps
+        self.skip_interval_steps = skip_interval_steps
+        self.states: Dict[str, TaylorSeerState] = {}
+        self.reset_cache()
+
+    def reset_cache(self):  # NEED
+        if self.states:
+            for state in self.states.values():
+                state.reset()
+
+    def maybe_init_state(
+        self,
+        name: str = "default",
+    ):
+        if name not in self.states:
+            self.states[name] = TaylorSeerState(
+                n_derivatives=self.n_derivatives,
+                max_warmup_steps=self.max_warmup_steps,
+                skip_interval_steps=self.skip_interval_steps,
+            )
+
+    def mark_step_begin(self, *args, **kwargs):
+        if self.states:
+            for state in self.states.values():
+                state.mark_step_begin()
+
+    def derivative(
+        self,
+        Y: torch.Tensor,
+        name: str = "default",
+    ) -> List[torch.Tensor]:
+        self.maybe_init_state(name)
+        state = self.states[name]
+        state.derivative(Y)
+        return state.state["dY_current"]
+
+    def approximate(
+        self,
+        name: str = "default",
+    ) -> torch.Tensor:  # NEED
+        assert name in self.states, f"State '{name}' not found."
+        state = self.states[name]
+        return state.approximate()
+
+    def update(
+        self,
+        Y: torch.Tensor,
+        name: str = "default",
+    ):  # NEED
+        self.maybe_init_state(name)
+        state = self.states[name]
+        state.update(Y)
+
+    def step(
+        self,
+        Y: torch.Tensor,
+        name: str = "default",
+    ):
+        self.maybe_init_state(name)
+        state = self.states[name]
+        return state.step(Y)
+
     def __repr__(self):
         return f"TaylorSeerCalibrator_O({self.n_derivatives})"

cache_dit/cache_factory/cache_contexts/context_manager.py (new file)

@@ -0,0 +1,29 @@
+from cache_dit.cache_factory.cache_types import CacheType
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
+from cache_dit.cache_factory.cache_contexts.prune_manager import (
+    PrunedContextManager,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class ContextManager:
+    _supported_managers = (
+        CachedContextManager,
+        PrunedContextManager,
+    )
+
+    def __new__(
+        cls,
+        cache_type: CacheType,
+        name: str = "default",
+    ) -> CachedContextManager | PrunedContextManager:
+        if cache_type == CacheType.DBCache:
+            return CachedContextManager(name)
+        elif cache_type == CacheType.DBPrune:
+            return PrunedContextManager(name)
+        else:
+            raise ValueError(f"Unsupported cache_type: {cache_type}.")
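
ContextManager is a factory rather than a manager in its own right: __new__ dispatches on cache_type and hands back one of the two concrete managers, so call sites stay uniform whichever strategy is active. A minimal sketch (not part of the diff), using only the names shown above:

    from cache_dit.cache_factory.cache_types import CacheType

    # Returns a CachedContextManager for DBCache, a PrunedContextManager for DBPrune.
    cache_mgr = ContextManager(CacheType.DBCache, name="default")
    prune_mgr = ContextManager(CacheType.DBPrune, name="default")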

cache_dit/cache_factory/cache_contexts/prune_config.py (new file)

@@ -0,0 +1,69 @@
+import dataclasses
+from typing import List
+from cache_dit.cache_factory.cache_types import CacheType
+from cache_dit.cache_factory.cache_contexts.cache_config import (
+    BasicCacheConfig,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class DBPruneConfig(BasicCacheConfig):
+    # Dyanamic Block Prune specific configurations
+    cache_type: CacheType = CacheType.DBPrune  # DBPrune
+
+    # enable_dynamic_prune_threshold (`bool`, *required*, defaults to False):
+    # Whether to enable the dynamic prune threshold or not. If True, we will
+    # compute the dynamic prune threshold based on the mean of the residual
+    # diffs of the previous computed or pruned blocks.
+    # But, also limit mean_diff to be at least 2x the residual_diff_threshold
+    # to avoid too aggressive pruning.
+    enable_dynamic_prune_threshold: bool = False
+    # max_dynamic_prune_threshold (`float`, *optional*, defaults to None):
+    # The max dynamic prune threshold, if not None, the dynamic prune threshold
+    # will not exceed this value. If None, we will limit it to be at least 2x
+    # the residual_diff_threshold to avoid too aggressive pruning.
+    max_dynamic_prune_threshold: float = None
+    # dynamic_prune_threshold_relax_ratio (`float`, *optional*, defaults to 1.25):
+    # The relax ratio for dynamic prune threshold, the dynamic prune threshold
+    # will be set as:
+    # dynamic_prune_threshold = mean_diff * dynamic_prune_threshold_relax_ratio
+    # to avoid too aggressive pruning.
+    # The default value is 1.25, which means the dynamic prune threshold will
+    # be 1.25 times the mean of the residual diffs of the previous computed
+    # or pruned blocks.
+    # Users can tune this value to achieve a better trade-off between speedup
+    # and precision. A higher value leads to more aggressive pruning
+    # and faster speedup, but may also lead to lower precision.
+    dynamic_prune_threshold_relax_ratio: float = 1.25
+    # non_prune_block_ids (`List[int]`, *optional*, defaults to []):
+    # The list of block ids that will not be pruned, even if their residual
+    # diffs are below the prune threshold. This can be useful for the first
+    # few blocks, which are usually more important for the model performance.
+    non_prune_block_ids: List[int] = dataclasses.field(default_factory=list)
+    # force_reduce_calibrator_vram (`bool`, *optional*, defaults to True):
+    # Whether to force reduce the VRAM usage of the calibrator for Dynamic Block
+    # Prune. If True, we will set the downsample_factor of the extra_cache_config
+    # to at least 2 to reduce the VRAM usage of the calibrator.
+    force_reduce_calibrator_vram: bool = False
+
+    def update(self, **kwargs) -> "DBPruneConfig":
+        for key, value in kwargs.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+        return self
+
+    def strify(self) -> str:
+        return (
+            f"{self.cache_type}_"
+            f"F{self.Fn_compute_blocks}"
+            f"B{self.Bn_compute_blocks}_"
+            f"W{self.max_warmup_steps}"
+            f"I{self.warmup_interval}"
+            f"M{max(0, self.max_cached_steps)}"
+            f"MC{max(0, self.max_continuous_cached_steps)}_"
+            f"R{self.residual_diff_threshold}"
+        )
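
DBPruneConfig inherits the FnBn and threshold fields from BasicCacheConfig and only adds the dynamic-prune knobs documented above. For illustration, a hedged sketch (not part of the diff) of putting a config together; the values are arbitrary:

    prune_cfg = DBPruneConfig(
        residual_diff_threshold=0.08,
        enable_dynamic_prune_threshold=True,     # threshold tracks the mean residual diff
        dynamic_prune_threshold_relax_ratio=1.25,
        max_dynamic_prune_threshold=None,        # None -> capped at 2x the base threshold
        non_prune_block_ids=[0, 1],              # never prune these blocks
    )
    print(prune_cfg.strify())  # e.g. "CacheType.DBPrune_F8B0_W8I1M0MC0_R0.08", depending on how CacheType renders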

cache_dit/cache_factory/cache_contexts/prune_context.py (new file)

@@ -0,0 +1,155 @@
+import torch
+import logging
+import dataclasses
+from typing import List
+from cache_dit.cache_factory.cache_types import CacheType
+from cache_dit.cache_factory.cache_contexts.prune_config import (
+    DBPruneConfig,
+)
+from cache_dit.cache_factory.cache_contexts.cache_context import (
+    CachedContext,
+)
+from cache_dit.logger import init_logger
+
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class PrunedContext(CachedContext):
+    # Overwrite the cache_config type for PrunedContext
+    cache_config: DBPruneConfig = dataclasses.field(
+        default_factory=DBPruneConfig,
+    )
+    # Specially for Dynamic Block Prune
+    pruned_blocks: List[int] = dataclasses.field(default_factory=list)
+    actual_blocks: List[int] = dataclasses.field(default_factory=list)
+    cfg_pruned_blocks: List[int] = dataclasses.field(default_factory=list)
+    cfg_actual_blocks: List[int] = dataclasses.field(default_factory=list)
+
+    def __post_init__(self):
+        super().__post_init__()
+        if not isinstance(self.cache_config, DBPruneConfig):
+            raise ValueError(
+                "PrunedContext only supports DBPruneConfig as cache_config."
+            )
+
+        if self.cache_config.cache_type == CacheType.DBPrune:
+            if (
+                self.calibrator_config is not None
+                and self.cache_config.force_reduce_calibrator_vram
+            ):
+                # May reduce VRAM usage for Dynamic Block Prune
+                self.extra_cache_config.downsample_factor = max(
+                    4, self.extra_cache_config.downsample_factor
+                )
+
+    def get_residual_diff_threshold(self):
+        # Overwite this func for Dynamic Block Prune
+        residual_diff_threshold = self.cache_config.residual_diff_threshold
+        if isinstance(residual_diff_threshold, torch.Tensor):
+            residual_diff_threshold = residual_diff_threshold.item()
+        if self.cache_config.enable_dynamic_prune_threshold:
+            # Compute the dynamic prune threshold based on the mean of the
+            # residual diffs of the previous computed or pruned blocks.
+            step = str(self.get_current_step())
+            if int(step) >= 0 and str(step) in self.residual_diffs:
+                assert isinstance(self.residual_diffs[step], list)
+                # Use all the recorded diffs for this step
+                # NOTE: Should we only use the last 5 diffs?
+                diffs = self.residual_diffs[step][:5]
+                diffs = [d for d in diffs if d > 0.0]
+                if diffs:
+                    mean_diff = sum(diffs) / len(diffs)
+                    relaxed_diff = (
+                        mean_diff
+                        * self.cache_config.dynamic_prune_threshold_relax_ratio
+                    )
+                    if self.cache_config.max_dynamic_prune_threshold is None:
+                        max_dynamic_prune_threshold = (
+                            2 * residual_diff_threshold
+                        )
+                    else:
+                        max_dynamic_prune_threshold = (
+                            self.cache_config.max_dynamic_prune_threshold
+                        )
+                    if relaxed_diff < max_dynamic_prune_threshold:
+                        # If the mean diff is less than twice the threshold,
+                        # we can use it as the dynamic prune threshold.
+                        residual_diff_threshold = (
+                            relaxed_diff
+                            if relaxed_diff > residual_diff_threshold
+                            else residual_diff_threshold
+                        )
+                    if logger.isEnabledFor(logging.DEBUG):
+                        logger.debug(
+                            f"Dynamic prune threshold for step {step}: "
+                            f"{residual_diff_threshold:.6f}"
+                        )
+        return residual_diff_threshold
+
+    def mark_step_begin(self):
+        # Overwite this func for Dynamic Block Prune
+        super().mark_step_begin()
+        # Reset pruned_blocks and actual_blocks at the beginning
+        # of each transformer step.
+        if self.get_current_transformer_step() == 0:
+            self.pruned_blocks.clear()
+            self.actual_blocks.clear()
+
+    def add_residual_diff(self, diff: float | torch.Tensor):
+        # Overwite this func for Dynamic Block Prune
+        if isinstance(diff, torch.Tensor):
+            diff = diff.item()
+        # step: executed_steps - 1, not transformer_steps - 1
+        step = str(self.get_current_step())
+        # For Dynamic Block Prune, we will record all the diffs for this step
+        # Only add the diff if it is not already recorded for this step
+        if not self.is_separate_cfg_step():
+            if step not in self.residual_diffs:
+                self.residual_diffs[step] = []
+            self.residual_diffs[step].append(diff)
+        else:
+            if step not in self.cfg_residual_diffs:
+                self.cfg_residual_diffs[step] = []
+            self.cfg_residual_diffs[step].append(diff)
+
+    def add_pruned_step(self):
+        curr_cached_step = self.get_current_step()
+        # Avoid adding the same step multiple times
+        if not self.is_separate_cfg_step():
+            if curr_cached_step not in self.cached_steps:
+                self.add_cached_step()
+        else:
+            if curr_cached_step not in self.cfg_cached_steps:
+                self.add_cached_step()
+
+    def add_pruned_block(self, num_blocks):
+        if not self.is_separate_cfg_step():
+            self.pruned_blocks.append(num_blocks)
+        else:
+            self.cfg_pruned_blocks.append(num_blocks)
+
+    def add_actual_block(self, num_blocks):
+        if not self.is_separate_cfg_step():
+            self.actual_blocks.append(num_blocks)
+        else:
+            self.cfg_actual_blocks.append(num_blocks)
+
+    def get_pruned_blocks(self):
+        return self.pruned_blocks.copy()
+
+    def get_cfg_pruned_blocks(self):
+        return self.cfg_pruned_blocks.copy()
+
+    def get_actual_blocks(self):
+        return self.actual_blocks.copy()
+
+    def get_cfg_actual_blocks(self):
+        return self.cfg_actual_blocks.copy()
+
+    def get_pruned_steps(self):
+        return self.get_cached_steps()
+
+    def get_cfg_pruned_steps(self):
+        return self.get_cfg_cached_steps()
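
The dynamic threshold in get_residual_diff_threshold() comes down to a little arithmetic over the diffs recorded by add_residual_diff(). A standalone sketch of that logic with made-up numbers (the real code reads them from residual_diffs and the DBPruneConfig fields):

    diffs = [0.05, 0.09, 0.12]   # hypothetical recorded diffs, positives only
    base_threshold = 0.08        # residual_diff_threshold
    relax_ratio = 1.25           # dynamic_prune_threshold_relax_ratio

    mean_diff = sum(diffs) / len(diffs)        # ~0.0867
    relaxed_diff = mean_diff * relax_ratio     # ~0.1083
    cap = 2 * base_threshold                   # 0.16, used when max_dynamic_prune_threshold is None

    # Only relax the threshold upward, and only while it stays under the cap.
    if relaxed_diff < cap:
        threshold = max(base_threshold, relaxed_diff)
    else:
        threshold = base_threshold
    print(round(threshold, 4))  # 0.1083: blocks whose residual diff falls below this are pruned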