cache-dit 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl
- cache_dit/__init__.py +3 -0
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +8 -1
- cache_dit/cache_factory/block_adapters/__init__.py +4 -1
- cache_dit/cache_factory/cache_adapters/cache_adapter.py +126 -80
- cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +256 -24
- cache_dit/cache_factory/cache_blocks/pattern_base.py +273 -38
- cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
- cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
- cache_dit/cache_factory/cache_contexts/cache_config.py +118 -0
- cache_dit/cache_factory/cache_contexts/cache_context.py +15 -93
- cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
- cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +22 -0
- cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
- cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
- cache_dit/cache_factory/cache_contexts/prune_config.py +63 -0
- cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
- cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
- cache_dit/cache_factory/cache_interface.py +20 -14
- cache_dit/cache_factory/cache_types.py +19 -2
- cache_dit/cache_factory/params_modifier.py +7 -7
- cache_dit/cache_factory/utils.py +18 -7
- cache_dit/quantize/quantize_ao.py +58 -17
- cache_dit/utils.py +191 -54
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/METADATA +11 -10
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/RECORD +32 -27
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/WHEEL +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/entry_points.txt +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/top_level.txt +0 -0

cache_dit/cache_factory/cache_contexts/prune_context.py
ADDED
@@ -0,0 +1,155 @@
import torch
import logging
import dataclasses
from typing import List
from cache_dit.cache_factory.cache_types import CacheType
from cache_dit.cache_factory.cache_contexts.prune_config import (
    DBPruneConfig,
)
from cache_dit.cache_factory.cache_contexts.cache_context import (
    CachedContext,
)
from cache_dit.logger import init_logger


logger = init_logger(__name__)


@dataclasses.dataclass
class PrunedContext(CachedContext):
    # Overwrite the cache_config type for PrunedContext
    cache_config: DBPruneConfig = dataclasses.field(
        default_factory=DBPruneConfig,
    )
    # Specially for Dynamic Block Prune
    pruned_blocks: List[int] = dataclasses.field(default_factory=list)
    actual_blocks: List[int] = dataclasses.field(default_factory=list)
    cfg_pruned_blocks: List[int] = dataclasses.field(default_factory=list)
    cfg_actual_blocks: List[int] = dataclasses.field(default_factory=list)

    def __post_init__(self):
        super().__post_init__()
        if not isinstance(self.cache_config, DBPruneConfig):
            raise ValueError(
                "PrunedContext only supports DBPruneConfig as cache_config."
            )

        if self.cache_config.cache_type == CacheType.DBPrune:
            if (
                self.calibrator_config is not None
                and self.cache_config.force_reduce_calibrator_vram
            ):
                # May reduce VRAM usage for Dynamic Block Prune
                self.extra_cache_config.downsample_factor = max(
                    4, self.extra_cache_config.downsample_factor
                )

    def get_residual_diff_threshold(self):
        # Overwrite this func for Dynamic Block Prune
        residual_diff_threshold = self.cache_config.residual_diff_threshold
        if isinstance(residual_diff_threshold, torch.Tensor):
            residual_diff_threshold = residual_diff_threshold.item()
        if self.cache_config.enable_dynamic_prune_threshold:
            # Compute the dynamic prune threshold based on the mean of the
            # residual diffs of the previous computed or pruned blocks.
            step = str(self.get_current_step())
            if int(step) >= 0 and str(step) in self.residual_diffs:
                assert isinstance(self.residual_diffs[step], list)
                # Use all the recorded diffs for this step
                # NOTE: Should we only use the last 5 diffs?
                diffs = self.residual_diffs[step][:5]
                diffs = [d for d in diffs if d > 0.0]
                if diffs:
                    mean_diff = sum(diffs) / len(diffs)
                    relaxed_diff = (
                        mean_diff
                        * self.cache_config.dynamic_prune_threshold_relax_ratio
                    )
                    if self.cache_config.max_dynamic_prune_threshold is None:
                        max_dynamic_prune_threshold = (
                            2 * residual_diff_threshold
                        )
                    else:
                        max_dynamic_prune_threshold = (
                            self.cache_config.max_dynamic_prune_threshold
                        )
                    if relaxed_diff < max_dynamic_prune_threshold:
                        # If the mean diff is less than twice the threshold,
                        # we can use it as the dynamic prune threshold.
                        residual_diff_threshold = (
                            relaxed_diff
                            if relaxed_diff > residual_diff_threshold
                            else residual_diff_threshold
                        )
                        if logger.isEnabledFor(logging.DEBUG):
                            logger.debug(
                                f"Dynamic prune threshold for step {step}: "
                                f"{residual_diff_threshold:.6f}"
                            )
        return residual_diff_threshold

    def mark_step_begin(self):
        # Overwrite this func for Dynamic Block Prune
        super().mark_step_begin()
        # Reset pruned_blocks and actual_blocks at the beginning
        # of each transformer step.
        if self.get_current_transformer_step() == 0:
            self.pruned_blocks.clear()
            self.actual_blocks.clear()

    def add_residual_diff(self, diff: float | torch.Tensor):
        # Overwrite this func for Dynamic Block Prune
        if isinstance(diff, torch.Tensor):
            diff = diff.item()
        # step: executed_steps - 1, not transformer_steps - 1
        step = str(self.get_current_step())
        # For Dynamic Block Prune, we will record all the diffs for this step
        # Only add the diff if it is not already recorded for this step
        if not self.is_separate_cfg_step():
            if step not in self.residual_diffs:
                self.residual_diffs[step] = []
            self.residual_diffs[step].append(diff)
        else:
            if step not in self.cfg_residual_diffs:
                self.cfg_residual_diffs[step] = []
            self.cfg_residual_diffs[step].append(diff)

    def add_pruned_step(self):
        curr_cached_step = self.get_current_step()
        # Avoid adding the same step multiple times
        if not self.is_separate_cfg_step():
            if curr_cached_step not in self.cached_steps:
                self.add_cached_step()
        else:
            if curr_cached_step not in self.cfg_cached_steps:
                self.add_cached_step()

    def add_pruned_block(self, num_blocks):
        if not self.is_separate_cfg_step():
            self.pruned_blocks.append(num_blocks)
        else:
            self.cfg_pruned_blocks.append(num_blocks)

    def add_actual_block(self, num_blocks):
        if not self.is_separate_cfg_step():
            self.actual_blocks.append(num_blocks)
        else:
            self.cfg_actual_blocks.append(num_blocks)

    def get_pruned_blocks(self):
        return self.pruned_blocks.copy()

    def get_cfg_pruned_blocks(self):
        return self.cfg_pruned_blocks.copy()

    def get_actual_blocks(self):
        return self.actual_blocks.copy()

    def get_cfg_actual_blocks(self):
        return self.cfg_actual_blocks.copy()

    def get_pruned_steps(self):
        return self.get_cached_steps()

    def get_cfg_pruned_steps(self):
        return self.get_cfg_cached_steps()
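
PrunedContext extends CachedContext with per-step bookkeeping for pruned versus actually computed blocks, and can relax the residual-diff threshold on the fly: for example, with a relax ratio of 1.2 and a mean recorded diff of 0.05, the relaxed threshold becomes 0.06 and is only adopted while it stays below the cap (max_dynamic_prune_threshold, defaulting to twice the base threshold). A minimal usage sketch follows; it assumes PrunedContext and DBPruneConfig can be constructed with defaults, which this diff does not show.

# Hypothetical sketch (not part of the diff): driving PrunedContext by hand.
from cache_dit.cache_factory.cache_contexts.prune_config import DBPruneConfig
from cache_dit.cache_factory.cache_contexts.prune_context import PrunedContext

ctx = PrunedContext(cache_config=DBPruneConfig())  # assumes remaining fields default
ctx.mark_step_begin()          # clears pruned/actual counters on transformer step 0
ctx.add_residual_diff(0.05)    # diffs accumulate per step as a list
ctx.add_pruned_block(4)        # 4 blocks skipped in this step
ctx.add_actual_block(24)       # 24 blocks actually computed in this step
print(ctx.get_pruned_blocks(), ctx.get_actual_blocks())   # [4] [24]
print(ctx.get_residual_diff_threshold())  # base threshold, or the relaxed value
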
cache_dit/cache_factory/cache_contexts/prune_manager.py
ADDED
@@ -0,0 +1,154 @@
import torch
import functools
from typing import Dict, List, Tuple, Union

from cache_dit.cache_factory.cache_contexts.cache_manager import (
    CachedContextManager,
)
from cache_dit.cache_factory.cache_contexts.prune_context import (
    PrunedContext,
)
from cache_dit.logger import init_logger

logger = init_logger(__name__)


class PrunedContextManager(CachedContextManager):
    # Reuse CachedContextManager for Dynamic Block Prune

    def __init__(self, name: str = None):
        super().__init__(name)
        # Overwrite for Dynamic Block Prune
        self._current_context: PrunedContext = None
        self._cached_context_manager: Dict[str, PrunedContext] = {}

    # Overwrite for Dynamic Block Prune
    def new_context(self, *args, **kwargs) -> PrunedContext:
        _context = PrunedContext(*args, **kwargs)
        self._cached_context_manager[_context.name] = _context
        return _context

    def set_context(
        self,
        cached_context: PrunedContext | str,
    ) -> PrunedContext:
        return super().set_context(cached_context)

    def get_context(self, name: str = None) -> PrunedContext:
        return super().get_context(name)

    def reset_context(
        self,
        cached_context: PrunedContext | str,
        *args,
        **kwargs,
    ) -> PrunedContext:
        return super().reset_context(cached_context, *args, **kwargs)

    # Specially for Dynamic Block Prune
    @torch.compiler.disable
    def add_pruned_step(self):
        cached_context = self.get_context()
        assert cached_context is not None, "cached_context must be set before"
        cached_context.add_pruned_step()

    @torch.compiler.disable
    def add_pruned_block(self, num_blocks):
        cached_context = self.get_context()
        assert cached_context is not None, "cached_context must be set before"
        cached_context.add_pruned_block(num_blocks)

    @torch.compiler.disable
    def add_actual_block(self, num_blocks):
        cached_context = self.get_context()
        assert cached_context is not None, "cached_context must be set before"
        cached_context.add_actual_block(num_blocks)

    @torch.compiler.disable
    def get_pruned_steps(self) -> List[int]:
        cached_context = self.get_context()
        assert cached_context is not None, "cached_context must be set before"
        return cached_context.get_pruned_steps()

    @torch.compiler.disable
    def get_cfg_pruned_steps(self) -> List[int]:
        cached_context = self.get_context()
        assert cached_context is not None, "cached_context must be set before"
        return cached_context.get_cfg_pruned_steps()

    @torch.compiler.disable
    def get_pruned_blocks(self) -> List[int]:
        cached_context = self.get_context()
        assert cached_context is not None, "cached_context must be set before"
        return cached_context.get_pruned_blocks()

    @torch.compiler.disable
    def get_actual_blocks(self) -> List[int]:
        cached_context = self.get_context()
        assert cached_context is not None, "cached_context must be set before"
        return cached_context.get_actual_blocks()

    @torch.compiler.disable
    def get_cfg_pruned_blocks(self) -> List[int]:
        cached_context = self.get_context()
        assert cached_context is not None, "cached_context must be set before"
        return cached_context.get_cfg_pruned_blocks()

    @torch.compiler.disable
    def get_cfg_actual_blocks(self) -> List[int]:
        cached_context = self.get_context()
        assert cached_context is not None, "cached_context must be set before"
        return cached_context.get_cfg_actual_blocks()

    @torch.compiler.disable
    @functools.lru_cache(maxsize=8)
    def get_non_prune_blocks_ids(self, num_blocks: int) -> List[int]:
        assert num_blocks is not None, "num_blocks must be provided"
        assert num_blocks > 0, "num_blocks must be greater than 0"
        # Get the non-prune block ids for current context
        # Never prune the first `Fn` and last `Bn` blocks.
        Fn_compute_blocks_ids = list(
            range(
                self.Fn_compute_blocks()
                if self.Fn_compute_blocks() < num_blocks
                else num_blocks
            )
        )

        Bn_compute_blocks_ids = list(
            range(
                num_blocks
                - (
                    self.Bn_compute_blocks()
                    if self.Bn_compute_blocks() < num_blocks
                    else num_blocks
                ),
                num_blocks,
            )
        )
        context = self.get_context()
        assert context is not None, "cached_context must be set before"

        non_prune_blocks_ids = list(
            set(
                Fn_compute_blocks_ids
                + Bn_compute_blocks_ids
                + context.cache_config.non_prune_block_ids
            )
        )
        non_prune_blocks_ids = [
            d for d in non_prune_blocks_ids if d < num_blocks
        ]
        return sorted(non_prune_blocks_ids)

    @torch.compiler.disable
    def can_prune(self, *args, **kwargs) -> bool:
        # Directly reuse can_cache for Dynamic Block Prune
        return self.can_cache(*args, **kwargs)

    @torch.compiler.disable
    def apply_prune(
        self, *args, **kwargs
    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
        # Directly reuse apply_cache for Dynamic Block Prune
        return self.apply_cache(*args, **kwargs)
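
PrunedContextManager is a thin shim over CachedContextManager: the add_*/get_* helpers forward to the active PrunedContext, and can_prune/apply_prune simply reuse can_cache/apply_cache. A rough sketch of how it might be queried; the new_context arguments and the surrounding pipeline wiring are assumptions, not shown in this diff.

# Hypothetical sketch (not part of the diff): querying the prune manager.
from cache_dit.cache_factory.cache_contexts.prune_config import DBPruneConfig
from cache_dit.cache_factory.cache_contexts.prune_manager import (
    PrunedContextManager,
)

manager = PrunedContextManager(name="transformer")
ctx = manager.new_context(cache_config=DBPruneConfig())  # assumed kwargs
manager.set_context(ctx)
manager.add_pruned_block(4)
manager.add_actual_block(24)
print(manager.get_pruned_blocks())                       # [4]
print(manager.get_non_prune_blocks_ids(num_blocks=28))   # first Fn + last Bn ids

One design note: stacking functools.lru_cache on the instance method get_non_prune_blocks_ids keys the cache on (self, num_blocks), so results are memoized per manager instance and the cache also keeps a reference to that instance alive.
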
cache_dit/cache_factory/cache_interface.py
CHANGED
@@ -5,6 +5,8 @@ from cache_dit.cache_factory.block_adapters import BlockAdapter
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
 from cache_dit.cache_factory.cache_adapters import CachedAdapter
 from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
+from cache_dit.cache_factory.cache_contexts import DBCacheConfig
+from cache_dit.cache_factory.cache_contexts import DBPruneConfig
 from cache_dit.cache_factory.cache_contexts import CalibratorConfig
 from cache_dit.cache_factory.params_modifier import ParamsModifier
@@ -19,8 +21,12 @@ def enable_cache(
         DiffusionPipeline,
         BlockAdapter,
     ],
-    #
-    cache_config:
+    # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
+    cache_config: Union[
+        BasicCacheConfig,
+        DBCacheConfig,
+        DBPruneConfig,
+    ] = DBCacheConfig(),
     # Calibrator config: TaylorSeerCalibratorConfig, etc.
     calibrator_config: Optional[CalibratorConfig] = None,
     # Modify cache context params for specific blocks.
@@ -136,14 +142,14 @@ def enable_cache(
     >>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
     """
     # Collect cache context kwargs
-
-    if (cache_type :=
+    context_kwargs = {}
+    if (cache_type := context_kwargs.get("cache_type", None)) is not None:
         if cache_type == CacheType.NONE:
             return pipe_or_adapter

     # WARNING: Deprecated cache config params. These parameters are now retained
     # for backward compatibility but will be removed in the future.
-
+    deprecated_kwargs = {
         "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
         "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
         "max_warmup_steps": kwargs.get("max_warmup_steps", None),
@@ -159,23 +165,23 @@ def enable_cache(
         ),
     }

-
-        k: v for k, v in
+    deprecated_kwargs = {
+        k: v for k, v in deprecated_kwargs.items() if v is not None
     }

-    if
+    if deprecated_kwargs:
         logger.warning(
             "Manually settup DBCache context without BasicCacheConfig is "
             "deprecated and will be removed in the future, please use "
             "`cache_config` parameter instead!"
         )
         if cache_config is not None:
-            cache_config.update(**
+            cache_config.update(**deprecated_kwargs)
         else:
-            cache_config = BasicCacheConfig(**
+            cache_config = BasicCacheConfig(**deprecated_kwargs)

     if cache_config is not None:
-
+        context_kwargs["cache_config"] = cache_config

     # WARNING: Deprecated taylorseer params. These parameters are now retained
     # for backward compatibility but will be removed in the future.
@@ -202,15 +208,15 @@ def enable_cache(
     )

     if calibrator_config is not None:
-
+        context_kwargs["calibrator_config"] = calibrator_config

     if params_modifiers is not None:
-
+        context_kwargs["params_modifiers"] = params_modifiers

     if isinstance(pipe_or_adapter, (DiffusionPipeline, BlockAdapter)):
         return CachedAdapter.apply(
             pipe_or_adapter,
-            **
+            **context_kwargs,
         )
     else:
         raise ValueError(
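
With the widened signature, enable_cache now accepts DBCacheConfig (the new default) or DBPruneConfig in addition to BasicCacheConfig. A hedged usage sketch; the pipeline checkpoint is a placeholder and only defaults shown in this diff are used:

# Hypothetical sketch (not part of the diff): switching to dynamic block prune.
import cache_dit
from diffusers import DiffusionPipeline
from cache_dit.cache_factory.cache_contexts import DBPruneConfig

pipe = DiffusionPipeline.from_pretrained("<your-DiT-pipeline>")  # placeholder id
cache_dit.enable_cache(pipe, cache_config=DBPruneConfig())  # dynamic block prune
# cache_dit.enable_cache(pipe) keeps the new default, DBCacheConfig()
# cache_dit.disable_cache(pipe) restores the original pipeline (per the docstring)
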
cache_dit/cache_factory/cache_types.py
CHANGED
@@ -6,7 +6,8 @@ logger = init_logger(__name__)

 class CacheType(Enum):
     NONE = "NONE"
-    DBCache = "Dual_Block_Cache"
+    DBCache = "DBCache"  # "Dual_Block_Cache"
+    DBPrune = "DBPrune"  # "Dynamic_Block_Prune"

     @staticmethod
     def type(type_hint: "CacheType | str") -> "CacheType":
@@ -14,6 +15,9 @@ class CacheType(Enum):
             return type_hint
         return cache_type(type_hint)

+    def __str__(self) -> str:
+        return self.value
+

 def cache_type(type_hint: "CacheType | str") -> "CacheType":
     if type_hint is None:
@@ -21,7 +25,6 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":

     if isinstance(type_hint, CacheType):
         return type_hint
-
     elif type_hint.upper() in (
         "DUAL_BLOCK_CACHE",
         "DB_CACHE",
@@ -29,6 +32,20 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":
         "DB",
     ):
         return CacheType.DBCache
+    elif type_hint.upper() in (
+        "DYNAMIC_BLOCK_PRUNE",
+        "DB_PRUNE",
+        "DBPRUNE",
+        "DBP",
+    ):
+        return CacheType.DBPrune
+    elif type_hint.upper() in (
+        "NONE",
+        "NO_CACHE",
+        "NOCACHE",
+        "NC",
+    ):
+        return CacheType.NONE
     return CacheType.NONE

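The string aliases accepted by cache_type() now cover the prune backend and explicit no-cache spellings, and the new __str__ makes the enum print as its short value. A small sketch, assuming both names are importable from cache_dit.cache_factory.cache_types as the other files in this diff do:

# Hypothetical sketch (not part of the diff): resolving the new aliases.
from cache_dit.cache_factory.cache_types import CacheType, cache_type

assert cache_type("DBPRUNE") is CacheType.DBPrune
assert cache_type("dynamic_block_prune") is CacheType.DBPrune  # case-insensitive
assert cache_type("no_cache") is CacheType.NONE
print(str(CacheType.DBCache), str(CacheType.DBPrune))  # DBCache DBPrune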
|
cache_dit/cache_factory/params_modifier.py
CHANGED
@@ -11,7 +11,7 @@ logger = init_logger(__name__)
 class ParamsModifier:
     def __init__(
         self,
-        #
+        # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
         cache_config: BasicCacheConfig = None,
         # Calibrator config: TaylorSeerCalibratorConfig, etc.
         calibrator_config: Optional[CalibratorConfig] = None,
@@ -22,7 +22,7 @@ class ParamsModifier:

         # WARNING: Deprecated cache config params. These parameters are now retained
         # for backward compatibility but will be removed in the future.
-
+        deprecated_kwargs = {
             "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
             "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
             "max_warmup_steps": kwargs.get("max_warmup_steps", None),
@@ -40,20 +40,20 @@ class ParamsModifier:
             ),
         }

-
-            k: v for k, v in
+        deprecated_kwargs = {
+            k: v for k, v in deprecated_kwargs.items() if v is not None
         }

-        if
+        if deprecated_kwargs:
             logger.warning(
                 "Manually settup DBCache context without BasicCacheConfig is "
                 "deprecated and will be removed in the future, please use "
                 "`cache_config` parameter instead!"
             )
             if cache_config is not None:
-                cache_config.update(**
+                cache_config.update(**deprecated_kwargs)
             else:
-                cache_config = BasicCacheConfig(**
+                cache_config = BasicCacheConfig(**deprecated_kwargs)

         if cache_config is not None:
             self._context_kwargs["cache_config"] = cache_config
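
ParamsModifier keeps the same deprecation shim as enable_cache, folding loose kwargs into a config object. A brief sketch under the assumption that a modifier built this way is handed to enable_cache via params_modifiers; only parameter names visible in this diff are used:

# Hypothetical sketch (not part of the diff): per-blocks cache config override.
from cache_dit.cache_factory.params_modifier import ParamsModifier
from cache_dit.cache_factory.cache_contexts import DBPruneConfig

modifier = ParamsModifier(cache_config=DBPruneConfig())
# e.g. cache_dit.enable_cache(pipe, params_modifiers=[modifier])
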
cache_dit/cache_factory/utils.py
CHANGED
@@ -7,10 +7,6 @@ def load_cache_options_from_yaml(yaml_file_path):
         kwargs: dict = yaml.safe_load(f)

     required_keys = [
-        "max_warmup_steps",
-        "max_cached_steps",
-        "Fn_compute_blocks",
-        "Bn_compute_blocks",
         "residual_diff_threshold",
     ]
     for key in required_keys:
@@ -38,10 +34,25 @@ def load_cache_options_from_yaml(yaml_file_path):
         )
     )

-
+    if "cache_type" not in kwargs:
+        from cache_dit.cache_factory.cache_contexts import BasicCacheConfig

-
-
+        cache_context_kwargs["cache_config"] = BasicCacheConfig()
+        cache_context_kwargs["cache_config"].update(**kwargs)
+    else:
+        cache_type = kwargs.pop("cache_type")
+        if cache_type == "DBCache":
+            from cache_dit.cache_factory.cache_contexts import DBCacheConfig
+
+            cache_context_kwargs["cache_config"] = DBCacheConfig()
+            cache_context_kwargs["cache_config"].update(**kwargs)
+        elif cache_type == "DBPrune":
+            from cache_dit.cache_factory.cache_contexts import DBPruneConfig
+
+            cache_context_kwargs["cache_config"] = DBPruneConfig()
+            cache_context_kwargs["cache_config"].update(**kwargs)
+        else:
+            raise ValueError(f"Unsupported cache_type: {cache_type}.")

     return cache_context_kwargs

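load_cache_options_from_yaml now only requires residual_diff_threshold and dispatches on an optional top-level cache_type key, building a BasicCacheConfig, DBCacheConfig, or DBPruneConfig and forwarding the remaining keys to its update(). A hedged sketch; the file path is a placeholder and only keys shown in this diff are used:

# Hypothetical sketch (not part of the diff): selecting the prune backend via YAML.
# cache_options.yaml might contain:
#   cache_type: DBPrune
#   residual_diff_threshold: 0.08
from cache_dit.cache_factory.utils import load_cache_options_from_yaml

cache_context_kwargs = load_cache_options_from_yaml("cache_options.yaml")
print(type(cache_context_kwargs["cache_config"]).__name__)  # DBPruneConfig
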
cache_dit/quantize/quantize_ao.py
CHANGED
@@ -83,11 +83,18 @@ def quantize_ao(
     def _quantization_fn():
         try:
             if quant_type == "fp8_w8a8_dq":
-
-
-
-
-
+                try:
+                    from torchao.quantization import (
+                        float8_dynamic_activation_float8_weight,
+                        PerTensor,
+                        PerRow,
+                    )
+                except ImportError:
+                    from torchao.quantization import (
+                        Float8DynamicActivationFloat8WeightConfig as float8_dynamic_activation_float8_weight,
+                        PerTensor,
+                        PerRow,
+                    )

                 if per_row: # Ensure bfloat16
                     module.to(torch.bfloat16)
@@ -109,7 +116,12 @@ def quantize_ao(
                 )

             elif quant_type == "fp8_w8a16_wo":
-
+                try:
+                    from torchao.quantization import float8_weight_only
+                except ImportError:
+                    from torchao.quantization import (
+                        Float8WeightOnlyConfig as float8_weight_only,
+                    )

                 quantization_fn = float8_weight_only(
                     weight_dtype=kwargs.get(
@@ -119,14 +131,25 @@ def quantize_ao(
                 )

             elif quant_type == "int8_w8a8_dq":
-
-
-
+                try:
+                    from torchao.quantization import (
+                        int8_dynamic_activation_int8_weight,
+                    )
+                except ImportError:
+                    from torchao.quantization import (
+                        Int8DynamicActivationInt8WeightConfig as int8_dynamic_activation_int8_weight,
+                    )

                 quantization_fn = int8_dynamic_activation_int8_weight()

             elif quant_type == "int8_w8a16_wo":
-
+
+                try:
+                    from torchao.quantization import int8_weight_only
+                except ImportError:
+                    from torchao.quantization import (
+                        Int8WeightOnlyConfig as int8_weight_only,
+                    )

                 quantization_fn = int8_weight_only(
                     # group_size is None -> per_channel, else per group
@@ -134,23 +157,41 @@ def quantize_ao(
                 )

             elif quant_type == "int4_w4a8_dq":
-
-
-
+
+                try:
+                    from torchao.quantization import (
+                        int8_dynamic_activation_int4_weight,
+                    )
+                except ImportError:
+                    from torchao.quantization import (
+                        Int8DynamicActivationInt4WeightConfig as int8_dynamic_activation_int4_weight,
+                    )

                 quantization_fn = int8_dynamic_activation_int4_weight(
                     group_size=kwargs.get("group_size", 32),
                 )

             elif quant_type == "int4_w4a4_dq":
-
-
-
+
+                try:
+                    from torchao.quantization import (
+                        int4_dynamic_activation_int4_weight,
+                    )
+                except ImportError:
+                    from torchao.quantization import (
+                        Int4DynamicActivationInt4WeightConfig as int4_dynamic_activation_int4_weight,
+                    )

                 quantization_fn = int4_dynamic_activation_int4_weight()

             elif quant_type == "int4_w4a16_wo":
-
+
+                try:
+                    from torchao.quantization import int4_weight_only
+                except ImportError:
+                    from torchao.quantization import (
+                        Int4WeightOnlyConfig as int4_weight_only,
+                    )

                 quantization_fn = int4_weight_only(
                     group_size=kwargs.get("group_size", 32),