cache-dit 0.3.0__py3-none-any.whl → 0.3.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Files changed (24)
  1. cache_dit/__init__.py +3 -0
  2. cache_dit/_version.py +2 -2
  3. cache_dit/cache_factory/__init__.py +7 -0
  4. cache_dit/cache_factory/block_adapters/block_adapters.py +16 -6
  5. cache_dit/cache_factory/cache_adapters/__init__.py +2 -0
  6. cache_dit/cache_factory/{cache_adapters.py → cache_adapters/cache_adapter.py} +6 -6
  7. cache_dit/cache_factory/cache_adapters/v2/__init__.py +3 -0
  8. cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py +524 -0
  9. cache_dit/cache_factory/cache_contexts/__init__.py +7 -0
  10. cache_dit/cache_factory/cache_contexts/v2/__init__.py +13 -0
  11. cache_dit/cache_factory/cache_contexts/v2/cache_context_v2.py +288 -0
  12. cache_dit/cache_factory/cache_contexts/v2/cache_manager_v2.py +799 -0
  13. cache_dit/cache_factory/cache_contexts/v2/calibrators/__init__.py +81 -0
  14. cache_dit/cache_factory/cache_contexts/v2/calibrators/base.py +27 -0
  15. cache_dit/cache_factory/cache_contexts/v2/calibrators/foca.py +26 -0
  16. cache_dit/cache_factory/cache_contexts/v2/calibrators/taylorseer.py +105 -0
  17. cache_dit/cache_factory/cache_interface.py +39 -12
  18. cache_dit/utils.py +17 -7
  19. {cache_dit-0.3.0.dist-info → cache_dit-0.3.1.dist-info}/METADATA +38 -29
  20. {cache_dit-0.3.0.dist-info → cache_dit-0.3.1.dist-info}/RECORD +24 -14
  21. {cache_dit-0.3.0.dist-info → cache_dit-0.3.1.dist-info}/WHEEL +0 -0
  22. {cache_dit-0.3.0.dist-info → cache_dit-0.3.1.dist-info}/entry_points.txt +0 -0
  23. {cache_dit-0.3.0.dist-info → cache_dit-0.3.1.dist-info}/licenses/LICENSE +0 -0
  24. {cache_dit-0.3.0.dist-info → cache_dit-0.3.1.dist-info}/top_level.txt +0 -0
cache_dit/__init__.py CHANGED
@@ -18,6 +18,9 @@ from cache_dit.cache_factory import BlockAdapter
  from cache_dit.cache_factory import ParamsModifier
  from cache_dit.cache_factory import ForwardPattern
  from cache_dit.cache_factory import PatchFunctor
+ from cache_dit.cache_factory import CalibratorConfig
+ from cache_dit.cache_factory import TaylorSeerCalibratorConfig
+ from cache_dit.cache_factory import FoCaCalibratorConfig
  from cache_dit.cache_factory import supported_pipelines
  from cache_dit.cache_factory import get_adapter
  from cache_dit.compile import set_compile_configs
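
Together with the existing `enable_cache` export, these new top-level names sketch the v2 entry point: a calibrator is chosen via a config object instead of the v1 boolean flags. A minimal usage sketch follows; since this diff does not show the new body of `cache_interface.py`, the assumption that `enable_cache` accepts and forwards `calibrator_config` is exactly that, an assumption, and the model id is only an example:

import cache_dit
from diffusers import DiffusionPipeline

# Assumption: enable_cache() forwards calibrator_config into the new
# v2 cache context (cache_interface.py changed +39/-12 in this release).
pipe = DiffusionPipeline.from_pretrained("black-forest-labs/FLUX.1-dev")

cache_dit.enable_cache(
    pipe,
    # One config object replaces v1's enable_taylorseer /
    # taylorseer_order flag pair.
    calibrator_config=cache_dit.TaylorSeerCalibratorConfig(
        taylorseer_order=2,
    ),
)

image = pipe("a cat wearing sunglasses").images[0]
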
cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
  commit_id: COMMIT_ID
  __commit_id__: COMMIT_ID
 
- __version__ = version = '0.3.0'
- __version_tuple__ = version_tuple = (0, 3, 0)
+ __version__ = version = '0.3.1'
+ __version_tuple__ = version_tuple = (0, 3, 1)
 
  __commit_id__ = commit_id = None
cache_dit/cache_factory/__init__.py CHANGED
@@ -12,9 +12,16 @@ from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
 
  from cache_dit.cache_factory.cache_contexts import CachedContext
  from cache_dit.cache_factory.cache_contexts import CachedContextManager
+ from cache_dit.cache_factory.cache_contexts import CachedContextV2
+ from cache_dit.cache_factory.cache_contexts import CachedContextManagerV2
+ from cache_dit.cache_factory.cache_contexts import CalibratorConfig  # no V1
+ from cache_dit.cache_factory.cache_contexts import TaylorSeerCalibratorConfig
+ from cache_dit.cache_factory.cache_contexts import FoCaCalibratorConfig
+
  from cache_dit.cache_factory.cache_blocks import CachedBlocks
 
  from cache_dit.cache_factory.cache_adapters import CachedAdapter
+ from cache_dit.cache_factory.cache_adapters import CachedAdapterV2
 
  from cache_dit.cache_factory.cache_interface import enable_cache
  from cache_dit.cache_factory.cache_interface import disable_cache
cache_dit/cache_factory/block_adapters/block_adapters.py CHANGED
@@ -9,6 +9,7 @@ from typing import Any, Tuple, List, Optional, Union
  from diffusers import DiffusionPipeline
  from cache_dit.cache_factory.forward_pattern import ForwardPattern
  from cache_dit.cache_factory.patch_functors import PatchFunctor
+ from cache_dit.cache_factory.cache_contexts import CalibratorConfig
 
  from cache_dit.logger import init_logger
 
@@ -34,6 +35,8 @@ class ParamsModifier:
          enable_encoder_taylorseer: Optional[bool] = None,
          taylorseer_cache_type: Optional[str] = None,
          taylorseer_order: Optional[int] = None,
+         # New param, only for the v2 API
+         calibrator_config: Optional[CalibratorConfig] = None,
          **other_cache_context_kwargs,
      ):
          self._context_kwargs = other_cache_context_kwargs.copy()
@@ -52,12 +55,19 @@ class ParamsModifier:
          self._maybe_update_param(
              "cfg_diff_compute_separate", cfg_diff_compute_separate
          )
-         self._maybe_update_param("enable_taylorseer", enable_taylorseer)
-         self._maybe_update_param(
-             "enable_encoder_taylorseer", enable_encoder_taylorseer
-         )
-         self._maybe_update_param("taylorseer_cache_type", taylorseer_cache_type)
-         self._maybe_update_param("taylorseer_order", taylorseer_order)
+         # V1 only supports the TaylorSeer calibrator. We have decided
+         # to keep this code for API compatibility reasons.
+         if calibrator_config is None:
+             self._maybe_update_param("enable_taylorseer", enable_taylorseer)
+             self._maybe_update_param(
+                 "enable_encoder_taylorseer", enable_encoder_taylorseer
+             )
+             self._maybe_update_param(
+                 "taylorseer_cache_type", taylorseer_cache_type
+             )
+             self._maybe_update_param("taylorseer_order", taylorseer_order)
+         else:
+             self._maybe_update_param("calibrator_config", calibrator_config)
 
      def _maybe_update_param(self, key: str, value: Any):
          if value is not None:
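
The branch above makes the two styles mutually exclusive within a single modifier: when `calibrator_config` is supplied, the v1 TaylorSeer flags are ignored rather than merged. A short sketch of both, assuming `TaylorSeerCalibratorConfig` mirrors the v1 parameter name for the order (the value 2 is illustrative):

from cache_dit import ParamsModifier, TaylorSeerCalibratorConfig

# v1 style: still honored when no calibrator_config is passed,
# kept for API compatibility as the comment in the diff notes.
legacy = ParamsModifier(
    enable_taylorseer=True,
    taylorseer_order=2,
)

# v2 style: the config object wins; any v1 flags passed alongside
# it would be silently dropped by the branch above.
modern = ParamsModifier(
    calibrator_config=TaylorSeerCalibratorConfig(taylorseer_order=2),
)
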
cache_dit/cache_factory/cache_adapters/__init__.py ADDED
@@ -0,0 +1,2 @@
+ from cache_dit.cache_factory.cache_adapters.cache_adapter import CachedAdapter
+ from cache_dit.cache_factory.cache_adapters.v2 import CachedAdapterV2
cache_dit/cache_factory/{cache_adapters.py → cache_adapters/cache_adapter.py} RENAMED
@@ -8,12 +8,12 @@ from typing import Dict, List, Tuple, Any, Union, Callable
 
  from diffusers import DiffusionPipeline
 
- from cache_dit.cache_factory import CacheType
- from cache_dit.cache_factory import BlockAdapter
- from cache_dit.cache_factory import ParamsModifier
- from cache_dit.cache_factory import BlockAdapterRegistry
- from cache_dit.cache_factory import CachedContextManager
- from cache_dit.cache_factory import CachedBlocks
+ from cache_dit.cache_factory.cache_types import CacheType
+ from cache_dit.cache_factory.block_adapters import BlockAdapter
+ from cache_dit.cache_factory.block_adapters import ParamsModifier
+ from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
+ from cache_dit.cache_factory.cache_contexts import CachedContextManager
+ from cache_dit.cache_factory.cache_blocks import CachedBlocks
  from cache_dit.cache_factory.cache_blocks.utils import (
      patch_cached_stats,
      remove_cached_stats,
cache_dit/cache_factory/cache_adapters/v2/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from cache_dit.cache_factory.cache_adapters.v2.cache_adapter_v2 import (
+     CachedAdapterV2,
+ )
cache_dit/cache_factory/cache_adapters/v2/cache_adapter_v2.py ADDED
@@ -0,0 +1,524 @@
+ import torch
+
+ import unittest
+ import functools
+
+ from contextlib import ExitStack
+ from typing import Dict, List, Tuple, Any, Union, Callable
+
+ from diffusers import DiffusionPipeline
+
+ from cache_dit.cache_factory.cache_types import CacheType
+ from cache_dit.cache_factory.block_adapters import BlockAdapter
+ from cache_dit.cache_factory.block_adapters import ParamsModifier
+ from cache_dit.cache_factory.block_adapters import BlockAdapterRegistry
+ from cache_dit.cache_factory.cache_contexts import CachedContextManagerV2
+ from cache_dit.cache_factory.cache_blocks import CachedBlocks
+ from cache_dit.cache_factory.cache_blocks.utils import (
+     patch_cached_stats,
+     remove_cached_stats,
+ )
+ from cache_dit.logger import init_logger
+
+ logger = init_logger(__name__)
+
+
+ # Unified Cached Adapter
+ class CachedAdapterV2:
+
+     def __call__(self, *args, **kwargs):
+         return self.apply(*args, **kwargs)
+
+     @classmethod
+     def apply(
+         cls,
+         pipe_or_adapter: Union[
+             DiffusionPipeline,
+             BlockAdapter,
+         ],
+         **cache_context_kwargs,
+     ) -> Union[
+         DiffusionPipeline,
+         BlockAdapter,
+     ]:
+         assert (
+             pipe_or_adapter is not None
+         ), "pipe or block_adapter can not be None!"
+
+         if isinstance(pipe_or_adapter, DiffusionPipeline):
+             if BlockAdapterRegistry.is_supported(pipe_or_adapter):
+                 logger.info(
+                     f"{pipe_or_adapter.__class__.__name__} is officially "
+                     "supported by cache-dit. Use its pre-defined BlockAdapter "
+                     "directly!"
+                 )
+                 block_adapter = BlockAdapterRegistry.get_adapter(
+                     pipe_or_adapter
+                 )
+                 return cls.cachify(
+                     block_adapter,
+                     **cache_context_kwargs,
+                 ).pipe
+             else:
+                 raise ValueError(
+                     f"{pipe_or_adapter.__class__.__name__} is not officially supported "
+                     "by cache-dit, please pass a custom BlockAdapter instead!"
+                 )
+         else:
+             assert isinstance(pipe_or_adapter, BlockAdapter)
+             logger.info(
+                 "Adapting cache acceleration using a custom BlockAdapter!"
+             )
+             return cls.cachify(
+                 pipe_or_adapter,
+                 **cache_context_kwargs,
+             )
+
+     @classmethod
+     def cachify(
+         cls,
+         block_adapter: BlockAdapter,
+         **cache_context_kwargs,
+     ) -> BlockAdapter:
+
+         if block_adapter.auto:
+             block_adapter = BlockAdapter.auto_block_adapter(
+                 block_adapter,
+             )
+
+         if BlockAdapter.check_block_adapter(block_adapter):
+
+             # 0. Must normalize block_adapter before applying cache
+             block_adapter = BlockAdapter.normalize(block_adapter)
+             if BlockAdapter.is_cached(block_adapter):
+                 return block_adapter
+
+             # 1. Apply cache on pipeline: wrap cache context; must
+             #    call create_context before mock_blocks.
+             cls.create_context(
+                 block_adapter,
+                 **cache_context_kwargs,
+             )
+
+             # 2. Apply cache on transformer: mock cached blocks
+             cls.mock_blocks(
+                 block_adapter,
+             )
+
+         return block_adapter
+
+     @classmethod
+     def check_context_kwargs(
+         cls,
+         block_adapter: BlockAdapter,
+         **cache_context_kwargs,
+     ):
+         # Check cache_context_kwargs
+         if cache_context_kwargs["enable_separate_cfg"] is None:
+             # Check cfg for some specific cases if users don't set it as True
+             if BlockAdapterRegistry.has_separate_cfg(block_adapter):
+                 cache_context_kwargs["enable_separate_cfg"] = True
+                 logger.info(
+                     f"Use custom 'enable_separate_cfg' from BlockAdapter: True. "
+                     f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+                 )
+             else:
+                 cache_context_kwargs["enable_separate_cfg"] = (
+                     BlockAdapterRegistry.has_separate_cfg(block_adapter.pipe)
+                 )
+                 logger.info(
+                     f"Use default 'enable_separate_cfg' from block adapter "
+                     f"register: {cache_context_kwargs['enable_separate_cfg']}, "
+                     f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+                 )
+         else:
+             logger.info(
+                 f"Use custom 'enable_separate_cfg' from cache context "
+                 f"kwargs: {cache_context_kwargs['enable_separate_cfg']}. "
+                 f"Pipeline: {block_adapter.pipe.__class__.__name__}."
+             )
+
+         if (
+             cache_type := cache_context_kwargs.pop("cache_type", None)
+         ) is not None:
+             assert (
+                 cache_type == CacheType.DBCache
+             ), "Custom cache settings only support DBCache for now!"
+
+         return cache_context_kwargs
+
+     @classmethod
+     def create_context(
+         cls,
+         block_adapter: BlockAdapter,
+         **cache_context_kwargs,
+     ) -> DiffusionPipeline:
+
+         BlockAdapter.assert_normalized(block_adapter)
+
+         if BlockAdapter.is_cached(block_adapter.pipe):
+             return block_adapter.pipe
+
+         # Check cache_context_kwargs
+         cache_context_kwargs = cls.check_context_kwargs(
+             block_adapter, **cache_context_kwargs
+         )
+         # Apply cache on pipeline: wrap cache context
+         pipe_cls_name = block_adapter.pipe.__class__.__name__
+
+         # Each pipeline should have its own context manager instance.
+         # Different transformers (Wan 2.2, etc.) should share the same
+         # cache manager but use different cache contexts (according
+         # to their unique instance ids).
+         cache_manager = CachedContextManagerV2(
+             name=f"{pipe_cls_name}_{hash(id(block_adapter.pipe))}",
+         )
+         block_adapter.pipe._cache_manager = cache_manager  # instance level
+
+         flatten_contexts, contexts_kwargs = cls.modify_context_params(
+             block_adapter, cache_manager, **cache_context_kwargs
+         )
+
+         original_call = block_adapter.pipe.__class__.__call__
+
+         @functools.wraps(original_call)
+         def new_call(self, *args, **kwargs):
+             with ExitStack() as stack:
+                 # cache contexts are reset for each pipe inference
+                 for context_name, context_kwargs in zip(
+                     flatten_contexts, contexts_kwargs
+                 ):
+                     stack.enter_context(
+                         cache_manager.enter_context(
+                             cache_manager.reset_context(
+                                 context_name,
+                                 **context_kwargs,
+                             ),
+                         )
+                     )
+                 outputs = original_call(self, *args, **kwargs)
+                 cls.apply_stats_hooks(block_adapter)
+                 return outputs
+
+         block_adapter.pipe.__class__.__call__ = new_call
+         block_adapter.pipe.__class__._original_call = original_call
+         block_adapter.pipe.__class__._is_cached = True
+
+         cls.apply_params_hooks(block_adapter, contexts_kwargs)
+
+         return block_adapter.pipe
+
+     @classmethod
+     def modify_context_params(
+         cls,
+         block_adapter: BlockAdapter,
+         cache_manager: CachedContextManagerV2,
+         **cache_context_kwargs,
+     ) -> Tuple[List[str], List[Dict[str, Any]]]:
+
+         flatten_contexts = BlockAdapter.flatten(
+             block_adapter.unique_blocks_name
+         )
+         contexts_kwargs = [
+             cache_context_kwargs.copy()
+             for _ in range(
+                 len(flatten_contexts),
+             )
+         ]
+
+         for i in range(len(contexts_kwargs)):
+             contexts_kwargs[i]["name"] = flatten_contexts[i]
+
+         if block_adapter.params_modifiers is None:
+             return flatten_contexts, contexts_kwargs
+
+         flatten_modifiers: List[ParamsModifier] = BlockAdapter.flatten(
+             block_adapter.params_modifiers,
+         )
+
+         for i in range(
+             min(len(contexts_kwargs), len(flatten_modifiers)),
+         ):
+             contexts_kwargs[i].update(
+                 flatten_modifiers[i]._context_kwargs,
+             )
+             contexts_kwargs[i], _ = cache_manager.collect_cache_kwargs(
+                 default_attrs={}, **contexts_kwargs[i]
+             )
+
+         return flatten_contexts, contexts_kwargs
+
+     @classmethod
+     def mock_blocks(
+         cls,
+         block_adapter: BlockAdapter,
+     ) -> List[torch.nn.Module]:
+
+         BlockAdapter.assert_normalized(block_adapter)
+
+         if BlockAdapter.is_cached(block_adapter.transformer):
+             return block_adapter.transformer
+
+         # Apply cache on transformer: mock cached transformer blocks
+         for (
+             cached_blocks,
+             transformer,
+             blocks_name,
+             unique_blocks_name,
+             dummy_blocks_names,
+         ) in zip(
+             cls.collect_cached_blocks(block_adapter),
+             block_adapter.transformer,
+             block_adapter.blocks_name,
+             block_adapter.unique_blocks_name,
+             block_adapter.dummy_blocks_names,
+         ):
+             cls.mock_transformer(
+                 cached_blocks,
+                 transformer,
+                 blocks_name,
+                 unique_blocks_name,
+                 dummy_blocks_names,
+             )
+
+         return block_adapter.transformer
+
+     @classmethod
+     def mock_transformer(
+         cls,
+         cached_blocks: Dict[str, torch.nn.ModuleList],
+         transformer: torch.nn.Module,
+         blocks_name: List[str],
+         unique_blocks_name: List[str],
+         dummy_blocks_names: List[str],
+     ) -> torch.nn.Module:
+         dummy_blocks = torch.nn.ModuleList()
+
+         original_forward = transformer.forward
+
+         assert isinstance(dummy_blocks_names, list)
+
+         @functools.wraps(original_forward)
+         def new_forward(self, *args, **kwargs):
+             with ExitStack() as stack:
+                 for name, context_name in zip(
+                     blocks_name,
+                     unique_blocks_name,
+                 ):
+                     stack.enter_context(
+                         unittest.mock.patch.object(
+                             self, name, cached_blocks[context_name]
+                         )
+                     )
+                 for dummy_name in dummy_blocks_names:
+                     stack.enter_context(
+                         unittest.mock.patch.object(
+                             self, dummy_name, dummy_blocks
+                         )
+                     )
+                 return original_forward(*args, **kwargs)
+
+         transformer.forward = new_forward.__get__(transformer)
+         transformer._original_forward = original_forward
+         transformer._is_cached = True
+
+         return transformer
+
+     @classmethod
+     def collect_cached_blocks(
+         cls,
+         block_adapter: BlockAdapter,
+     ) -> List[Dict[str, torch.nn.ModuleList]]:
+
+         BlockAdapter.assert_normalized(block_adapter)
+
+         total_cached_blocks: List[Dict[str, torch.nn.ModuleList]] = []
+         assert hasattr(block_adapter.pipe, "_cache_manager")
+         assert isinstance(
+             block_adapter.pipe._cache_manager, CachedContextManagerV2
+         )
+
+         for i in range(len(block_adapter.transformer)):
+
+             cached_blocks_bind_context = {}
+             for j in range(len(block_adapter.blocks[i])):
+                 cached_blocks_bind_context[
+                     block_adapter.unique_blocks_name[i][j]
+                 ] = torch.nn.ModuleList(
+                     [
+                         CachedBlocks(
+                             # 0. Transformer blocks configuration
+                             block_adapter.blocks[i][j],
+                             transformer=block_adapter.transformer[i],
+                             forward_pattern=block_adapter.forward_pattern[i][j],
+                             check_forward_pattern=block_adapter.check_forward_pattern,
+                             check_num_outputs=block_adapter.check_num_outputs,
+                             # 1. Cache context configuration
+                             cache_prefix=block_adapter.blocks_name[i][j],
+                             cache_context=block_adapter.unique_blocks_name[i][
+                                 j
+                             ],
+                             cache_manager=block_adapter.pipe._cache_manager,
+                         )
+                     ]
+                 )
+
+             total_cached_blocks.append(cached_blocks_bind_context)
+
+         return total_cached_blocks
+
+     @classmethod
+     def apply_params_hooks(
+         cls,
+         block_adapter: BlockAdapter,
+         contexts_kwargs: List[Dict],
+     ):
+         block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
+
+         params_shift = 0
+         for i in range(len(block_adapter.transformer)):
+
+             block_adapter.transformer[i]._forward_pattern = (
+                 block_adapter.forward_pattern
+             )
+             block_adapter.transformer[i]._has_separate_cfg = (
+                 block_adapter.has_separate_cfg
+             )
+             block_adapter.transformer[i]._cache_context_kwargs = (
+                 contexts_kwargs[params_shift]
+             )
+
+             blocks = block_adapter.blocks[i]
+             for j in range(len(blocks)):
+                 blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
+                 blocks[j]._cache_context_kwargs = contexts_kwargs[
+                     params_shift + j
+                 ]
+
+             params_shift += len(blocks)
+
+     @classmethod
+     def apply_stats_hooks(
+         cls,
+         block_adapter: BlockAdapter,
+     ):
+         cache_manager = block_adapter.pipe._cache_manager
+
+         for i in range(len(block_adapter.transformer)):
+             patch_cached_stats(
+                 block_adapter.transformer[i],
+                 cache_context=block_adapter.unique_blocks_name[i][-1],
+                 cache_manager=cache_manager,
+             )
+             for blocks, unique_name in zip(
+                 block_adapter.blocks[i],
+                 block_adapter.unique_blocks_name[i],
+             ):
+                 patch_cached_stats(
+                     blocks,
+                     cache_context=unique_name,
+                     cache_manager=cache_manager,
+                 )
+
+     @classmethod
+     def maybe_release_hooks(
+         cls,
+         pipe_or_adapter: Union[
+             DiffusionPipeline,
+             BlockAdapter,
+         ],
+     ):
+         # release model hooks
+         def _release_blocks_hooks(blocks):
+             return
+
+         def _release_transformer_hooks(transformer):
+             if hasattr(transformer, "_original_forward"):
+                 original_forward = transformer._original_forward
+                 transformer.forward = original_forward.__get__(transformer)
+                 del transformer._original_forward
+             if hasattr(transformer, "_is_cached"):
+                 del transformer._is_cached
+
+         def _release_pipeline_hooks(pipe):
+             if hasattr(pipe, "_original_call"):
+                 original_call = pipe.__class__._original_call
+                 pipe.__class__.__call__ = original_call
+                 del pipe.__class__._original_call
+             if hasattr(pipe, "_cache_manager"):
+                 cache_manager = pipe._cache_manager
+                 if isinstance(cache_manager, CachedContextManagerV2):
+                     cache_manager.clear_contexts()
+                 del pipe._cache_manager
+             if hasattr(pipe, "_is_cached"):
+                 del pipe.__class__._is_cached
+
+         cls.release_hooks(
+             pipe_or_adapter,
+             _release_blocks_hooks,
+             _release_transformer_hooks,
+             _release_pipeline_hooks,
+         )
+
+         # release params hooks
+         def _release_blocks_params(blocks):
+             if hasattr(blocks, "_forward_pattern"):
+                 del blocks._forward_pattern
+             if hasattr(blocks, "_cache_context_kwargs"):
+                 del blocks._cache_context_kwargs
+
+         def _release_transformer_params(transformer):
+             if hasattr(transformer, "_forward_pattern"):
+                 del transformer._forward_pattern
+             if hasattr(transformer, "_has_separate_cfg"):
+                 del transformer._has_separate_cfg
+             if hasattr(transformer, "_cache_context_kwargs"):
+                 del transformer._cache_context_kwargs
+             for blocks in BlockAdapter.find_blocks(transformer):
+                 _release_blocks_params(blocks)
+
+         def _release_pipeline_params(pipe):
+             if hasattr(pipe, "_cache_context_kwargs"):
+                 del pipe._cache_context_kwargs
+
+         cls.release_hooks(
+             pipe_or_adapter,
+             _release_blocks_params,
+             _release_transformer_params,
+             _release_pipeline_params,
+         )
+
+         # release stats hooks
+         cls.release_hooks(
+             pipe_or_adapter,
+             remove_cached_stats,
+             remove_cached_stats,
+             remove_cached_stats,
+         )
+
+     @classmethod
+     def release_hooks(
+         cls,
+         pipe_or_adapter: Union[
+             DiffusionPipeline,
+             BlockAdapter,
+         ],
+         _release_blocks: Callable,
+         _release_transformer: Callable,
+         _release_pipeline: Callable,
+     ):
+         if isinstance(pipe_or_adapter, DiffusionPipeline):
+             pipe = pipe_or_adapter
+             _release_pipeline(pipe)
+             if hasattr(pipe, "transformer"):
+                 _release_transformer(pipe.transformer)
+             if hasattr(pipe, "transformer_2"):  # Wan 2.2
+                 _release_transformer(pipe.transformer_2)
+         elif isinstance(pipe_or_adapter, BlockAdapter):
+             adapter = pipe_or_adapter
+             BlockAdapter.assert_normalized(adapter)
+             _release_pipeline(adapter.pipe)
+             for transformer in BlockAdapter.flatten(adapter.transformer):
+                 _release_transformer(transformer)
+             for blocks in BlockAdapter.flatten(adapter.blocks):
+                 _release_blocks(blocks)
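
Two patterns in CachedAdapterV2 deserve a closer look: mock_transformer swaps the cached blocks in only for the duration of each forward call, via unittest.mock.patch.object entries on an ExitStack, and then rebinds the wrapper as a bound method with __get__. A self-contained sketch of that pattern; ToyTransformer and wrap_with_swapped_blocks are invented names, not cache-dit API:

import functools
import unittest.mock
from contextlib import ExitStack

import torch


class ToyTransformer(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


def wrap_with_swapped_blocks(transformer, replacement_blocks):
    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(self, *args, **kwargs):
        # The patch lives only as long as the ExitStack: `blocks` is
        # restored on exit, so the module is unchanged between calls.
        with ExitStack() as stack:
            stack.enter_context(
                unittest.mock.patch.object(self, "blocks", replacement_blocks)
            )
            return original_forward(*args, **kwargs)

    # Rebind as a bound method, mirroring new_forward.__get__(transformer)
    # in CachedAdapterV2.mock_transformer, and keep the original around
    # so a release hook can restore it.
    transformer.forward = new_forward.__get__(transformer)
    transformer._original_forward = original_forward
    return transformer


t = wrap_with_swapped_blocks(
    ToyTransformer(), torch.nn.ModuleList([torch.nn.Identity()])
)
x = torch.ones(1, 4)
assert torch.equal(t(x), x)  # Identity ran instead of the Linear block
assert isinstance(t.blocks[0], torch.nn.Linear)  # restored after the call

Because each patch is reverted when the stack unwinds, state_dict(), .to(device), and other module introspection still see the original blocks between calls; this is also why maybe_release_hooks only needs to restore forward and delete a handful of attributes.
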
cache_dit/cache_factory/cache_contexts/__init__.py CHANGED
@@ -3,3 +3,10 @@ from cache_dit.cache_factory.cache_contexts.cache_context import CachedContext
  from cache_dit.cache_factory.cache_contexts.cache_manager import (
      CachedContextManager,
  )
+ from cache_dit.cache_factory.cache_contexts.v2 import (
+     CachedContextV2,
+     CachedContextManagerV2,
+     CalibratorConfig,
+     TaylorSeerCalibratorConfig,
+     FoCaCalibratorConfig,
+ )
cache_dit/cache_factory/cache_contexts/v2/__init__.py ADDED
@@ -0,0 +1,13 @@
+ from cache_dit.cache_factory.cache_contexts.v2.calibrators import (
+     Calibrator,
+     CalibratorBase,
+     CalibratorConfig,
+     TaylorSeerCalibratorConfig,
+     FoCaCalibratorConfig,
+ )
+ from cache_dit.cache_factory.cache_contexts.v2.cache_context_v2 import (
+     CachedContextV2,
+ )
+ from cache_dit.cache_factory.cache_contexts.v2.cache_manager_v2 import (
+     CachedContextManagerV2,
+ )
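
The calibrators package re-exported here (base.py, taylorseer.py, foca.py) is what the config objects select between. Following the TaylorSeer paper's idea, a calibrator of this kind predicts a skipped step's features with a truncated Taylor expansion over recently cached states instead of reusing them unchanged. A first-order illustration of the concept; the class and method names below are invented and are not the contents of taylorseer.py:

# Conceptual illustration only -- names here are invented; the real
# implementation lives in cache_contexts/v2/calibrators/taylorseer.py.
from typing import List

import torch


class TaylorSeerSketch:
    """Predict a skipped block output via a truncated Taylor series."""

    def __init__(self, order: int = 1):
        self.order = order
        self.history: List[torch.Tensor] = []  # most recent state last

    def update(self, fresh: torch.Tensor) -> None:
        # Called on steps where the block is actually computed.
        self.history.append(fresh)
        self.history = self.history[-(self.order + 1):]

    def predict(self, steps_ahead: int = 1) -> torch.Tensor:
        # 0th-order term: plain reuse of the last computed state.
        out = self.history[-1].clone()
        if self.order >= 1 and len(self.history) >= 2:
            # 1st-order term: a finite difference approximates the
            # derivative of the feature w.r.t. the timestep index.
            out = out + steps_ahead * (self.history[-1] - self.history[-2])
        return out


calib = TaylorSeerSketch(order=1)
calib.update(torch.zeros(2, 2))
calib.update(torch.ones(2, 2))
# Extrapolates to 2.0 everywhere rather than reusing 1.0.
print(calib.predict(steps_ahead=1))
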