cache-dit 0.2.29-py3-none-any.whl → 0.2.31-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


cache_dit/_version.py CHANGED
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
 commit_id: COMMIT_ID
 __commit_id__: COMMIT_ID
 
-__version__ = version = '0.2.29'
-__version_tuple__ = version_tuple = (0, 2, 29)
+__version__ = version = '0.2.31'
+__version_tuple__ = version_tuple = (0, 2, 31)
 
 __commit_id__ = commit_id = None
@@ -123,6 +123,7 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
     assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
     return BlockAdapter(
         pipe=pipe,
+        transformer=pipe.transformer,
         blocks=[
             pipe.transformer.transformer_blocks,
             pipe.transformer.single_transformer_blocks,
@@ -131,6 +132,8 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
             ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_0,
         ],
+        # The type hint in diffusers is wrong
+        check_num_outputs=False,
         **kwargs,
     )
 
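For orientation (not part of the diff itself): a registered adapter like the one above is normally exercised through cache-dit's top-level entry point, which resolves the adapter by pipeline class name. A minimal sketch, assuming the public enable_cache() API and the community HunyuanVideo checkpoint:

import cache_dit
from diffusers import HunyuanVideoPipeline

# Hypothetical usage sketch: enable_cache() is assumed to look up
# hunyuanvideo_adapter() in BlockAdapterRegistry and wrap both
# transformer_blocks and single_transformer_blocks with Pattern_0.
pipe = HunyuanVideoPipeline.from_pretrained(
    "hunyuanvideo-community/HunyuanVideo"
)
cache_dit.enable_cache(pipe)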
@@ -327,37 +330,6 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("HunyuanDiT")
-def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
-    from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
-
-    assert isinstance(
-        pipe.transformer,
-        (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
-    )
-    return BlockAdapter(
-        pipe=pipe,
-        transformer=pipe.transformer,
-        blocks=pipe.transformer.blocks,
-        forward_pattern=ForwardPattern.Pattern_3,
-        **kwargs,
-    )
-
-
-@BlockAdapterRegistry.register("HunyuanDiTPAG")
-def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
-    from diffusers import HunyuanDiT2DModel
-
-    assert isinstance(pipe.transformer, HunyuanDiT2DModel)
-    return BlockAdapter(
-        pipe=pipe,
-        transformer=pipe.transformer,
-        blocks=pipe.transformer.blocks,
-        forward_pattern=ForwardPattern.Pattern_3,
-        **kwargs,
-    )
-
-
 @BlockAdapterRegistry.register("Lumina")
 def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import LuminaNextDiT2DModel
@@ -414,10 +386,12 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("Sana")
+@BlockAdapterRegistry.register("Sana", supported=False)
 def sana_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import SanaTransformer2DModel
 
+    # TODO: fix -> got multiple values for argument 'encoder_hidden_states'
+
     assert isinstance(pipe.transformer, SanaTransformer2DModel)
     return BlockAdapter(
         pipe=pipe,
@@ -428,20 +402,6 @@ def sana_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
-@BlockAdapterRegistry.register("ShapE")
-def shape_adapter(pipe, **kwargs) -> BlockAdapter:
-    from diffusers import PriorTransformer
-
-    assert isinstance(pipe.prior, PriorTransformer)
-    return BlockAdapter(
-        pipe=pipe,
-        transformer=pipe.prior,
-        blocks=pipe.prior.transformer_blocks,
-        forward_pattern=ForwardPattern.Pattern_3,
-        **kwargs,
-    )
-
-
 @BlockAdapterRegistry.register("StableAudio")
 def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import StableAudioDiTModel
@@ -459,21 +419,37 @@ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
 @BlockAdapterRegistry.register("VisualCloze")
 def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
     from diffusers import FluxTransformer2DModel
+    from cache_dit.utils import is_diffusers_at_least_0_3_5
 
     assert isinstance(pipe.transformer, FluxTransformer2DModel)
-    return BlockAdapter(
-        pipe=pipe,
-        transformer=pipe.transformer,
-        blocks=[
-            pipe.transformer.transformer_blocks,
-            pipe.transformer.single_transformer_blocks,
-        ],
-        forward_pattern=[
-            ForwardPattern.Pattern_1,
-            ForwardPattern.Pattern_3,
-        ],
-        **kwargs,
-    )
+    if is_diffusers_at_least_0_3_5():
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=[
+                pipe.transformer.transformer_blocks,
+                pipe.transformer.single_transformer_blocks,
+            ],
+            forward_pattern=[
+                ForwardPattern.Pattern_1,
+                ForwardPattern.Pattern_1,
+            ],
+            **kwargs,
+        )
+    else:
+        return BlockAdapter(
+            pipe=pipe,
+            transformer=pipe.transformer,
+            blocks=[
+                pipe.transformer.transformer_blocks,
+                pipe.transformer.single_transformer_blocks,
+            ],
+            forward_pattern=[
+                ForwardPattern.Pattern_1,
+                ForwardPattern.Pattern_3,
+            ],
+            **kwargs,
+        )
 
 
 @BlockAdapterRegistry.register("AuraFlow")
@@ -511,9 +487,29 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
     )
 
 
+@BlockAdapterRegistry.register("ShapE")
+def shape_adapter(pipe, **kwargs) -> BlockAdapter:
+    from diffusers import PriorTransformer
+
+    assert isinstance(pipe.prior, PriorTransformer)
+    return BlockAdapter(
+        pipe=pipe,
+        transformer=pipe.prior,
+        blocks=pipe.prior.transformer_blocks,
+        forward_pattern=ForwardPattern.Pattern_3,
+        **kwargs,
+    )
+
+
 @BlockAdapterRegistry.register("HiDream")
 def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
+    # NOTE: Need to patch Transformer forward to fully support
+    # double_stream_blocks and single_stream_blocks, namely, need
+    # to remove the logics inside the blocks forward loop:
+    # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
+    # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
     from diffusers import HiDreamImageTransformer2DModel
+    from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
 
     assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
     return BlockAdapter(
@@ -524,9 +520,47 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
             pipe.transformer.single_stream_blocks,
         ],
         forward_pattern=[
-            ForwardPattern.Pattern_4,
+            ForwardPattern.Pattern_0,
             ForwardPattern.Pattern_3,
         ],
-        check_num_outputs=False,
+        patch_functor=HiDreamPatchFunctor(),
+        # NOTE: The type hint in diffusers is wrong
+        check_forward_pattern=True,
+        check_num_outputs=True,
+        **kwargs,
+    )
+
+
+@BlockAdapterRegistry.register("HunyuanDiT")
+def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
+    from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
+    from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
+
+    assert isinstance(
+        pipe.transformer,
+        (HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
+    )
+    return BlockAdapter(
+        pipe=pipe,
+        transformer=pipe.transformer,
+        blocks=pipe.transformer.blocks,
+        forward_pattern=ForwardPattern.Pattern_3,
+        patch_functor=HunyuanDiTPatchFunctor(),
+        **kwargs,
+    )
+
+
+@BlockAdapterRegistry.register("HunyuanDiTPAG")
+def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
+    from diffusers import HunyuanDiT2DModel
+    from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
+
+    assert isinstance(pipe.transformer, HunyuanDiT2DModel)
+    return BlockAdapter(
+        pipe=pipe,
+        transformer=pipe.transformer,
+        blocks=pipe.transformer.blocks,
+        forward_pattern=ForwardPattern.Pattern_3,
+        patch_functor=HunyuanDiTPatchFunctor(),
         **kwargs,
     )
@@ -75,7 +75,8 @@ class BlockAdapter:
         List[List[ParamsModifier]],
     ] = None
 
-    check_num_outputs: bool = True
+    check_forward_pattern: bool = True
+    check_num_outputs: bool = False
 
     # Pipeline Level Flags
     # Patch Functor: Flux, etc.
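Note the default flip: check_num_outputs is now opt-in rather than opt-out, and the new check_forward_pattern flag can disable signature matching entirely. A short sketch of how the two flags combine when declaring a custom adapter (pipe here is an assumed, already-loaded pipeline; field names come from the diff):

# Sketch: opting out of both checks for blocks whose forward signatures
# do not match any predefined pattern.
adapter = BlockAdapter(
    pipe=pipe,                      # assumed DiffusionPipeline instance
    transformer=pipe.transformer,
    blocks=pipe.transformer.blocks,
    forward_pattern=ForwardPattern.Pattern_3,
    check_forward_pattern=False,    # skip the signature check
    check_num_outputs=False,        # the new 0.2.31 default
)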
@@ -111,9 +112,9 @@
     def __post_init__(self):
         if self.skip_post_init:
             return
-        assert any((self.pipe is not None, self.transformer is not None))
-        self.maybe_fill_attrs()
-        self.maybe_patchify()
+        if any((self.pipe is not None, self.transformer is not None)):
+            self.maybe_fill_attrs()
+            self.maybe_patchify()
 
     def maybe_fill_attrs(self):
         # NOTE: This func should be call before normalize.
@@ -130,7 +131,9 @@
             assert isinstance(blocks, torch.nn.ModuleList)
             blocks_name = None
             for attr_name in attr_names:
-                if attr := getattr(transformer, attr_name, None):
+                if (
+                    attr := getattr(transformer, attr_name, None)
+                ) is not None:
                     if isinstance(attr, torch.nn.ModuleList) and id(
                         attr
                     ) == id(blocks):
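The switch from a bare walrus truthiness test to an explicit is-not-None comparison matters because torch.nn.ModuleList implements __len__, so an empty (but present) block list is falsy and the old check silently skipped it. A minimal repro:

import torch

blocks = torch.nn.ModuleList()  # attribute exists but holds no modules
print(bool(blocks))             # False -> old truthiness check skipped it
print(blocks is not None)       # True  -> new check still inspects it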
@@ -389,11 +392,20 @@
         forward_pattern: ForwardPattern,
         **kwargs,
     ) -> bool:
+
+        if not kwargs.get("check_forward_pattern", True):
+            return True
+
         assert (
             forward_pattern.Supported
             and forward_pattern in ForwardPattern.supported_patterns()
         ), f"Pattern {forward_pattern} is not support now!"
 
+        # NOTE: Special case for HiDreamBlock
+        if hasattr(block, "block"):
+            if isinstance(block.block, torch.nn.Module):
+                block = block.block
+
         forward_parameters = set(
             inspect.signature(block.forward).parameters.keys()
         )
@@ -423,6 +435,14 @@
         logging: bool = True,
         **kwargs,
     ) -> bool:
+
+        if not kwargs.get("check_forward_pattern", True):
+            if logging:
+                logger.warning(
+                    f"Skipped Forward Pattern Check: {forward_pattern}"
+                )
+            return True
+
         assert (
             forward_pattern.Supported
             and forward_pattern in ForwardPattern.supported_patterns()
@@ -529,6 +549,7 @@
                 blocks,
                 forward_pattern=forward_pattern,
                 check_num_outputs=adapter.check_num_outputs,
+                check_forward_pattern=adapter.check_forward_pattern,
             ), (
                 "No block forward pattern matched, "
                 f"supported lists: {ForwardPattern.supported_patterns()}"
@@ -558,7 +579,7 @@
             assert isinstance(adapter[0], torch.nn.Module)
             return getattr(adapter[0], "_is_cached", False)
         else:
-            raise TypeError(f"Can't check this type: {adapter}!")
+            raise TypeError(f"Can't check this type: {type(adapter)}!")
 
     @classmethod
     def nested_depth(cls, obj: Any):
@@ -1,4 +1,4 @@
-from typing import Any, Tuple, List, Dict
+from typing import Any, Tuple, List, Dict, Callable
 
 from diffusers import DiffusionPipeline
 from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
@@ -9,20 +9,23 @@ logger = init_logger(__name__)
 
 
 class BlockAdapterRegistry:
-    _adapters: Dict[str, BlockAdapter] = {}
-    _predefined_adapters_has_spearate_cfg: List[str] = {
+    _adapters: Dict[str, Callable[..., BlockAdapter]] = {}
+    _predefined_adapters_has_spearate_cfg: List[str] = [
         "QwenImage",
         "Wan",
         "CogView4",
         "Cosmos",
         "SkyReelsV2",
         "Chroma",
-    }
+    ]
 
     @classmethod
-    def register(cls, name):
-        def decorator(func):
-            cls._adapters[name] = func
+    def register(cls, name: str, supported: bool = True):
+        def decorator(
+            func: Callable[..., BlockAdapter]
+        ) -> Callable[..., BlockAdapter]:
+            if supported:
+                cls._adapters[name] = func
             return func
 
         return decorator
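With the new supported flag, a decorated adapter stays importable and callable but never enters the registry, which is how the Sana adapter above is parked while its encoder_hidden_states bug is open. A quick sketch, reusing names from this diff:

# The decorator always returns func unchanged; it only records the
# adapter in _adapters when supported=True.
@BlockAdapterRegistry.register("Sana", supported=False)
def sana_adapter(pipe, **kwargs):
    ...

assert "Sana" not in BlockAdapterRegistry._adapters  # never registered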
@@ -4,7 +4,7 @@ import unittest
 import functools
 
 from contextlib import ExitStack
-from typing import Dict, List, Tuple, Any
+from typing import Dict, List, Tuple, Any, Union, Callable
 
 from diffusers import DiffusionPipeline
 
@@ -14,7 +14,10 @@ from cache_dit.cache_factory import ParamsModifier
 from cache_dit.cache_factory import BlockAdapterRegistry
 from cache_dit.cache_factory import CachedContextManager
 from cache_dit.cache_factory import CachedBlocks
-
+from cache_dit.cache_factory.cache_blocks.utils import (
+    patch_cached_stats,
+    remove_cached_stats,
+)
 from cache_dit.logger import init_logger
 
 logger = init_logger(__name__)
@@ -29,9 +32,15 @@ class CachedAdapter:
     @classmethod
     def apply(
         cls,
-        pipe_or_adapter: DiffusionPipeline | BlockAdapter,
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            BlockAdapter,
+        ],
         **cache_context_kwargs,
-    ) -> BlockAdapter:
+    ) -> Union[
+        DiffusionPipeline,
+        BlockAdapter,
+    ]:
         assert (
             pipe_or_adapter is not None
         ), "pipe or block_adapter can not both None!"
@@ -49,7 +58,7 @@ class CachedAdapter:
             return cls.cachify(
                 block_adapter,
                 **cache_context_kwargs,
-            )
+            ).pipe
         else:
             raise ValueError(
                 f"{pipe_or_adapter.__class__.__name__} is not officially supported "
@@ -82,7 +91,7 @@ class CachedAdapter:
         # 0. Must normalize block_adapter before apply cache
         block_adapter = BlockAdapter.normalize(block_adapter)
         if BlockAdapter.is_cached(block_adapter):
-            return block_adapter.pipe
+            return block_adapter
 
         # 1. Apply cache on pipeline: wrap cache context, must
         # call create_context before mock_blocks.
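Taken together, the three hunks above clarify the apply() contract: callers get back the same kind of object they passed in. A sketch of the intended round-trip (pipe and adapter are assumed, pre-built inputs):

# cachify() now returns the BlockAdapter; apply() unwraps .pipe only
# when the caller handed in a plain pipeline.
cached_pipe = CachedAdapter.apply(pipe)        # DiffusionPipeline in, pipeline out
cached_adapter = CachedAdapter.apply(adapter)  # BlockAdapter in, adapter out (per the Union return type)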
@@ -98,36 +107,6 @@
 
         return block_adapter
 
-    @classmethod
-    def patch_params(
-        cls,
-        block_adapter: BlockAdapter,
-        contexts_kwargs: List[Dict],
-    ):
-        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
-
-        params_shift = 0
-        for i in range(len(block_adapter.transformer)):
-
-            block_adapter.transformer[i]._forward_pattern = (
-                block_adapter.forward_pattern
-            )
-            block_adapter.transformer[i]._has_separate_cfg = (
-                block_adapter.has_separate_cfg
-            )
-            block_adapter.transformer[i]._cache_context_kwargs = (
-                contexts_kwargs[params_shift]
-            )
-
-            blocks = block_adapter.blocks[i]
-            for j in range(len(blocks)):
-                blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
-                blocks[j]._cache_context_kwargs = contexts_kwargs[
-                    params_shift + j
-                ]
-
-            params_shift += len(blocks)
-
     @classmethod
     def check_context_kwargs(
         cls,
@@ -153,7 +132,9 @@
            f"Pipeline: {block_adapter.pipe.__class__.__name__}."
        )
 
-        if cache_type := cache_context_kwargs.pop("cache_type", None):
+        if (
+            cache_type := cache_context_kwargs.pop("cache_type", None)
+        ) is not None:
             assert (
                 cache_type == CacheType.DBCache
             ), "Custom cache setting only support for DBCache now!"
@@ -210,14 +191,14 @@
                )
            )
            outputs = original_call(self, *args, **kwargs)
-           cls.patch_stats(block_adapter)
+           cls.apply_stats_hooks(block_adapter)
            return outputs
 
        block_adapter.pipe.__class__.__call__ = new_call
        block_adapter.pipe.__class__._original_call = original_call
        block_adapter.pipe.__class__._is_cached = True
 
-       cls.patch_params(block_adapter, contexts_kwargs)
+       cls.apply_params_hooks(block_adapter, contexts_kwargs)
 
        return block_adapter.pipe
 
@@ -261,33 +242,6 @@
 
         return flatten_contexts, contexts_kwargs
 
-    @classmethod
-    def patch_stats(
-        cls,
-        block_adapter: BlockAdapter,
-    ):
-        from cache_dit.cache_factory.cache_blocks.utils import (
-            patch_cached_stats,
-        )
-
-        cache_manager = block_adapter.pipe._cache_manager
-
-        for i in range(len(block_adapter.transformer)):
-            patch_cached_stats(
-                block_adapter.transformer[i],
-                cache_context=block_adapter.unique_blocks_name[i][-1],
-                cache_manager=cache_manager,
-            )
-            for blocks, unique_name in zip(
-                block_adapter.blocks[i],
-                block_adapter.unique_blocks_name[i],
-            ):
-                patch_cached_stats(
-                    blocks,
-                    cache_context=unique_name,
-                    cache_manager=cache_manager,
-                )
-
     @classmethod
     def mock_blocks(
         cls,
@@ -391,6 +345,7 @@
                 block_adapter.blocks[i][j],
                 transformer=block_adapter.transformer[i],
                 forward_pattern=block_adapter.forward_pattern[i][j],
+                check_forward_pattern=block_adapter.check_forward_pattern,
                 check_num_outputs=block_adapter.check_num_outputs,
                 # 1. Cache context configuration
                 cache_prefix=block_adapter.blocks_name[i][j],
@@ -405,3 +360,159 @@
             total_cached_blocks.append(cached_blocks_bind_context)
 
         return total_cached_blocks
+
+    @classmethod
+    def apply_params_hooks(
+        cls,
+        block_adapter: BlockAdapter,
+        contexts_kwargs: List[Dict],
+    ):
+        block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
+
+        params_shift = 0
+        for i in range(len(block_adapter.transformer)):
+
+            block_adapter.transformer[i]._forward_pattern = (
+                block_adapter.forward_pattern
+            )
+            block_adapter.transformer[i]._has_separate_cfg = (
+                block_adapter.has_separate_cfg
+            )
+            block_adapter.transformer[i]._cache_context_kwargs = (
+                contexts_kwargs[params_shift]
+            )
+
+            blocks = block_adapter.blocks[i]
+            for j in range(len(blocks)):
+                blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
+                blocks[j]._cache_context_kwargs = contexts_kwargs[
+                    params_shift + j
+                ]
+
+            params_shift += len(blocks)
+
+    @classmethod
+    def apply_stats_hooks(
+        cls,
+        block_adapter: BlockAdapter,
+    ):
+        cache_manager = block_adapter.pipe._cache_manager
+
+        for i in range(len(block_adapter.transformer)):
+            patch_cached_stats(
+                block_adapter.transformer[i],
+                cache_context=block_adapter.unique_blocks_name[i][-1],
+                cache_manager=cache_manager,
+            )
+            for blocks, unique_name in zip(
+                block_adapter.blocks[i],
+                block_adapter.unique_blocks_name[i],
+            ):
+                patch_cached_stats(
+                    blocks,
+                    cache_context=unique_name,
+                    cache_manager=cache_manager,
+                )
+
+    @classmethod
+    def maybe_release_hooks(
+        cls,
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            BlockAdapter,
+        ],
+    ):
+        # release model hooks
+        def _release_blocks_hooks(blocks):
+            return
+
+        def _release_transformer_hooks(transformer):
+            if hasattr(transformer, "_original_forward"):
+                original_forward = transformer._original_forward
+                transformer.forward = original_forward.__get__(transformer)
+                del transformer._original_forward
+            if hasattr(transformer, "_is_cached"):
+                del transformer._is_cached
+
+        def _release_pipeline_hooks(pipe):
+            if hasattr(pipe, "_original_call"):
+                original_call = pipe.__class__._original_call
+                pipe.__class__.__call__ = original_call
+                del pipe.__class__._original_call
+            if hasattr(pipe, "_cache_manager"):
+                cache_manager = pipe._cache_manager
+                if isinstance(cache_manager, CachedContextManager):
+                    cache_manager.clear_contexts()
+                del pipe._cache_manager
+            if hasattr(pipe, "_is_cached"):
+                del pipe.__class__._is_cached
+
+        cls.release_hooks(
+            pipe_or_adapter,
+            _release_blocks_hooks,
+            _release_transformer_hooks,
+            _release_pipeline_hooks,
+        )
+
+        # release params hooks
+        def _release_blocks_params(blocks):
+            if hasattr(blocks, "_forward_pattern"):
+                del blocks._forward_pattern
+            if hasattr(blocks, "_cache_context_kwargs"):
+                del blocks._cache_context_kwargs
+
+        def _release_transformer_params(transformer):
+            if hasattr(transformer, "_forward_pattern"):
+                del transformer._forward_pattern
+            if hasattr(transformer, "_has_separate_cfg"):
+                del transformer._has_separate_cfg
+            if hasattr(transformer, "_cache_context_kwargs"):
+                del transformer._cache_context_kwargs
+            for blocks in BlockAdapter.find_blocks(transformer):
+                _release_blocks_params(blocks)
+
+        def _release_pipeline_params(pipe):
+            if hasattr(pipe, "_cache_context_kwargs"):
+                del pipe._cache_context_kwargs
+
+        cls.release_hooks(
+            pipe_or_adapter,
+            _release_blocks_params,
+            _release_transformer_params,
+            _release_pipeline_params,
+        )
+
+        # release stats hooks
+        cls.release_hooks(
+            pipe_or_adapter,
+            remove_cached_stats,
+            remove_cached_stats,
+            remove_cached_stats,
+        )
+
+    @classmethod
+    def release_hooks(
+        cls,
+        pipe_or_adapter: Union[
+            DiffusionPipeline,
+            BlockAdapter,
+        ],
+        _release_blocks: Callable,
+        _release_transformer: Callable,
+        _release_pipeline: Callable,
+    ):
+        if isinstance(pipe_or_adapter, DiffusionPipeline):
+            pipe = pipe_or_adapter
+            _release_pipeline(pipe)
+            if hasattr(pipe, "transformer"):
+                _release_transformer(pipe.transformer)
+            if hasattr(pipe, "transformer_2"):  # Wan 2.2
+                _release_transformer(pipe.transformer_2)
+        elif isinstance(pipe_or_adapter, BlockAdapter):
+            adapter = pipe_or_adapter
+            BlockAdapter.assert_normalized(adapter)
+            _release_pipeline(adapter.pipe)
+            for transformer in BlockAdapter.flatten(adapter.transformer):
+                _release_transformer(transformer)
+            for blocks in BlockAdapter.flatten(adapter.blocks):
+                _release_blocks(blocks)
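The new release path is the inverse of cachify(): hooks, params, and stats are unwound in three passes over the same pipeline/transformer/blocks traversal. A usage sketch, assuming pipe was previously cachified via CachedAdapter.apply(pipe):

# Restore the original __call__/forward, clear cached contexts, and
# strip the _forward_pattern/_cache_context_kwargs/stats attributes.
CachedAdapter.maybe_release_hooks(pipe)
assert not getattr(pipe, "_is_cached", False)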
@@ -25,6 +25,7 @@ class CachedBlocks:
         transformer_blocks: torch.nn.ModuleList,
         transformer: torch.nn.Module = None,
         forward_pattern: ForwardPattern = None,
+        check_forward_pattern: bool = True,
         check_num_outputs: bool = True,
         # 1. Cache context configuration
         # 'transformer_blocks', 'blocks', 'single_transformer_blocks',
@@ -45,6 +46,7 @@ class CachedBlocks:
                 transformer_blocks,
                 transformer=transformer,
                 forward_pattern=forward_pattern,
+                check_forward_pattern=check_forward_pattern,
                 check_num_outputs=check_num_outputs,
                 # 1. Cache context configuration
                 cache_prefix=cache_prefix,
@@ -58,6 +60,7 @@ class CachedBlocks:
                 transformer_blocks,
                 transformer=transformer,
                 forward_pattern=forward_pattern,
+                check_forward_pattern=check_forward_pattern,
                 check_num_outputs=check_num_outputs,
                 # 1. Cache context configuration
                 cache_prefix=cache_prefix,