cache-dit 1.0.3__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff shows the changes between publicly released versions of the package as they appear in their respective public registries, and is provided for informational purposes only.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +126 -11
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +78 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +214 -114
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +18 -94
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +133 -12
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/{cache_factory → caching}/cache_interface.py +150 -37
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_qwen_image_controlnet.py +1 -1
- cache_dit/{cache_factory → caching}/utils.py +19 -8
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +40 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/METADATA +123 -116
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -76
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -306
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -458
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -15
- cache_dit/cache_factory/patch_functors/__init__.py +0 -15
- cache_dit-1.0.3.dist-info/RECORD +0 -58
- cache_dit-1.0.3.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_blocks/offload_utils.py +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
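The dominant change in this release is the package-level rename of cache_dit.cache_factory to cache_dit.caching (and custom_ops to kernels), alongside new parallelism, pruning, and quantization-backend modules. For downstream code that imported from the old module paths, the change looks roughly like the sketch below; whether top-level re-exports in cache_dit/__init__.py also cover these symbols is not shown in this diff.

# 1.0.3 layout (old module path, inferred from the renames listed above)
from cache_dit.cache_factory.cache_contexts.calibrators.base import CalibratorBase

# 1.0.14 layout (new module path used throughout this diff)
from cache_dit.caching.cache_contexts.calibrators.base import CalibratorBase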
cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py

@@ -1,7 +1,9 @@
+# The TaylorSeerState codebase is adapted from FBCache. Over time its codebase
+# diverged a lot, and TaylorSeerState API is no longer compatible with FBCache.
 import math
 import torch
 from typing import List, Dict
-from cache_dit.
+from cache_dit.caching.cache_contexts.calibrators.base import (
     CalibratorBase,
 )
 
@@ -10,13 +12,12 @@ from cache_dit.logger import init_logger
 logger = init_logger(__name__)
 
 
-class
+class TaylorSeerState:
     def __init__(
         self,
         n_derivatives=1,
         max_warmup_steps=1,
         skip_interval_steps=1,
-        **kwargs,
     ):
         self.n_derivatives = n_derivatives
         self.order = n_derivatives + 1
@@ -28,9 +29,8 @@ class TaylorSeerCalibrator(CalibratorBase):
             "dY_prev": [None] * self.order,
             "dY_current": [None] * self.order,
         }
-        self.reset_cache()
 
-    def
+    def reset(self):
         self.state: Dict[str, List[torch.Tensor]] = {
             "dY_prev": [None] * self.order,
             "dY_current": [None] * self.order,
@@ -38,6 +38,9 @@ class TaylorSeerCalibrator(CalibratorBase):
         self.current_step = -1
         self.last_non_approximated_step = -1
 
+    def mark_step_begin(self): # NEED
+        self.current_step += 1
+
     def should_compute(self, step=None):
         step = self.current_step if step is None else step
         if (
@@ -56,7 +59,7 @@ class TaylorSeerCalibrator(CalibratorBase):
         window = self.current_step - self.last_non_approximated_step
         if self.state["dY_prev"][0] is not None:
             if dY_current[0].shape != self.state["dY_prev"][0].shape:
-                self.
+                self.reset()
 
         for i in range(self.n_derivatives):
             if self.state["dY_prev"][i] is not None and self.current_step > 1:
@@ -77,9 +80,6 @@ class TaylorSeerCalibrator(CalibratorBase):
                 break
         return output
 
-    def mark_step_begin(self): # NEED
-        self.current_step += 1
-
     def update(self, Y: torch.Tensor): # NEED
         # Directly call this method will ingnore the warmup
         # policy and force full computation.
@@ -106,5 +106,77 @@ class TaylorSeerCalibrator(CalibratorBase):
         else:
             return self.approximate()
 
+
+class TaylorSeerCalibrator(CalibratorBase):
+    def __init__(
+        self,
+        n_derivatives=1,
+        max_warmup_steps=1,
+        skip_interval_steps=1,
+        **kwargs,
+    ):
+        self.n_derivatives = n_derivatives
+        self.max_warmup_steps = max_warmup_steps
+        self.skip_interval_steps = skip_interval_steps
+        self.states: Dict[str, TaylorSeerState] = {}
+        self.reset_cache()
+
+    def reset_cache(self): # NEED
+        if self.states:
+            for state in self.states.values():
+                state.reset()
+
+    def maybe_init_state(
+        self,
+        name: str = "default",
+    ):
+        if name not in self.states:
+            self.states[name] = TaylorSeerState(
+                n_derivatives=self.n_derivatives,
+                max_warmup_steps=self.max_warmup_steps,
+                skip_interval_steps=self.skip_interval_steps,
+            )
+
+    def mark_step_begin(self, *args, **kwargs):
+        if self.states:
+            for state in self.states.values():
+                state.mark_step_begin()
+
+    def derivative(
+        self,
+        Y: torch.Tensor,
+        name: str = "default",
+    ) -> List[torch.Tensor]:
+        self.maybe_init_state(name)
+        state = self.states[name]
+        state.derivative(Y)
+        return state.state["dY_current"]
+
+    def approximate(
+        self,
+        name: str = "default",
+    ) -> torch.Tensor: # NEED
+        assert name in self.states, f"State '{name}' not found."
+        state = self.states[name]
+        return state.approximate()
+
+    def update(
+        self,
+        Y: torch.Tensor,
+        name: str = "default",
+    ): # NEED
+        self.maybe_init_state(name)
+        state = self.states[name]
+        state.update(Y)
+
+    def step(
+        self,
+        Y: torch.Tensor,
+        name: str = "default",
+    ):
+        self.maybe_init_state(name)
+        state = self.states[name]
+        return state.step(Y)
+
     def __repr__(self):
         return f"TaylorSeerCalibrator_O({self.n_derivatives})"
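The refactor above splits the old single-class calibrator into a per-tensor TaylorSeerState plus a TaylorSeerCalibrator wrapper that keeps one named state per cached quantity. A minimal usage sketch of that wrapper API follows; the driving loop, tensor shapes, and the "hidden_states" state name are illustrative assumptions, not part of this diff.

import torch
from cache_dit.caching.cache_contexts.calibrators.taylorseer import (
    TaylorSeerCalibrator,
)

calibrator = TaylorSeerCalibrator(
    n_derivatives=1, max_warmup_steps=1, skip_interval_steps=1
)
for _ in range(4):                    # one iteration per denoising step (assumed)
    calibrator.mark_step_begin()      # advances every named TaylorSeerState
    hidden = torch.randn(2, 16)       # placeholder tensor (assumed shape)
    # step() lazily creates a TaylorSeerState keyed by `name`, then either
    # records the new value or returns its Taylor-series approximation.
    out = calibrator.step(hidden, name="hidden_states")
calibrator.reset_cache()              # resets all named states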
cache_dit/caching/cache_contexts/context_manager.py (new file)

@@ -0,0 +1,36 @@
+from cache_dit.caching.cache_types import CacheType
+from cache_dit.caching.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
+from cache_dit.caching.cache_contexts.prune_manager import (
+    PrunedContextManager,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class ContextManager:
+    _supported_managers = (
+        CachedContextManager,
+        PrunedContextManager,
+    )
+
+    def __new__(
+        cls,
+        cache_type: CacheType,
+        name: str = "default",
+        persistent_context: bool = False,
+    ) -> CachedContextManager | PrunedContextManager:
+        if cache_type == CacheType.DBCache:
+            return CachedContextManager(
+                name=name,
+                persistent_context=persistent_context,
+            )
+        elif cache_type == CacheType.DBPrune:
+            return PrunedContextManager(
+                name=name,
+                persistent_context=persistent_context,
+            )
+        else:
+            raise ValueError(f"Unsupported cache_type: {cache_type}.")
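context_manager.py adds a small factory: ContextManager.__new__ dispatches on CacheType and hands back either a CachedContextManager or a PrunedContextManager. A sketch of how that dispatch looks from the caller's side; the "transformer" context name is just an example.

from cache_dit.caching.cache_types import CacheType
from cache_dit.caching.cache_contexts.context_manager import ContextManager

cache_manager = ContextManager(cache_type=CacheType.DBCache, name="transformer")
# -> a CachedContextManager instance
prune_manager = ContextManager(cache_type=CacheType.DBPrune, name="transformer")
# -> a PrunedContextManager instance (Dynamic Block Prune)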
cache_dit/caching/cache_contexts/prune_config.py (new file)

@@ -0,0 +1,63 @@
+import dataclasses
+from typing import List
+from cache_dit.caching.cache_types import CacheType
+from cache_dit.caching.cache_contexts.cache_config import (
+    BasicCacheConfig,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class DBPruneConfig(BasicCacheConfig):
+    # Dyanamic Block Prune specific configurations
+    cache_type: CacheType = CacheType.DBPrune # DBPrune
+
+    # enable_dynamic_prune_threshold (`bool`, *required*, defaults to False):
+    #     Whether to enable the dynamic prune threshold or not. If True, we will
+    #     compute the dynamic prune threshold based on the mean of the residual
+    #     diffs of the previous computed or pruned blocks.
+    #     But, also limit mean_diff to be at least 2x the residual_diff_threshold
+    #     to avoid too aggressive pruning.
+    enable_dynamic_prune_threshold: bool = False
+    # max_dynamic_prune_threshold (`float`, *optional*, defaults to None):
+    #     The max dynamic prune threshold, if not None, the dynamic prune threshold
+    #     will not exceed this value. If None, we will limit it to be at least 2x
+    #     the residual_diff_threshold to avoid too aggressive pruning.
+    max_dynamic_prune_threshold: float = None
+    # dynamic_prune_threshold_relax_ratio (`float`, *optional*, defaults to 1.25):
+    #     The relax ratio for dynamic prune threshold, the dynamic prune threshold
+    #     will be set as:
+    #     dynamic_prune_threshold = mean_diff * dynamic_prune_threshold_relax_ratio
+    #     to avoid too aggressive pruning.
+    #     The default value is 1.25, which means the dynamic prune threshold will
+    #     be 1.25 times the mean of the residual diffs of the previous computed
+    #     or pruned blocks.
+    #     Users can tune this value to achieve a better trade-off between speedup
+    #     and precision. A higher value leads to more aggressive pruning
+    #     and faster speedup, but may also lead to lower precision.
+    dynamic_prune_threshold_relax_ratio: float = 1.25
+    # non_prune_block_ids (`List[int]`, *optional*, defaults to []):
+    #     The list of block ids that will not be pruned, even if their residual
+    #     diffs are below the prune threshold. This can be useful for the first
+    #     few blocks, which are usually more important for the model performance.
+    non_prune_block_ids: List[int] = dataclasses.field(default_factory=list)
+    # force_reduce_calibrator_vram (`bool`, *optional*, defaults to True):
+    #     Whether to force reduce the VRAM usage of the calibrator for Dynamic Block
+    #     Prune. If True, we will set the downsample_factor of the extra_cache_config
+    #     to at least 2 to reduce the VRAM usage of the calibrator.
+    force_reduce_calibrator_vram: bool = False
+
+    def strify(self) -> str:
+        return (
+            f"{self.cache_type}_"
+            f"F{self.Fn_compute_blocks}"
+            f"B{self.Bn_compute_blocks}_"
+            f"W{self.max_warmup_steps}"
+            f"I{self.warmup_interval}"
+            f"M{max(0, self.max_cached_steps)}"
+            f"MC{max(0, self.max_continuous_cached_steps)}_"
+            f"R{self.residual_diff_threshold}"
+        )
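DBPruneConfig extends BasicCacheConfig with the DBPrune-specific knobs documented above. A hedged construction sketch; the inherited field names (e.g. residual_diff_threshold) are inferred from their use in strify() and elsewhere in this diff, and the chosen values are purely illustrative.

from cache_dit.caching.cache_contexts.prune_config import DBPruneConfig

config = DBPruneConfig(
    residual_diff_threshold=0.08,         # inherited from BasicCacheConfig (assumed kwarg)
    enable_dynamic_prune_threshold=True,  # DBPrune-specific, defined above
    dynamic_prune_threshold_relax_ratio=1.25,
    non_prune_block_ids=[0, 1],           # never prune the first two blocks
)
print(config.strify())  # compact tag built by strify() above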
cache_dit/caching/cache_contexts/prune_context.py (new file)

@@ -0,0 +1,155 @@
+import torch
+import logging
+import dataclasses
+from typing import List
+from cache_dit.caching.cache_types import CacheType
+from cache_dit.caching.cache_contexts.prune_config import (
+    DBPruneConfig,
+)
+from cache_dit.caching.cache_contexts.cache_context import (
+    CachedContext,
+)
+from cache_dit.logger import init_logger
+
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class PrunedContext(CachedContext):
+    # Overwrite the cache_config type for PrunedContext
+    cache_config: DBPruneConfig = dataclasses.field(
+        default_factory=DBPruneConfig,
+    )
+    # Specially for Dynamic Block Prune
+    pruned_blocks: List[int] = dataclasses.field(default_factory=list)
+    actual_blocks: List[int] = dataclasses.field(default_factory=list)
+    cfg_pruned_blocks: List[int] = dataclasses.field(default_factory=list)
+    cfg_actual_blocks: List[int] = dataclasses.field(default_factory=list)
+
+    def __post_init__(self):
+        super().__post_init__()
+        if not isinstance(self.cache_config, DBPruneConfig):
+            raise ValueError(
+                "PrunedContext only supports DBPruneConfig as cache_config."
+            )
+
+        if self.cache_config.cache_type == CacheType.DBPrune:
+            if (
+                self.calibrator_config is not None
+                and self.cache_config.force_reduce_calibrator_vram
+            ):
+                # May reduce VRAM usage for Dynamic Block Prune
+                self.extra_cache_config.downsample_factor = max(
+                    4, self.extra_cache_config.downsample_factor
+                )
+
+    def get_residual_diff_threshold(self):
+        # Overwite this func for Dynamic Block Prune
+        residual_diff_threshold = self.cache_config.residual_diff_threshold
+        if isinstance(residual_diff_threshold, torch.Tensor):
+            residual_diff_threshold = residual_diff_threshold.item()
+        if self.cache_config.enable_dynamic_prune_threshold:
+            # Compute the dynamic prune threshold based on the mean of the
+            # residual diffs of the previous computed or pruned blocks.
+            step = str(self.get_current_step())
+            if int(step) >= 0 and str(step) in self.residual_diffs:
+                assert isinstance(self.residual_diffs[step], list)
+                # Use all the recorded diffs for this step
+                # NOTE: Should we only use the last 5 diffs?
+                diffs = self.residual_diffs[step][:5]
+                diffs = [d for d in diffs if d > 0.0]
+                if diffs:
+                    mean_diff = sum(diffs) / len(diffs)
+                    relaxed_diff = (
+                        mean_diff
+                        * self.cache_config.dynamic_prune_threshold_relax_ratio
+                    )
+                    if self.cache_config.max_dynamic_prune_threshold is None:
+                        max_dynamic_prune_threshold = (
+                            2 * residual_diff_threshold
+                        )
+                    else:
+                        max_dynamic_prune_threshold = (
+                            self.cache_config.max_dynamic_prune_threshold
+                        )
+                    if relaxed_diff < max_dynamic_prune_threshold:
+                        # If the mean diff is less than twice the threshold,
+                        # we can use it as the dynamic prune threshold.
+                        residual_diff_threshold = (
+                            relaxed_diff
+                            if relaxed_diff > residual_diff_threshold
+                            else residual_diff_threshold
+                        )
+                    if logger.isEnabledFor(logging.DEBUG):
+                        logger.debug(
+                            f"Dynamic prune threshold for step {step}: "
+                            f"{residual_diff_threshold:.6f}"
+                        )
+        return residual_diff_threshold
+
+    def mark_step_begin(self):
+        # Overwite this func for Dynamic Block Prune
+        super().mark_step_begin()
+        # Reset pruned_blocks and actual_blocks at the beginning
+        # of each transformer step.
+        if self.get_current_transformer_step() == 0:
+            self.pruned_blocks.clear()
+            self.actual_blocks.clear()
+
+    def add_residual_diff(self, diff: float | torch.Tensor):
+        # Overwite this func for Dynamic Block Prune
+        if isinstance(diff, torch.Tensor):
+            diff = diff.item()
+        # step: executed_steps - 1, not transformer_steps - 1
+        step = str(self.get_current_step())
+        # For Dynamic Block Prune, we will record all the diffs for this step
+        # Only add the diff if it is not already recorded for this step
+        if not self.is_separate_cfg_step():
+            if step not in self.residual_diffs:
+                self.residual_diffs[step] = []
+            self.residual_diffs[step].append(diff)
+        else:
+            if step not in self.cfg_residual_diffs:
+                self.cfg_residual_diffs[step] = []
+            self.cfg_residual_diffs[step].append(diff)
+
+    def add_pruned_step(self):
+        curr_cached_step = self.get_current_step()
+        # Avoid adding the same step multiple times
+        if not self.is_separate_cfg_step():
+            if curr_cached_step not in self.cached_steps:
+                self.add_cached_step()
+        else:
+            if curr_cached_step not in self.cfg_cached_steps:
+                self.add_cached_step()
+
+    def add_pruned_block(self, num_blocks):
+        if not self.is_separate_cfg_step():
+            self.pruned_blocks.append(num_blocks)
+        else:
+            self.cfg_pruned_blocks.append(num_blocks)
+
+    def add_actual_block(self, num_blocks):
+        if not self.is_separate_cfg_step():
+            self.actual_blocks.append(num_blocks)
+        else:
+            self.cfg_actual_blocks.append(num_blocks)
+
+    def get_pruned_blocks(self):
+        return self.pruned_blocks.copy()
+
+    def get_cfg_pruned_blocks(self):
+        return self.cfg_pruned_blocks.copy()
+
+    def get_actual_blocks(self):
+        return self.actual_blocks.copy()
+
+    def get_cfg_actual_blocks(self):
+        return self.cfg_actual_blocks.copy()
+
+    def get_pruned_steps(self):
+        return self.get_cached_steps()
+
+    def get_cfg_pruned_steps(self):
+        return self.get_cfg_cached_steps()
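The core of PrunedContext is the dynamic threshold rule in get_residual_diff_threshold(): take up to the first five positive residual diffs recorded for the current step, relax their mean by dynamic_prune_threshold_relax_ratio, cap the result at max_dynamic_prune_threshold (or 2x the base threshold when unset), and only ever raise the base threshold, never lower it. Below is a standalone pure-Python restatement of that rule; it is not part of the package API, just the arithmetic, and the sample numbers are made up.

from typing import List, Optional

def dynamic_prune_threshold(
    base: float,
    step_diffs: List[float],
    relax_ratio: float = 1.25,
    max_threshold: Optional[float] = None,
) -> float:
    diffs = [d for d in step_diffs[:5] if d > 0.0]  # first few positive diffs
    if not diffs:
        return base
    relaxed = (sum(diffs) / len(diffs)) * relax_ratio
    cap = 2 * base if max_threshold is None else max_threshold
    # Relax only upwards, and only while staying under the cap.
    if relaxed < cap and relaxed > base:
        return relaxed
    return base

print(dynamic_prune_threshold(0.05, [0.06, 0.07]))  # ~0.081, i.e. mean(0.065) * 1.25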
cache_dit/caching/cache_contexts/prune_manager.py (new file)

@@ -0,0 +1,167 @@
+import torch
+import functools
+from typing import Dict, List, Tuple, Union
+
+from cache_dit.caching.cache_contexts.cache_manager import (
+    BasicCacheConfig,
+    CachedContextManager,
+)
+from cache_dit.caching.cache_contexts.prune_context import (
+    PrunedContext,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class PrunedContextManager(CachedContextManager):
+    # Reuse CachedContextManager for Dynamic Block Prune
+
+    def __init__(self, name: str = None, **kwargs):
+        super().__init__(name, **kwargs)
+        # Overwrite for Dynamic Block Prune
+        self._current_context: PrunedContext = None
+        self._cached_context_manager: Dict[str, PrunedContext] = {}
+
+    # Overwrite for Dynamic Block Prune
+    def new_context(self, *args, **kwargs) -> PrunedContext:
+        if self._persistent_context:
+            cache_config: BasicCacheConfig = kwargs.get("cache_config", None)
+            assert (
+                cache_config is not None
+                and cache_config.num_inference_steps is not None
+            ), (
+                "When persistent_context is True, num_inference_steps "
+                "must be set in cache_config for proper cache refreshing."
+            )
+        _context = PrunedContext(*args, **kwargs)
+        # NOTE: Patch args and kwargs for implicit refresh.
+        _context._init_args = args # maybe empty tuple: ()
+        _context._init_kwargs = kwargs # maybe empty dict: {}
+        self._cached_context_manager[_context.name] = _context
+        return _context
+
+    def set_context(
+        self,
+        cached_context: PrunedContext | str,
+    ) -> PrunedContext:
+        return super().set_context(cached_context)
+
+    def get_context(self, name: str = None) -> PrunedContext:
+        return super().get_context(name)
+
+    def reset_context(
+        self,
+        cached_context: PrunedContext | str,
+        *args,
+        **kwargs,
+    ) -> PrunedContext:
+        return super().reset_context(cached_context, *args, **kwargs)
+
+    # Specially for Dynamic Block Prune
+    @torch.compiler.disable
+    def add_pruned_step(self):
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        cached_context.add_pruned_step()
+
+    @torch.compiler.disable
+    def add_pruned_block(self, num_blocks):
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        cached_context.add_pruned_block(num_blocks)
+
+    @torch.compiler.disable
+    def add_actual_block(self, num_blocks):
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        cached_context.add_actual_block(num_blocks)
+
+    @torch.compiler.disable
+    def get_pruned_steps(self) -> List[int]:
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        return cached_context.get_pruned_steps()
+
+    @torch.compiler.disable
+    def get_cfg_pruned_steps(self) -> List[int]:
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        return cached_context.get_cfg_pruned_steps()
+
+    @torch.compiler.disable
+    def get_pruned_blocks(self) -> List[int]:
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        return cached_context.get_pruned_blocks()
+
+    @torch.compiler.disable
+    def get_actual_blocks(self) -> List[int]:
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        return cached_context.get_actual_blocks()
+
+    @torch.compiler.disable
+    def get_cfg_pruned_blocks(self) -> List[int]:
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        return cached_context.get_cfg_pruned_blocks()
+
+    @torch.compiler.disable
+    def get_cfg_actual_blocks(self) -> List[int]:
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        return cached_context.get_cfg_actual_blocks()
+
+    @torch.compiler.disable
+    @functools.lru_cache(maxsize=8)
+    def get_non_prune_blocks_ids(self, num_blocks: int) -> List[int]:
+        assert num_blocks is not None, "num_blocks must be provided"
+        assert num_blocks > 0, "num_blocks must be greater than 0"
+        # Get the non-prune block ids for current context
+        # Never prune the first `Fn` and last `Bn` blocks.
+        Fn_compute_blocks_ids = list(
+            range(
+                self.Fn_compute_blocks()
+                if self.Fn_compute_blocks() < num_blocks
+                else num_blocks
+            )
+        )
+
+        Bn_compute_blocks_ids = list(
+            range(
+                num_blocks
+                - (
+                    self.Bn_compute_blocks()
+                    if self.Bn_compute_blocks() < num_blocks
+                    else num_blocks
+                ),
+                num_blocks,
+            )
+        )
+        context = self.get_context()
+        assert context is not None, "cached_context must be set before"
+
+        non_prune_blocks_ids = list(
+            set(
+                Fn_compute_blocks_ids
+                + Bn_compute_blocks_ids
+                + context.cache_config.non_prune_block_ids
+            )
+        )
+        non_prune_blocks_ids = [
+            d for d in non_prune_blocks_ids if d < num_blocks
+        ]
+        return sorted(non_prune_blocks_ids)
+
+    @torch.compiler.disable
+    def can_prune(self, *args, **kwargs) -> bool:
+        # Directly reuse can_cache for Dynamic Block Prune
+        return self.can_cache(*args, **kwargs)
+
+    @torch.compiler.disable
+    def apply_prune(
+        self, *args, **kwargs
+    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
+        # Directly reuse apply_cache for Dynamic Block Prune
+        return self.apply_cache(*args, **kwargs)
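PrunedContextManager mostly proxies the new PrunedContext bookkeeping onto the current context; the one piece of real logic is get_non_prune_blocks_ids(), which pins the first Fn and last Bn blocks plus any user-specified non_prune_block_ids so they are always computed. A standalone restatement of that selection follows; the function name and the concrete numbers are illustrative only.

from typing import List

def non_prune_block_ids(
    num_blocks: int, fn: int, bn: int, pinned: List[int]
) -> List[int]:
    front = list(range(min(fn, num_blocks)))                          # first Fn blocks
    back = list(range(num_blocks - min(bn, num_blocks), num_blocks))  # last Bn blocks
    ids = set(front + back + pinned)
    return sorted(i for i in ids if i < num_blocks)

print(non_prune_block_ids(10, fn=2, bn=2, pinned=[5]))  # [0, 1, 5, 8, 9]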