cache-dit 0.3.0__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cache_dit/__init__.py CHANGED
@@ -18,6 +18,10 @@ from cache_dit.cache_factory import BlockAdapter
18
18
  from cache_dit.cache_factory import ParamsModifier
19
19
  from cache_dit.cache_factory import ForwardPattern
20
20
  from cache_dit.cache_factory import PatchFunctor
21
+ from cache_dit.cache_factory import BasicCacheConfig
22
+ from cache_dit.cache_factory import CalibratorConfig
23
+ from cache_dit.cache_factory import TaylorSeerCalibratorConfig
24
+ from cache_dit.cache_factory import FoCaCalibratorConfig
21
25
  from cache_dit.cache_factory import supported_pipelines
22
26
  from cache_dit.cache_factory import get_adapter
23
27
  from cache_dit.compile import set_compile_configs
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.3.0'
32
- __version_tuple__ = version_tuple = (0, 3, 0)
31
+ __version__ = version = '0.3.2'
32
+ __version_tuple__ = version_tuple = (0, 3, 2)
33
33
 
34
34
  __commit_id__ = commit_id = None
@@ -3,15 +3,19 @@ from cache_dit.cache_factory.cache_types import cache_type
3
3
  from cache_dit.cache_factory.cache_types import block_range
4
4
 
5
5
  from cache_dit.cache_factory.forward_pattern import ForwardPattern
6
-
6
+ from cache_dit.cache_factory.params_modifier import ParamsModifier
7
7
  from cache_dit.cache_factory.patch_functors import PatchFunctor
8
8
 
9
9
  from cache_dit.cache_factory.block_adapters import BlockAdapter
10
- from cache_dit.cache_factory.block_adapters import ParamsModifier
11
10
  from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
12
11
 
13
12
  from cache_dit.cache_factory.cache_contexts import CachedContext
13
+ from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
14
14
  from cache_dit.cache_factory.cache_contexts import CachedContextManager
15
+ from cache_dit.cache_factory.cache_contexts import CalibratorConfig
16
+ from cache_dit.cache_factory.cache_contexts import TaylorSeerCalibratorConfig
17
+ from cache_dit.cache_factory.cache_contexts import FoCaCalibratorConfig
18
+
15
19
  from cache_dit.cache_factory.cache_blocks import CachedBlocks
16
20
 
17
21
  from cache_dit.cache_factory.cache_adapters import CachedAdapter
@@ -7,63 +7,15 @@ from collections.abc import Iterable
7
7
  from typing import Any, Tuple, List, Optional, Union
8
8
 
9
9
  from diffusers import DiffusionPipeline
10
- from cache_dit.cache_factory.forward_pattern import ForwardPattern
11
10
  from cache_dit.cache_factory.patch_functors import PatchFunctor
11
+ from cache_dit.cache_factory.forward_pattern import ForwardPattern
12
+ from cache_dit.cache_factory.params_modifier import ParamsModifier
12
13
 
13
14
  from cache_dit.logger import init_logger
14
15
 
15
16
  logger = init_logger(__name__)
16
17
 
17
18
 
18
- class ParamsModifier:
19
- def __init__(
20
- self,
21
- # Cache context kwargs
22
- Fn_compute_blocks: Optional[int] = None,
23
- Bn_compute_blocks: Optional[int] = None,
24
- max_warmup_steps: Optional[int] = None,
25
- max_cached_steps: Optional[int] = None,
26
- max_continuous_cached_steps: Optional[int] = None,
27
- residual_diff_threshold: Optional[float] = None,
28
- # Cache CFG or not
29
- enable_separate_cfg: Optional[bool] = None,
30
- cfg_compute_first: Optional[bool] = None,
31
- cfg_diff_compute_separate: Optional[bool] = None,
32
- # Hybird TaylorSeer
33
- enable_taylorseer: Optional[bool] = None,
34
- enable_encoder_taylorseer: Optional[bool] = None,
35
- taylorseer_cache_type: Optional[str] = None,
36
- taylorseer_order: Optional[int] = None,
37
- **other_cache_context_kwargs,
38
- ):
39
- self._context_kwargs = other_cache_context_kwargs.copy()
40
- self._maybe_update_param("Fn_compute_blocks", Fn_compute_blocks)
41
- self._maybe_update_param("Bn_compute_blocks", Bn_compute_blocks)
42
- self._maybe_update_param("max_warmup_steps", max_warmup_steps)
43
- self._maybe_update_param("max_cached_steps", max_cached_steps)
44
- self._maybe_update_param(
45
- "max_continuous_cached_steps", max_continuous_cached_steps
46
- )
47
- self._maybe_update_param(
48
- "residual_diff_threshold", residual_diff_threshold
49
- )
50
- self._maybe_update_param("enable_separate_cfg", enable_separate_cfg)
51
- self._maybe_update_param("cfg_compute_first", cfg_compute_first)
52
- self._maybe_update_param(
53
- "cfg_diff_compute_separate", cfg_diff_compute_separate
54
- )
55
- self._maybe_update_param("enable_taylorseer", enable_taylorseer)
56
- self._maybe_update_param(
57
- "enable_encoder_taylorseer", enable_encoder_taylorseer
58
- )
59
- self._maybe_update_param("taylorseer_cache_type", taylorseer_cache_type)
60
- self._maybe_update_param("taylorseer_order", taylorseer_order)
61
-
62
- def _maybe_update_param(self, key: str, value: Any):
63
- if value is not None:
64
- self._context_kwargs[key] = value
65
-
66
-
67
19
  @dataclasses.dataclass
68
20
  class BlockAdapter:
69
21
 
@@ -113,10 +65,12 @@ class BlockAdapter:
113
65
  ] = None
114
66
 
115
67
  # modify cache context params for specific blocks.
116
- params_modifiers: Union[
117
- ParamsModifier,
118
- List[ParamsModifier],
119
- List[List[ParamsModifier]],
68
+ params_modifiers: Optional[
69
+ Union[
70
+ ParamsModifier,
71
+ List[ParamsModifier],
72
+ List[List[ParamsModifier]],
73
+ ]
120
74
  ] = None
121
75
 
122
76
  check_forward_pattern: bool = True
@@ -0,0 +1 @@
1
+ from cache_dit.cache_factory.cache_adapters.cache_adapter import CachedAdapter
@@ -8,12 +8,14 @@ from typing import Dict, List, Tuple, Any, Union, Callable
8
8
 
9
9
  from diffusers import DiffusionPipeline
10
10
 
11
- from cache_dit.cache_factory import CacheType
12
- from cache_dit.cache_factory import BlockAdapter
13
- from cache_dit.cache_factory import ParamsModifier
14
- from cache_dit.cache_factory import BlockAdapterRegistry
15
- from cache_dit.cache_factory import CachedContextManager
16
- from cache_dit.cache_factory import CachedBlocks
11
+ from cache_dit.cache_factory.cache_types import CacheType
12
+ from cache_dit.cache_factory.block_adapters import BlockAdapter
13
+ from cache_dit.cache_factory.block_adapters import ParamsModifier
14
+ from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
15
+ from cache_dit.cache_factory.cache_contexts import CachedContextManager
16
+ from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
17
+ from cache_dit.cache_factory.cache_contexts import CalibratorConfig
18
+ from cache_dit.cache_factory.cache_blocks import CachedBlocks
17
19
  from cache_dit.cache_factory.cache_blocks.utils import (
18
20
  patch_cached_stats,
19
21
  remove_cached_stats,
@@ -55,6 +57,12 @@ class CachedAdapter:
55
57
  block_adapter = BlockAdapterRegistry.get_adapter(
56
58
  pipe_or_adapter
57
59
  )
60
+ if params_modifiers := cache_context_kwargs.pop(
61
+ "params_modifiers",
62
+ None,
63
+ ):
64
+ block_adapter.params_modifiers = params_modifiers
65
+
58
66
  return cls.cachify(
59
67
  block_adapter,
60
68
  **cache_context_kwargs,
@@ -69,6 +77,12 @@ class CachedAdapter:
69
77
  logger.info(
70
78
  "Adapting Cache Acceleration using custom BlockAdapter!"
71
79
  )
80
+ if pipe_or_adapter.params_modifiers is None:
81
+ if params_modifiers := cache_context_kwargs.pop(
82
+ "params_modifiers", None
83
+ ):
84
+ pipe_or_adapter.params_modifiers = params_modifiers
85
+
72
86
  return cls.cachify(
73
87
  pipe_or_adapter,
74
88
  **cache_context_kwargs,
@@ -114,33 +128,36 @@ class CachedAdapter:
114
128
  **cache_context_kwargs,
115
129
  ):
116
130
  # Check cache_context_kwargs
117
- if cache_context_kwargs["enable_separate_cfg"] is None:
131
+ cache_config: BasicCacheConfig = cache_context_kwargs[
132
+ "cache_config"
133
+ ] # ref
134
+ assert cache_config is not None, "cache_config can not be None."
135
+ if cache_config.enable_separate_cfg is None:
118
136
  # Check cfg for some specific case if users don't set it as True
119
137
  if BlockAdapterRegistry.has_separate_cfg(block_adapter):
120
- cache_context_kwargs["enable_separate_cfg"] = True
138
+ cache_config.enable_separate_cfg = True
121
139
  logger.info(
122
140
  f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
123
141
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
124
142
  )
125
143
  else:
126
- cache_context_kwargs["enable_separate_cfg"] = (
144
+ cache_config.enable_separate_cfg = (
127
145
  BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
128
146
  )
129
147
  logger.info(
130
148
  f"Use default 'enable_separate_cfg' from block adapter "
131
- f"register: {cache_context_kwargs['enable_separate_cfg']}, "
149
+ f"register: {cache_config.enable_separate_cfg}, "
132
150
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
133
151
  )
134
152
  else:
135
153
  logger.info(
136
154
  f"Use custom 'enable_separate_cfg' from cache context "
137
- f"kwargs: {cache_context_kwargs['enable_separate_cfg']}. "
155
+ f"kwargs: {cache_config.enable_separate_cfg}. "
138
156
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
139
157
  )
140
158
 
141
- if (
142
- cache_type := cache_context_kwargs.pop("cache_type", None)
143
- ) is not None:
159
+ cache_type = cache_context_kwargs.pop("cache_type", None)
160
+ if cache_type is not None:
144
161
  assert (
145
162
  cache_type == CacheType.DBCache
146
163
  ), "Custom cache setting only support for DBCache now!"
@@ -176,7 +193,7 @@ class CachedAdapter:
176
193
  block_adapter.pipe._cache_manager = cache_manager # instance level
177
194
 
178
195
  flatten_contexts, contexts_kwargs = cls.modify_context_params(
179
- block_adapter, cache_manager, **cache_context_kwargs
196
+ block_adapter, **cache_context_kwargs
180
197
  )
181
198
 
182
199
  original_call = block_adapter.pipe.__class__.__call__
@@ -212,7 +229,6 @@ class CachedAdapter:
212
229
  def modify_context_params(
213
230
  cls,
214
231
  block_adapter: BlockAdapter,
215
- cache_manager: CachedContextManager,
216
232
  **cache_context_kwargs,
217
233
  ) -> Tuple[List[str], List[Dict[str, Any]]]:
218
234
 
@@ -230,6 +246,8 @@ class CachedAdapter:
230
246
  contexts_kwargs[i]["name"] = flatten_contexts[i]
231
247
 
232
248
  if block_adapter.params_modifiers is None:
249
+ for i in range(len(contexts_kwargs)):
250
+ cls._config_messages(**contexts_kwargs[i])
233
251
  return flatten_contexts, contexts_kwargs
234
252
 
235
253
  flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
@@ -242,12 +260,26 @@ class CachedAdapter:
242
260
  contexts_kwargs[i].update(
243
261
  flatten_modifiers[i]._context_kwargs,
244
262
  )
245
- contexts_kwargs[i], _ = cache_manager.collect_cache_kwargs(
246
- default_attrs={}, **contexts_kwargs[i]
247
- )
263
+ cls._config_messages(**contexts_kwargs[i])
248
264
 
249
265
  return flatten_contexts, contexts_kwargs
250
266
 
267
+ @classmethod
268
+ def _config_messages(cls, **contexts_kwargs):
269
+ cache_config: BasicCacheConfig = contexts_kwargs.get(
270
+ "cache_config", None
271
+ )
272
+ calibrator_config: CalibratorConfig = contexts_kwargs.get(
273
+ "calibrator_config", None
274
+ )
275
+ if cache_config is not None:
276
+ message = f"Collected Cache Config: {cache_config.strify()}"
277
+ if calibrator_config is not None:
278
+ message += f", Calibrator Config: {calibrator_config.strify(details=True)}"
279
+ else:
280
+ message += ", Calibrator Config: None"
281
+ logger.info(message)
282
+
251
283
  @classmethod
252
284
  def mock_blocks(
253
285
  cls,
@@ -335,7 +367,8 @@ class CachedAdapter:
335
367
  total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
336
368
  assert hasattr(block_adapter.pipe, "_cache_manager")
337
369
  assert isinstance(
338
- block_adapter.pipe._cache_manager, CachedContextManager
370
+ block_adapter.pipe._cache_manager,
371
+ CachedContextManager,
339
372
  )
340
373
 
341
374
  for i in range(len(block_adapter.transformer)):
@@ -1,5 +1,14 @@
1
- # namespace alias: for _CachedContext and many others' cache context funcs.
2
- from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
1
+ from cache_dit.cache_factory.cache_contexts.calibrators import (
2
+ Calibrator,
3
+ CalibratorBase,
4
+ CalibratorConfig,
5
+ TaylorSeerCalibratorConfig,
6
+ FoCaCalibratorConfig,
7
+ )
8
+ from cache_dit.cache_factory.cache_contexts.cache_context import (
9
+ CachedContext,
10
+ BasicCacheConfig,
11
+ )
3
12
  from cache_dit.cache_factory.cache_contexts.cache_manager import (
4
13
  CachedContextManager,
5
14
  )