cache-dit 1.0.3__py3-none-any.whl → 1.0.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (32)
  1. cache_dit/__init__.py +3 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +8 -1
  4. cache_dit/cache_factory/block_adapters/__init__.py +4 -1
  5. cache_dit/cache_factory/cache_adapters/cache_adapter.py +126 -80
  6. cache_dit/cache_factory/cache_blocks/__init__.py +167 -17
  7. cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +10 -0
  8. cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +256 -24
  9. cache_dit/cache_factory/cache_blocks/pattern_base.py +273 -38
  10. cache_dit/cache_factory/cache_blocks/pattern_utils.py +55 -10
  11. cache_dit/cache_factory/cache_contexts/__init__.py +15 -2
  12. cache_dit/cache_factory/cache_contexts/cache_config.py +118 -0
  13. cache_dit/cache_factory/cache_contexts/cache_context.py +15 -93
  14. cache_dit/cache_factory/cache_contexts/cache_manager.py +7 -7
  15. cache_dit/cache_factory/cache_contexts/calibrators/__init__.py +22 -0
  16. cache_dit/cache_factory/cache_contexts/calibrators/taylorseer.py +78 -8
  17. cache_dit/cache_factory/cache_contexts/context_manager.py +29 -0
  18. cache_dit/cache_factory/cache_contexts/prune_config.py +63 -0
  19. cache_dit/cache_factory/cache_contexts/prune_context.py +155 -0
  20. cache_dit/cache_factory/cache_contexts/prune_manager.py +154 -0
  21. cache_dit/cache_factory/cache_interface.py +20 -14
  22. cache_dit/cache_factory/cache_types.py +19 -2
  23. cache_dit/cache_factory/params_modifier.py +7 -7
  24. cache_dit/cache_factory/utils.py +18 -7
  25. cache_dit/quantize/quantize_ao.py +58 -17
  26. cache_dit/utils.py +191 -54
  27. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/METADATA +11 -10
  28. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/RECORD +32 -27
  29. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/WHEEL +0 -0
  30. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/entry_points.txt +0 -0
  31. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/licenses/LICENSE +0 -0
  32. {cache_dit-1.0.3.dist-info → cache_dit-1.0.5.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py CHANGED
@@ -19,6 +19,8 @@ from cache_dit.cache_factory import ParamsModifier
  from cache_dit.cache_factory import ForwardPattern
  from cache_dit.cache_factory import PatchFunctor
  from cache_dit.cache_factory import BasicCacheConfig
+ from cache_dit.cache_factory import DBCacheConfig
+ from cache_dit.cache_factory import DBPruneConfig
  from cache_dit.cache_factory import CalibratorConfig
  from cache_dit.cache_factory import TaylorSeerCalibratorConfig
  from cache_dit.cache_factory import FoCaCalibratorConfig
@@ -30,6 +32,7 @@ from cache_dit.quantize import quantize
 
  NONE = CacheType.NONE
  DBCache = CacheType.DBCache
+ DBPrune = CacheType.DBPrune
 
  Pattern_0 = ForwardPattern.Pattern_0
  Pattern_1 = ForwardPattern.Pattern_1
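
The net effect of these two hunks is that DBCacheConfig, DBPruneConfig and the DBPrune cache type become importable from the package root. A minimal usage sketch, not taken from this diff, assuming the package's existing enable_cache entry point accepts the new config objects via a cache_config keyword:

import cache_dit
from diffusers import FluxPipeline

pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

# Select either the existing DBCache behaviour or the new DBPrune path;
# both config objects are now re-exported at the package root.
config = cache_dit.DBCacheConfig()   # or: cache_dit.DBPruneConfig()
cache_dit.enable_cache(pipe, cache_config=config)
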
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID
 
- __version__ = version = '1.0.3'
- __version_tuple__ = version_tuple = (1, 0, 3)
+ __version__ = version = '1.0.5'
+ __version_tuple__ = version_tuple = (1, 0, 5)
 
  __commit_id__ = commit_id = None
cache_dit/cache_factory/__init__.py CHANGED
@@ -9,14 +9,21 @@ from cache_dit.cache_factory.patch_functors import PatchFunctor
  from cache_dit.cache_factory.block_adapters import BlockAdapter
  from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
 
- from cache_dit.cache_factory.cache_contexts import CachedContext
  from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
+ from cache_dit.cache_factory.cache_contexts import DBCacheConfig
+ from cache_dit.cache_factory.cache_contexts import CachedContext
  from cache_dit.cache_factory.cache_contexts import CachedContextManager
+ from cache_dit.cache_factory.cache_contexts import DBPruneConfig
+ from cache_dit.cache_factory.cache_contexts import PrunedContext
+ from cache_dit.cache_factory.cache_contexts import PrunedContextManager
+ from cache_dit.cache_factory.cache_contexts import ContextManager
  from cache_dit.cache_factory.cache_contexts import CalibratorConfig
  from cache_dit.cache_factory.cache_contexts import TaylorSeerCalibratorConfig
  from cache_dit.cache_factory.cache_contexts import FoCaCalibratorConfig
 
  from cache_dit.cache_factory.cache_blocks import CachedBlocks
+ from cache_dit.cache_factory.cache_blocks import PrunedBlocks
+ from cache_dit.cache_factory.cache_blocks import UnifiedBlocks
 
  from cache_dit.cache_factory.cache_adapters import CachedAdapter
cache_dit/cache_factory/block_adapters/__init__.py CHANGED
@@ -12,7 +12,10 @@ def flux_adapter(pipe, **kwargs) -> BlockAdapter:
  from cache_dit.utils import is_diffusers_at_least_0_3_5
 
  assert isinstance(pipe.transformer, FluxTransformer2DModel)
- if is_diffusers_at_least_0_3_5():
+ transformer_cls_name: str = pipe.transformer.__class__.__name__
+ if is_diffusers_at_least_0_3_5() and not transformer_cls_name.startswith(
+ "Nunchaku"
+ ):
  return BlockAdapter(
  pipe=pipe,
  transformer=pipe.transformer,
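
The added guard keeps Nunchaku-patched Flux transformers on the older adapter layout by inspecting the transformer's class name rather than its type. A toy illustration of the check; the class below is a stand-in, not the real Nunchaku class:

# Stand-in for a Nunchaku-patched transformer; the real class lives in the
# nunchaku package and is only identified here by its name prefix.
class NunchakuFluxTransformer2DModel:
    pass

transformer_cls_name = NunchakuFluxTransformer2DModel.__name__
# The new condition routes such transformers to the legacy BlockAdapter layout.
assert transformer_cls_name.startswith("Nunchaku")
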
cache_dit/cache_factory/cache_adapters/cache_adapter.py CHANGED
@@ -1,3 +1,4 @@
+ import copy
  import torch
  import unittest
  import functools
@@ -10,10 +11,10 @@ from cache_dit.cache_factory.cache_types import CacheType
  from cache_dit.cache_factory.block_adapters import BlockAdapter
  from cache_dit.cache_factory.block_adapters import ParamsModifier
  from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
- from cache_dit.cache_factory.cache_contexts import CachedContextManager
+ from cache_dit.cache_factory.cache_contexts import ContextManager
  from cache_dit.cache_factory.cache_contexts import BasicCacheConfig
  from cache_dit.cache_factory.cache_contexts import CalibratorConfig
- from cache_dit.cache_factory.cache_blocks import CachedBlocks
+ from cache_dit.cache_factory.cache_blocks import UnifiedBlocks
  from cache_dit.logger import init_logger
 
  logger = init_logger(__name__)
@@ -32,7 +33,7 @@ class CachedAdapter:
  DiffusionPipeline,
  BlockAdapter,
  ],
- **cache_context_kwargs,
+ **context_kwargs,
  ) -> Union[
  DiffusionPipeline,
  BlockAdapter,
@@ -51,7 +52,7 @@ class CachedAdapter:
  block_adapter = BlockAdapterRegistry.get_adapter(
  pipe_or_adapter
  )
- if params_modifiers := cache_context_kwargs.pop(
+ if params_modifiers := context_kwargs.pop(
  "params_modifiers",
  None,
  ):
@@ -59,7 +60,7 @@ class CachedAdapter:
 
  return cls.cachify(
  block_adapter,
- **cache_context_kwargs,
+ **context_kwargs,
  ).pipe
  else:
  raise ValueError(
@@ -72,21 +73,21 @@ class CachedAdapter:
  "Adapting Cache Acceleration using custom BlockAdapter!"
  )
  if pipe_or_adapter.params_modifiers is None:
- if params_modifiers := cache_context_kwargs.pop(
+ if params_modifiers := context_kwargs.pop(
  "params_modifiers", None
  ):
  pipe_or_adapter.params_modifiers = params_modifiers
 
  return cls.cachify(
  pipe_or_adapter,
- **cache_context_kwargs,
+ **context_kwargs,
  )
 
  @classmethod
  def cachify(
  cls,
  block_adapter: BlockAdapter,
- **cache_context_kwargs,
+ **context_kwargs,
  ) -> BlockAdapter:
 
  if block_adapter.auto:
@@ -103,14 +104,15 @@ class CachedAdapter:
 
  # 1. Apply cache on pipeline: wrap cache context, must
  # call create_context before mock_blocks.
- cls.create_context(
+ _, contexts_kwargs = cls.create_context(
  block_adapter,
- **cache_context_kwargs,
+ **context_kwargs,
  )
 
  # 2. Apply cache on transformer: mock cached blocks
  cls.mock_blocks(
  block_adapter,
+ contexts_kwargs,
  )
 
  return block_adapter
@@ -119,12 +121,10 @@ class CachedAdapter:
  def check_context_kwargs(
  cls,
  block_adapter: BlockAdapter,
- **cache_context_kwargs,
+ **context_kwargs,
  ):
- # Check cache_context_kwargs
- cache_config: BasicCacheConfig = cache_context_kwargs[
- "cache_config"
- ] # ref
+ # Check context_kwargs
+ cache_config: BasicCacheConfig = context_kwargs["cache_config"] # ref
  assert cache_config is not None, "cache_config can not be None."
  if cache_config.enable_separate_cfg is None:
  # Check cfg for some specific case if users don't set it as True
@@ -150,19 +150,23 @@ class CachedAdapter:
  f"Pipeline: {block_adapter.pipe.__class__.__name__}."
  )
 
- cache_type = cache_context_kwargs.pop("cache_type", None)
+ cache_type = context_kwargs.pop("cache_type", None)
  if cache_type is not None:
- assert (
- cache_type == CacheType.DBCache
- ), "Custom cache setting only support for DBCache now!"
+ assert isinstance(
+ cache_type, CacheType
+ ), f"cache_type must be CacheType Enum, but got {type(cache_type)}."
+ assert cache_type == cache_config.cache_type, (
+ f"cache_type from context_kwargs ({cache_type}) must be the same "
+ f"as that from cache_config ({cache_config.cache_type})."
+ )
 
- return cache_context_kwargs
+ return context_kwargs
 
  @classmethod
  def create_context(
  cls,
  block_adapter: BlockAdapter,
- **cache_context_kwargs,
+ **context_kwargs,
  ) -> Tuple[List[str], List[Dict[str, Any]]]:
 
  BlockAdapter.assert_normalized(block_adapter)
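
1.0.5 replaces the hard-coded "DBCache only" restriction with two checks: an explicit cache_type must be a CacheType enum member and must agree with cache_config.cache_type. An illustrative sketch of what now passes and fails; it assumes DBCacheConfig defaults its cache_type field to CacheType.DBCache, which this diff does not show directly:

from cache_dit import CacheType, DBCacheConfig

cfg = DBCacheConfig()
assert isinstance(CacheType.DBCache, CacheType)  # first new assertion
assert CacheType.DBCache == cfg.cache_type       # second new assertion

# Passing cache_type=CacheType.DBPrune together with a DBCacheConfig would
# now raise an AssertionError in check_context_kwargs instead of being
# silently accepted.
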
@@ -170,9 +174,9 @@ class CachedAdapter:
  if BlockAdapter.is_cached(block_adapter.pipe):
  return block_adapter.pipe
 
- # Check cache_context_kwargs
- cache_context_kwargs = cls.check_context_kwargs(
- block_adapter, **cache_context_kwargs
+ # Check context_kwargs
+ context_kwargs = cls.check_context_kwargs(
+ block_adapter, **context_kwargs
  )
  # Apply cache on pipeline: wrap cache context
  pipe_cls_name = block_adapter.pipe.__class__.__name__
@@ -181,15 +185,19 @@ class CachedAdapter:
  # Different transformers (Wan2.2, etc) should shared the same
  # cache manager but with different cache context (according
  # to their unique instance id).
- cache_manager = CachedContextManager(
+ cache_config: BasicCacheConfig = context_kwargs.get(
+ "cache_config", None
+ )
+ assert cache_config is not None, "cache_config can not be None."
+ context_manager = ContextManager(
  name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
+ cache_type=cache_config.cache_type,
  )
- block_adapter.pipe._cache_manager = cache_manager # instance level
+ block_adapter.pipe._context_manager = context_manager # instance level
 
  flatten_contexts, contexts_kwargs = cls.modify_context_params(
- block_adapter, **cache_context_kwargs
+ block_adapter, **context_kwargs
  )
-
  original_call = block_adapter.pipe.__class__.__call__
 
  @functools.wraps(original_call)
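
The pipeline now carries a generic _context_manager selected by cache_type instead of always a CachedContextManager. The ContextManager factory itself is not shown in this diff; one plausible reading of how it could dispatch, consistent with the later isinstance(..., ContextManager._supported_managers) checks, is sketched below (the PrunedContextManager constructor arguments are an assumption):

from cache_dit.cache_factory.cache_contexts import (
    CachedContextManager,
    PrunedContextManager,
)
from cache_dit.cache_factory.cache_types import CacheType

class ContextManager:
    # The diff only shows that this tuple exists and is used with isinstance().
    _supported_managers = (CachedContextManager, PrunedContextManager)

    def __new__(cls, name: str, cache_type: CacheType):
        # Return a concrete manager: DBPrune gets the pruning manager,
        # everything else keeps the caching manager.
        if cache_type == CacheType.DBPrune:
            return PrunedContextManager(name=name)
        return CachedContextManager(name=name)
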
@@ -200,8 +208,8 @@ class CachedAdapter:
  flatten_contexts, contexts_kwargs
  ):
  stack.enter_context(
- cache_manager.enter_context(
- cache_manager.reset_context(
+ context_manager.enter_context(
+ context_manager.reset_context(
  context_name,
  **context_kwargs,
  ),
@@ -223,14 +231,14 @@ class CachedAdapter:
  def modify_context_params(
  cls,
  block_adapter: BlockAdapter,
- **cache_context_kwargs,
+ **context_kwargs,
  ) -> Tuple[List[str], List[Dict[str, Any]]]:
 
  flatten_contexts = BlockAdapter.flatten(
  block_adapter.unique_blocks_name
  )
  contexts_kwargs = [
- cache_context_kwargs.copy()
+ copy.deepcopy(context_kwargs) # must deep copy
  for _ in range(
  len(flatten_contexts),
  )
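
Switching from dict.copy() to copy.deepcopy matters because the per-context kwargs now have their config objects updated in place by the ParamsModifier merge in the next hunk; a shallow copy would share one config instance across all contexts. A quick illustration of the difference, with DBCacheConfig standing in for any mutable config object:

import copy

from cache_dit import DBCacheConfig

base = {"cache_config": DBCacheConfig()}

shallow = [base.copy() for _ in range(2)]
# Shallow copies still point at the same config object ...
assert shallow[0]["cache_config"] is shallow[1]["cache_config"]

deep = [copy.deepcopy(base) for _ in range(2)]
# ... while deep copies can be updated per context independently.
assert deep[0]["cache_config"] is not deep[1]["cache_config"]
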
@@ -251,9 +259,41 @@ class CachedAdapter:
  for i in range(
  min(len(contexts_kwargs), len(flatten_modifiers)),
  ):
- contexts_kwargs[i].update(
- flatten_modifiers[i]._context_kwargs,
- )
+ if "cache_config" in flatten_modifiers[i]._context_kwargs:
+ modifier_cache_config = flatten_modifiers[
+ i
+ ]._context_kwargs.get("cache_config", None)
+ modifier_calibrator_config = flatten_modifiers[
+ i
+ ]._context_kwargs.get("calibrator_config", None)
+ if modifier_cache_config is not None:
+ assert isinstance(
+ modifier_cache_config, BasicCacheConfig
+ ), (
+ f"cache_config must be BasicCacheConfig, but got "
+ f"{type(modifier_cache_config)}."
+ )
+ contexts_kwargs[i]["cache_config"].update(
+ **modifier_cache_config.as_dict()
+ )
+ if modifier_calibrator_config is not None:
+ assert isinstance(
+ modifier_calibrator_config, CalibratorConfig
+ ), (
+ f"calibrator_config must be CalibratorConfig, but got "
+ f"{type(modifier_calibrator_config)}."
+ )
+ if (
+ contexts_kwargs[i].get("calibrator_config", None)
+ is None
+ ):
+ contexts_kwargs[i][
+ "calibrator_config"
+ ] = modifier_calibrator_config
+ else:
+ contexts_kwargs[i]["calibrator_config"].update(
+ **modifier_calibrator_config.as_dict()
+ )
  cls._config_messages(**contexts_kwargs[i])
 
  return flatten_contexts, contexts_kwargs
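
Per-blocks overrides from ParamsModifier are no longer applied with a blanket dict.update(); cache_config and calibrator_config are merged field by field into the deep-copied base configs via update(**modifier_config.as_dict()). A hedged sketch of what that means for callers, assuming ParamsModifier simply stores its keyword arguments as _context_kwargs (this diff only shows that attribute being read):

from cache_dit import DBCacheConfig, ParamsModifier

base_cfg = DBCacheConfig()
modifier = ParamsModifier(cache_config=DBCacheConfig())

# 1.0.3 replaced the whole "cache_config" entry with the modifier's object;
# 1.0.5 updates the per-context copy in place from the modifier's fields.
base_cfg.update(**modifier._context_kwargs["cache_config"].as_dict())
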
@@ -267,7 +307,7 @@ class CachedAdapter:
  "calibrator_config", None
  )
  if cache_config is not None:
- message = f"Collected Cache Config: {cache_config.strify()}"
+ message = f"Collected Context Config: {cache_config.strify()}"
  if calibrator_config is not None:
  message += f", Calibrator Config: {calibrator_config.strify(details=True)}"
  else:
@@ -278,6 +318,7 @@ class CachedAdapter:
  def mock_blocks(
  cls,
  block_adapter: BlockAdapter,
+ contexts_kwargs: List[Dict],
  ) -> List[torch.nn.Module]:
 
  BlockAdapter.assert_normalized(block_adapter)
@@ -287,20 +328,23 @@ class CachedAdapter:
 
  # Apply cache on transformer: mock cached transformer blocks
  for (
- cached_blocks,
+ unified_blocks,
  transformer,
  blocks_name,
  unique_blocks_name,
  dummy_blocks_names,
  ) in zip(
- cls.collect_cached_blocks(block_adapter),
+ cls.collect_unified_blocks(
+ block_adapter,
+ contexts_kwargs,
+ ),
  block_adapter.transformer,
  block_adapter.blocks_name,
  block_adapter.unique_blocks_name,
  block_adapter.dummy_blocks_names,
  ):
  cls.mock_transformer(
- cached_blocks,
+ unified_blocks,
  transformer,
  blocks_name,
  unique_blocks_name,
@@ -312,7 +356,7 @@ class CachedAdapter:
  @classmethod
  def mock_transformer(
  cls,
- cached_blocks: Dict[str, torch.nn.ModuleList],
+ unified_blocks: Dict[str, torch.nn.ModuleList],
  transformer: torch.nn.Module,
  blocks_name: List[str],
  unique_blocks_name: List[str],
@@ -352,7 +396,7 @@ class CachedAdapter:
  ):
  stack.enter_context(
  unittest.mock.patch.object(
- self, name, cached_blocks[context_name]
+ self, name, unified_blocks[context_name]
  )
  )
  for dummy_name in dummy_blocks_names:
@@ -388,46 +432,51 @@ class CachedAdapter:
  return transformer
 
  @classmethod
- def collect_cached_blocks(
+ def collect_unified_blocks(
  cls,
  block_adapter: BlockAdapter,
+ contexts_kwargs: List[Dict],
  ) -> List[Dict[str, torch.nn.ModuleList]]:
 
  BlockAdapter.assert_normalized(block_adapter)
 
  total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
- assert hasattr(block_adapter.pipe, "_cache_manager")
+ assert hasattr(block_adapter.pipe, "_context_manager")
  assert isinstance(
- block_adapter.pipe._cache_manager,
- CachedContextManager,
+ block_adapter.pipe._context_manager,
+ ContextManager._supported_managers,
  )
 
  for i in range(len(block_adapter.transformer)):
 
- cached_blocks_bind_context = {}
+ unified_blocks_bind_context = {}
  for j in range(len(block_adapter.blocks[i])):
- cached_blocks_bind_context[
+ cache_config: BasicCacheConfig = contexts_kwargs[
+ i * len(block_adapter.blocks[i]) + j
+ ]["cache_config"]
+ unified_blocks_bind_context[
  block_adapter.unique_blocks_name[i][j]
  ] = torch.nn.ModuleList(
  [
- CachedBlocks(
+ UnifiedBlocks(
  # 0. Transformer blocks configuration
  block_adapter.blocks[i][j],
  transformer=block_adapter.transformer[i],
  forward_pattern=block_adapter.forward_pattern[i][j],
  check_forward_pattern=block_adapter.check_forward_pattern,
  check_num_outputs=block_adapter.check_num_outputs,
- # 1. Cache context configuration
+ # 1. Cache/Prune context configuration
  cache_prefix=block_adapter.blocks_name[i][j],
  cache_context=block_adapter.unique_blocks_name[i][
  j
  ],
- cache_manager=block_adapter.pipe._cache_manager,
+ context_manager=block_adapter.pipe._context_manager,
+ cache_type=cache_config.cache_type,
  )
  ]
  )
 
- total_cached_blocks.append(cached_blocks_bind_context)
+ total_cached_blocks.append(unified_blocks_bind_context)
 
  return total_cached_blocks
 
@@ -437,7 +486,7 @@ class CachedAdapter:
  block_adapter: BlockAdapter,
  contexts_kwargs: List[Dict],
  ):
- block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
+ block_adapter.pipe._context_kwargs = contexts_kwargs[0]
 
  params_shift = 0
  for i in range(len(block_adapter.transformer)):
@@ -448,16 +497,14 @@ class CachedAdapter:
  block_adapter.transformer[i]._has_separate_cfg = (
  block_adapter.has_separate_cfg
  )
- block_adapter.transformer[i]._cache_context_kwargs = (
- contexts_kwargs[params_shift]
- )
+ block_adapter.transformer[i]._context_kwargs = contexts_kwargs[
+ params_shift
+ ]
 
  blocks = block_adapter.blocks[i]
  for j in range(len(blocks)):
  blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
- blocks[j]._cache_context_kwargs = contexts_kwargs[
- params_shift + j
- ]
+ blocks[j]._context_kwargs = contexts_kwargs[params_shift + j]
 
  params_shift += len(blocks)
 
@@ -467,25 +514,25 @@ class CachedAdapter:
  block_adapter: BlockAdapter,
  ):
  from cache_dit.cache_factory.cache_blocks import (
- patch_cached_stats,
+ apply_stats,
  )
 
- cache_manager = block_adapter.pipe._cache_manager
+ context_manager = block_adapter.pipe._context_manager
 
  for i in range(len(block_adapter.transformer)):
- patch_cached_stats(
+ apply_stats(
  block_adapter.transformer[i],
  cache_context=block_adapter.unique_blocks_name[i][-1],
- cache_manager=cache_manager,
+ context_manager=context_manager,
  )
  for blocks, unique_name in zip(
  block_adapter.blocks[i],
  block_adapter.unique_blocks_name[i],
  ):
- patch_cached_stats(
+ apply_stats(
  blocks,
  cache_context=unique_name,
- cache_manager=cache_manager,
+ context_manager=context_manager,
  )
 
  @classmethod
@@ -513,11 +560,13 @@ class CachedAdapter:
  original_call = pipe.__class__._original_call
  pipe.__class__.__call__ = original_call
  del pipe.__class__._original_call
- if hasattr(pipe, "_cache_manager"):
- cache_manager = pipe._cache_manager
- if isinstance(cache_manager, CachedContextManager):
- cache_manager.clear_contexts()
- del pipe._cache_manager
+ if hasattr(pipe, "_context_manager"):
+ context_manager = pipe._context_manager
+ if isinstance(
+ context_manager, ContextManager._supported_managers
+ ):
+ context_manager.clear_contexts()
+ del pipe._context_manager
  if hasattr(pipe, "_is_cached"):
  del pipe.__class__._is_cached
 
@@ -532,22 +581,22 @@ class CachedAdapter:
  def _release_blocks_params(blocks):
  if hasattr(blocks, "_forward_pattern"):
  del blocks._forward_pattern
- if hasattr(blocks, "_cache_context_kwargs"):
- del blocks._cache_context_kwargs
+ if hasattr(blocks, "_context_kwargs"):
+ del blocks._context_kwargs
 
  def _release_transformer_params(transformer):
  if hasattr(transformer, "_forward_pattern"):
  del transformer._forward_pattern
  if hasattr(transformer, "_has_separate_cfg"):
  del transformer._has_separate_cfg
- if hasattr(transformer, "_cache_context_kwargs"):
- del transformer._cache_context_kwargs
+ if hasattr(transformer, "_context_kwargs"):
+ del transformer._context_kwargs
  for blocks in BlockAdapter.find_blocks(transformer):
  _release_blocks_params(blocks)
 
  def _release_pipeline_params(pipe):
- if hasattr(pipe, "_cache_context_kwargs"):
- del pipe._cache_context_kwargs
+ if hasattr(pipe, "_context_kwargs"):
+ del pipe._context_kwargs
 
  cls.release_hooks(
  pipe_or_adapter,
@@ -558,14 +607,11 @@ class CachedAdapter:
 
  # release stats hooks
  from cache_dit.cache_factory.cache_blocks import (
- remove_cached_stats,
+ remove_stats,
  )
 
  cls.release_hooks(
- pipe_or_adapter,
- remove_cached_stats,
- remove_cached_stats,
- remove_cached_stats,
+ pipe_or_adapter, remove_stats, remove_stats, remove_stats
  )
 
  @classmethod