cache-dit 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Potentially problematic release.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/adapters.py +47 -5
- cache_dit/cache_factory/dual_block_cache/__init__.py +4 -0
- cache_dit/cache_factory/dual_block_cache/cache_blocks.py +487 -0
- cache_dit/cache_factory/dual_block_cache/cache_context.py +10 -860
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +5 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +14 -4
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +5 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +5 -2
- cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py} +14 -12
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +7 -4
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +4 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +10 -4
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +5 -2
- cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py} +28 -23
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +276 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +228 -516
- cache_dit/cache_factory/patch/flux.py +241 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA +22 -80
- cache_dit-0.2.16.dist-info/RECORD +47 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
- cache_dit-0.2.14.dist-info/RECORD +0 -49
- /cache_dit/cache_factory/{first_block_cache → patch}/__init__.py +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py

@@ -15,6 +15,8 @@ def apply_db_cache_on_transformer(transformer, *args, **kwargs):
         adapter_name = "wan"
     elif transformer_cls_name.startswith("HunyuanVideo"):
         adapter_name = "hunyuan_video"
+    elif transformer_cls_name.startswith("QwenImage"):
+        adapter_name = "qwen_image"
     else:
         raise ValueError(
             f"Unknown transformer class name: {transformer_cls_name}"

@@ -41,6 +43,8 @@ def apply_db_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
         adapter_name = "wan"
     elif pipe_cls_name.startswith("HunyuanVideo"):
         adapter_name = "hunyuan_video"
+    elif pipe_cls_name.startswith("QwenImage"):
+        adapter_name = "qwen_image"
     else:
         raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
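Both hunks extend the same class-name prefix dispatch, so any QwenImage* transformer or pipeline class now resolves to the new qwen_image adapter. A minimal usage sketch of this dispatch path (the checkpoint id is an assumption, not taken from this diff):

    # Hypothetical usage; "Qwen/Qwen-Image" is an assumed checkpoint id.
    from diffusers import DiffusionPipeline
    from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
        apply_db_cache_on_pipe,
    )

    pipe = DiffusionPipeline.from_pretrained("Qwen/Qwen-Image")
    # pipe.__class__.__name__ starts with "QwenImage" -> adapter_name = "qwen_image"
    pipe = apply_db_cache_on_pipe(pipe)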
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py

@@ -4,7 +4,10 @@ import unittest
 import torch
 from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline
 
-from cache_dit.cache_factory.dual_block_cache import
+from cache_dit.cache_factory.dual_block_cache import (
+    cache_context,
+    DBCachedTransformerBlocks,
+)
 
 
 def apply_db_cache_on_transformer(

@@ -15,7 +18,7 @@ def apply_db_cache_on_transformer(
 
     cached_transformer_blocks = torch.nn.ModuleList(
         [
-
+            DBCachedTransformerBlocks(
                 transformer.transformer_blocks,
                 transformer=transformer,
             )
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py

@@ -4,7 +4,15 @@ import unittest
 import torch
 from diffusers import DiffusionPipeline, FluxTransformer2DModel
 
-from cache_dit.cache_factory.
+from cache_dit.cache_factory.patch.flux import maybe_patch_flux_transformer
+from cache_dit.cache_factory.dual_block_cache import (
+    cache_context,
+    DBCachedTransformerBlocks,
+)
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
 
 
 def apply_db_cache_on_transformer(

@@ -13,11 +21,13 @@ def apply_db_cache_on_transformer(
     if getattr(transformer, "_is_cached", False):
         return transformer
 
+    transformer = maybe_patch_flux_transformer(transformer)
+
     cached_transformer_blocks = torch.nn.ModuleList(
         [
-
-                transformer.transformer_blocks
-                transformer.single_transformer_blocks,
+            DBCachedTransformerBlocks(
+                transformer.transformer_blocks
+                + transformer.single_transformer_blocks,
                 transformer=transformer,
                 return_hidden_states_first=False,
             )
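Flux is the only adapter here that first rewrites the transformer via maybe_patch_flux_transformer (from the new cache_dit/cache_factory/patch/flux.py, +241 lines) and then wraps the concatenation of transformer_blocks and single_transformer_blocks in a single DBCachedTransformerBlocks; return_hidden_states_first=False reflects that Flux blocks return (encoder_hidden_states, hidden_states). A hedged end-to-end sketch (checkpoint id and generation arguments are assumptions):

    # Assumed usage; only apply_db_cache_on_pipe and the patch-then-wrap
    # behavior are taken from this diff.
    import torch
    from diffusers import FluxPipeline
    from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
        apply_db_cache_on_pipe,
    )

    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev",  # assumed checkpoint id
        torch_dtype=torch.bfloat16,
    ).to("cuda")
    pipe = apply_db_cache_on_pipe(pipe)  # patches the transformer, wraps both block lists
    image = pipe("a photo of a cat", num_inference_steps=28).images[0]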
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py

@@ -11,7 +11,10 @@ from diffusers.utils import (
     USE_PEFT_BACKEND,
 )
 
-from cache_dit.cache_factory.dual_block_cache import
+from cache_dit.cache_factory.dual_block_cache import (
+    cache_context,
+    DBCachedTransformerBlocks,
+)
 from cache_dit.logger import init_logger
 
 try:

@@ -44,7 +47,7 @@ def apply_db_cache_on_transformer(
 
     cached_transformer_blocks = torch.nn.ModuleList(
         [
-
+            DBCachedTransformerBlocks(
                 transformer.transformer_blocks
                 + transformer.single_transformer_blocks,
                 transformer=transformer,
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py

@@ -4,7 +4,10 @@ import unittest
 import torch
 from diffusers import DiffusionPipeline, MochiTransformer3DModel
 
-from cache_dit.cache_factory.dual_block_cache import
+from cache_dit.cache_factory.dual_block_cache import (
+    cache_context,
+    DBCachedTransformerBlocks,
+)
 
 
 def apply_db_cache_on_transformer(

@@ -15,7 +18,7 @@ def apply_db_cache_on_transformer(
 
     cached_transformer_blocks = torch.nn.ModuleList(
         [
-
+            DBCachedTransformerBlocks(
                 transformer.transformer_blocks,
                 transformer=transformer,
             )
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/qwen_image.py
(renamed from cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py)

@@ -1,25 +1,27 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/mochi.py
-
 import functools
 import unittest
 
 import torch
-from diffusers import
+from diffusers import QwenImagePipeline, QwenImageTransformer2DModel
 
-from cache_dit.cache_factory.
+from cache_dit.cache_factory.dual_block_cache import (
+    cache_context,
+    DBCachedTransformerBlocks,
+)
 
 
-def
-    transformer:
+def apply_db_cache_on_transformer(
+    transformer: QwenImageTransformer2DModel,
 ):
     if getattr(transformer, "_is_cached", False):
         return transformer
 
-
+    transformer_blocks = torch.nn.ModuleList(
         [
-
+            DBCachedTransformerBlocks(
                 transformer.transformer_blocks,
                 transformer=transformer,
+                return_hidden_states_first=False,
             )
         ]
     )

@@ -35,7 +37,7 @@ def apply_cache_on_transformer(
     with unittest.mock.patch.object(
         self,
         "transformer_blocks",
-
+        transformer_blocks,
     ):
         return original_forward(
             *args,

@@ -49,8 +51,8 @@ def apply_cache_on_transformer(
     return transformer
 
 
-def
-    pipe:
+def apply_db_cache_on_pipe(
+    pipe: QwenImagePipeline,
     *,
     shallow_patch: bool = False,
     residual_diff_threshold=0.06,

@@ -84,6 +86,6 @@ def apply_cache_on_pipe(
     pipe.__class__._is_cached = True
 
     if not shallow_patch:
-
+        apply_db_cache_on_transformer(pipe.transformer)
 
     return pipe
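With the first-block-cache adapters removed (first_block_cache/* is deleted above), this renamed module gives QwenImage a dual-block-cache adapter instead. A usage sketch (the checkpoint id is assumed; function names, defaults, and the shallow_patch behavior come from the hunks above):

    from diffusers import QwenImagePipeline
    from cache_dit.cache_factory.dual_block_cache.diffusers_adapters.qwen_image import (
        apply_db_cache_on_pipe,
    )

    pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image")  # assumed checkpoint id
    pipe = apply_db_cache_on_pipe(
        pipe,
        residual_diff_threshold=0.06,  # default shown above
        shallow_patch=False,  # False also wraps pipe.transformer (last hunk)
    )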
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py

@@ -2,9 +2,12 @@ import functools
 import unittest
 
 import torch
-from diffusers import
+from diffusers import WanPipeline, WanTransformer3DModel
 
-from cache_dit.cache_factory.dual_block_cache import
+from cache_dit.cache_factory.dual_block_cache import (
+    cache_context,
+    DBCachedTransformerBlocks,
+)
 
 
 def apply_db_cache_on_transformer(

@@ -15,7 +18,7 @@ def apply_db_cache_on_transformer(
 
     blocks = torch.nn.ModuleList(
         [
-
+            DBCachedTransformerBlocks(
                 transformer.blocks,
                 transformer=transformer,
                 return_hidden_states_only=True,

@@ -49,7 +52,7 @@ def apply_db_cache_on_transformer(
 
 
 def apply_db_cache_on_pipe(
-    pipe:
+    pipe: WanPipeline,
     *,
     shallow_patch: bool = False,
     residual_diff_threshold=0.03,
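Note the per-model defaults: Wan ships with residual_diff_threshold=0.03, stricter than the 0.06 used by the Flux and QwenImage adapters, and its blocks are wrapped with return_hidden_states_only=True. Overriding the default is a keyword argument; the 0.05 below is purely illustrative:

    # Illustrative override of Wan's stricter default; the checkpoint id is assumed.
    from diffusers import WanPipeline
    from cache_dit.cache_factory.dual_block_cache.diffusers_adapters.wan import (
        apply_db_cache_on_pipe,
    )

    pipe = WanPipeline.from_pretrained("Wan-AI/Wan2.1-T2V-1.3B-Diffusers")
    pipe = apply_db_cache_on_pipe(pipe, residual_diff_threshold=0.05)  # default: 0.03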
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py

@@ -15,6 +15,8 @@ def apply_db_prune_on_transformer(transformer, *args, **kwargs):
         adapter_name = "wan"
     elif transformer_cls_name.startswith("HunyuanVideo"):
         adapter_name = "hunyuan_video"
+    elif transformer_cls_name.startswith("QwenImage"):
+        adapter_name = "qwen_image"
     else:
         raise ValueError(
             f"Unknown transformer class name: {transformer_cls_name}"

@@ -41,6 +43,8 @@ def apply_db_prune_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
         adapter_name = "wan"
     elif pipe_cls_name.startswith("HunyuanVideo"):
         adapter_name = "hunyuan_video"
+    elif pipe_cls_name.startswith("QwenImage"):
+        adapter_name = "qwen_image"
     else:
         raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py

@@ -4,7 +4,10 @@ import unittest
 import torch
 from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline
 
-from cache_dit.cache_factory.dynamic_block_prune import
+from cache_dit.cache_factory.dynamic_block_prune import (
+    prune_context,
+    DBPrunedTransformerBlocks,
+)
 
 
 def apply_db_prune_on_transformer(

@@ -15,7 +18,7 @@ def apply_db_prune_on_transformer(
 
     cached_transformer_blocks = torch.nn.ModuleList(
         [
-
+            DBPrunedTransformerBlocks(
                 transformer.transformer_blocks,
                 transformer=transformer,
             )
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py

@@ -4,7 +4,11 @@ import unittest
 import torch
 from diffusers import DiffusionPipeline, FluxTransformer2DModel
 
-from cache_dit.cache_factory.
+from cache_dit.cache_factory.patch.flux import maybe_patch_flux_transformer
+from cache_dit.cache_factory.dynamic_block_prune import (
+    prune_context,
+    DBPrunedTransformerBlocks,
+)
 
 
 def apply_db_prune_on_transformer(

@@ -13,11 +17,13 @@ def apply_db_prune_on_transformer(
     if getattr(transformer, "_is_pruned", False):
         return transformer
 
+    transformer = maybe_patch_flux_transformer(transformer)
+
     cached_transformer_blocks = torch.nn.ModuleList(
         [
-
-                transformer.transformer_blocks
-                transformer.single_transformer_blocks,
+            DBPrunedTransformerBlocks(
+                transformer.transformer_blocks
+                + transformer.single_transformer_blocks,
                 transformer=transformer,
                 return_hidden_states_first=False,
             )
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py

@@ -11,7 +11,10 @@ from diffusers.utils import (
     USE_PEFT_BACKEND,
 )
 
-from cache_dit.cache_factory.dynamic_block_prune import
+from cache_dit.cache_factory.dynamic_block_prune import (
+    prune_context,
+    DBPrunedTransformerBlocks,
+)
 from cache_dit.logger import init_logger
 
 try:

@@ -44,7 +47,7 @@ def apply_db_prune_on_transformer(
 
     cached_transformer_blocks = torch.nn.ModuleList(
         [
-
+            DBPrunedTransformerBlocks(
                 transformer.transformer_blocks
                 + transformer.single_transformer_blocks,
                 transformer=transformer,
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py

@@ -4,7 +4,10 @@ import unittest
 import torch
 from diffusers import DiffusionPipeline, MochiTransformer3DModel
 
-from cache_dit.cache_factory.dynamic_block_prune import
+from cache_dit.cache_factory.dynamic_block_prune import (
+    prune_context,
+    DBPrunedTransformerBlocks,
+)
 
 
 def apply_db_prune_on_transformer(

@@ -15,7 +18,7 @@ def apply_db_prune_on_transformer(
 
     cached_transformer_blocks = torch.nn.ModuleList(
         [
-
+            DBPrunedTransformerBlocks(
                 transformer.transformer_blocks,
                 transformer=transformer,
             )
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/qwen_image.py
(renamed from cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py)

@@ -1,25 +1,30 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/cogvideox.py
-
 import functools
 import unittest
 
 import torch
-from diffusers import
+from diffusers import QwenImagePipeline, QwenImageTransformer2DModel
+
+from cache_dit.cache_factory.dynamic_block_prune import (
+    prune_context,
+    DBPrunedTransformerBlocks,
+)
+from cache_dit.logger import init_logger
 
-
+logger = init_logger(__name__)
 
 
-def
-    transformer:
+def apply_db_prune_on_transformer(
+    transformer: QwenImageTransformer2DModel,
 ):
-    if getattr(transformer, "
+    if getattr(transformer, "_is_pruned", False):
         return transformer
 
-
+    transformer_blocks = torch.nn.ModuleList(
         [
-
+            DBPrunedTransformerBlocks(
                 transformer.transformer_blocks,
                 transformer=transformer,
+                return_hidden_states_first=False,
             )
         ]
     )

@@ -35,7 +40,7 @@ def apply_cache_on_transformer(
     with unittest.mock.patch.object(
         self,
         "transformer_blocks",
-
+        transformer_blocks,
     ):
         return original_forward(
             *args,

@@ -44,46 +49,46 @@ def apply_cache_on_transformer(
 
     transformer.forward = new_forward.__get__(transformer)
 
-    transformer.
+    transformer._is_pruned = True
 
     return transformer
 
 
-def
-    pipe:
+def apply_db_prune_on_pipe(
+    pipe: QwenImagePipeline,
     *,
     shallow_patch: bool = False,
-    residual_diff_threshold=0.
+    residual_diff_threshold=0.06,
     downsample_factor=1,
     warmup_steps=0,
-
+    max_pruned_steps=-1,
     **kwargs,
 ):
-
+    prune_kwargs, kwargs = prune_context.collect_prune_kwargs(
         default_attrs={
             "residual_diff_threshold": residual_diff_threshold,
             "downsample_factor": downsample_factor,
             "warmup_steps": warmup_steps,
-            "
+            "max_pruned_steps": max_pruned_steps,
         },
         **kwargs,
     )
-    if not getattr(pipe, "
+    if not getattr(pipe, "_is_pruned", False):
         original_call = pipe.__class__.__call__
 
         @functools.wraps(original_call)
         def new_call(self, *args, **kwargs):
-            with
-
-                **
+            with prune_context.prune_context(
+                prune_context.create_prune_context(
+                    **prune_kwargs,
                 )
             ):
                 return original_call(self, *args, **kwargs)
 
         pipe.__class__.__call__ = new_call
-        pipe.__class__.
+        pipe.__class__._is_pruned = True
 
     if not shallow_patch:
-
+        apply_db_prune_on_transformer(pipe.transformer)
 
     return pipe
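This is the hunk that surfaces the new max_pruned_steps knob, collected into prune_kwargs alongside residual_diff_threshold, downsample_factor, and warmup_steps. A usage sketch (checkpoint id assumed; the "-1 means no cap" reading is an inference from the default, not stated in the diff):

    from diffusers import QwenImagePipeline
    from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters.qwen_image import (
        apply_db_prune_on_pipe,
    )

    pipe = QwenImagePipeline.from_pretrained("Qwen/Qwen-Image")  # assumed checkpoint id
    pipe = apply_db_prune_on_pipe(
        pipe,
        residual_diff_threshold=0.06,
        warmup_steps=0,
        max_pruned_steps=-1,  # new in 0.2.16; -1 presumably leaves pruning uncapped
    )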
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py

@@ -4,7 +4,10 @@ import unittest
 import torch
 from diffusers import DiffusionPipeline, WanTransformer3DModel
 
-from cache_dit.cache_factory.dynamic_block_prune import
+from cache_dit.cache_factory.dynamic_block_prune import (
+    prune_context,
+    DBPrunedTransformerBlocks,
+)
 
 
 def apply_db_prune_on_transformer(

@@ -15,7 +18,7 @@ def apply_db_prune_on_transformer(
 
     blocks = torch.nn.ModuleList(
         [
-
+            DBPrunedTransformerBlocks(
                 transformer.blocks,
                 transformer=transformer,
                 return_hidden_states_only=True,