cache-dit 0.3.1__py3-none-any.whl → 0.3.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (30):
  1. cache_dit/__init__.py +1 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +3 -6
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +8 -64
  5. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  6. cache_dit/cache_factory/cache_adapters/cache_adapter.py +47 -14
  7. cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
  8. cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
  9. cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
  10. cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
  11. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
  12. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
  13. cache_dit/cache_factory/cache_interface.py +128 -111
  14. cache_dit/cache_factory/params_modifier.py +87 -0
  15. cache_dit/metrics/__init__.py +3 -1
  16. cache_dit/utils.py +12 -21
  17. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/METADATA +78 -64
  18. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/RECORD +23 -28
  19. cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
  20. cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
  21. cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
  22. cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
  23. cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
  24. cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
  25. cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
  26. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
  27. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/WHEEL +0 -0
  28. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/entry_points.txt +0 -0
  29. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/licenses/LICENSE +0 -0
  30. {cache_dit-0.3.1.dist-info → cache_dit-0.3.2.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py CHANGED
@@ -18,6 +18,7 @@ from cache_dit.cache_factory import BlockAdapter
18
18
  from cache_dit.cache_factory import ParamsModifier
19
19
  from cache_dit.cache_factory import ForwardPattern
20
20
  from cache_dit.cache_factory import PatchFunctor
21
+ from cache_dit.cache_factory import BasicCacheConfig
21
22
  from cache_dit.cache_factory import CalibratorConfig
22
23
  from cache_dit.cache_factory import TaylorSeerCalibratorConfig
23
24
  from cache_dit.cache_factory import FoCaCalibratorConfig
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
28
28
  commit_id: COMMIT_ID
29
29
  __commit_id__: COMMIT_ID
30
30
 
31
- __version__ = version = '0.3.1'
32
- __version_tuple__ = version_tuple = (0, 3, 1)
31
+ __version__ = version = '0.3.2'
32
+ __version_tuple__ = version_tuple = (0, 3, 2)
33
33
 
34
34
  __commit_id__ = commit_id = None
cache_dit/cache_factory/__init__.py CHANGED
@@ -3,25 +3,22 @@ from cache_dit.cache_factory.cache_types import cache_type
3
3
  from cache_dit.cache_factory.cache_types import block_range
4
4
 
5
5
  from cache_dit.cache_factory.forward_pattern import ForwardPattern
6
-
6
+ from cache_dit.cache_factory.params_modifier import ParamsModifier
7
7
  from cache_dit.cache_factory.patch_functors import PatchFunctor
8
8
 
9
9
  from cache_dit.cache_factory.block_adapters import BlockAdapter
10
- from cache_dit.cache_factory.block_adapters import ParamsModifier
11
10
  from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
12
11
 
13
12
  from cache_dit.cache_factory.cache_contexts import CachedContext
13
+ from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
14
14
  from cache_dit.cache_factory.cache_contexts import CachedContextManager
15
- from cache_dit.cache_factory.cache_contexts import CachedContextV2
16
- from cache_dit.cache_factory.cache_contexts import CachedContextManagerV2
17
- from cache_dit.cache_factory.cache_contexts import CalibratorConfig # no V1
15
+ from cache_dit.cache_factory.cache_contexts import CalibratorConfig
18
16
  from cache_dit.cache_factory.cache_contexts import TaylorSeerCalibratorConfig
19
17
  from cache_dit.cache_factory.cache_contexts import FoCaCalibratorConfig
20
18
 
21
19
  from cache_dit.cache_factory.cache_blocks import CachedBlocks
22
20
 
23
21
  from cache_dit.cache_factory.cache_adapters import CachedAdapter
24
- from cache_dit.cache_factory.cache_adapters import CachedAdapterV2
25
22
 
26
23
  from cache_dit.cache_factory.cache_interface import enable_cache
27
24
  from cache_dit.cache_factory.cache_interface import disable_cache
cache_dit/cache_factory/block_adapters/block_adapters.py CHANGED
@@ -7,73 +7,15 @@ from collections.abc import Iterable
7
7
  from typing import Any, Tuple, List, Optional, Union
8
8
 
9
9
  from diffusers import DiffusionPipeline
10
- from cache_dit.cache_factory.forward_pattern import ForwardPattern
11
10
  from cache_dit.cache_factory.patch_functors import PatchFunctor
12
- from cache_dit.cache_factory.cache_contexts import CalibratorConfig
11
+ from cache_dit.cache_factory.forward_pattern import ForwardPattern
12
+ from cache_dit.cache_factory.params_modifier import ParamsModifier
13
13
 
14
14
  from cache_dit.logger import init_logger
15
15
 
16
16
  logger = init_logger(__name__)
17
17
 
18
18
 
19
- class ParamsModifier:
20
- def __init__(
21
- self,
22
- # Cache context kwargs
23
- Fn_compute_blocks: Optional[int] = None,
24
- Bn_compute_blocks: Optional[int] = None,
25
- max_warmup_steps: Optional[int] = None,
26
- max_cached_steps: Optional[int] = None,
27
- max_continuous_cached_steps: Optional[int] = None,
28
- residual_diff_threshold: Optional[float] = None,
29
- # Cache CFG or not
30
- enable_separate_cfg: Optional[bool] = None,
31
- cfg_compute_first: Optional[bool] = None,
32
- cfg_diff_compute_separate: Optional[bool] = None,
33
- # Hybird TaylorSeer
34
- enable_taylorseer: Optional[bool] = None,
35
- enable_encoder_taylorseer: Optional[bool] = None,
36
- taylorseer_cache_type: Optional[str] = None,
37
- taylorseer_order: Optional[int] = None,
38
- # New param only for v2 API
39
- calibrator_config: Optional[CalibratorConfig] = None,
40
- **other_cache_context_kwargs,
41
- ):
42
- self._context_kwargs = other_cache_context_kwargs.copy()
43
- self._maybe_update_param("Fn_compute_blocks", Fn_compute_blocks)
44
- self._maybe_update_param("Bn_compute_blocks", Bn_compute_blocks)
45
- self._maybe_update_param("max_warmup_steps", max_warmup_steps)
46
- self._maybe_update_param("max_cached_steps", max_cached_steps)
47
- self._maybe_update_param(
48
- "max_continuous_cached_steps", max_continuous_cached_steps
49
- )
50
- self._maybe_update_param(
51
- "residual_diff_threshold", residual_diff_threshold
52
- )
53
- self._maybe_update_param("enable_separate_cfg", enable_separate_cfg)
54
- self._maybe_update_param("cfg_compute_first", cfg_compute_first)
55
- self._maybe_update_param(
56
- "cfg_diff_compute_separate", cfg_diff_compute_separate
57
- )
58
- # V1 only supports the Taylorseer calibrator. We have decided to
59
- # keep this code for API compatibility reasons.
60
- if calibrator_config is None:
61
- self._maybe_update_param("enable_taylorseer", enable_taylorseer)
62
- self._maybe_update_param(
63
- "enable_encoder_taylorseer", enable_encoder_taylorseer
64
- )
65
- self._maybe_update_param(
66
- "taylorseer_cache_type", taylorseer_cache_type
67
- )
68
- self._maybe_update_param("taylorseer_order", taylorseer_order)
69
- else:
70
- self._maybe_update_param("calibrator_config", calibrator_config)
71
-
72
- def _maybe_update_param(self, key: str, value: Any):
73
- if value is not None:
74
- self._context_kwargs[key] = value
75
-
76
-
77
19
  @dataclasses.dataclass
78
20
  class BlockAdapter:
79
21
 
@@ -123,10 +65,12 @@ class BlockAdapter:
123
65
  ] = None
124
66
 
125
67
  # modify cache context params for specific blocks.
126
- params_modifiers: Union[
127
- ParamsModifier,
128
- List[ParamsModifier],
129
- List[List[ParamsModifier]],
68
+ params_modifiers: Optional[
69
+ Union[
70
+ ParamsModifier,
71
+ List[ParamsModifier],
72
+ List[List[ParamsModifier]],
73
+ ]
130
74
  ] = None
131
75
 
132
76
  check_forward_pattern: bool = True
cache_dit/cache_factory/cache_adapters/__init__.py CHANGED
@@ -1,2 +1 @@
1
1
  from cache_dit.cache_factory.cache_adapters.cache_adapter import CachedAdapter
2
- from cache_dit.cache_factory.cache_adapters.v2 import CachedAdapterV2
cache_dit/cache_factory/cache_adapters/cache_adapter.py CHANGED
@@ -13,6 +13,8 @@ from cache_dit.cache_factory.block_adapters import BlockAdapter
13
13
  from cache_dit.cache_factory.block_adapters import ParamsModifier
14
14
  from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
15
15
  from cache_dit.cache_factory.cache_contexts import CachedContextManager
16
+ from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
17
+ from cache_dit.cache_factory.cache_contexts import CalibratorConfig
16
18
  from cache_dit.cache_factory.cache_blocks import CachedBlocks
17
19
  from cache_dit.cache_factory.cache_blocks.utils import (
18
20
  patch_cached_stats,
@@ -55,6 +57,12 @@ class CachedAdapter:
55
57
  block_adapter = BlockAdapterRegistry.get_adapter(
56
58
  pipe_or_adapter
57
59
  )
60
+ if params_modifiers := cache_context_kwargs.pop(
61
+ "params_modifiers",
62
+ None,
63
+ ):
64
+ block_adapter.params_modifiers = params_modifiers
65
+
58
66
  return cls.cachify(
59
67
  block_adapter,
60
68
  **cache_context_kwargs,
@@ -69,6 +77,12 @@ class CachedAdapter:
69
77
  logger.info(
70
78
  "Adapting Cache Acceleration using custom BlockAdapter!"
71
79
  )
80
+ if pipe_or_adapter.params_modifiers is None:
81
+ if params_modifiers := cache_context_kwargs.pop(
82
+ "params_modifiers", None
83
+ ):
84
+ pipe_or_adapter.params_modifiers = params_modifiers
85
+
72
86
  return cls.cachify(
73
87
  pipe_or_adapter,
74
88
  **cache_context_kwargs,
@@ -114,33 +128,36 @@ class CachedAdapter:
114
128
  **cache_context_kwargs,
115
129
  ):
116
130
  # Check cache_context_kwargs
117
- if cache_context_kwargs["enable_separate_cfg"] is None:
131
+ cache_config: BasicCacheConfig = cache_context_kwargs[
132
+ "cache_config"
133
+ ] # ref
134
+ assert cache_config is not None, "cache_config can not be None."
135
+ if cache_config.enable_separate_cfg is None:
118
136
  # Check cfg for some specific case if users don't set it as True
119
137
  if BlockAdapterRegistry.has_separate_cfg(block_adapter):
120
- cache_context_kwargs["enable_separate_cfg"] = True
138
+ cache_config.enable_separate_cfg = True
121
139
  logger.info(
122
140
  f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
123
141
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
124
142
  )
125
143
  else:
126
- cache_context_kwargs["enable_separate_cfg"] = (
144
+ cache_config.enable_separate_cfg = (
127
145
  BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
128
146
  )
129
147
  logger.info(
130
148
  f"Use default 'enable_separate_cfg' from block adapter "
131
- f"register: {cache_context_kwargs['enable_separate_cfg']}, "
149
+ f"register: {cache_config.enable_separate_cfg}, "
132
150
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
133
151
  )
134
152
  else:
135
153
  logger.info(
136
154
  f"Use custom 'enable_separate_cfg' from cache context "
137
- f"kwargs: {cache_context_kwargs['enable_separate_cfg']}. "
155
+ f"kwargs: {cache_config.enable_separate_cfg}. "
138
156
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
139
157
  )
140
158
 
141
- if (
142
- cache_type := cache_context_kwargs.pop("cache_type", None)
143
- ) is not None:
159
+ cache_type = cache_context_kwargs.pop("cache_type", None)
160
+ if cache_type is not None:
144
161
  assert (
145
162
  cache_type == CacheType.DBCache
146
163
  ), "Custom cache setting only support for DBCache now!"
@@ -176,7 +193,7 @@ class CachedAdapter:
176
193
  block_adapter.pipe._cache_manager = cache_manager # instance level
177
194
 
178
195
  flatten_contexts, contexts_kwargs = cls.modify_context_params(
179
- block_adapter, cache_manager, **cache_context_kwargs
196
+ block_adapter, **cache_context_kwargs
180
197
  )
181
198
 
182
199
  original_call = block_adapter.pipe.__class__.__call__
@@ -212,7 +229,6 @@ class CachedAdapter:
212
229
  def modify_context_params(
213
230
  cls,
214
231
  block_adapter: BlockAdapter,
215
- cache_manager: CachedContextManager,
216
232
  **cache_context_kwargs,
217
233
  ) -> Tuple[List[str], List[Dict[str, Any]]]:
218
234
 
@@ -230,6 +246,8 @@ class CachedAdapter:
230
246
  contexts_kwargs[i]["name"] = flatten_contexts[i]
231
247
 
232
248
  if block_adapter.params_modifiers is None:
249
+ for i in range(len(contexts_kwargs)):
250
+ cls._config_messages(**contexts_kwargs[i])
233
251
  return flatten_contexts, contexts_kwargs
234
252
 
235
253
  flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
@@ -242,12 +260,26 @@ class CachedAdapter:
242
260
  contexts_kwargs[i].update(
243
261
  flatten_modifiers[i]._context_kwargs,
244
262
  )
245
- contexts_kwargs[i], _ = cache_manager.collect_cache_kwargs(
246
- default_attrs={}, **contexts_kwargs[i]
247
- )
263
+ cls._config_messages(**contexts_kwargs[i])
248
264
 
249
265
  return flatten_contexts, contexts_kwargs
250
266
 
267
+ @classmethod
268
+ def _config_messages(cls, **contexts_kwargs):
269
+ cache_config: BasicCacheConfig = contexts_kwargs.get(
270
+ "cache_config", None
271
+ )
272
+ calibrator_config: CalibratorConfig = contexts_kwargs.get(
273
+ "calibrator_config", None
274
+ )
275
+ if cache_config is not None:
276
+ message = f"Collected Cache Config: {cache_config.strify()}"
277
+ if calibrator_config is not None:
278
+ message += f", Calibrator Config: {calibrator_config.strify(details=True)}"
279
+ else:
280
+ message += ", Calibrator Config: None"
281
+ logger.info(message)
282
+
251
283
  @classmethod
252
284
  def mock_blocks(
253
285
  cls,
@@ -335,7 +367,8 @@ class CachedAdapter:
335
367
  total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
336
368
  assert hasattr(block_adapter.pipe, "_cache_manager")
337
369
  assert isinstance(
338
- block_adapter.pipe._cache_manager, CachedContextManager
370
+ block_adapter.pipe._cache_manager,
371
+ CachedContextManager,
339
372
  )
340
373
 
341
374
  for i in range(len(block_adapter.transformer)):
cache_dit/cache_factory/cache_contexts/__init__.py CHANGED
@@ -1,12 +1,14 @@
1
- # namespace alias: for _CachedContext and many others' cache context funcs.
2
- from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
3
- from cache_dit.cache_factory.cache_contexts.cache_manager import (
4
- CachedContextManager,
5
- )
6
- from cache_dit.cache_factory.cache_contexts.v2 import (
7
- CachedContextV2,
8
- CachedContextManagerV2,
1
+ from cache_dit.cache_factory.cache_contexts.calibrators import (
2
+ Calibrator,
3
+ CalibratorBase,
9
4
  CalibratorConfig,
10
5
  TaylorSeerCalibratorConfig,
11
6
  FoCaCalibratorConfig,
12
7
  )
8
+ from cache_dit.cache_factory.cache_contexts.cache_context import (
9
+ CachedContext,
10
+ BasicCacheConfig,
11
+ )
12
+ from cache_dit.cache_factory.cache_contexts.cache_manager import (
13
+ CachedContextManager,
14
+ )