cache-dit 0.2.29__py3-none-any.whl → 0.2.31__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of cache-dit might be problematic. Click here for more details.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/block_adapters/__init__.py +95 -61
- cache_dit/cache_factory/block_adapters/block_adapters.py +27 -6
- cache_dit/cache_factory/block_adapters/block_registers.py +10 -7
- cache_dit/cache_factory/cache_adapters.py +177 -66
- cache_dit/cache_factory/cache_blocks/__init__.py +3 -0
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +70 -67
- cache_dit/cache_factory/cache_blocks/pattern_base.py +13 -0
- cache_dit/cache_factory/cache_contexts/cache_manager.py +8 -10
- cache_dit/cache_factory/cache_interface.py +19 -77
- cache_dit/cache_factory/cache_types.py +5 -5
- cache_dit/cache_factory/patch_functors/__init__.py +6 -0
- cache_dit/cache_factory/patch_functors/functor_chroma.py +5 -3
- cache_dit/cache_factory/patch_functors/functor_flux.py +5 -3
- cache_dit/cache_factory/patch_functors/functor_hidream.py +412 -0
- cache_dit/cache_factory/patch_functors/functor_hunyuan_dit.py +213 -0
- cache_dit/utils.py +5 -1
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/METADATA +14 -48
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/RECORD +23 -21
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.29.dist-info → cache_dit-0.2.31.dist-info}/top_level.txt +0 -0
cache_dit/_version.py
CHANGED
|
@@ -28,7 +28,7 @@ version_tuple: VERSION_TUPLE
|
|
|
28
28
|
commit_id: COMMIT_ID
|
|
29
29
|
__commit_id__: COMMIT_ID
|
|
30
30
|
|
|
31
|
-
__version__ = version = '0.2.
|
|
32
|
-
__version_tuple__ = version_tuple = (0, 2,
|
|
31
|
+
__version__ = version = '0.2.31'
|
|
32
|
+
__version_tuple__ = version_tuple = (0, 2, 31)
|
|
33
33
|
|
|
34
34
|
__commit_id__ = commit_id = None
|
|
@@ -123,6 +123,7 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
123
123
|
assert isinstance(pipe.transformer, HunyuanVideoTransformer3DModel)
|
|
124
124
|
return BlockAdapter(
|
|
125
125
|
pipe=pipe,
|
|
126
|
+
transformer=pipe.transformer,
|
|
126
127
|
blocks=[
|
|
127
128
|
pipe.transformer.transformer_blocks,
|
|
128
129
|
pipe.transformer.single_transformer_blocks,
|
|
@@ -131,6 +132,8 @@ def hunyuanvideo_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
131
132
|
ForwardPattern.Pattern_0,
|
|
132
133
|
ForwardPattern.Pattern_0,
|
|
133
134
|
],
|
|
135
|
+
# The type hint in diffusers is wrong
|
|
136
|
+
check_num_outputs=False,
|
|
134
137
|
**kwargs,
|
|
135
138
|
)
|
|
136
139
|
|
|
@@ -327,37 +330,6 @@ def bria_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
327
330
|
)
|
|
328
331
|
|
|
329
332
|
|
|
330
|
-
@BlockAdapterRegistry.register("HunyuanDiT")
|
|
331
|
-
def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
332
|
-
from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
|
|
333
|
-
|
|
334
|
-
assert isinstance(
|
|
335
|
-
pipe.transformer,
|
|
336
|
-
(HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
|
|
337
|
-
)
|
|
338
|
-
return BlockAdapter(
|
|
339
|
-
pipe=pipe,
|
|
340
|
-
transformer=pipe.transformer,
|
|
341
|
-
blocks=pipe.transformer.blocks,
|
|
342
|
-
forward_pattern=ForwardPattern.Pattern_3,
|
|
343
|
-
**kwargs,
|
|
344
|
-
)
|
|
345
|
-
|
|
346
|
-
|
|
347
|
-
@BlockAdapterRegistry.register("HunyuanDiTPAG")
|
|
348
|
-
def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
349
|
-
from diffusers import HunyuanDiT2DModel
|
|
350
|
-
|
|
351
|
-
assert isinstance(pipe.transformer, HunyuanDiT2DModel)
|
|
352
|
-
return BlockAdapter(
|
|
353
|
-
pipe=pipe,
|
|
354
|
-
transformer=pipe.transformer,
|
|
355
|
-
blocks=pipe.transformer.blocks,
|
|
356
|
-
forward_pattern=ForwardPattern.Pattern_3,
|
|
357
|
-
**kwargs,
|
|
358
|
-
)
|
|
359
|
-
|
|
360
|
-
|
|
361
333
|
@BlockAdapterRegistry.register("Lumina")
|
|
362
334
|
def lumina_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
363
335
|
from diffusers import LuminaNextDiT2DModel
|
|
@@ -414,10 +386,12 @@ def pixart_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
414
386
|
)
|
|
415
387
|
|
|
416
388
|
|
|
417
|
-
@BlockAdapterRegistry.register("Sana")
|
|
389
|
+
@BlockAdapterRegistry.register("Sana", supported=False)
|
|
418
390
|
def sana_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
419
391
|
from diffusers import SanaTransformer2DModel
|
|
420
392
|
|
|
393
|
+
# TODO: fix -> got multiple values for argument 'encoder_hidden_states'
|
|
394
|
+
|
|
421
395
|
assert isinstance(pipe.transformer, SanaTransformer2DModel)
|
|
422
396
|
return BlockAdapter(
|
|
423
397
|
pipe=pipe,
|
|
@@ -428,20 +402,6 @@ def sana_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
428
402
|
)
|
|
429
403
|
|
|
430
404
|
|
|
431
|
-
@BlockAdapterRegistry.register("ShapE")
|
|
432
|
-
def shape_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
433
|
-
from diffusers import PriorTransformer
|
|
434
|
-
|
|
435
|
-
assert isinstance(pipe.prior, PriorTransformer)
|
|
436
|
-
return BlockAdapter(
|
|
437
|
-
pipe=pipe,
|
|
438
|
-
transformer=pipe.prior,
|
|
439
|
-
blocks=pipe.prior.transformer_blocks,
|
|
440
|
-
forward_pattern=ForwardPattern.Pattern_3,
|
|
441
|
-
**kwargs,
|
|
442
|
-
)
|
|
443
|
-
|
|
444
|
-
|
|
445
405
|
@BlockAdapterRegistry.register("StableAudio")
|
|
446
406
|
def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
447
407
|
from diffusers import StableAudioDiTModel
|
|
@@ -459,21 +419,37 @@ def stabledudio_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
459
419
|
@BlockAdapterRegistry.register("VisualCloze")
|
|
460
420
|
def visualcloze_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
461
421
|
from diffusers import FluxTransformer2DModel
|
|
422
|
+
from cache_dit.utils import is_diffusers_at_least_0_3_5
|
|
462
423
|
|
|
463
424
|
assert isinstance(pipe.transformer, FluxTransformer2DModel)
|
|
464
|
-
|
|
465
|
-
|
|
466
|
-
|
|
467
|
-
|
|
468
|
-
|
|
469
|
-
|
|
470
|
-
|
|
471
|
-
|
|
472
|
-
|
|
473
|
-
|
|
474
|
-
|
|
475
|
-
|
|
476
|
-
|
|
425
|
+
if is_diffusers_at_least_0_3_5():
|
|
426
|
+
return BlockAdapter(
|
|
427
|
+
pipe=pipe,
|
|
428
|
+
transformer=pipe.transformer,
|
|
429
|
+
blocks=[
|
|
430
|
+
pipe.transformer.transformer_blocks,
|
|
431
|
+
pipe.transformer.single_transformer_blocks,
|
|
432
|
+
],
|
|
433
|
+
forward_pattern=[
|
|
434
|
+
ForwardPattern.Pattern_1,
|
|
435
|
+
ForwardPattern.Pattern_1,
|
|
436
|
+
],
|
|
437
|
+
**kwargs,
|
|
438
|
+
)
|
|
439
|
+
else:
|
|
440
|
+
return BlockAdapter(
|
|
441
|
+
pipe=pipe,
|
|
442
|
+
transformer=pipe.transformer,
|
|
443
|
+
blocks=[
|
|
444
|
+
pipe.transformer.transformer_blocks,
|
|
445
|
+
pipe.transformer.single_transformer_blocks,
|
|
446
|
+
],
|
|
447
|
+
forward_pattern=[
|
|
448
|
+
ForwardPattern.Pattern_1,
|
|
449
|
+
ForwardPattern.Pattern_3,
|
|
450
|
+
],
|
|
451
|
+
**kwargs,
|
|
452
|
+
)
|
|
477
453
|
|
|
478
454
|
|
|
479
455
|
@BlockAdapterRegistry.register("AuraFlow")
|
|
@@ -511,9 +487,29 @@ def chroma_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
511
487
|
)
|
|
512
488
|
|
|
513
489
|
|
|
490
|
+
@BlockAdapterRegistry.register("ShapE")
|
|
491
|
+
def shape_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
492
|
+
from diffusers import PriorTransformer
|
|
493
|
+
|
|
494
|
+
assert isinstance(pipe.prior, PriorTransformer)
|
|
495
|
+
return BlockAdapter(
|
|
496
|
+
pipe=pipe,
|
|
497
|
+
transformer=pipe.prior,
|
|
498
|
+
blocks=pipe.prior.transformer_blocks,
|
|
499
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
500
|
+
**kwargs,
|
|
501
|
+
)
|
|
502
|
+
|
|
503
|
+
|
|
514
504
|
@BlockAdapterRegistry.register("HiDream")
|
|
515
505
|
def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
506
|
+
# NOTE: Need to patch Transformer forward to fully support
|
|
507
|
+
# double_stream_blocks and single_stream_blocks, namely, need
|
|
508
|
+
# to remove the logics inside the blocks forward loop:
|
|
509
|
+
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L893
|
|
510
|
+
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_hidream_image.py#L927
|
|
516
511
|
from diffusers import HiDreamImageTransformer2DModel
|
|
512
|
+
from cache_dit.cache_factory.patch_functors import HiDreamPatchFunctor
|
|
517
513
|
|
|
518
514
|
assert isinstance(pipe.transformer, HiDreamImageTransformer2DModel)
|
|
519
515
|
return BlockAdapter(
|
|
@@ -524,9 +520,47 @@ def hidream_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
|
524
520
|
pipe.transformer.single_stream_blocks,
|
|
525
521
|
],
|
|
526
522
|
forward_pattern=[
|
|
527
|
-
ForwardPattern.
|
|
523
|
+
ForwardPattern.Pattern_0,
|
|
528
524
|
ForwardPattern.Pattern_3,
|
|
529
525
|
],
|
|
530
|
-
|
|
526
|
+
patch_functor=HiDreamPatchFunctor(),
|
|
527
|
+
# NOTE: The type hint in diffusers is wrong
|
|
528
|
+
check_forward_pattern=True,
|
|
529
|
+
check_num_outputs=True,
|
|
530
|
+
**kwargs,
|
|
531
|
+
)
|
|
532
|
+
|
|
533
|
+
|
|
534
|
+
@BlockAdapterRegistry.register("HunyuanDiT")
|
|
535
|
+
def hunyuandit_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
536
|
+
from diffusers import HunyuanDiT2DModel, HunyuanDiT2DControlNetModel
|
|
537
|
+
from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
|
|
538
|
+
|
|
539
|
+
assert isinstance(
|
|
540
|
+
pipe.transformer,
|
|
541
|
+
(HunyuanDiT2DModel, HunyuanDiT2DControlNetModel),
|
|
542
|
+
)
|
|
543
|
+
return BlockAdapter(
|
|
544
|
+
pipe=pipe,
|
|
545
|
+
transformer=pipe.transformer,
|
|
546
|
+
blocks=pipe.transformer.blocks,
|
|
547
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
548
|
+
patch_functor=HunyuanDiTPatchFunctor(),
|
|
549
|
+
**kwargs,
|
|
550
|
+
)
|
|
551
|
+
|
|
552
|
+
|
|
553
|
+
@BlockAdapterRegistry.register("HunyuanDiTPAG")
|
|
554
|
+
def hunyuanditpag_adapter(pipe, **kwargs) -> BlockAdapter:
|
|
555
|
+
from diffusers import HunyuanDiT2DModel
|
|
556
|
+
from cache_dit.cache_factory.patch_functors import HunyuanDiTPatchFunctor
|
|
557
|
+
|
|
558
|
+
assert isinstance(pipe.transformer, HunyuanDiT2DModel)
|
|
559
|
+
return BlockAdapter(
|
|
560
|
+
pipe=pipe,
|
|
561
|
+
transformer=pipe.transformer,
|
|
562
|
+
blocks=pipe.transformer.blocks,
|
|
563
|
+
forward_pattern=ForwardPattern.Pattern_3,
|
|
564
|
+
patch_functor=HunyuanDiTPatchFunctor(),
|
|
531
565
|
**kwargs,
|
|
532
566
|
)
|
|
@@ -75,7 +75,8 @@ class BlockAdapter:
|
|
|
75
75
|
List[List[ParamsModifier]],
|
|
76
76
|
] = None
|
|
77
77
|
|
|
78
|
-
|
|
78
|
+
check_forward_pattern: bool = True
|
|
79
|
+
check_num_outputs: bool = False
|
|
79
80
|
|
|
80
81
|
# Pipeline Level Flags
|
|
81
82
|
# Patch Functor: Flux, etc.
|
|
@@ -111,9 +112,9 @@ class BlockAdapter:
|
|
|
111
112
|
def __post_init__(self):
|
|
112
113
|
if self.skip_post_init:
|
|
113
114
|
return
|
|
114
|
-
|
|
115
|
-
|
|
116
|
-
|
|
115
|
+
if any((self.pipe is not None, self.transformer is not None)):
|
|
116
|
+
self.maybe_fill_attrs()
|
|
117
|
+
self.maybe_patchify()
|
|
117
118
|
|
|
118
119
|
def maybe_fill_attrs(self):
|
|
119
120
|
# NOTE: This func should be call before normalize.
|
|
@@ -130,7 +131,9 @@ class BlockAdapter:
|
|
|
130
131
|
assert isinstance(blocks, torch.nn.ModuleList)
|
|
131
132
|
blocks_name = None
|
|
132
133
|
for attr_name in attr_names:
|
|
133
|
-
if
|
|
134
|
+
if (
|
|
135
|
+
attr := getattr(transformer, attr_name, None)
|
|
136
|
+
) is not None:
|
|
134
137
|
if isinstance(attr, torch.nn.ModuleList) and id(
|
|
135
138
|
attr
|
|
136
139
|
) == id(blocks):
|
|
@@ -389,11 +392,20 @@ class BlockAdapter:
|
|
|
389
392
|
forward_pattern: ForwardPattern,
|
|
390
393
|
**kwargs,
|
|
391
394
|
) -> bool:
|
|
395
|
+
|
|
396
|
+
if not kwargs.get("check_forward_pattern", True):
|
|
397
|
+
return True
|
|
398
|
+
|
|
392
399
|
assert (
|
|
393
400
|
forward_pattern.Supported
|
|
394
401
|
and forward_pattern in ForwardPattern.supported_patterns()
|
|
395
402
|
), f"Pattern {forward_pattern} is not support now!"
|
|
396
403
|
|
|
404
|
+
# NOTE: Special case for HiDreamBlock
|
|
405
|
+
if hasattr(block, "block"):
|
|
406
|
+
if isinstance(block.block, torch.nn.Module):
|
|
407
|
+
block = block.block
|
|
408
|
+
|
|
397
409
|
forward_parameters = set(
|
|
398
410
|
inspect.signature(block.forward).parameters.keys()
|
|
399
411
|
)
|
|
@@ -423,6 +435,14 @@ class BlockAdapter:
|
|
|
423
435
|
logging: bool = True,
|
|
424
436
|
**kwargs,
|
|
425
437
|
) -> bool:
|
|
438
|
+
|
|
439
|
+
if not kwargs.get("check_forward_pattern", True):
|
|
440
|
+
if logging:
|
|
441
|
+
logger.warning(
|
|
442
|
+
f"Skipped Forward Pattern Check: {forward_pattern}"
|
|
443
|
+
)
|
|
444
|
+
return True
|
|
445
|
+
|
|
426
446
|
assert (
|
|
427
447
|
forward_pattern.Supported
|
|
428
448
|
and forward_pattern in ForwardPattern.supported_patterns()
|
|
@@ -529,6 +549,7 @@ class BlockAdapter:
|
|
|
529
549
|
blocks,
|
|
530
550
|
forward_pattern=forward_pattern,
|
|
531
551
|
check_num_outputs=adapter.check_num_outputs,
|
|
552
|
+
check_forward_pattern=adapter.check_forward_pattern,
|
|
532
553
|
), (
|
|
533
554
|
"No block forward pattern matched, "
|
|
534
555
|
f"supported lists: {ForwardPattern.supported_patterns()}"
|
|
@@ -558,7 +579,7 @@ class BlockAdapter:
|
|
|
558
579
|
assert isinstance(adapter[0], torch.nn.Module)
|
|
559
580
|
return getattr(adapter[0], "_is_cached", False)
|
|
560
581
|
else:
|
|
561
|
-
raise TypeError(f"Can't check this type: {adapter}!")
|
|
582
|
+
raise TypeError(f"Can't check this type: {type(adapter)}!")
|
|
562
583
|
|
|
563
584
|
@classmethod
|
|
564
585
|
def nested_depth(cls, obj: Any):
|
|
@@ -1,4 +1,4 @@
|
|
|
1
|
-
from typing import Any, Tuple, List, Dict
|
|
1
|
+
from typing import Any, Tuple, List, Dict, Callable
|
|
2
2
|
|
|
3
3
|
from diffusers import DiffusionPipeline
|
|
4
4
|
from cache_dit.cache_factory.block_adapters.block_adapters import BlockAdapter
|
|
@@ -9,20 +9,23 @@ logger = init_logger(__name__)
|
|
|
9
9
|
|
|
10
10
|
|
|
11
11
|
class BlockAdapterRegistry:
|
|
12
|
-
_adapters: Dict[str, BlockAdapter] = {}
|
|
13
|
-
_predefined_adapters_has_spearate_cfg: List[str] =
|
|
12
|
+
_adapters: Dict[str, Callable[..., BlockAdapter]] = {}
|
|
13
|
+
_predefined_adapters_has_spearate_cfg: List[str] = [
|
|
14
14
|
"QwenImage",
|
|
15
15
|
"Wan",
|
|
16
16
|
"CogView4",
|
|
17
17
|
"Cosmos",
|
|
18
18
|
"SkyReelsV2",
|
|
19
19
|
"Chroma",
|
|
20
|
-
|
|
20
|
+
]
|
|
21
21
|
|
|
22
22
|
@classmethod
|
|
23
|
-
def register(cls, name):
|
|
24
|
-
def decorator(
|
|
25
|
-
|
|
23
|
+
def register(cls, name: str, supported: bool = True):
|
|
24
|
+
def decorator(
|
|
25
|
+
func: Callable[..., BlockAdapter]
|
|
26
|
+
) -> Callable[..., BlockAdapter]:
|
|
27
|
+
if supported:
|
|
28
|
+
cls._adapters[name] = func
|
|
26
29
|
return func
|
|
27
30
|
|
|
28
31
|
return decorator
|
|
@@ -4,7 +4,7 @@ import unittest
|
|
|
4
4
|
import functools
|
|
5
5
|
|
|
6
6
|
from contextlib import ExitStack
|
|
7
|
-
from typing import Dict, List, Tuple, Any
|
|
7
|
+
from typing import Dict, List, Tuple, Any, Union, Callable
|
|
8
8
|
|
|
9
9
|
from diffusers import DiffusionPipeline
|
|
10
10
|
|
|
@@ -14,7 +14,10 @@ from cache_dit.cache_factory import ParamsModifier
|
|
|
14
14
|
from cache_dit.cache_factory import BlockAdapterRegistry
|
|
15
15
|
from cache_dit.cache_factory import CachedContextManager
|
|
16
16
|
from cache_dit.cache_factory import CachedBlocks
|
|
17
|
-
|
|
17
|
+
from cache_dit.cache_factory.cache_blocks.utils import (
|
|
18
|
+
patch_cached_stats,
|
|
19
|
+
remove_cached_stats,
|
|
20
|
+
)
|
|
18
21
|
from cache_dit.logger import init_logger
|
|
19
22
|
|
|
20
23
|
logger = init_logger(__name__)
|
|
@@ -29,9 +32,15 @@ class CachedAdapter:
|
|
|
29
32
|
@classmethod
|
|
30
33
|
def apply(
|
|
31
34
|
cls,
|
|
32
|
-
pipe_or_adapter:
|
|
35
|
+
pipe_or_adapter: Union[
|
|
36
|
+
DiffusionPipeline,
|
|
37
|
+
BlockAdapter,
|
|
38
|
+
],
|
|
33
39
|
**cache_context_kwargs,
|
|
34
|
-
) ->
|
|
40
|
+
) -> Union[
|
|
41
|
+
DiffusionPipeline,
|
|
42
|
+
BlockAdapter,
|
|
43
|
+
]:
|
|
35
44
|
assert (
|
|
36
45
|
pipe_or_adapter is not None
|
|
37
46
|
), "pipe or block_adapter can not both None!"
|
|
@@ -49,7 +58,7 @@ class CachedAdapter:
|
|
|
49
58
|
return cls.cachify(
|
|
50
59
|
block_adapter,
|
|
51
60
|
**cache_context_kwargs,
|
|
52
|
-
)
|
|
61
|
+
).pipe
|
|
53
62
|
else:
|
|
54
63
|
raise ValueError(
|
|
55
64
|
f"{pipe_or_adapter.__class__.__name__} is not officially supported "
|
|
@@ -82,7 +91,7 @@ class CachedAdapter:
|
|
|
82
91
|
# 0. Must normalize block_adapter before apply cache
|
|
83
92
|
block_adapter = BlockAdapter.normalize(block_adapter)
|
|
84
93
|
if BlockAdapter.is_cached(block_adapter):
|
|
85
|
-
return block_adapter
|
|
94
|
+
return block_adapter
|
|
86
95
|
|
|
87
96
|
# 1. Apply cache on pipeline: wrap cache context, must
|
|
88
97
|
# call create_context before mock_blocks.
|
|
@@ -98,36 +107,6 @@ class CachedAdapter:
|
|
|
98
107
|
|
|
99
108
|
return block_adapter
|
|
100
109
|
|
|
101
|
-
@classmethod
|
|
102
|
-
def patch_params(
|
|
103
|
-
cls,
|
|
104
|
-
block_adapter: BlockAdapter,
|
|
105
|
-
contexts_kwargs: List[Dict],
|
|
106
|
-
):
|
|
107
|
-
block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
|
|
108
|
-
|
|
109
|
-
params_shift = 0
|
|
110
|
-
for i in range(len(block_adapter.transformer)):
|
|
111
|
-
|
|
112
|
-
block_adapter.transformer[i]._forward_pattern = (
|
|
113
|
-
block_adapter.forward_pattern
|
|
114
|
-
)
|
|
115
|
-
block_adapter.transformer[i]._has_separate_cfg = (
|
|
116
|
-
block_adapter.has_separate_cfg
|
|
117
|
-
)
|
|
118
|
-
block_adapter.transformer[i]._cache_context_kwargs = (
|
|
119
|
-
contexts_kwargs[params_shift]
|
|
120
|
-
)
|
|
121
|
-
|
|
122
|
-
blocks = block_adapter.blocks[i]
|
|
123
|
-
for j in range(len(blocks)):
|
|
124
|
-
blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
|
|
125
|
-
blocks[j]._cache_context_kwargs = contexts_kwargs[
|
|
126
|
-
params_shift + j
|
|
127
|
-
]
|
|
128
|
-
|
|
129
|
-
params_shift += len(blocks)
|
|
130
|
-
|
|
131
110
|
@classmethod
|
|
132
111
|
def check_context_kwargs(
|
|
133
112
|
cls,
|
|
@@ -153,7 +132,9 @@ class CachedAdapter:
|
|
|
153
132
|
f"Pipeline: {block_adapter.pipe.__class__.__name__}."
|
|
154
133
|
)
|
|
155
134
|
|
|
156
|
-
if
|
|
135
|
+
if (
|
|
136
|
+
cache_type := cache_context_kwargs.pop("cache_type", None)
|
|
137
|
+
) is not None:
|
|
157
138
|
assert (
|
|
158
139
|
cache_type == CacheType.DBCache
|
|
159
140
|
), "Custom cache setting only support for DBCache now!"
|
|
@@ -210,14 +191,14 @@ class CachedAdapter:
|
|
|
210
191
|
)
|
|
211
192
|
)
|
|
212
193
|
outputs = original_call(self, *args, **kwargs)
|
|
213
|
-
cls.
|
|
194
|
+
cls.apply_stats_hooks(block_adapter)
|
|
214
195
|
return outputs
|
|
215
196
|
|
|
216
197
|
block_adapter.pipe.__class__.__call__ = new_call
|
|
217
198
|
block_adapter.pipe.__class__._original_call = original_call
|
|
218
199
|
block_adapter.pipe.__class__._is_cached = True
|
|
219
200
|
|
|
220
|
-
cls.
|
|
201
|
+
cls.apply_params_hooks(block_adapter, contexts_kwargs)
|
|
221
202
|
|
|
222
203
|
return block_adapter.pipe
|
|
223
204
|
|
|
@@ -261,33 +242,6 @@ class CachedAdapter:
|
|
|
261
242
|
|
|
262
243
|
return flatten_contexts, contexts_kwargs
|
|
263
244
|
|
|
264
|
-
@classmethod
|
|
265
|
-
def patch_stats(
|
|
266
|
-
cls,
|
|
267
|
-
block_adapter: BlockAdapter,
|
|
268
|
-
):
|
|
269
|
-
from cache_dit.cache_factory.cache_blocks.utils import (
|
|
270
|
-
patch_cached_stats,
|
|
271
|
-
)
|
|
272
|
-
|
|
273
|
-
cache_manager = block_adapter.pipe._cache_manager
|
|
274
|
-
|
|
275
|
-
for i in range(len(block_adapter.transformer)):
|
|
276
|
-
patch_cached_stats(
|
|
277
|
-
block_adapter.transformer[i],
|
|
278
|
-
cache_context=block_adapter.unique_blocks_name[i][-1],
|
|
279
|
-
cache_manager=cache_manager,
|
|
280
|
-
)
|
|
281
|
-
for blocks, unique_name in zip(
|
|
282
|
-
block_adapter.blocks[i],
|
|
283
|
-
block_adapter.unique_blocks_name[i],
|
|
284
|
-
):
|
|
285
|
-
patch_cached_stats(
|
|
286
|
-
blocks,
|
|
287
|
-
cache_context=unique_name,
|
|
288
|
-
cache_manager=cache_manager,
|
|
289
|
-
)
|
|
290
|
-
|
|
291
245
|
@classmethod
|
|
292
246
|
def mock_blocks(
|
|
293
247
|
cls,
|
|
@@ -391,6 +345,7 @@ class CachedAdapter:
|
|
|
391
345
|
block_adapter.blocks[i][j],
|
|
392
346
|
transformer=block_adapter.transformer[i],
|
|
393
347
|
forward_pattern=block_adapter.forward_pattern[i][j],
|
|
348
|
+
check_forward_pattern=block_adapter.check_forward_pattern,
|
|
394
349
|
check_num_outputs=block_adapter.check_num_outputs,
|
|
395
350
|
# 1. Cache context configuration
|
|
396
351
|
cache_prefix=block_adapter.blocks_name[i][j],
|
|
@@ -405,3 +360,159 @@ class CachedAdapter:
|
|
|
405
360
|
total_cached_blocks.append(cached_blocks_bind_context)
|
|
406
361
|
|
|
407
362
|
return total_cached_blocks
|
|
363
|
+
|
|
364
|
+
@classmethod
|
|
365
|
+
def apply_params_hooks(
|
|
366
|
+
cls,
|
|
367
|
+
block_adapter: BlockAdapter,
|
|
368
|
+
contexts_kwargs: List[Dict],
|
|
369
|
+
):
|
|
370
|
+
block_adapter.pipe._cache_context_kwargs = contexts_kwargs[0]
|
|
371
|
+
|
|
372
|
+
params_shift = 0
|
|
373
|
+
for i in range(len(block_adapter.transformer)):
|
|
374
|
+
|
|
375
|
+
block_adapter.transformer[i]._forward_pattern = (
|
|
376
|
+
block_adapter.forward_pattern
|
|
377
|
+
)
|
|
378
|
+
block_adapter.transformer[i]._has_separate_cfg = (
|
|
379
|
+
block_adapter.has_separate_cfg
|
|
380
|
+
)
|
|
381
|
+
block_adapter.transformer[i]._cache_context_kwargs = (
|
|
382
|
+
contexts_kwargs[params_shift]
|
|
383
|
+
)
|
|
384
|
+
|
|
385
|
+
blocks = block_adapter.blocks[i]
|
|
386
|
+
for j in range(len(blocks)):
|
|
387
|
+
blocks[j]._forward_pattern = block_adapter.forward_pattern[i][j]
|
|
388
|
+
blocks[j]._cache_context_kwargs = contexts_kwargs[
|
|
389
|
+
params_shift + j
|
|
390
|
+
]
|
|
391
|
+
|
|
392
|
+
params_shift += len(blocks)
|
|
393
|
+
|
|
394
|
+
@classmethod
|
|
395
|
+
def apply_stats_hooks(
|
|
396
|
+
cls,
|
|
397
|
+
block_adapter: BlockAdapter,
|
|
398
|
+
):
|
|
399
|
+
cache_manager = block_adapter.pipe._cache_manager
|
|
400
|
+
|
|
401
|
+
for i in range(len(block_adapter.transformer)):
|
|
402
|
+
patch_cached_stats(
|
|
403
|
+
block_adapter.transformer[i],
|
|
404
|
+
cache_context=block_adapter.unique_blocks_name[i][-1],
|
|
405
|
+
cache_manager=cache_manager,
|
|
406
|
+
)
|
|
407
|
+
for blocks, unique_name in zip(
|
|
408
|
+
block_adapter.blocks[i],
|
|
409
|
+
block_adapter.unique_blocks_name[i],
|
|
410
|
+
):
|
|
411
|
+
patch_cached_stats(
|
|
412
|
+
blocks,
|
|
413
|
+
cache_context=unique_name,
|
|
414
|
+
cache_manager=cache_manager,
|
|
415
|
+
)
|
|
416
|
+
|
|
417
|
+
@classmethod
|
|
418
|
+
def maybe_release_hooks(
|
|
419
|
+
cls,
|
|
420
|
+
pipe_or_adapter: Union[
|
|
421
|
+
DiffusionPipeline,
|
|
422
|
+
BlockAdapter,
|
|
423
|
+
],
|
|
424
|
+
):
|
|
425
|
+
# release model hooks
|
|
426
|
+
def _release_blocks_hooks(blocks):
|
|
427
|
+
return
|
|
428
|
+
|
|
429
|
+
def _release_transformer_hooks(transformer):
|
|
430
|
+
if hasattr(transformer, "_original_forward"):
|
|
431
|
+
original_forward = transformer._original_forward
|
|
432
|
+
transformer.forward = original_forward.__get__(transformer)
|
|
433
|
+
del transformer._original_forward
|
|
434
|
+
if hasattr(transformer, "_is_cached"):
|
|
435
|
+
del transformer._is_cached
|
|
436
|
+
|
|
437
|
+
def _release_pipeline_hooks(pipe):
|
|
438
|
+
if hasattr(pipe, "_original_call"):
|
|
439
|
+
original_call = pipe.__class__._original_call
|
|
440
|
+
pipe.__class__.__call__ = original_call
|
|
441
|
+
del pipe.__class__._original_call
|
|
442
|
+
if hasattr(pipe, "_cache_manager"):
|
|
443
|
+
cache_manager = pipe._cache_manager
|
|
444
|
+
if isinstance(cache_manager, CachedContextManager):
|
|
445
|
+
cache_manager.clear_contexts()
|
|
446
|
+
del pipe._cache_manager
|
|
447
|
+
if hasattr(pipe, "_is_cached"):
|
|
448
|
+
del pipe.__class__._is_cached
|
|
449
|
+
|
|
450
|
+
cls.release_hooks(
|
|
451
|
+
pipe_or_adapter,
|
|
452
|
+
_release_blocks_hooks,
|
|
453
|
+
_release_transformer_hooks,
|
|
454
|
+
_release_pipeline_hooks,
|
|
455
|
+
)
|
|
456
|
+
|
|
457
|
+
# release params hooks
|
|
458
|
+
def _release_blocks_params(blocks):
|
|
459
|
+
if hasattr(blocks, "_forward_pattern"):
|
|
460
|
+
del blocks._forward_pattern
|
|
461
|
+
if hasattr(blocks, "_cache_context_kwargs"):
|
|
462
|
+
del blocks._cache_context_kwargs
|
|
463
|
+
|
|
464
|
+
def _release_transformer_params(transformer):
|
|
465
|
+
if hasattr(transformer, "_forward_pattern"):
|
|
466
|
+
del transformer._forward_pattern
|
|
467
|
+
if hasattr(transformer, "_has_separate_cfg"):
|
|
468
|
+
del transformer._has_separate_cfg
|
|
469
|
+
if hasattr(transformer, "_cache_context_kwargs"):
|
|
470
|
+
del transformer._cache_context_kwargs
|
|
471
|
+
for blocks in BlockAdapter.find_blocks(transformer):
|
|
472
|
+
_release_blocks_params(blocks)
|
|
473
|
+
|
|
474
|
+
def _release_pipeline_params(pipe):
|
|
475
|
+
if hasattr(pipe, "_cache_context_kwargs"):
|
|
476
|
+
del pipe._cache_context_kwargs
|
|
477
|
+
|
|
478
|
+
cls.release_hooks(
|
|
479
|
+
pipe_or_adapter,
|
|
480
|
+
_release_blocks_params,
|
|
481
|
+
_release_transformer_params,
|
|
482
|
+
_release_pipeline_params,
|
|
483
|
+
)
|
|
484
|
+
|
|
485
|
+
# release stats hooks
|
|
486
|
+
cls.release_hooks(
|
|
487
|
+
pipe_or_adapter,
|
|
488
|
+
remove_cached_stats,
|
|
489
|
+
remove_cached_stats,
|
|
490
|
+
remove_cached_stats,
|
|
491
|
+
)
|
|
492
|
+
|
|
493
|
+
@classmethod
|
|
494
|
+
def release_hooks(
|
|
495
|
+
cls,
|
|
496
|
+
pipe_or_adapter: Union[
|
|
497
|
+
DiffusionPipeline,
|
|
498
|
+
BlockAdapter,
|
|
499
|
+
],
|
|
500
|
+
_release_blocks: Callable,
|
|
501
|
+
_release_transformer: Callable,
|
|
502
|
+
_release_pipeline: Callable,
|
|
503
|
+
):
|
|
504
|
+
if isinstance(pipe_or_adapter, DiffusionPipeline):
|
|
505
|
+
pipe = pipe_or_adapter
|
|
506
|
+
_release_pipeline(pipe)
|
|
507
|
+
if hasattr(pipe, "transformer"):
|
|
508
|
+
_release_transformer(pipe.transformer)
|
|
509
|
+
if hasattr(pipe, "transformer_2"): # Wan 2.2
|
|
510
|
+
_release_transformer(pipe.transformer_2)
|
|
511
|
+
elif isinstance(pipe_or_adapter, BlockAdapter):
|
|
512
|
+
adapter = pipe_or_adapter
|
|
513
|
+
BlockAdapter.assert_normalized(adapter)
|
|
514
|
+
_release_pipeline(adapter.pipe)
|
|
515
|
+
for transformer in BlockAdapter.flatten(adapter.transformer):
|
|
516
|
+
_release_transformer(transformer)
|
|
517
|
+
for blocks in BlockAdapter.flatten(adapter.blocks):
|
|
518
|
+
_release_blocks(blocks)
|
|
@@ -25,6 +25,7 @@ class CachedBlocks:
|
|
|
25
25
|
transformer_blocks: torch.nn.ModuleList,
|
|
26
26
|
transformer: torch.nn.Module = None,
|
|
27
27
|
forward_pattern: ForwardPattern = None,
|
|
28
|
+
check_forward_pattern: bool = True,
|
|
28
29
|
check_num_outputs: bool = True,
|
|
29
30
|
# 1. Cache context configuration
|
|
30
31
|
# 'transformer_blocks', 'blocks', 'single_transformer_blocks',
|
|
@@ -45,6 +46,7 @@ class CachedBlocks:
|
|
|
45
46
|
transformer_blocks,
|
|
46
47
|
transformer=transformer,
|
|
47
48
|
forward_pattern=forward_pattern,
|
|
49
|
+
check_forward_pattern=check_forward_pattern,
|
|
48
50
|
check_num_outputs=check_num_outputs,
|
|
49
51
|
# 1. Cache context configuration
|
|
50
52
|
cache_prefix=cache_prefix,
|
|
@@ -58,6 +60,7 @@ class CachedBlocks:
|
|
|
58
60
|
transformer_blocks,
|
|
59
61
|
transformer=transformer,
|
|
60
62
|
forward_pattern=forward_pattern,
|
|
63
|
+
check_forward_pattern=check_forward_pattern,
|
|
61
64
|
check_num_outputs=check_num_outputs,
|
|
62
65
|
# 1. Cache context configuration
|
|
63
66
|
cache_prefix=cache_prefix,
|