cache-dit 0.3.1__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of cache-dit might be problematic.
Files changed (34)
  1. cache_dit/__init__.py +1 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +3 -6
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +21 -64
  5. cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
  6. cache_dit/cache_factory/cache_adapters/cache_adapter.py +82 -21
  7. cache_dit/cache_factory/cache_blocks/__init__.py +4 -0
  8. cache_dit/cache_factory/cache_blocks/offload_utils.py +115 -0
  9. cache_dit/cache_factory/cache_blocks/pattern_base.py +3 -0
  10. cache_dit/cache_factory/cache_contexts/__init__.py +10 -8
  11. cache_dit/cache_factory/cache_contexts/cache_context.py +186 -117
  12. cache_dit/cache_factory/cache_contexts/cache_manager.py +63 -131
  13. cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +132 -0
  14. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/foca.py +1 -1
  15. cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/taylorseer.py +7 -2
  16. cache_dit/cache_factory/cache_interface.py +128 -111
  17. cache_dit/cache_factory/params_modifier.py +87 -0
  18. cache_dit/metrics/__init__.py +3 -1
  19. cache_dit/utils.py +12 -21
  20. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/METADATA +200 -434
  21. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/RECORD +27 -31
  22. cache_dit/cache_factory/cache_adapters/v2/__init__.py +0 -3
  23. cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +0 -524
  24. cache_dit/cache_factory/cache_contexts/taylorseer.py +0 -102
  25. cache_dit/cache_factory/cache_contexts/v2/__init__.py +0 -13
  26. cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +0 -288
  27. cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +0 -799
  28. cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +0 -81
  29. /cache_dit/cache_factory/cache_blocks/{utils.py → pattern_utils.py} +0 -0
  30. /cache_dit/cache_factory/cache_contexts/{v2/calibrators → calibrators}/base.py +0 -0
  31. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/WHEEL +0 -0
  32. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/entry_points.txt +0 -0
  33. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/licenses/LICENSE +0 -0
  34. {cache_dit-0.3.1.dist-info → cache_dit-0.3.3.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py CHANGED
@@ -18,6 +18,7 @@ from cache_dit.cache_factory import BlockAdapter
 from cache_dit.cache_factory import ParamsModifier
 from cache_dit.cache_factory import ForwardPattern
 from cache_dit.cache_factory import PatchFunctor
+from cache_dit.cache_factory import BasicCacheConfig
 from cache_dit.cache_factory import CalibratorConfig
 from cache_dit.cache_factory import TaylorSeerCalibratorConfig
 from cache_dit.cache_factory import FoCaCalibratorConfig
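The hunk above adds BasicCacheConfig to the package's public namespace, next to the calibrator configs that were already re-exported. A minimal import sketch (the constructors' arguments are not shown in this diff, so the no-argument calls below are assumptions):

import cache_dit
from cache_dit import BasicCacheConfig, TaylorSeerCalibratorConfig

# Both names resolve through cache_dit.cache_factory as of 0.3.3.
cache_config = BasicCacheConfig()                 # default DBCache settings (assumed)
calibrator_config = TaylorSeerCalibratorConfig()  # TaylorSeer calibrator (assumed)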
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID

-__version__ = version = '0.3.1'
-__version_tuple__ = version_tuple = (0, 3, 1)
+__version__ = version = '0.3.3'
+__version_tuple__ = version_tuple = (0, 3, 3)

 __commit_id__ = commit_id = None
@@ -3,25 +3,22 @@ from cache_dit.cache_factory.cache_types import cache_type
 from cache_dit.cache_factory.cache_types import block_range

 from cache_dit.cache_factory.forward_pattern import ForwardPattern
-
+from cache_dit.cache_factory.params_modifier import ParamsModifier
 from cache_dit.cache_factory.patch_functors import PatchFunctor

 from cache_dit.cache_factory.block_adapters import BlockAdapter
-from cache_dit.cache_factory.block_adapters import ParamsModifier
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry

 from cache_dit.cache_factory.cache_contexts import CachedContext
+from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
 from cache_dit.cache_factory.cache_contexts import CachedContextManager
-from cache_dit.cache_factory.cache_contexts import CachedContextV2
-from cache_dit.cache_factory.cache_contexts import CachedContextManagerV2
-from cache_dit.cache_factory.cache_contexts import CalibratorConfig  # no V1
+from cache_dit.cache_factory.cache_contexts import CalibratorConfig
 from cache_dit.cache_factory.cache_contexts import TaylorSeerCalibratorConfig
 from cache_dit.cache_factory.cache_contexts import FoCaCalibratorConfig

 from cache_dit.cache_factory.cache_blocks import CachedBlocks

 from cache_dit.cache_factory.cache_adapters import CachedAdapter
-from cache_dit.cache_factory.cache_adapters import CachedAdapterV2

 from cache_dit.cache_factory.cache_interface import enable_cache
 from cache_dit.cache_factory.cache_interface import disable_cache
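For code written against the 0.3.1 layout, the removals above mean the V2 aliases no longer import from cache_dit.cache_factory; the unified names take their place. A hedged migration sketch:

# 0.3.1 (removed in 0.3.3):
#   from cache_dit.cache_factory import CachedAdapterV2, CachedContextManagerV2
# 0.3.3 unified paths:
from cache_dit.cache_factory import CachedAdapter         # replaces CachedAdapterV2
from cache_dit.cache_factory import CachedContextManager  # replaces CachedContextManagerV2
from cache_dit.cache_factory import CalibratorConfig      # single implementation, no V1/V2 split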
@@ -7,73 +7,15 @@ from collections.abc import Iterable
 from typing import Any, Tuple, List, Optional, Union

 from diffusers import DiffusionPipeline
-from cache_dit.cache_factory.forward_pattern import ForwardPattern
 from cache_dit.cache_factory.patch_functors import PatchFunctor
-from cache_dit.cache_factory.cache_contexts import CalibratorConfig
+from cache_dit.cache_factory.forward_pattern import ForwardPattern
+from cache_dit.cache_factory.params_modifier import ParamsModifier

 from cache_dit.logger import init_logger

 logger = init_logger(__name__)


-class ParamsModifier:
-    def __init__(
-        self,
-        # Cache context kwargs
-        Fn_compute_blocks: Optional[int] = None,
-        Bn_compute_blocks: Optional[int] = None,
-        max_warmup_steps: Optional[int] = None,
-        max_cached_steps: Optional[int] = None,
-        max_continuous_cached_steps: Optional[int] = None,
-        residual_diff_threshold: Optional[float] = None,
-        # Cache CFG or not
-        enable_separate_cfg: Optional[bool] = None,
-        cfg_compute_first: Optional[bool] = None,
-        cfg_diff_compute_separate: Optional[bool] = None,
-        # Hybird TaylorSeer
-        enable_taylorseer: Optional[bool] = None,
-        enable_encoder_taylorseer: Optional[bool] = None,
-        taylorseer_cache_type: Optional[str] = None,
-        taylorseer_order: Optional[int] = None,
-        # New param only for v2 API
-        calibrator_config: Optional[CalibratorConfig] = None,
-        **other_cache_context_kwargs,
-    ):
-        self._context_kwargs = other_cache_context_kwargs.copy()
-        self._maybe_update_param("Fn_compute_blocks", Fn_compute_blocks)
-        self._maybe_update_param("Bn_compute_blocks", Bn_compute_blocks)
-        self._maybe_update_param("max_warmup_steps", max_warmup_steps)
-        self._maybe_update_param("max_cached_steps", max_cached_steps)
-        self._maybe_update_param(
-            "max_continuous_cached_steps", max_continuous_cached_steps
-        )
-        self._maybe_update_param(
-            "residual_diff_threshold", residual_diff_threshold
-        )
-        self._maybe_update_param("enable_separate_cfg", enable_separate_cfg)
-        self._maybe_update_param("cfg_compute_first", cfg_compute_first)
-        self._maybe_update_param(
-            "cfg_diff_compute_separate", cfg_diff_compute_separate
-        )
-        # V1 only supports the Taylorseer calibrator. We have decided to
-        # keep this code for API compatibility reasons.
-        if calibrator_config is None:
-            self._maybe_update_param("enable_taylorseer", enable_taylorseer)
-            self._maybe_update_param(
-                "enable_encoder_taylorseer", enable_encoder_taylorseer
-            )
-            self._maybe_update_param(
-                "taylorseer_cache_type", taylorseer_cache_type
-            )
-            self._maybe_update_param("taylorseer_order", taylorseer_order)
-        else:
-            self._maybe_update_param("calibrator_config", calibrator_config)
-
-    def _maybe_update_param(self, key: str, value: Any):
-        if value is not None:
-            self._context_kwargs[key] = value
-
-
 @dataclasses.dataclass
 class BlockAdapter:

@@ -123,10 +65,12 @@ class BlockAdapter:
     ] = None

     # modify cache context params for specific blocks.
-    params_modifiers: Union[
-        ParamsModifier,
-        List[ParamsModifier],
-        List[List[ParamsModifier]],
+    params_modifiers: Optional[
+        Union[
+            ParamsModifier,
+            List[ParamsModifier],
+            List[List[ParamsModifier]],
+        ]
     ] = None

     check_forward_pattern: bool = True
@@ -169,6 +113,19 @@ class BlockAdapter:
         if any((self.pipe is not None, self.transformer is not None)):
             self.maybe_fill_attrs()
             self.maybe_patchify()
+            self.maybe_skip_checks()
+
+    def maybe_skip_checks(self):
+        if getattr(self.transformer, "_hf_hook", None) is not None:
+            logger.warning("_hf_hook is not None, force skip pattern check!")
+            self.check_forward_pattern = False
+            self.check_num_outputs = False
+        elif getattr(self.transformer, "_diffusers_hook", None) is not None:
+            logger.warning(
+                "_diffusers_hook is not None, force skip pattern check!"
+            )
+            self.check_forward_pattern = False
+            self.check_num_outputs = False

     def maybe_fill_attrs(self):
         # NOTE: This func should be call before normalize.
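ParamsModifier now lives in cache_dit.cache_factory.params_modifier (file 17 in the list above) instead of block_adapters.py, and BlockAdapter.params_modifiers is now explicitly Optional. A hedged sketch of per-block overrides, assuming the relocated ParamsModifier keeps a constructor similar to the removed V1 class and that the blocks/forward_pattern field names match the 0.3.x BlockAdapter API:

from cache_dit import BlockAdapter, ParamsModifier, ForwardPattern

adapter = BlockAdapter(
    pipe=pipe,                                   # an existing DiffusionPipeline
    transformer=pipe.transformer,
    blocks=pipe.transformer.transformer_blocks,  # field name assumed
    forward_pattern=ForwardPattern.Pattern_0,    # pattern choice assumed
    params_modifiers=[
        # kwarg assumed to survive the move, as in the removed constructor above
        ParamsModifier(residual_diff_threshold=0.12),
    ],
)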
@@ -1,2 +1 @@
 from cache_dit.cache_factory.cache_adapters.cache_adapter import CachedAdapter
-from cache_dit.cache_factory.cache_adapters.v2 import CachedAdapterV2
@@ -1,10 +1,8 @@
 import torch
-
 import unittest
 import functools
-
 from contextlib import ExitStack
-from typing import Dict, List, Tuple, Any, Union, Callable
+from typing import Dict, List, Tuple, Any, Union, Callable, Optional

 from diffusers import DiffusionPipeline

@@ -13,8 +11,10 @@ from cache_dit.cache_factory.block_adapters import BlockAdapter
 from cache_dit.cache_factory.block_adapters import ParamsModifier
 from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
 from cache_dit.cache_factory.cache_contexts import CachedContextManager
+from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
+from cache_dit.cache_factory.cache_contexts import CalibratorConfig
 from cache_dit.cache_factory.cache_blocks import CachedBlocks
-from cache_dit.cache_factory.cache_blocks.utils import (
+from cache_dit.cache_factory.cache_blocks import (
     patch_cached_stats,
     remove_cached_stats,
 )
@@ -55,6 +55,12 @@ class CachedAdapter:
             block_adapter = BlockAdapterRegistry.get_adapter(
                 pipe_or_adapter
             )
+            if params_modifiers := cache_context_kwargs.pop(
+                "params_modifiers",
+                None,
+            ):
+                block_adapter.params_modifiers = params_modifiers
+
             return cls.cachify(
                 block_adapter,
                 **cache_context_kwargs,
@@ -69,6 +75,12 @@ class CachedAdapter:
             logger.info(
                 "Adapting Cache Acceleration using custom BlockAdapter!"
             )
+            if pipe_or_adapter.params_modifiers is None:
+                if params_modifiers := cache_context_kwargs.pop(
+                    "params_modifiers", None
+                ):
+                    pipe_or_adapter.params_modifiers = params_modifiers
+
             return cls.cachify(
                 pipe_or_adapter,
                 **cache_context_kwargs,
@@ -114,33 +126,36 @@ class CachedAdapter:
         **cache_context_kwargs,
     ):
         # Check cache_context_kwargs
-        if cache_context_kwargs["enable_separate_cfg"] is None:
+        cache_config: BasicCacheConfig = cache_context_kwargs[
+            "cache_config"
+        ]  # ref
+        assert cache_config is not None, "cache_config can not be None."
+        if cache_config.enable_separate_cfg is None:
             # Check cfg for some specific case if users don't set it as True
             if BlockAdapterRegistry.has_separate_cfg(block_adapter):
-                cache_context_kwargs["enable_separate_cfg"] = True
+                cache_config.enable_separate_cfg = True
                 logger.info(
                     f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
                     f"Pipeline: {block_adapter.pipe.__class__.__name__}."
                 )
             else:
-                cache_context_kwargs["enable_separate_cfg"] = (
+                cache_config.enable_separate_cfg = (
                     BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
                 )
                 logger.info(
                     f"Use default 'enable_separate_cfg' from block adapter "
-                    f"register: {cache_context_kwargs['enable_separate_cfg']}, "
+                    f"register: {cache_config.enable_separate_cfg}, "
                     f"Pipeline: {block_adapter.pipe.__class__.__name__}."
                 )
         else:
             logger.info(
                 f"Use custom 'enable_separate_cfg' from cache context "
-                f"kwargs: {cache_context_kwargs['enable_separate_cfg']}. "
+                f"kwargs: {cache_config.enable_separate_cfg}. "
                 f"Pipeline: {block_adapter.pipe.__class__.__name__}."
             )

-        if (
-            cache_type := cache_context_kwargs.pop("cache_type", None)
-        ) is not None:
+        cache_type = cache_context_kwargs.pop("cache_type", None)
+        if cache_type is not None:
             assert (
                 cache_type == CacheType.DBCache
             ), "Custom cache setting only support for DBCache now!"
@@ -176,7 +191,7 @@ class CachedAdapter:
         block_adapter.pipe._cache_manager = cache_manager  # instance level

         flatten_contexts, contexts_kwargs = cls.modify_context_params(
-            block_adapter, cache_manager, **cache_context_kwargs
+            block_adapter, **cache_context_kwargs
         )

         original_call = block_adapter.pipe.__class__.__call__
@@ -212,7 +227,6 @@ class CachedAdapter:
     def modify_context_params(
         cls,
         block_adapter: BlockAdapter,
-        cache_manager: CachedContextManager,
         **cache_context_kwargs,
     ) -> Tuple[List[str], List[Dict[str, Any]]]:

@@ -230,6 +244,8 @@ class CachedAdapter:
             contexts_kwargs[i]["name"] = flatten_contexts[i]

         if block_adapter.params_modifiers is None:
+            for i in range(len(contexts_kwargs)):
+                cls._config_messages(**contexts_kwargs[i])
             return flatten_contexts, contexts_kwargs

         flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
@@ -242,12 +258,26 @@ class CachedAdapter:
             contexts_kwargs[i].update(
                 flatten_modifiers[i]._context_kwargs,
             )
-            contexts_kwargs[i], _ = cache_manager.collect_cache_kwargs(
-                default_attrs={}, **contexts_kwargs[i]
-            )
+            cls._config_messages(**contexts_kwargs[i])

         return flatten_contexts, contexts_kwargs

+    @classmethod
+    def _config_messages(cls, **contexts_kwargs):
+        cache_config: BasicCacheConfig = contexts_kwargs.get(
+            "cache_config", None
+        )
+        calibrator_config: CalibratorConfig = contexts_kwargs.get(
+            "calibrator_config", None
+        )
+        if cache_config is not None:
+            message = f"Collected Cache Config: {cache_config.strify()}"
+            if calibrator_config is not None:
+                message += f", Calibrator Config: {calibrator_config.strify(details=True)}"
+            else:
+                message += ", Calibrator Config: None"
+            logger.info(message)
+
     @classmethod
     def mock_blocks(
         cls,
@@ -298,7 +328,19 @@ class CachedAdapter:

         assert isinstance(dummy_blocks_names, list)

-        @functools.wraps(original_forward)
+        from accelerate import hooks
+
+        _hf_hook: Optional[hooks.ModelHook] = None
+
+        if getattr(transformer, "_hf_hook", None) is not None:
+            _hf_hook = transformer._hf_hook  # hooks from accelerate.hooks
+
+        # TODO: remove group offload hooks the re-apply after cache applied.
+        # hooks = _diffusers_hook.hooks.copy(); _diffusers_hook.hooks.clear()
+        # re-apply hooks to transformer after cache applied.
+        # from diffusers.hooks.hooks import HookFunctionReference, HookRegistry
+        # from diffusers.hooks.group_offloading import apply_group_offloading
+
         def new_forward(self, *args, **kwargs):
             with ExitStack() as stack:
                 for name, context_name in zip(
@@ -316,9 +358,27 @@ class CachedAdapter:
                             self, dummy_name, dummy_blocks
                         )
                     )
-                return original_forward(*args, **kwargs)
+                outputs = original_forward(*args, **kwargs)
+                return outputs
+
+        def new_forward_with_hf_hook(self, *args, **kwargs):
+            # Compatible with model cpu offload
+            if _hf_hook is not None and hasattr(_hf_hook, "pre_forward"):
+                args, kwargs = _hf_hook.pre_forward(self, *args, **kwargs)
+
+            outputs = new_forward(self, *args, **kwargs)
+
+            if _hf_hook is not None and hasattr(_hf_hook, "post_forward"):
+                outputs = _hf_hook.post_forward(self, outputs)
+
+            return outputs
+
+        # NOTE: Still can't fully compatible with group offloading
+        transformer.forward = functools.update_wrapper(
+            functools.partial(new_forward_with_hf_hook, transformer),
+            new_forward_with_hf_hook,
+        )

-        transformer.forward = new_forward.__get__(transformer)
         transformer._original_forward = original_forward
         transformer._is_cached = True

@@ -335,7 +395,8 @@ class CachedAdapter:
         total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
         assert hasattr(block_adapter.pipe, "_cache_manager")
         assert isinstance(
-            block_adapter.pipe._cache_manager, CachedContextManager
+            block_adapter.pipe._cache_manager,
+            CachedContextManager,
         )

         for i in range(len(block_adapter.transformer)):
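The forward-mocking changes above route calls through new_forward_with_hf_hook so that accelerate's pre_forward/post_forward hooks (the ones model CPU offload installs) still run around the cached forward, while the NOTE concedes group offloading is not fully covered yet. A hedged compatibility sketch:

import torch
import cache_dit
from diffusers import DiffusionPipeline

pipe = DiffusionPipeline.from_pretrained("<model-id>", torch_dtype=torch.bfloat16)
pipe.enable_model_cpu_offload()  # installs accelerate's _hf_hook on the transformer

# enable_cache() replaces transformer.forward; the wrapper above re-invokes the
# accelerate hooks, so offloaded weights are still moved to GPU for each call.
cache_dit.enable_cache(pipe)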
@@ -12,6 +12,10 @@ from cache_dit.cache_factory.cache_blocks.pattern_0_1_2 import (
 from cache_dit.cache_factory.cache_blocks.pattern_3_4_5 import (
     CachedBlocks_Pattern_3_4_5,
 )
+from cache_dit.cache_factory.cache_blocks.pattern_utils import (
+    patch_cached_stats,
+    remove_cached_stats,
+)

 from cache_dit.logger import init_logger

@@ -0,0 +1,115 @@
+import torch
+import asyncio
+import logging
+from contextlib import contextmanager
+from typing import Generator, Optional, List
+from diffusers.hooks.group_offloading import _is_group_offload_enabled
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+@torch.compiler.disable
+@contextmanager
+def maybe_onload(
+    block: torch.nn.Module,
+    reference_tensor: torch.Tensor,
+    pending_tasks: List[asyncio.Task] = [],
+) -> Generator:
+
+    if not _is_group_offload_enabled(block):
+        yield block
+        return
+
+    original_devices: Optional[List[torch.device]] = None
+    if hasattr(block, "parameters"):
+        params = list(block.parameters())
+        if params:
+            original_devices = [param.data.device for param in params]
+
+    target_device: torch.device = reference_tensor.device
+    move_task: Optional[asyncio.Task] = None
+    need_restore: bool = False
+
+    try:
+        if original_devices is not None:
+            unique_devices = list(set(original_devices))
+            if len(unique_devices) > 1 or unique_devices[0] != target_device:
+                if logger.isEnabledFor(logging.DEBUG):
+                    logger.debug(
+                        f"Onloading from {unique_devices} to {target_device}"
+                    )
+
+                has_meta_params = any(
+                    dev.type == "meta" for dev in original_devices
+                )
+                if has_meta_params:  # compatible with sequential cpu offload
+                    block = block.to_empty(device=target_device)
+                else:
+                    block = block.to(target_device, non_blocking=False)
+                need_restore = True
+        yield block
+    finally:
+        if need_restore and original_devices:
+
+            async def restore_device():
+                for param, original_device in zip(
+                    block.parameters(), original_devices
+                ):
+                    param.data = await asyncio.to_thread(
+                        lambda p, d: p.to(d, non_blocking=True),
+                        param.data,  # type: torch.Tensor
+                        original_device,  # type: torch.device
+                    )  # type: ignore[assignment]
+
+            loop = get_event_loop()
+            move_task = loop.create_task(restore_device())
+            if move_task:
+                pending_tasks.append(move_task)
+
+
+def get_event_loop() -> asyncio.AbstractEventLoop:
+    try:
+        loop = asyncio.get_running_loop()
+    except RuntimeError:
+        try:
+            loop = asyncio.get_event_loop()
+        except RuntimeError:
+            loop = asyncio.new_event_loop()
+            asyncio.set_event_loop(loop)
+
+    if not loop.is_running():
+
+        def run_loop() -> None:
+            asyncio.set_event_loop(loop)
+            loop.run_forever()
+
+        import threading
+
+        if not any(t.name == "_my_loop" for t in threading.enumerate()):
+            threading.Thread(
+                target=run_loop, name="_my_loop", daemon=True
+            ).start()
+
+    return loop
+
+
+@torch.compiler.disable
+def maybe_offload(
+    pending_tasks: List[asyncio.Task],
+) -> None:
+    if not pending_tasks:
+        return
+
+    loop = get_event_loop()
+
+    async def gather_tasks():
+        return await asyncio.gather(*pending_tasks)
+
+    future = asyncio.run_coroutine_threadsafe(gather_tasks(), loop)
+    try:
+        future.result(timeout=30.0)
+    except Exception as e:
+        logger.error(f"May Offload Error: {e}")
+
+    pending_tasks.clear()
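The new helpers are no-ops unless diffusers group offloading is enabled on a block: maybe_onload moves the block to the device of a reference tensor for the duration of the with-block and schedules an asynchronous move back, while maybe_offload waits for those scheduled restores to finish. A usage sketch inferred from the signatures (transformer_blocks and hidden_states are placeholders; the real call sites in pattern_base.py are not shown in this diff):

import asyncio
from typing import List

pending_tasks: List[asyncio.Task] = []
for block in transformer_blocks:                       # hypothetical block list
    with maybe_onload(block, hidden_states, pending_tasks) as onloaded_block:
        hidden_states = onloaded_block(hidden_states)  # compute on the onloaded block
maybe_offload(pending_tasks)                           # wait for async device restores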
@@ -1,7 +1,9 @@
 import inspect
+import asyncio
 import torch
 import torch.distributed as dist

+from typing import List
 from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
 from cache_dit.cache_factory.cache_contexts.cache_manager import (
     CachedContextManager,
@@ -45,6 +47,7 @@ class CachedBlocks_Pattern_Base(torch.nn.Module):
         self.cache_prefix = cache_prefix
         self.cache_context = cache_context
         self.cache_manager = cache_manager
+        self.pending_tasks: List[asyncio.Task] = []

         self._check_forward_pattern()
         logger.info(
@@ -1,12 +1,14 @@
-# namespace alias: for _CachedContext and many others' cache context funcs.
-from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
-from cache_dit.cache_factory.cache_contexts.cache_manager import (
-    CachedContextManager,
-)
-from cache_dit.cache_factory.cache_contexts.v2 import (
-    CachedContextV2,
-    CachedContextManagerV2,
+from cache_dit.cache_factory.cache_contexts.calibrators import (
+    Calibrator,
+    CalibratorBase,
     CalibratorConfig,
     TaylorSeerCalibratorConfig,
     FoCaCalibratorConfig,
 )
+from cache_dit.cache_factory.cache_contexts.cache_context import (
+    CachedContext,
+    BasicCacheConfig,
+)
+from cache_dit.cache_factory.cache_contexts.cache_manager import (
+    CachedContextManager,
+)