cache-dit 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.

Files changed (32)
  1. cache_dit/__init__.py +3 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +8 -1
  4. cache_dit/cache_factory/block_adapters/__init__.py +4 -1
  5. cache_dit/cache_factory/cache_adapters/cache_adapter.py +126 -80
  6. cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
  7. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
  8. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +256 -24
  9. cache_dit/cache_factory/cache_blocks/pattern_base.py +273 -38
  10. cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
  11. cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
  12. cache_dit/cache_factory/cache_contexts/cache_config.py +118 -0
  13. cache_dit/cache_factory/cache_contexts/cache_context.py +15 -93
  14. cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
  15. cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +22 -0
  16. cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
  17. cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
  18. cache_dit/cache_factory/cache_contexts/prune_config.py +63 -0
  19. cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
  20. cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
  21. cache_dit/cache_factory/cache_interface.py +20 -14
  22. cache_dit/cache_factory/cache_types.py +19 -2
  23. cache_dit/cache_factory/params_modifier.py +7 -7
  24. cache_dit/cache_factory/utils.py +18 -7
  25. cache_dit/quantize/quantize_ao.py +58 -17
  26. cache_dit/utils.py +191 -54
  27. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/METADATA +11 -10
  28. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/RECORD +32 -27
  29. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/WHEEL +0 -0
  30. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/entry_points.txt +0 -0
  31. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/licenses/LICENSE +0 -0
  32. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/top_level.txt +0 -0

cache_dit/cache_factory/cache_contexts/prune_context.py (new file)
@@ -0,0 +1,155 @@
+import torch
+import logging
+import dataclasses
+from typing import List
+from cache_dit.cache_factory.cache_types import CacheType
+from cache_dit.cache_factory.cache_contexts.prune_config import (
+    DBPruneConfig,
+)
+from cache_dit.cache_factory.cache_contexts.cache_context import (
+    CachedContext,
+)
+from cache_dit.logger import init_logger
+
+
+logger = init_logger(__name__)
+
+
+@dataclasses.dataclass
+class PrunedContext(CachedContext):
+    # Overwrite the cache_config type for PrunedContext
+    cache_config: DBPruneConfig = dataclasses.field(
+        default_factory=DBPruneConfig,
+    )
+    # Specially for Dynamic Block Prune
+    pruned_blocks: List[int] = dataclasses.field(default_factory=list)
+    actual_blocks: List[int] = dataclasses.field(default_factory=list)
+    cfg_pruned_blocks: List[int] = dataclasses.field(default_factory=list)
+    cfg_actual_blocks: List[int] = dataclasses.field(default_factory=list)
+
+    def __post_init__(self):
+        super().__post_init__()
+        if not isinstance(self.cache_config, DBPruneConfig):
+            raise ValueError(
+                "PrunedContext only supports DBPruneConfig as cache_config."
+            )
+
+        if self.cache_config.cache_type == CacheType.DBPrune:
+            if (
+                self.calibrator_config is not None
+                and self.cache_config.force_reduce_calibrator_vram
+            ):
+                # May reduce VRAM usage for Dynamic Block Prune
+                self.extra_cache_config.downsample_factor = max(
+                    4, self.extra_cache_config.downsample_factor
+                )
+
+    def get_residual_diff_threshold(self):
+        # Overwrite this func for Dynamic Block Prune
+        residual_diff_threshold = self.cache_config.residual_diff_threshold
+        if isinstance(residual_diff_threshold, torch.Tensor):
+            residual_diff_threshold = residual_diff_threshold.item()
+        if self.cache_config.enable_dynamic_prune_threshold:
+            # Compute the dynamic prune threshold based on the mean of the
+            # residual diffs of the previously computed or pruned blocks.
+            step = str(self.get_current_step())
+            if int(step) >= 0 and str(step) in self.residual_diffs:
+                assert isinstance(self.residual_diffs[step], list)
+                # Use the first 5 recorded diffs for this step
+                # NOTE: Should we only use the last 5 diffs?
+                diffs = self.residual_diffs[step][:5]
+                diffs = [d for d in diffs if d > 0.0]
+                if diffs:
+                    mean_diff = sum(diffs) / len(diffs)
+                    relaxed_diff = (
+                        mean_diff
+                        * self.cache_config.dynamic_prune_threshold_relax_ratio
+                    )
+                    if self.cache_config.max_dynamic_prune_threshold is None:
+                        max_dynamic_prune_threshold = (
+                            2 * residual_diff_threshold
+                        )
+                    else:
+                        max_dynamic_prune_threshold = (
+                            self.cache_config.max_dynamic_prune_threshold
+                        )
+                    if relaxed_diff < max_dynamic_prune_threshold:
+                        # If the relaxed mean diff stays below the cap,
+                        # use it as the dynamic prune threshold.
+                        residual_diff_threshold = (
+                            relaxed_diff
+                            if relaxed_diff > residual_diff_threshold
+                            else residual_diff_threshold
+                        )
+                    if logger.isEnabledFor(logging.DEBUG):
+                        logger.debug(
+                            f"Dynamic prune threshold for step {step}: "
+                            f"{residual_diff_threshold:.6f}"
+                        )
+        return residual_diff_threshold
+
+    def mark_step_begin(self):
+        # Overwrite this func for Dynamic Block Prune
+        super().mark_step_begin()
+        # Reset pruned_blocks and actual_blocks at the beginning
+        # of each transformer step.
+        if self.get_current_transformer_step() == 0:
+            self.pruned_blocks.clear()
+            self.actual_blocks.clear()
+
+    def add_residual_diff(self, diff: float | torch.Tensor):
+        # Overwrite this func for Dynamic Block Prune
+        if isinstance(diff, torch.Tensor):
+            diff = diff.item()
+        # step: executed_steps - 1, not transformer_steps - 1
+        step = str(self.get_current_step())
+        # For Dynamic Block Prune, record every diff observed in this step;
+        # diffs are appended per step rather than overwritten.
+        if not self.is_separate_cfg_step():
+            if step not in self.residual_diffs:
+                self.residual_diffs[step] = []
+            self.residual_diffs[step].append(diff)
+        else:
+            if step not in self.cfg_residual_diffs:
+                self.cfg_residual_diffs[step] = []
+            self.cfg_residual_diffs[step].append(diff)
+
+    def add_pruned_step(self):
+        curr_cached_step = self.get_current_step()
+        # Avoid adding the same step multiple times
+        if not self.is_separate_cfg_step():
+            if curr_cached_step not in self.cached_steps:
+                self.add_cached_step()
+        else:
+            if curr_cached_step not in self.cfg_cached_steps:
+                self.add_cached_step()
+
+    def add_pruned_block(self, num_blocks):
+        if not self.is_separate_cfg_step():
+            self.pruned_blocks.append(num_blocks)
+        else:
+            self.cfg_pruned_blocks.append(num_blocks)
+
+    def add_actual_block(self, num_blocks):
+        if not self.is_separate_cfg_step():
+            self.actual_blocks.append(num_blocks)
+        else:
+            self.cfg_actual_blocks.append(num_blocks)
+
+    def get_pruned_blocks(self):
+        return self.pruned_blocks.copy()
+
+    def get_cfg_pruned_blocks(self):
+        return self.cfg_pruned_blocks.copy()
+
+    def get_actual_blocks(self):
+        return self.actual_blocks.copy()
+
+    def get_cfg_actual_blocks(self):
+        return self.cfg_actual_blocks.copy()
+
+    def get_pruned_steps(self):
+        return self.get_cached_steps()
+
+    def get_cfg_pruned_steps(self):
+        return self.get_cfg_cached_steps()
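
The dynamic-threshold rule in get_residual_diff_threshold above reduces to a small calculation: average the first few positive residual diffs of the current step, relax the mean by a ratio, and adopt it only if it stays under a cap (2x the static threshold unless max_dynamic_prune_threshold is set) and is larger than the static threshold. A minimal standalone sketch, with the relax ratio and example numbers assumed for illustration:

# Standalone sketch of the dynamic prune threshold rule from
# PrunedContext.get_residual_diff_threshold above (illustrative only).
from typing import List, Optional


def dynamic_prune_threshold(
    base_threshold: float,
    step_diffs: List[float],
    relax_ratio: float = 1.25,  # assumed value; the real default lives in DBPruneConfig
    max_threshold: Optional[float] = None,
) -> float:
    # Use at most the first 5 positive diffs recorded for this step.
    diffs = [d for d in step_diffs[:5] if d > 0.0]
    if not diffs:
        return base_threshold
    relaxed = (sum(diffs) / len(diffs)) * relax_ratio
    # The cap defaults to 2x the static threshold when not configured.
    cap = 2 * base_threshold if max_threshold is None else max_threshold
    if relaxed < cap:
        # Only ever relax the threshold upwards, never tighten it.
        return max(relaxed, base_threshold)
    return base_threshold


# Example: diffs of 0.06/0.08/0.10 against a 0.08 base threshold relax it to ~0.10.
print(dynamic_prune_threshold(0.08, [0.06, 0.08, 0.10]))  # ~= 0.1
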

cache_dit/cache_factory/cache_contexts/prune_manager.py (new file)
@@ -0,0 +1,154 @@
+import torch
+import functools
+from typing import Dict, List, Tuple, Union
+
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)
+from cache_dit.cache_factory.cache_contexts.prune_context import (
+    PrunedContext,
+)
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+class PrunedContextManager(CachedContextManager):
+    # Reuse CachedContextManager for Dynamic Block Prune
+
+    def __init__(self, name: str = None):
+        super().__init__(name)
+        # Overwrite for Dynamic Block Prune
+        self._current_context: PrunedContext = None
+        self._cached_context_manager: Dict[str, PrunedContext] = {}
+
+    # Overwrite for Dynamic Block Prune
+    def new_context(self, *args, **kwargs) -> PrunedContext:
+        _context = PrunedContext(*args, **kwargs)
+        self._cached_context_manager[_context.name] = _context
+        return _context
+
+    def set_context(
+        self,
+        cached_context: PrunedContext | str,
+    ) -> PrunedContext:
+        return super().set_context(cached_context)
+
+    def get_context(self, name: str = None) -> PrunedContext:
+        return super().get_context(name)
+
+    def reset_context(
+        self,
+        cached_context: PrunedContext | str,
+        *args,
+        **kwargs,
+    ) -> PrunedContext:
+        return super().reset_context(cached_context, *args, **kwargs)
+
+    # Specially for Dynamic Block Prune
+    @torch.compiler.disable
+    def add_pruned_step(self):
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        cached_context.add_pruned_step()
+
+    @torch.compiler.disable
+    def add_pruned_block(self, num_blocks):
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        cached_context.add_pruned_block(num_blocks)
+
+    @torch.compiler.disable
+    def add_actual_block(self, num_blocks):
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        cached_context.add_actual_block(num_blocks)
+
+    @torch.compiler.disable
+    def get_pruned_steps(self) -> List[int]:
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        return cached_context.get_pruned_steps()
+
+    @torch.compiler.disable
+    def get_cfg_pruned_steps(self) -> List[int]:
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        return cached_context.get_cfg_pruned_steps()
+
+    @torch.compiler.disable
+    def get_pruned_blocks(self) -> List[int]:
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        return cached_context.get_pruned_blocks()
+
+    @torch.compiler.disable
+    def get_actual_blocks(self) -> List[int]:
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        return cached_context.get_actual_blocks()
+
+    @torch.compiler.disable
+    def get_cfg_pruned_blocks(self) -> List[int]:
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        return cached_context.get_cfg_pruned_blocks()
+
+    @torch.compiler.disable
+    def get_cfg_actual_blocks(self) -> List[int]:
+        cached_context = self.get_context()
+        assert cached_context is not None, "cached_context must be set before"
+        return cached_context.get_cfg_actual_blocks()
+
+    @torch.compiler.disable
+    @functools.lru_cache(maxsize=8)
+    def get_non_prune_blocks_ids(self, num_blocks: int) -> List[int]:
+        assert num_blocks is not None, "num_blocks must be provided"
+        assert num_blocks > 0, "num_blocks must be greater than 0"
+        # Get the non-prune block ids for current context
+        # Never prune the first `Fn` and last `Bn` blocks.
+        Fn_compute_blocks_ids = list(
+            range(
+                self.Fn_compute_blocks()
+                if self.Fn_compute_blocks() < num_blocks
+                else num_blocks
+            )
+        )
+
+        Bn_compute_blocks_ids = list(
+            range(
+                num_blocks
+                - (
+                    self.Bn_compute_blocks()
+                    if self.Bn_compute_blocks() < num_blocks
+                    else num_blocks
+                ),
+                num_blocks,
+            )
+        )
+        context = self.get_context()
+        assert context is not None, "cached_context must be set before"
+
+        non_prune_blocks_ids = list(
+            set(
+                Fn_compute_blocks_ids
+                + Bn_compute_blocks_ids
+                + context.cache_config.non_prune_block_ids
+            )
+        )
+        non_prune_blocks_ids = [
+            d for d in non_prune_blocks_ids if d < num_blocks
+        ]
+        return sorted(non_prune_blocks_ids)
+
+    @torch.compiler.disable
+    def can_prune(self, *args, **kwargs) -> bool:
+        # Directly reuse can_cache for Dynamic Block Prune
+        return self.can_cache(*args, **kwargs)
+
+    @torch.compiler.disable
+    def apply_prune(
+        self, *args, **kwargs
+    ) -> Tuple[torch.Tensor, Union[torch.Tensor, None]]:
+        # Directly reuse apply_cache for Dynamic Block Prune
+        return self.apply_cache(*args, **kwargs)
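
get_non_prune_blocks_ids boils down to a set union over three sources of protected blocks. A standalone sketch of that computation (the helper name and example values are illustrative, not part of the API):

# Standalone sketch of PrunedContextManager.get_non_prune_blocks_ids above:
# the first Fn blocks, the last Bn blocks, and any explicitly listed ids
# are always computed and never pruned.
from typing import List, Sequence


def non_prune_block_ids(
    num_blocks: int,
    Fn: int,
    Bn: int,
    extra_ids: Sequence[int] = (),
) -> List[int]:
    fn_ids = list(range(min(Fn, num_blocks)))
    bn_ids = list(range(num_blocks - min(Bn, num_blocks), num_blocks))
    ids = set(fn_ids) | set(bn_ids) | set(extra_ids)
    return sorted(i for i in ids if i < num_blocks)


# Example: 28 blocks, protect the first 8 and block 13, nothing at the tail.
print(non_prune_block_ids(28, Fn=8, Bn=0, extra_ids=[13]))
# -> [0, 1, 2, 3, 4, 5, 6, 7, 13]
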

cache_dit/cache_factory/cache_interface.py
@@ -5,6 +5,8 @@ from cache_dit.cache_factory.block_adapters import BlockAdapter
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
 from cache_dit.cache_factory.cache_adapters import CachedAdapter
 from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
+from cache_dit.cache_factory.cache_contexts import DBCacheConfig
+from cache_dit.cache_factory.cache_contexts import DBPruneConfig
 from cache_dit.cache_factory.cache_contexts import CalibratorConfig
 from cache_dit.cache_factory.params_modifier import ParamsModifier
 
@@ -19,8 +21,12 @@ def enable_cache(
         DiffusionPipeline,
         BlockAdapter,
     ],
-    # Basic DBCache config: BasicCacheConfig
-    cache_config: BasicCacheConfig = BasicCacheConfig(),
+    # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
+    cache_config: Union[
+        BasicCacheConfig,
+        DBCacheConfig,
+        DBPruneConfig,
+    ] = DBCacheConfig(),
     # Calibrator config: TaylorSeerCalibratorConfig, etc.
     calibrator_config: Optional[CalibratorConfig] = None,
     # Modify cache context params for specific blocks.
@@ -136,14 +142,14 @@ def enable_cache(
     >>> cache_dit.disable_cache(pipe) # Disable cache and run original pipe.
     """
     # Collect cache context kwargs
-    cache_context_kwargs = {}
-    if (cache_type := cache_context_kwargs.pop("cache_type", None)) is not None:
+    context_kwargs = {}
+    if (cache_type := context_kwargs.get("cache_type", None)) is not None:
         if cache_type == CacheType.NONE:
             return pipe_or_adapter
 
     # WARNING: Deprecated cache config params. These parameters are now retained
     # for backward compatibility but will be removed in the future.
-    deprecated_cache_kwargs = {
+    deprecated_kwargs = {
        "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
        "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
        "max_warmup_steps": kwargs.get("max_warmup_steps", None),
@@ -159,23 +165,23 @@ def enable_cache(
        ),
    }
 
-    deprecated_cache_kwargs = {
-        k: v for k, v in deprecated_cache_kwargs.items() if v is not None
+    deprecated_kwargs = {
+        k: v for k, v in deprecated_kwargs.items() if v is not None
    }
 
-    if deprecated_cache_kwargs:
+    if deprecated_kwargs:
         logger.warning(
             "Manually settup DBCache context without BasicCacheConfig is "
             "deprecated and will be removed in the future, please use "
             "`cache_config` parameter instead!"
         )
         if cache_config is not None:
-            cache_config.update(**deprecated_cache_kwargs)
+            cache_config.update(**deprecated_kwargs)
         else:
-            cache_config = BasicCacheConfig(**deprecated_cache_kwargs)
+            cache_config = BasicCacheConfig(**deprecated_kwargs)
 
     if cache_config is not None:
-        cache_context_kwargs["cache_config"] = cache_config
+        context_kwargs["cache_config"] = cache_config
 
     # WARNING: Deprecated taylorseer params. These parameters are now retained
     # for backward compatibility but will be removed in the future.
@@ -202,15 +208,15 @@ def enable_cache(
         )
 
     if calibrator_config is not None:
-        cache_context_kwargs["calibrator_config"] = calibrator_config
+        context_kwargs["calibrator_config"] = calibrator_config
 
    if params_modifiers is not None:
-        cache_context_kwargs["params_modifiers"] = params_modifiers
+        context_kwargs["params_modifiers"] = params_modifiers
 
    if isinstance(pipe_or_adapter, (DiffusionPipeline, BlockAdapter)):
        return CachedAdapter.apply(
            pipe_or_adapter,
-            **cache_context_kwargs,
+            **context_kwargs,
        )
    else:
        raise ValueError(
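
Taken together with the import hunk above, the practical effect is that enable_cache() now accepts DBCacheConfig (the new default) or DBPruneConfig in place of a plain BasicCacheConfig. A hedged usage sketch; the config import path comes from the hunk above, while the model id and field values are illustrative assumptions:

# Hedged usage sketch for the widened cache_config parameter of enable_cache().
import cache_dit
from cache_dit.cache_factory.cache_contexts import DBCacheConfig, DBPruneConfig
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

# 1.0.5 default: DBCache, now expressed as an explicit DBCacheConfig.
cache_dit.enable_cache(
    pipe,
    cache_config=DBCacheConfig(residual_diff_threshold=0.08),
)

# Alternatively, Dynamic Block Prune via the new DBPruneConfig
# (call cache_dit.disable_cache(pipe) first if the pipe is already cached):
# cache_dit.enable_cache(
#     pipe,
#     cache_config=DBPruneConfig(
#         residual_diff_threshold=0.08,
#         enable_dynamic_prune_threshold=True,
#     ),
# )
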

cache_dit/cache_factory/cache_types.py
@@ -6,7 +6,8 @@ logger = init_logger(__name__)
 
 class CacheType(Enum):
     NONE = "NONE"
-    DBCache = "Dual_Block_Cache"
+    DBCache = "DBCache"  # "Dual_Block_Cache"
+    DBPrune = "DBPrune"  # "Dynamic_Block_Prune"
 
     @staticmethod
     def type(type_hint: "CacheType | str") -> "CacheType":
@@ -14,6 +15,9 @@ class CacheType(Enum):
             return type_hint
         return cache_type(type_hint)
 
+    def __str__(self) -> str:
+        return self.value
+
 
 def cache_type(type_hint: "CacheType | str") -> "CacheType":
     if type_hint is None:
@@ -21,7 +25,6 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":
 
     if isinstance(type_hint, CacheType):
         return type_hint
-
     elif type_hint.upper() in (
         "DUAL_BLOCK_CACHE",
         "DB_CACHE",
@@ -29,6 +32,20 @@ def cache_type(type_hint: "CacheType | str") -> "CacheType":
         "DB",
     ):
         return CacheType.DBCache
+    elif type_hint.upper() in (
+        "DYNAMIC_BLOCK_PRUNE",
+        "DB_PRUNE",
+        "DBPRUNE",
+        "DBP",
+    ):
+        return CacheType.DBPrune
+    elif type_hint.upper() in (
+        "NONE",
+        "NO_CACHE",
+        "NOCACHE",
+        "NC",
+    ):
+        return CacheType.NONE
     return CacheType.NONE
 
 
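
The net effect of the cache_types.py changes: the common spellings of each mode resolve to the enum, and str() on a member now returns its short value.

# Behaviour implied by the cache_types.py hunk above (module path taken
# from the files-changed list).
from cache_dit.cache_factory.cache_types import CacheType, cache_type

assert cache_type("db_cache") is CacheType.DBCache   # case-insensitive alias
assert cache_type("db_prune") is CacheType.DBPrune   # new DBPrune aliases
assert cache_type("no_cache") is CacheType.NONE      # new NONE aliases
assert str(CacheType.DBPrune) == "DBPrune"           # new __str__ -> enum value
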

cache_dit/cache_factory/params_modifier.py
@@ -11,7 +11,7 @@ logger = init_logger(__name__)
 class ParamsModifier:
     def __init__(
         self,
-        # Basic DBCache config: BasicCacheConfig
+        # BasicCacheConfig, DBCacheConfig, DBPruneConfig, etc.
         cache_config: BasicCacheConfig = None,
         # Calibrator config: TaylorSeerCalibratorConfig, etc.
         calibrator_config: Optional[CalibratorConfig] = None,
@@ -22,7 +22,7 @@ class ParamsModifier:
 
         # WARNING: Deprecated cache config params. These parameters are now retained
         # for backward compatibility but will be removed in the future.
-        deprecated_cache_kwargs = {
+        deprecated_kwargs = {
            "Fn_compute_blocks": kwargs.get("Fn_compute_blocks", None),
            "Bn_compute_blocks": kwargs.get("Bn_compute_blocks", None),
            "max_warmup_steps": kwargs.get("max_warmup_steps", None),
@@ -40,20 +40,20 @@ class ParamsModifier:
            ),
        }
 
-        deprecated_cache_kwargs = {
-            k: v for k, v in deprecated_cache_kwargs.items() if v is not None
+        deprecated_kwargs = {
+            k: v for k, v in deprecated_kwargs.items() if v is not None
        }
 
-        if deprecated_cache_kwargs:
+        if deprecated_kwargs:
         logger.warning(
             "Manually settup DBCache context without BasicCacheConfig is "
             "deprecated and will be removed in the future, please use "
             "`cache_config` parameter instead!"
         )
         if cache_config is not None:
-                cache_config.update(**deprecated_cache_kwargs)
+                cache_config.update(**deprecated_kwargs)
         else:
-                cache_config = BasicCacheConfig(**deprecated_cache_kwargs)
+                cache_config = BasicCacheConfig(**deprecated_kwargs)
 
         if cache_config is not None:
             self._context_kwargs["cache_config"] = cache_config
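
ParamsModifier keeps the same deprecation shim as enable_cache, and its comment now points at the same config union. A hedged sketch of building a per-block override; how params_modifiers entries are matched to blocks is not shown in this diff:

# Hedged sketch: building a per-block override with ParamsModifier.
# Import paths follow the files-changed list; the threshold value is illustrative.
from cache_dit.cache_factory.params_modifier import ParamsModifier
from cache_dit.cache_factory.cache_contexts import DBCacheConfig

modifier = ParamsModifier(
    cache_config=DBCacheConfig(residual_diff_threshold=0.12),
)
# Passed through enable_cache(..., params_modifiers=[...]) as shown above.
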

cache_dit/cache_factory/utils.py
@@ -7,10 +7,6 @@ def load_cache_options_from_yaml(yaml_file_path):
         kwargs: dict = yaml.safe_load(f)
 
     required_keys = [
-        "max_warmup_steps",
-        "max_cached_steps",
-        "Fn_compute_blocks",
-        "Bn_compute_blocks",
         "residual_diff_threshold",
     ]
     for key in required_keys:
@@ -38,10 +34,25 @@ def load_cache_options_from_yaml(yaml_file_path):
             )
         )
 
-    from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
+    if "cache_type" not in kwargs:
+        from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
 
-    cache_context_kwargs["cache_config"] = BasicCacheConfig()
-    cache_context_kwargs["cache_config"].update(**kwargs)
+        cache_context_kwargs["cache_config"] = BasicCacheConfig()
+        cache_context_kwargs["cache_config"].update(**kwargs)
+    else:
+        cache_type = kwargs.pop("cache_type")
+        if cache_type == "DBCache":
+            from cache_dit.cache_factory.cache_contexts import DBCacheConfig
+
+            cache_context_kwargs["cache_config"] = DBCacheConfig()
+            cache_context_kwargs["cache_config"].update(**kwargs)
+        elif cache_type == "DBPrune":
+            from cache_dit.cache_factory.cache_contexts import DBPruneConfig
+
+            cache_context_kwargs["cache_config"] = DBPruneConfig()
+            cache_context_kwargs["cache_config"].update(**kwargs)
+        else:
+            raise ValueError(f"Unsupported cache_type: {cache_type}.")
 
     return cache_context_kwargs
 
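
With the new cache_type key, a minimal YAML accepted by the loader only needs residual_diff_threshold plus the type selector. A sketch (the full set of optional keys is not shown in this diff):

# Hedged sketch of the updated YAML loading path. Only `residual_diff_threshold`
# is still required; `cache_type` selects DBCacheConfig vs. DBPruneConfig.
import tempfile

from cache_dit.cache_factory.utils import load_cache_options_from_yaml

yaml_text = "cache_type: DBPrune\nresidual_diff_threshold: 0.08\n"

with tempfile.NamedTemporaryFile("w", suffix=".yaml", delete=False) as f:
    f.write(yaml_text)

options = load_cache_options_from_yaml(f.name)
print(options["cache_config"])  # -> a DBPruneConfig with the YAML values applied
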

cache_dit/quantize/quantize_ao.py
@@ -83,11 +83,18 @@
     def _quantization_fn():
         try:
             if quant_type == "fp8_w8a8_dq":
-                from torchao.quantization import (
-                    float8_dynamic_activation_float8_weight,
-                    PerTensor,
-                    PerRow,
-                )
+                try:
+                    from torchao.quantization import (
+                        float8_dynamic_activation_float8_weight,
+                        PerTensor,
+                        PerRow,
+                    )
+                except ImportError:
+                    from torchao.quantization import (
+                        Float8DynamicActivationFloat8WeightConfig as float8_dynamic_activation_float8_weight,
+                        PerTensor,
+                        PerRow,
+                    )
 
                 if per_row:  # Ensure bfloat16
                     module.to(torch.bfloat16)
@@ -109,7 +116,12 @@
                 )
 
             elif quant_type == "fp8_w8a16_wo":
-                from torchao.quantization import float8_weight_only
+                try:
+                    from torchao.quantization import float8_weight_only
+                except ImportError:
+                    from torchao.quantization import (
+                        Float8WeightOnlyConfig as float8_weight_only,
+                    )
 
                 quantization_fn = float8_weight_only(
                     weight_dtype=kwargs.get(
@@ -119,14 +131,25 @@
                 )
 
             elif quant_type == "int8_w8a8_dq":
-                from torchao.quantization import (
-                    int8_dynamic_activation_int8_weight,
-                )
+                try:
+                    from torchao.quantization import (
+                        int8_dynamic_activation_int8_weight,
+                    )
+                except ImportError:
+                    from torchao.quantization import (
+                        Int8DynamicActivationInt8WeightConfig as int8_dynamic_activation_int8_weight,
+                    )
 
                 quantization_fn = int8_dynamic_activation_int8_weight()
 
             elif quant_type == "int8_w8a16_wo":
-                from torchao.quantization import int8_weight_only
+
+                try:
+                    from torchao.quantization import int8_weight_only
+                except ImportError:
+                    from torchao.quantization import (
+                        Int8WeightOnlyConfig as int8_weight_only,
+                    )
 
                 quantization_fn = int8_weight_only(
                     # group_size is None -> per_channel, else per group
@@ -134,23 +157,41 @@
                 )
 
             elif quant_type == "int4_w4a8_dq":
-                from torchao.quantization import (
-                    int8_dynamic_activation_int4_weight,
-                )
+
+                try:
+                    from torchao.quantization import (
+                        int8_dynamic_activation_int4_weight,
+                    )
+                except ImportError:
+                    from torchao.quantization import (
+                        Int8DynamicActivationInt4WeightConfig as int8_dynamic_activation_int4_weight,
+                    )
 
                 quantization_fn = int8_dynamic_activation_int4_weight(
                     group_size=kwargs.get("group_size", 32),
                 )
 
             elif quant_type == "int4_w4a4_dq":
-                from torchao.quantization import (
-                    int4_dynamic_activation_int4_weight,
-                )
+
+                try:
+                    from torchao.quantization import (
+                        int4_dynamic_activation_int4_weight,
+                    )
+                except ImportError:
+                    from torchao.quantization import (
+                        Int4DynamicActivationInt4WeightConfig as int4_dynamic_activation_int4_weight,
+                    )
 
                 quantization_fn = int4_dynamic_activation_int4_weight()
 
             elif quant_type == "int4_w4a16_wo":
-                from torchao.quantization import int4_weight_only
+
+                try:
+                    from torchao.quantization import int4_weight_only
+                except ImportError:
+                    from torchao.quantization import (
+                        Int4WeightOnlyConfig as int4_weight_only,
+                    )
 
                 quantization_fn = int4_weight_only(
                     group_size=kwargs.get("group_size", 32),
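
Every branch above applies the same version-compatibility shim: prefer torchao's legacy lowercase helpers and fall back to the renamed *Config classes on releases where the old names are no longer available. Reduced to its generic form (a sketch; the surrounding quantize_ao plumbing is omitted):

# Generic form of the torchao compatibility shim used throughout this hunk.
try:
    # Older torchao releases expose lowercase helper functions.
    from torchao.quantization import int8_weight_only
except ImportError:
    # Newer releases renamed them to *Config classes; the alias keeps call sites unchanged.
    from torchao.quantization import Int8WeightOnlyConfig as int8_weight_only

from torchao.quantization import quantize_

# Call sites such as quantize_ao() stay identical under either import:
# quantize_(module, int8_weight_only(group_size=32))
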