cache-dit 0.1.1.dev2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- cache_dit/__init__.py +0 -0
- cache_dit/_version.py +21 -0
- cache_dit/cache_factory/__init__.py +166 -0
- cache_dit/cache_factory/dual_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/dual_block_cache/cache_context.py +1361 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +45 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +88 -0
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +45 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +89 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +979 -0
- cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +727 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +53 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +89 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +100 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +89 -0
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +98 -0
- cache_dit/cache_factory/taylorseer.py +76 -0
- cache_dit/cache_factory/utils.py +0 -0
- cache_dit/logger.py +97 -0
- cache_dit/primitives.py +152 -0
- cache_dit-0.1.1.dev2.dist-info/METADATA +31 -0
- cache_dit-0.1.1.dev2.dist-info/RECORD +30 -0
- cache_dit-0.1.1.dev2.dist-info/WHEEL +5 -0
- cache_dit-0.1.1.dev2.dist-info/top_level.txt +1 -0
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py
@@ -0,0 +1,45 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/__init__.py

import importlib

from diffusers import DiffusionPipeline


def apply_db_cache_on_transformer(transformer, *args, **kwargs):
    transformer_cls_name: str = transformer.__class__.__name__
    if transformer_cls_name.startswith("Flux"):
        adapter_name = "flux"
    elif transformer_cls_name.startswith("Mochi"):
        adapter_name = "mochi"
    elif transformer_cls_name.startswith("CogVideoX"):
        adapter_name = "cogvideox"
    else:
        raise ValueError(
            f"Unknown transformer class name: {transformer_cls_name}"
        )

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_db_cache_on_transformer_fn = getattr(
        adapter_module, "apply_db_cache_on_transformer"
    )
    return apply_db_cache_on_transformer_fn(transformer, *args, **kwargs)


def apply_db_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    assert isinstance(pipe, DiffusionPipeline)

    pipe_cls_name: str = pipe.__class__.__name__
    if pipe_cls_name.startswith("Flux"):
        adapter_name = "flux"
    elif pipe_cls_name.startswith("Mochi"):
        adapter_name = "mochi"
    elif pipe_cls_name.startswith("CogVideoX"):
        adapter_name = "cogvideox"
    else:
        raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_db_cache_on_pipe_fn = getattr(
        adapter_module, "apply_db_cache_on_pipe"
    )
    return apply_db_cache_on_pipe_fn(pipe, *args, **kwargs)
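For reference, a minimal usage sketch of the dispatcher above. It assumes cache-dit and diffusers are installed and a CUDA GPU is available; the model id, prompt, and cache parameters are placeholders, not recommended values.

import torch
from diffusers import FluxPipeline

from cache_dit.cache_factory.dual_block_cache.diffusers_adapters import (
    apply_db_cache_on_pipe,
)

# Any pipeline whose class name starts with "Flux", "Mochi", or "CogVideoX"
# is routed to the matching adapter module by the dispatcher.
pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # placeholder model id
    torch_dtype=torch.bfloat16,
).to("cuda")

apply_db_cache_on_pipe(pipe, residual_diff_threshold=0.05)
image = pipe("a photo of a cat", num_inference_steps=28).images[0]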
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py
@@ -0,0 +1,89 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/cogvideox.py

import functools
import unittest

import torch
from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline

from cache_dit.cache_factory.dual_block_cache import cache_context


def apply_db_cache_on_transformer(
    transformer: CogVideoXTransformer3DModel,
):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            cache_context.DBCachedTransformerBlocks(
                transformer.transformer_blocks,
                transformer=transformer,
            )
        ]
    )

    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with unittest.mock.patch.object(
            self,
            "transformer_blocks",
            cached_transformer_blocks,
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    transformer._is_cached = True

    return transformer


def apply_db_cache_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    residual_diff_threshold=0.04,
    downsample_factor=1,
    warmup_steps=0,
    max_cached_steps=-1,
    **kwargs,
):
    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
        default_attrs={
            "residual_diff_threshold": residual_diff_threshold,
            "downsample_factor": downsample_factor,
            "warmup_steps": warmup_steps,
            "max_cached_steps": max_cached_steps,
        },
        **kwargs,
    )
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with cache_context.cache_context(
                cache_context.create_cache_context(
                    **cache_kwargs,
                )
            ):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_db_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
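A usage sketch for the CogVideoX adapter above, importing its entry point directly rather than through the dispatcher. It assumes the CogVideoX weights are available; the model id, prompt, and tuning values are illustrative only.

import torch
from diffusers import CogVideoXPipeline

from cache_dit.cache_factory.dual_block_cache.diffusers_adapters.cogvideox import (
    apply_db_cache_on_pipe,
)

pipe = CogVideoXPipeline.from_pretrained(
    "THUDM/CogVideoX-5b",  # placeholder model id
    torch_dtype=torch.bfloat16,
).to("cuda")

# A tighter threshold plus a few warmup steps trades speed for fidelity;
# max_cached_steps=-1 leaves the number of cached steps unbounded.
apply_db_cache_on_pipe(
    pipe,
    residual_diff_threshold=0.03,
    warmup_steps=2,
    max_cached_steps=-1,
)
video = pipe("a sailboat at sunset", num_inference_steps=50).frames[0]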
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py
@@ -0,0 +1,100 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/flux.py

import functools
import unittest

import torch
from diffusers import DiffusionPipeline, FluxTransformer2DModel

from cache_dit.cache_factory.dual_block_cache import cache_context


def apply_db_cache_on_transformer(
    transformer: FluxTransformer2DModel,
):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            cache_context.DBCachedTransformerBlocks(
                transformer.transformer_blocks,
                transformer.single_transformer_blocks,
                transformer=transformer,
                return_hidden_states_first=False,
            )
        ]
    )
    dummy_single_transformer_blocks = torch.nn.ModuleList()

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with (
            unittest.mock.patch.object(
                self,
                "transformer_blocks",
                cached_transformer_blocks,
            ),
            unittest.mock.patch.object(
                self,
                "single_transformer_blocks",
                dummy_single_transformer_blocks,
            ),
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    transformer._is_cached = True

    return transformer


def apply_db_cache_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    residual_diff_threshold=0.05,
    downsample_factor=1,
    warmup_steps=0,
    max_cached_steps=-1,
    **kwargs,
):
    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
        default_attrs={
            "residual_diff_threshold": residual_diff_threshold,
            "downsample_factor": downsample_factor,
            "warmup_steps": warmup_steps,
            "max_cached_steps": max_cached_steps,
        },
        **kwargs,
    )

    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with cache_context.cache_context(
                cache_context.create_cache_context(
                    **cache_kwargs,
                )
            ):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_db_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
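The Flux adapter above shows the patching pattern used throughout these files: the cached wrapper replaces transformer_blocks only for the duration of a forward call, single_transformer_blocks is swapped for an empty ModuleList because the wrapper already consumes both block lists, and the new forward is bound back onto the instance with __get__. A stripped-down, self-contained sketch of that pattern on a toy module; ToyModel and patch_blocks are hypothetical names, not part of cache-dit.

import functools
import unittest.mock

import torch


class ToyModel(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


def patch_blocks(model, new_blocks):
    original_forward = model.forward

    @functools.wraps(model.__class__.forward)
    def new_forward(self, x):
        # Swap the block list only while this call is running; it is
        # restored as soon as the context manager exits.
        with unittest.mock.patch.object(self, "blocks", new_blocks):
            return original_forward(x)

    # __get__ binds the function to the instance so `self` is passed in.
    model.forward = new_forward.__get__(model)
    return model


model = ToyModel()
patched = patch_blocks(model, torch.nn.ModuleList([torch.nn.Identity()]))
print(patched(torch.ones(1, 4)))  # Identity block: the input passes through unchanged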
cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py
@@ -0,0 +1,88 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/mochi.py
import functools
import unittest

import torch
from diffusers import DiffusionPipeline, MochiTransformer3DModel

from cache_dit.cache_factory.dual_block_cache import cache_context


def apply_db_cache_on_transformer(
    transformer: MochiTransformer3DModel,
):
    if getattr(transformer, "_is_cached", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            cache_context.DBCachedTransformerBlocks(
                transformer.transformer_blocks,
                transformer=transformer,
            )
        ]
    )

    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with unittest.mock.patch.object(
            self,
            "transformer_blocks",
            cached_transformer_blocks,
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    transformer._is_cached = True

    return transformer


def apply_db_cache_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    residual_diff_threshold=0.06,
    downsample_factor=1,
    warmup_steps=0,
    max_cached_steps=-1,
    **kwargs,
):
    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
        default_attrs={
            "residual_diff_threshold": residual_diff_threshold,
            "downsample_factor": downsample_factor,
            "warmup_steps": warmup_steps,
            "max_cached_steps": max_cached_steps,
        },
        **kwargs,
    )
    if not getattr(pipe, "_is_cached", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with cache_context.cache_context(
                cache_context.create_cache_context(
                    **cache_kwargs,
                )
            ):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_cached = True

    if not shallow_patch:
        apply_db_cache_on_transformer(pipe.transformer, **kwargs)

    return pipe
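In all of these adapters, shallow_patch=False (the default) patches both the pipeline __call__ and the transformer, while shallow_patch=True installs only the pipeline-level cache context. A sketch of the two-step variant with the Mochi adapter above; the model id is a placeholder and the "patch the transformer later" step is only one plausible reason to split the calls.

import torch
from diffusers import MochiPipeline

from cache_dit.cache_factory.dual_block_cache.diffusers_adapters.mochi import (
    apply_db_cache_on_pipe,
    apply_db_cache_on_transformer,
)

pipe = MochiPipeline.from_pretrained(
    "genmo/mochi-1-preview",  # placeholder model id
    torch_dtype=torch.bfloat16,
).to("cuda")

# shallow_patch=True wraps only pipe.__call__ with the cache context and
# leaves pipe.transformer untouched...
apply_db_cache_on_pipe(pipe, shallow_patch=True, residual_diff_threshold=0.06)

# ...so the block-level cache has to be applied explicitly afterwards.
apply_db_cache_on_transformer(pipe.transformer)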
cache_dit/cache_factory/dynamic_block_prune/__init__.py: file without changes.
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py
@@ -0,0 +1,45 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/__init__.py

import importlib

from diffusers import DiffusionPipeline


def apply_db_prune_on_transformer(transformer, *args, **kwargs):
    transformer_cls_name: str = transformer.__class__.__name__
    if transformer_cls_name.startswith("Flux"):
        adapter_name = "flux"
    elif transformer_cls_name.startswith("Mochi"):
        adapter_name = "mochi"
    elif transformer_cls_name.startswith("CogVideoX"):
        adapter_name = "cogvideox"
    else:
        raise ValueError(
            f"Unknown transformer class name: {transformer_cls_name}"
        )

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_db_prune_on_transformer_fn = getattr(
        adapter_module, "apply_db_prune_on_transformer"
    )
    return apply_db_prune_on_transformer_fn(transformer, *args, **kwargs)


def apply_db_prune_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
    assert isinstance(pipe, DiffusionPipeline)

    pipe_cls_name: str = pipe.__class__.__name__
    if pipe_cls_name.startswith("Flux"):
        adapter_name = "flux"
    elif pipe_cls_name.startswith("Mochi"):
        adapter_name = "mochi"
    elif pipe_cls_name.startswith("CogVideoX"):
        adapter_name = "cogvideox"
    else:
        raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")

    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
    apply_db_prune_on_pipe_fn = getattr(
        adapter_module, "apply_db_prune_on_pipe"
    )
    return apply_db_prune_on_pipe_fn(pipe, *args, **kwargs)
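The prune dispatcher mirrors the cache dispatcher: it resolves the adapter from the class-name prefix and forwards to apply_db_prune_on_transformer / apply_db_prune_on_pipe. A usage sketch; the model id, prompt, and parameter values are placeholders.

import torch
from diffusers import FluxPipeline

from cache_dit.cache_factory.dynamic_block_prune.diffusers_adapters import (
    apply_db_prune_on_pipe,
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # placeholder model id
    torch_dtype=torch.bfloat16,
).to("cuda")

# The "Flux" class-name prefix routes to the flux prune adapter; pipelines
# outside the Flux/Mochi/CogVideoX families raise ValueError instead.
apply_db_prune_on_pipe(pipe, residual_diff_threshold=0.05, max_pruned_steps=-1)
image = pipe("a photo of a cat", num_inference_steps=28).images[0]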
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py
@@ -0,0 +1,89 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/cogvideox.py

import functools
import unittest

import torch
from diffusers import CogVideoXTransformer3DModel, DiffusionPipeline

from cache_dit.cache_factory.dynamic_block_prune import prune_context


def apply_db_prune_on_transformer(
    transformer: CogVideoXTransformer3DModel,
):
    if getattr(transformer, "_is_pruned", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            prune_context.DBPrunedTransformerBlocks(
                transformer.transformer_blocks,
                transformer=transformer,
            )
        ]
    )

    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with unittest.mock.patch.object(
            self,
            "transformer_blocks",
            cached_transformer_blocks,
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    transformer._is_pruned = True

    return transformer


def apply_db_prune_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    residual_diff_threshold=0.04,
    downsample_factor=1,
    warmup_steps=0,
    max_pruned_steps=-1,
    **kwargs,
):
    cache_kwargs, kwargs = prune_context.collect_prune_kwargs(
        default_attrs={
            "residual_diff_threshold": residual_diff_threshold,
            "downsample_factor": downsample_factor,
            "warmup_steps": warmup_steps,
            "max_pruned_steps": max_pruned_steps,
        },
        **kwargs,
    )
    if not getattr(pipe, "_is_pruned", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with prune_context.prune_context(
                prune_context.create_prune_context(
                    **cache_kwargs,
                )
            ):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_pruned = True

    if not shallow_patch:
        apply_db_prune_on_transformer(pipe.transformer, **kwargs)

    return pipe
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py
@@ -0,0 +1,100 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/flux.py

import functools
import unittest

import torch
from diffusers import DiffusionPipeline, FluxTransformer2DModel

from cache_dit.cache_factory.dynamic_block_prune import prune_context


def apply_db_prune_on_transformer(
    transformer: FluxTransformer2DModel,
):
    if getattr(transformer, "_is_pruned", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            prune_context.DBPrunedTransformerBlocks(
                transformer.transformer_blocks,
                transformer.single_transformer_blocks,
                transformer=transformer,
                return_hidden_states_first=False,
            )
        ]
    )
    dummy_single_transformer_blocks = torch.nn.ModuleList()

    original_forward = transformer.forward

    @functools.wraps(original_forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with (
            unittest.mock.patch.object(
                self,
                "transformer_blocks",
                cached_transformer_blocks,
            ),
            unittest.mock.patch.object(
                self,
                "single_transformer_blocks",
                dummy_single_transformer_blocks,
            ),
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    transformer._is_pruned = True

    return transformer


def apply_db_prune_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    residual_diff_threshold=0.05,
    downsample_factor=1,
    warmup_steps=0,
    max_pruned_steps=-1,
    **kwargs,
):
    cache_kwargs, kwargs = prune_context.collect_prune_kwargs(
        default_attrs={
            "residual_diff_threshold": residual_diff_threshold,
            "downsample_factor": downsample_factor,
            "warmup_steps": warmup_steps,
            "max_pruned_steps": max_pruned_steps,
        },
        **kwargs,
    )

    if not getattr(pipe, "_is_pruned", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with prune_context.prune_context(
                prune_context.create_prune_context(
                    **cache_kwargs,
                )
            ):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_pruned = True

    if not shallow_patch:
        apply_db_prune_on_transformer(pipe.transformer, **kwargs)

    return pipe
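One subtlety worth noting in the adapters above: the pipe-level flag is set on the pipeline class (pipe.__class__._is_pruned = True, and _is_cached in the caching variants), not on the instance, so another pipeline of the same class is also seen as already patched at the __call__ level. A toy illustration of that class-versus-instance behaviour; Pipe is a hypothetical stand-in, plain Python, no diffusers required.

class Pipe:
    pass


a = Pipe()
b = Pipe()

# Mirrors `pipe.__class__._is_pruned = True` in the adapters above.
a.__class__._is_pruned = True

print(getattr(a, "_is_pruned", False))  # True
print(getattr(b, "_is_pruned", False))  # True: the flag lives on the class, so b sees it too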
cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py
@@ -0,0 +1,89 @@
# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/mochi.py

import functools
import unittest

import torch
from diffusers import DiffusionPipeline, MochiTransformer3DModel

from cache_dit.cache_factory.dynamic_block_prune import prune_context


def apply_db_prune_on_transformer(
    transformer: MochiTransformer3DModel,
):
    if getattr(transformer, "_is_pruned", False):
        return transformer

    cached_transformer_blocks = torch.nn.ModuleList(
        [
            prune_context.DBPrunedTransformerBlocks(
                transformer.transformer_blocks,
                transformer=transformer,
            )
        ]
    )

    original_forward = transformer.forward

    @functools.wraps(transformer.__class__.forward)
    def new_forward(
        self,
        *args,
        **kwargs,
    ):
        with unittest.mock.patch.object(
            self,
            "transformer_blocks",
            cached_transformer_blocks,
        ):
            return original_forward(
                *args,
                **kwargs,
            )

    transformer.forward = new_forward.__get__(transformer)

    transformer._is_pruned = True

    return transformer


def apply_db_prune_on_pipe(
    pipe: DiffusionPipeline,
    *,
    shallow_patch: bool = False,
    residual_diff_threshold=0.06,
    downsample_factor=1,
    warmup_steps=0,
    max_pruned_steps=-1,
    **kwargs,
):
    cache_kwargs, kwargs = prune_context.collect_prune_kwargs(
        default_attrs={
            "residual_diff_threshold": residual_diff_threshold,
            "downsample_factor": downsample_factor,
            "warmup_steps": warmup_steps,
            "max_pruned_steps": max_pruned_steps,
        },
        **kwargs,
    )
    if not getattr(pipe, "_is_pruned", False):
        original_call = pipe.__class__.__call__

        @functools.wraps(original_call)
        def new_call(self, *args, **kwargs):
            with prune_context.prune_context(
                prune_context.create_prune_context(
                    **cache_kwargs,
                )
            ):
                return original_call(self, *args, **kwargs)

        pipe.__class__.__call__ = new_call
        pipe.__class__._is_pruned = True

    if not shallow_patch:
        apply_db_prune_on_transformer(pipe.transformer, **kwargs)

    return pipe
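Across both the cache and the prune adapters, the per-model defaults differ only in residual_diff_threshold; the values below are taken from the signatures above, and the dictionary name is purely illustrative, not part of cache-dit.

# Default residual_diff_threshold per adapter, as defined in the signatures above.
DEFAULT_RESIDUAL_DIFF_THRESHOLD = {
    "cogvideox": 0.04,
    "flux": 0.05,
    "mochi": 0.06,
}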