cache-dit 1.0.2__py3-none-any.whl → 1.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/__init__.py +3 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +8 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +90 -76
- cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +271 -36
- cache_dit/cache_factory/cache_blocks/pattern_base.py +286 -45
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
- cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
- cache_dit/cache_factory/cache_contexts/cache_config.py +102 -0
- cache_dit/cache_factory/cache_contexts/cache_context.py +26 -89
- cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
- cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
- cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
- cache_dit/cache_factory/cache_contexts/prune_config.py +69 -0
- cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
- cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
- cache_dit/cache_factory/cache_interface.py +23 -14
- cache_dit/cache_factory/cache_types.py +19 -2
- cache_dit/cache_factory/params_modifier.py +7 -7
- cache_dit/cache_factory/utils.py +38 -27
- cache_dit/utils.py +191 -54
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/METADATA +14 -7
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/RECORD +29 -24
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.2.dist-info → cache_dit-1.0.4.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/cache_contexts/cache_config.py (new file):

@@ -0,0 +1,102 @@
+import torch
+import dataclasses
+from typing import Optional, Union
+from cache_dit.cache_factory.cache_types import CacheType
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class BasicCacheConfig:
+    # Default: Dual Block Cache with Flexible FnBn configuration.
+    cache_type: CacheType = CacheType.DBCache  # DBCache, DBPrune, NONE
+
+    # Fn_compute_blocks: (`int`, *required*, defaults to 8):
+    # Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
+    # at time step t, enabling the calculation of a more stable L1 diff and delivering more
+    # accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
+    # for more details of DBCache.
+    Fn_compute_blocks: int = 8
+    # Bn_compute_blocks: (`int`, *required*, defaults to 0):
+    # Further fuses approximate information in the **last n** Transformer blocks to enhance
+    # prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
+    # that use residual cache.
+    Bn_compute_blocks: int = 0
+    # residual_diff_threshold (`float`, *required*, defaults to 0.08):
+    # the value of residual diff threshold, a higher value leads to faster performance at the
+    # cost of lower precision.
+    residual_diff_threshold: Union[torch.Tensor, float] = 0.08
+    # max_warmup_steps (`int`, *required*, defaults to 8):
+    # DBCache does not apply the caching strategy when the number of running steps is less than
+    # or equal to this value, ensuring the model sufficiently learns basic features during warmup.
+    max_warmup_steps: int = 8  # DON'T Cache in warmup steps
+    # warmup_interval (`int`, *required*, defaults to 1):
+    # Skip interval in warmup steps, e.g., when warmup_interval is 2, only 0, 2, 4, ... steps
+    # in warmup steps will be computed, others will use dynamic cache.
+    warmup_interval: int = 1  # skip interval in warmup steps
+    # max_cached_steps (`int`, *required*, defaults to -1):
+    # DBCache disables the caching strategy when the previous cached steps exceed this value to
+    # prevent precision degradation.
+    max_cached_steps: int = -1  # for both CFG and non-CFG
+    # max_continuous_cached_steps (`int`, *required*, defaults to -1):
+    # DBCache disables the caching strategy when the previous continous cached steps exceed this value to
+    # prevent precision degradation.
+    max_continuous_cached_steps: int = -1  # the max continuous cached steps
+    # enable_separate_cfg (`bool`, *required*, defaults to None):
+    # Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
+    # and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
+    # CogVideoX, HunyuanVideo, Mochi, etc.
+    enable_separate_cfg: Optional[bool] = None
+    # cfg_compute_first (`bool`, *required*, defaults to False):
+    # Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
+    # 1, 3, 5, ... -> CFG step.
+    cfg_compute_first: bool = False
+    # cfg_diff_compute_separate (`bool`, *required*, defaults to True):
+    # Compute separate diff values for CFG and non-CFG step, default True. If False, we will
+    # use the computed diff from current non-CFG transformer step for current CFG step.
+    cfg_diff_compute_separate: bool = True
+
+    def update(self, **kwargs) -> "BasicCacheConfig":
+        for key, value in kwargs.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+        return self
+
+    def strify(self) -> str:
+        return (
+            f"{self.cache_type}_"
+            f"F{self.Fn_compute_blocks}"
+            f"B{self.Bn_compute_blocks}_"
+            f"W{self.max_warmup_steps}"
+            f"I{self.warmup_interval}"
+            f"M{max(0, self.max_cached_steps)}"
+            f"MC{max(0, self.max_continuous_cached_steps)}_"
+            f"R{self.residual_diff_threshold}"
+        )
+
+
+@dataclasses.dataclass
+class ExtraCacheConfig:
+    # Some other not very important settings for Dual Block Cache.
+    # NOTE: These flags maybe deprecated in the future and users
+    # should never use these extra configurations in their cases.
+
+    # l1_hidden_states_diff_threshold (`float`, *optional*, defaults to None):
+    # The hidden states diff threshold for DBCache if use hidden_states as
+    # cache (not residual).
+    l1_hidden_states_diff_threshold: float = None
+    # important_condition_threshold (`float`, *optional*, defaults to 0.0):
+    # Only select the most important tokens while calculating the l1 diff.
+    important_condition_threshold: float = 0.0
+    # downsample_factor (`int`, *optional*, defaults to 1):
+    # Downsample factor for Fn buffer, in order the save GPU memory.
+    downsample_factor: int = 1
+    # num_inference_steps (`int`, *optional*, defaults to -1):
+    # num_inference_steps for DiffusionPipeline, for future use.
+    num_inference_steps: int = -1
+
+
+@dataclasses.dataclass
+class DBCacheConfig(BasicCacheConfig):
+    pass  # Just an alias for BasicCacheConfig
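For readers tracking the API change: the configuration dataclasses now live in their own module, cache_dit/cache_factory/cache_contexts/cache_config.py, and DBCacheConfig is just an alias for BasicCacheConfig. A minimal usage sketch based only on what this hunk shows; the exact string produced by strify() depends on how CacheType renders, so the comment below is illustrative:

from cache_dit.cache_factory.cache_contexts.cache_config import DBCacheConfig

# Build with the defaults (F8B0, 8 warmup steps, threshold 0.08), then adjust in place.
cfg = DBCacheConfig()
cfg.update(Fn_compute_blocks=4, residual_diff_threshold=0.12)

# strify() packs the settings into a short tag, roughly
# "<cache_type>_F4B0_W8I1M0MC0_R0.12", handy for naming runs or output files.
print(cfg.strify())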
cache_dit/cache_factory/cache_contexts/cache_context.py:

@@ -5,6 +5,11 @@ from typing import Any, DefaultDict, Dict, List, Optional, Union, Tuple
 
 import torch
 
+from cache_dit.cache_factory.cache_contexts.cache_config import (
+    BasicCacheConfig,
+    ExtraCacheConfig,
+    DBCacheConfig,
+)
 from cache_dit.cache_factory.cache_contexts.calibrators import (
     Calibrator,
     CalibratorBase,

@@ -15,96 +20,16 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
-@dataclasses.dataclass
-class BasicCacheConfig:
-    # Dual Block Cache with Flexible FnBn configuration.
-
-    # Fn_compute_blocks: (`int`, *required*, defaults to 8):
-    # Specifies that `DBCache` uses the **first n** Transformer blocks to fit the information
-    # at time step t, enabling the calculation of a more stable L1 diff and delivering more
-    # accurate information to subsequent blocks. Please check https://github.com/vipshop/cache-dit/blob/main/docs/DBCache.md
-    # for more details of DBCache.
-    Fn_compute_blocks: int = 8
-    # Bn_compute_blocks: (`int`, *required*, defaults to 0):
-    # Further fuses approximate information in the **last n** Transformer blocks to enhance
-    # prediction accuracy. These blocks act as an auto-scaler for approximate hidden states
-    # that use residual cache.
-    Bn_compute_blocks: int = 0
-    # residual_diff_threshold (`float`, *required*, defaults to 0.08):
-    # the value of residual diff threshold, a higher value leads to faster performance at the
-    # cost of lower precision.
-    residual_diff_threshold: Union[torch.Tensor, float] = 0.08
-    # max_warmup_steps (`int`, *required*, defaults to 8):
-    # DBCache does not apply the caching strategy when the number of running steps is less than
-    # or equal to this value, ensuring the model sufficiently learns basic features during warmup.
-    max_warmup_steps: int = 8  # DON'T Cache in warmup steps
-    # max_cached_steps (`int`, *required*, defaults to -1):
-    # DBCache disables the caching strategy when the previous cached steps exceed this value to
-    # prevent precision degradation.
-    max_cached_steps: int = -1  # for both CFG and non-CFG
-    # max_continuous_cached_steps (`int`, *required*, defaults to -1):
-    # DBCache disables the caching strategy when the previous continous cached steps exceed this value to
-    # prevent precision degradation.
-    max_continuous_cached_steps: int = -1  # the max continuous cached steps
-    # enable_separate_cfg (`bool`, *required*, defaults to None):
-    # Whether to do separate cfg or not, such as Wan 2.1, Qwen-Image. For model that fused CFG
-    # and non-CFG into single forward step, should set enable_separate_cfg as False, for example:
-    # CogVideoX, HunyuanVideo, Mochi, etc.
-    enable_separate_cfg: Optional[bool] = None
-    # cfg_compute_first (`bool`, *required*, defaults to False):
-    # Compute cfg forward first or not, default False, namely, 0, 2, 4, ..., -> non-CFG step;
-    # 1, 3, 5, ... -> CFG step.
-    cfg_compute_first: bool = False
-    # cfg_diff_compute_separate (`bool`, *required*, defaults to True):
-    # Compute separate diff values for CFG and non-CFG step, default True. If False, we will
-    # use the computed diff from current non-CFG transformer step for current CFG step.
-    cfg_diff_compute_separate: bool = True
-
-    def update(self, **kwargs) -> "BasicCacheConfig":
-        for key, value in kwargs.items():
-            if hasattr(self, key):
-                setattr(self, key, value)
-        return self
-
-    def strify(self) -> str:
-        return (
-            f"DBCACHE_F{self.Fn_compute_blocks}"
-            f"B{self.Bn_compute_blocks}_"
-            f"W{self.max_warmup_steps}"
-            f"M{max(0, self.max_cached_steps)}"
-            f"MC{max(0, self.max_continuous_cached_steps)}_"
-            f"R{self.residual_diff_threshold}"
-        )
-
-
-@dataclasses.dataclass
-class ExtraCacheConfig:
-    # Some other not very important settings for Dual Block Cache.
-    # NOTE: These flags maybe deprecated in the future and users
-    # should never use these extra configurations in their cases.
-
-    # l1_hidden_states_diff_threshold (`float`, *optional*, defaults to None):
-    # The hidden states diff threshold for DBCache if use hidden_states as
-    # cache (not residual).
-    l1_hidden_states_diff_threshold: float = None
-    # important_condition_threshold (`float`, *optional*, defaults to 0.0):
-    # Only select the most important tokens while calculating the l1 diff.
-    important_condition_threshold: float = 0.0
-    # downsample_factor (`int`, *optional*, defaults to 1):
-    # Downsample factor for Fn buffer, in order the save GPU memory.
-    downsample_factor: int = 1
-    # num_inference_steps (`int`, *optional*, defaults to -1):
-    # num_inference_steps for DiffusionPipeline, for future use.
-    num_inference_steps: int = -1
-
-
 @dataclasses.dataclass
 class CachedContext:
     name: str = "default"
     # Buffer for storing the residuals and other tensors
     buffers: Dict[str, Any] = dataclasses.field(default_factory=dict)
     # Basic Dual Block Cache Config
-    cache_config:
+    cache_config: Union[
+        BasicCacheConfig,
+        DBCacheConfig,
+    ] = dataclasses.field(
         default_factory=BasicCacheConfig,
     )
     # NOTE: Users should never use these extra configurations.

@@ -126,14 +51,14 @@ class CachedContext:
     # be double of executed_steps.
     transformer_executed_steps: int = 0
 
-    # CFG & non-CFG cached steps
+    # CFG & non-CFG cached/pruned steps
     cached_steps: List[int] = dataclasses.field(default_factory=list)
-    residual_diffs: DefaultDict[str, float] = dataclasses.field(
+    residual_diffs: DefaultDict[str, float | list] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
     continuous_cached_steps: int = 0
     cfg_cached_steps: List[int] = dataclasses.field(default_factory=list)
-    cfg_residual_diffs: DefaultDict[str, float] = dataclasses.field(
+    cfg_residual_diffs: DefaultDict[str, float | list] = dataclasses.field(
         default_factory=lambda: defaultdict(float),
     )
     cfg_continuous_cached_steps: int = 0

@@ -281,7 +206,9 @@ class CachedContext:
     def get_cfg_calibrators(self) -> Tuple[CalibratorBase, CalibratorBase]:
         return self.cfg_calibrator, self.cfg_encoder_calibrator
 
-    def add_residual_diff(self, diff):
+    def add_residual_diff(self, diff: float | torch.Tensor):
+        if isinstance(diff, torch.Tensor):
+            diff = diff.item()
         # step: executed_steps - 1, not transformer_steps - 1
         step = str(self.get_current_step())
         # Only add the diff if it is not already recorded for this step

@@ -346,5 +273,15 @@ class CachedContext:
         # CFG steps: 1, 3, 5, 7, ...
         return self.get_current_transformer_step() % 2 != 0
 
+    @property
+    def warmup_steps(self) -> List[int]:
+        return list(
+            range(
+                0,
+                self.cache_config.max_warmup_steps,
+                self.cache_config.warmup_interval,
+            )
+        )
+
     def is_in_warmup(self):
-        return self.get_current_step()
+        return self.get_current_step() in self.warmup_steps
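The warmup change above is the behavioral one to note: is_in_warmup() no longer compares the current step against max_warmup_steps directly, it checks membership in a strided list, so a warmup_interval greater than 1 fully computes only every N-th warmup step and lets the rest fall through to dynamic caching. A small standalone illustration of the schedule (plain Python, no cache-dit import required):

# With max_warmup_steps=8 and warmup_interval=2, only steps 0, 2, 4 and 6 count
# as warmup (always fully computed); steps 1, 3, 5 and 7 may be served from cache.
max_warmup_steps, warmup_interval = 8, 2
warmup_steps = list(range(0, max_warmup_steps, warmup_interval))
print(warmup_steps)  # [0, 2, 4, 6]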
cache_dit/cache_factory/cache_contexts/cache_manager.py:

@@ -14,7 +14,7 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
-class
+class ContextNotExistError(Exception):
     pass
 
 

@@ -36,14 +36,14 @@ class CachedContextManager:
             self._current_context = cached_context
         else:
             if cached_context not in self._cached_context_manager:
-                raise
+                raise ContextNotExistError("Context not exist!")
             self._current_context = self._cached_context_manager[cached_context]
         return self._current_context
 
     def get_context(self, name: str = None) -> CachedContext:
         if name is not None:
             if name not in self._cached_context_manager:
-                raise
+                raise ContextNotExistError("Context not exist!")
             return self._cached_context_manager[name]
         return self._current_context
 

@@ -482,7 +482,7 @@ class CachedContextManager:
 
         if calibrator is not None:
             # Use calibrator to update the buffer
-            calibrator.update(buffer)
+            calibrator.update(buffer, name=prefix)
         else:
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(

@@ -513,7 +513,7 @@ class CachedContextManager:
         calibrator, _ = self.get_calibrator()
 
         if calibrator is not None:
-            return calibrator.approximate()
+            return calibrator.approximate(name=prefix)
         else:
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(

@@ -551,7 +551,7 @@ class CachedContextManager:
 
         if encoder_calibrator is not None:
            # Use CalibratorBase to update the buffer
-            encoder_calibrator.update(buffer)
+            encoder_calibrator.update(buffer, name=prefix)
         else:
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(

@@ -582,7 +582,7 @@ class CachedContextManager:
 
 
         if encoder_calibrator is not None:
             # Use calibrator to approximate the value
-            return encoder_calibrator.approximate()
+            return encoder_calibrator.approximate(name=prefix)
         else:
             if logger.isEnabledFor(logging.DEBUG):
                 logger.debug(
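The previously bare raise statements are replaced by a named exception, so callers can now catch a failed context lookup explicitly. A sketch of that, assuming only what the hunks show (ContextNotExistError is defined next to CachedContextManager, and the manager takes a name as its first argument, which is how the new ContextManager factory below constructs it):

from cache_dit.cache_factory.cache_contexts.cache_manager import (
    CachedContextManager,
    ContextNotExistError,
)

manager = CachedContextManager("demo")
try:
    manager.get_context("does-not-exist")   # no context registered under this name
except ContextNotExistError:
    print("no cached context with that name")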
cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py:

@@ -10,13 +10,12 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
-class
+class TaylorSeerState:
     def __init__(
         self,
         n_derivatives=1,
         max_warmup_steps=1,
         skip_interval_steps=1,
-        **kwargs,
     ):
         self.n_derivatives = n_derivatives
         self.order = n_derivatives + 1

@@ -28,9 +27,8 @@ class TaylorSeerCalibrator(CalibratorBase):
             "dY_prev": [None] * self.order,
             "dY_current": [None] * self.order,
         }
-        self.reset_cache()
 
-    def
+    def reset(self):
         self.state: Dict[str, List[torch.Tensor]] = {
             "dY_prev": [None] * self.order,
             "dY_current": [None] * self.order,

@@ -38,6 +36,9 @@ class TaylorSeerCalibrator(CalibratorBase):
         self.current_step = -1
         self.last_non_approximated_step = -1
 
+    def mark_step_begin(self):  # NEED
+        self.current_step += 1
+
     def should_compute(self, step=None):
         step = self.current_step if step is None else step
         if (

@@ -56,7 +57,7 @@ class TaylorSeerCalibrator(CalibratorBase):
         window = self.current_step - self.last_non_approximated_step
         if self.state["dY_prev"][0] is not None:
             if dY_current[0].shape != self.state["dY_prev"][0].shape:
-                self.
+                self.reset()
 
         for i in range(self.n_derivatives):
             if self.state["dY_prev"][i] is not None and self.current_step > 1:

@@ -77,9 +78,6 @@ class TaylorSeerCalibrator(CalibratorBase):
                 break
         return output
 
-    def mark_step_begin(self):  # NEED
-        self.current_step += 1
-
     def update(self, Y: torch.Tensor):  # NEED
         # Directly call this method will ingnore the warmup
         # policy and force full computation.

@@ -106,5 +104,77 @@ class TaylorSeerCalibrator(CalibratorBase):
         else:
             return self.approximate()
 
+
+class TaylorSeerCalibrator(CalibratorBase):
+    def __init__(
+        self,
+        n_derivatives=1,
+        max_warmup_steps=1,
+        skip_interval_steps=1,
+        **kwargs,
+    ):
+        self.n_derivatives = n_derivatives
+        self.max_warmup_steps = max_warmup_steps
+        self.skip_interval_steps = skip_interval_steps
+        self.states: Dict[str, TaylorSeerState] = {}
+        self.reset_cache()
+
+    def reset_cache(self):  # NEED
+        if self.states:
+            for state in self.states.values():
+                state.reset()
+
+    def maybe_init_state(
+        self,
+        name: str = "default",
+    ):
+        if name not in self.states:
+            self.states[name] = TaylorSeerState(
+                n_derivatives=self.n_derivatives,
+                max_warmup_steps=self.max_warmup_steps,
+                skip_interval_steps=self.skip_interval_steps,
+            )
+
+    def mark_step_begin(self, *args, **kwargs):
+        if self.states:
+            for state in self.states.values():
+                state.mark_step_begin()
+
+    def derivative(
+        self,
+        Y: torch.Tensor,
+        name: str = "default",
+    ) -> List[torch.Tensor]:
+        self.maybe_init_state(name)
+        state = self.states[name]
+        state.derivative(Y)
+        return state.state["dY_current"]
+
+    def approximate(
+        self,
+        name: str = "default",
+    ) -> torch.Tensor:  # NEED
+        assert name in self.states, f"State '{name}' not found."
+        state = self.states[name]
+        return state.approximate()
+
+    def update(
+        self,
+        Y: torch.Tensor,
+        name: str = "default",
+    ):  # NEED
+        self.maybe_init_state(name)
+        state = self.states[name]
+        state.update(Y)
+
+    def step(
+        self,
+        Y: torch.Tensor,
+        name: str = "default",
+    ):
+        self.maybe_init_state(name)
+        state = self.states[name]
+        return state.step(Y)
+
     def __repr__(self):
         return f"TaylorSeerCalibrator_O({self.n_derivatives})"
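This is the largest behavioral change in the calibrator: the old single-state class is renamed to TaylorSeerState, and the new TaylorSeerCalibrator wrapper keeps one state per name, lazily created on first use. That is why the CachedContextManager hunks above now pass name=prefix to update() and approximate(). A rough usage sketch; the tensor shape, the "hidden_states" key, and the exact call order are illustrative assumptions rather than something this diff pins down:

import torch
from cache_dit.cache_factory.cache_contexts.calibrators.taylorseer import (
    TaylorSeerCalibrator,
)

calibrator = TaylorSeerCalibrator(n_derivatives=1)
calibrator.mark_step_begin()                 # advances every tracked state (none yet)

hidden = torch.randn(2, 16)
calibrator.update(hidden, name="hidden_states")        # creates a TaylorSeerState for this name
approx = calibrator.approximate(name="hidden_states")  # Taylor-series estimate for the same name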
cache_dit/cache_factory/cache_contexts/context_manager.py (new file):

@@ -0,0 +1,29 @@
+from cache_dit.cache_factory.cache_types import CacheType
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
+from cache_dit.cache_factory.cache_contexts.prune_manager import (
+    PrunedContextManager,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class ContextManager:
+    _supported_managers = (
+        CachedContextManager,
+        PrunedContextManager,
+    )
+
+    def __new__(
+        cls,
+        cache_type: CacheType,
+        name: str = "default",
+    ) -> CachedContextManager | PrunedContextManager:
+        if cache_type == CacheType.DBCache:
+            return CachedContextManager(name)
+        elif cache_type == CacheType.DBPrune:
+            return PrunedContextManager(name)
+        else:
+            raise ValueError(f"Unsupported cache_type: {cache_type}.")
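The new ContextManager is a thin factory: __new__ hands back a CachedContextManager or a PrunedContextManager depending on cache_type rather than an instance of ContextManager itself. A dispatch sketch using only names that appear in this file:

from cache_dit.cache_factory.cache_types import CacheType
from cache_dit.cache_factory.cache_contexts.context_manager import ContextManager

manager = ContextManager(CacheType.DBCache, name="default")
# `manager` is a CachedContextManager here; CacheType.DBPrune would return a
# PrunedContextManager, and any other value raises ValueError.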
cache_dit/cache_factory/cache_contexts/prune_config.py (new file):

@@ -0,0 +1,69 @@
+import dataclasses
+from typing import List
+from cache_dit.cache_factory.cache_types import CacheType
+from cache_dit.cache_factory.cache_contexts.cache_config import (
+    BasicCacheConfig,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class DBPruneConfig(BasicCacheConfig):
+    # Dyanamic Block Prune specific configurations
+    cache_type: CacheType = CacheType.DBPrune  # DBPrune
+
+    # enable_dynamic_prune_threshold (`bool`, *required*, defaults to False):
+    # Whether to enable the dynamic prune threshold or not. If True, we will
+    # compute the dynamic prune threshold based on the mean of the residual
+    # diffs of the previous computed or pruned blocks.
+    # But, also limit mean_diff to be at least 2x the residual_diff_threshold
+    # to avoid too aggressive pruning.
+    enable_dynamic_prune_threshold: bool = False
+    # max_dynamic_prune_threshold (`float`, *optional*, defaults to None):
+    # The max dynamic prune threshold, if not None, the dynamic prune threshold
+    # will not exceed this value. If None, we will limit it to be at least 2x
+    # the residual_diff_threshold to avoid too aggressive pruning.
+    max_dynamic_prune_threshold: float = None
+    # dynamic_prune_threshold_relax_ratio (`float`, *optional*, defaults to 1.25):
+    # The relax ratio for dynamic prune threshold, the dynamic prune threshold
+    # will be set as:
+    # dynamic_prune_threshold = mean_diff * dynamic_prune_threshold_relax_ratio
+    # to avoid too aggressive pruning.
+    # The default value is 1.25, which means the dynamic prune threshold will
+    # be 1.25 times the mean of the residual diffs of the previous computed
+    # or pruned blocks.
+    # Users can tune this value to achieve a better trade-off between speedup
+    # and precision. A higher value leads to more aggressive pruning
+    # and faster speedup, but may also lead to lower precision.
+    dynamic_prune_threshold_relax_ratio: float = 1.25
+    # non_prune_block_ids (`List[int]`, *optional*, defaults to []):
+    # The list of block ids that will not be pruned, even if their residual
+    # diffs are below the prune threshold. This can be useful for the first
+    # few blocks, which are usually more important for the model performance.
+    non_prune_block_ids: List[int] = dataclasses.field(default_factory=list)
+    # force_reduce_calibrator_vram (`bool`, *optional*, defaults to True):
+    # Whether to force reduce the VRAM usage of the calibrator for Dynamic Block
+    # Prune. If True, we will set the downsample_factor of the extra_cache_config
+    # to at least 2 to reduce the VRAM usage of the calibrator.
+    force_reduce_calibrator_vram: bool = False
+
+    def update(self, **kwargs) -> "DBPruneConfig":
+        for key, value in kwargs.items():
+            if hasattr(self, key):
+                setattr(self, key, value)
+        return self
+
+    def strify(self) -> str:
+        return (
+            f"{self.cache_type}_"
+            f"F{self.Fn_compute_blocks}"
+            f"B{self.Bn_compute_blocks}_"
+            f"W{self.max_warmup_steps}"
+            f"I{self.warmup_interval}"
+            f"M{max(0, self.max_cached_steps)}"
+            f"MC{max(0, self.max_continuous_cached_steps)}_"
+            f"R{self.residual_diff_threshold}"
+        )
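DBPruneConfig extends BasicCacheConfig with prune-specific knobs and switches cache_type to DBPrune. A construction sketch using only the fields shown above; how the config is then wired into the pipeline-level interface is not visible in this section (the cache_interface.py hunks are not shown here):

from cache_dit.cache_factory.cache_contexts.prune_config import DBPruneConfig

prune_cfg = DBPruneConfig(
    residual_diff_threshold=0.08,
    enable_dynamic_prune_threshold=True,       # threshold tracks the mean of recent residual diffs
    dynamic_prune_threshold_relax_ratio=1.25,  # relaxes that mean by 25%
    non_prune_block_ids=[0, 1],                # never prune the first two blocks
)
print(prune_cfg.strify())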