cache-dit 0.2.14__py3-none-any.whl → 0.2.16__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
Potentially problematic release.
This version of cache-dit might be problematic.
- cache_dit/_version.py +2 -2
- cache_dit/cache_factory/__init__.py +1 -0
- cache_dit/cache_factory/adapters.py +47 -5
- cache_dit/cache_factory/dual_block_cache/__init__.py +4 -0
- cache_dit/cache_factory/dual_block_cache/cache_blocks.py +487 -0
- cache_dit/cache_factory/dual_block_cache/cache_context.py +10 -860
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +5 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +14 -4
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +5 -2
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +5 -2
- cache_dit/cache_factory/{first_block_cache/diffusers_adapters/mochi.py → dual_block_cache/diffusers_adapters/qwen_image.py} +14 -12
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +7 -4
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +4 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +4 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +10 -4
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +5 -2
- cache_dit/cache_factory/{first_block_cache/diffusers_adapters/cogvideox.py → dynamic_block_prune/diffusers_adapters/qwen_image.py} +28 -23
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +5 -2
- cache_dit/cache_factory/dynamic_block_prune/prune_blocks.py +276 -0
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +228 -516
- cache_dit/cache_factory/patch/flux.py +241 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/METADATA +22 -80
- cache_dit-0.2.16.dist-info/RECORD +47 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
- cache_dit-0.2.14.dist-info/RECORD +0 -49
- /cache_dit/cache_factory/{first_block_cache → patch}/__init__.py +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.14.dist-info → cache_dit-0.2.16.dist-info}/top_level.txt +0 -0

cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py (removed)
@@ -1,57 +0,0 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/__init__.py
-
-import importlib
-from typing import Callable
-
-from diffusers import DiffusionPipeline
-
-
-def apply_fb_cache_on_transformer(transformer, *args, **kwargs):
-    transformer_cls_name: str = transformer.__class__.__name__
-    if transformer_cls_name.startswith("Flux"):
-        adapter_name = "flux"
-    elif transformer_cls_name.startswith("Mochi"):
-        adapter_name = "mochi"
-    elif transformer_cls_name.startswith("CogVideoX"):
-        adapter_name = "cogvideox"
-    elif transformer_cls_name.startswith("Wan"):
-        adapter_name = "wan"
-    elif transformer_cls_name.startswith("HunyuanVideo"):
-        adapter_name = "hunyuan_video"
-    else:
-        raise ValueError(
-            f"Unknown transformer class name: {transformer_cls_name}"
-        )
-
-    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
-    apply_cache_on_transformer_fn = getattr(
-        adapter_module, "apply_cache_on_transformer"
-    )
-    return apply_cache_on_transformer_fn(transformer, *args, **kwargs)
-
-
-def apply_fb_cache_on_pipe(pipe: DiffusionPipeline, *args, **kwargs):
-    assert isinstance(pipe, DiffusionPipeline)
-
-    pipe_cls_name: str = pipe.__class__.__name__
-    if pipe_cls_name.startswith("Flux"):
-        adapter_name = "flux"
-    elif pipe_cls_name.startswith("Mochi"):
-        adapter_name = "mochi"
-    elif pipe_cls_name.startswith("CogVideoX"):
-        adapter_name = "cogvideox"
-    elif pipe_cls_name.startswith("Wan"):
-        adapter_name = "wan"
-    elif pipe_cls_name.startswith("HunyuanVideo"):
-        adapter_name = "hunyuan_video"
-    else:
-        raise ValueError(f"Unknown pipeline class name: {pipe_cls_name}")
-
-    adapter_module = importlib.import_module(f".{adapter_name}", __package__)
-    apply_cache_on_pipe_fn = getattr(adapter_module, "apply_cache_on_pipe")
-    return apply_cache_on_pipe_fn(pipe, *args, **kwargs)
-
-
-# re-export functions for compatibility
-apply_cache_on_transformer: Callable = apply_fb_cache_on_transformer
-apply_cache_on_pipe: Callable = apply_fb_cache_on_pipe
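
For context on what is being removed here: the dispatcher above maps a pipeline or transformer class-name prefix to an adapter module and forwards to that module's `apply_cache_on_pipe` / `apply_cache_on_transformer`. A minimal usage sketch of the removed entry point follows; the checkpoint id, dtype, and device are illustrative and not taken from this diff.

```python
import torch
from diffusers import FluxPipeline

# Entry point re-exported by the deleted __init__.py shown above.
from cache_dit.cache_factory.first_block_cache.diffusers_adapters import (
    apply_cache_on_pipe,
)

pipe = FluxPipeline.from_pretrained(
    "black-forest-labs/FLUX.1-dev",  # illustrative checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

# The class name starts with "Flux", so the dispatcher imports the `flux`
# adapter and patches this pipeline instance in place.
apply_cache_on_pipe(pipe, residual_diff_threshold=0.05)
```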

cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py (removed)
@@ -1,100 +0,0 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/flux.py
-
-import functools
-import unittest
-
-import torch
-from diffusers import DiffusionPipeline, FluxTransformer2DModel
-
-from cache_dit.cache_factory.first_block_cache import cache_context
-
-
-def apply_cache_on_transformer(
-    transformer: FluxTransformer2DModel,
-):
-    if getattr(transformer, "_is_cached", False):
-        return transformer
-
-    cached_transformer_blocks = torch.nn.ModuleList(
-        [
-            cache_context.CachedTransformerBlocks(
-                transformer.transformer_blocks,
-                transformer.single_transformer_blocks,
-                transformer=transformer,
-                return_hidden_states_first=False,
-            )
-        ]
-    )
-    dummy_single_transformer_blocks = torch.nn.ModuleList()
-
-    original_forward = transformer.forward
-
-    @functools.wraps(original_forward)
-    def new_forward(
-        self,
-        *args,
-        **kwargs,
-    ):
-        with (
-            unittest.mock.patch.object(
-                self,
-                "transformer_blocks",
-                cached_transformer_blocks,
-            ),
-            unittest.mock.patch.object(
-                self,
-                "single_transformer_blocks",
-                dummy_single_transformer_blocks,
-            ),
-        ):
-            return original_forward(
-                *args,
-                **kwargs,
-            )
-
-    transformer.forward = new_forward.__get__(transformer)
-
-    transformer._is_cached = True
-
-    return transformer
-
-
-def apply_cache_on_pipe(
-    pipe: DiffusionPipeline,
-    *,
-    shallow_patch: bool = False,
-    residual_diff_threshold=0.05,
-    downsample_factor=1,
-    warmup_steps=0,
-    max_cached_steps=-1,
-    **kwargs,
-):
-    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
-        default_attrs={
-            "residual_diff_threshold": residual_diff_threshold,
-            "downsample_factor": downsample_factor,
-            "warmup_steps": warmup_steps,
-            "max_cached_steps": max_cached_steps,
-        },
-        **kwargs,
-    ) # noqa
-
-    if not getattr(pipe, "_is_cached", False):
-        original_call = pipe.__class__.__call__
-
-        @functools.wraps(original_call)
-        def new_call(self, *args, **kwargs):
-            with cache_context.cache_context(
-                cache_context.create_cache_context(
-                    **cache_kwargs,
-                )
-            ):
-                return original_call(self, *args, **kwargs)
-
-        pipe.__class__.__call__ = new_call
-        pipe.__class__._is_cached = True
-
-    if not shallow_patch:
-        apply_cache_on_transformer(pipe.transformer, **kwargs)
-
-    return pipe
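
The flux adapter above avoids subclassing by combining two standard-library mechanisms: `unittest.mock.patch.object` swaps the block lists only for the duration of a forward call, and `new_forward.__get__(transformer)` binds the wrapper as an instance method so that a single object is patched rather than the class. A self-contained toy sketch of the same pattern, with hypothetical names that are not part of cache-dit:

```python
import functools
import unittest.mock

import torch


class TinyModel(torch.nn.Module):
    """Stand-in for a transformer with a swappable block list."""

    def __init__(self):
        super().__init__()
        self.blocks = torch.nn.ModuleList([torch.nn.Linear(4, 4)])

    def forward(self, x):
        for block in self.blocks:
            x = block(x)
        return x


model = TinyModel()
fast_blocks = torch.nn.ModuleList([torch.nn.Identity()])  # stand-in for "cached" blocks
original_forward = model.forward


@functools.wraps(original_forward)
def new_forward(self, *args, **kwargs):
    # Swap self.blocks only while this call runs, then delegate to the original.
    with unittest.mock.patch.object(self, "blocks", fast_blocks):
        return original_forward(*args, **kwargs)


# Bind as an instance method: only this instance is patched, not the class.
model.forward = new_forward.__get__(model)
print(model.forward(torch.ones(1, 4)))  # Identity block, so the input passes through unchanged
```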

cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py (removed)
@@ -1,295 +0,0 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/blob/main/src/para_attn/first_block_cache/diffusers_adapters/hunyuan_video.py
-import functools
-import unittest
-from typing import Any, Dict, Optional, Union
-
-import torch
-from diffusers import DiffusionPipeline, HunyuanVideoTransformer3DModel
-from diffusers.models.modeling_outputs import Transformer2DModelOutput
-from diffusers.utils import (
-    scale_lora_layers,
-    unscale_lora_layers,
-    USE_PEFT_BACKEND,
-)
-
-from cache_dit.cache_factory.first_block_cache import cache_context
-from cache_dit.logger import init_logger
-
-try:
-    from para_attn.para_attn_interface import SparseKVAttnMode
-
-    def is_sparse_kv_attn_available():
-        return True
-
-except ImportError:
-
-    class SparseKVAttnMode:
-        def __enter__(self):
-            pass
-
-        def __exit__(self, exc_type, exc_value, traceback):
-            pass
-
-    def is_sparse_kv_attn_available():
-        return False
-
-
-logger = init_logger(__name__) # pylint: disable=invalid-name
-
-
-def apply_cache_on_transformer(
-    transformer: HunyuanVideoTransformer3DModel,
-):
-    if getattr(transformer, "_is_cached", False):
-        return transformer
-
-    cached_transformer_blocks = torch.nn.ModuleList(
-        [
-            cache_context.CachedTransformerBlocks(
-                transformer.transformer_blocks
-                + transformer.single_transformer_blocks,
-                transformer=transformer,
-            )
-        ]
-    )
-    dummy_single_transformer_blocks = torch.nn.ModuleList()
-
-    original_forward = transformer.forward
-
-    @functools.wraps(transformer.__class__.forward)
-    def new_forward(
-        self,
-        hidden_states: torch.Tensor,
-        timestep: torch.LongTensor,
-        encoder_hidden_states: torch.Tensor,
-        encoder_attention_mask: torch.Tensor,
-        pooled_projections: torch.Tensor,
-        guidance: torch.Tensor = None,
-        attention_kwargs: Optional[Dict[str, Any]] = None,
-        return_dict: bool = True,
-        **kwargs,
-    ) -> Union[torch.Tensor, Dict[str, torch.Tensor]]:
-        with (
-            unittest.mock.patch.object(
-                self,
-                "transformer_blocks",
-                cached_transformer_blocks,
-            ),
-            unittest.mock.patch.object(
-                self,
-                "single_transformer_blocks",
-                dummy_single_transformer_blocks,
-            ),
-        ):
-            if getattr(self, "_is_parallelized", False):
-                return original_forward(
-                    hidden_states,
-                    timestep,
-                    encoder_hidden_states,
-                    encoder_attention_mask,
-                    pooled_projections,
-                    guidance=guidance,
-                    attention_kwargs=attention_kwargs,
-                    return_dict=return_dict,
-                    **kwargs,
-                )
-            else:
-                if attention_kwargs is not None:
-                    attention_kwargs = attention_kwargs.copy()
-                    lora_scale = attention_kwargs.pop("scale", 1.0)
-                else:
-                    lora_scale = 1.0
-
-                if USE_PEFT_BACKEND:
-                    # weight the lora layers by setting `lora_scale` for each PEFT layer
-                    scale_lora_layers(self, lora_scale)
-                else:
-                    if (
-                        attention_kwargs is not None
-                        and attention_kwargs.get("scale", None) is not None
-                    ):
-                        logger.warning(
-                            "Passing `scale` via `attention_kwargs` when not using the PEFT backend is ineffective."
-                        )
-
-                batch_size, num_channels, num_frames, height, width = (
-                    hidden_states.shape
-                )
-                p, p_t = self.config.patch_size, self.config.patch_size_t
-                post_patch_num_frames = num_frames // p_t
-                post_patch_height = height // p
-                post_patch_width = width // p
-
-                # 1. RoPE
-                image_rotary_emb = self.rope(hidden_states)
-
-                # 2. Conditional embeddings
-                temb = self.time_text_embed(
-                    timestep, guidance, pooled_projections
-                )
-                hidden_states = self.x_embedder(hidden_states)
-                encoder_hidden_states = self.context_embedder(
-                    encoder_hidden_states, timestep, encoder_attention_mask
-                )
-
-                # 3. Attention mask preparation
-                latent_sequence_length = hidden_states.shape[1]
-                latent_attention_mask = torch.ones(
-                    batch_size,
-                    1,
-                    latent_sequence_length,
-                    device=hidden_states.device,
-                    dtype=torch.bool,
-                ) # [B, 1, N]
-                attention_mask = torch.cat(
-                    [
-                        latent_attention_mask,
-                        encoder_attention_mask.unsqueeze(1).to(torch.bool),
-                    ],
-                    dim=-1,
-                ) # [B, 1, N + M]
-
-                with SparseKVAttnMode():
-                    # 4. Transformer blocks
-                    hidden_states, encoder_hidden_states = (
-                        self.call_transformer_blocks(
-                            hidden_states,
-                            encoder_hidden_states,
-                            temb,
-                            attention_mask,
-                            image_rotary_emb,
-                        )
-                    )
-
-                # 5. Output projection
-                hidden_states = self.norm_out(hidden_states, temb)
-                hidden_states = self.proj_out(hidden_states)
-
-                hidden_states = hidden_states.reshape(
-                    batch_size,
-                    post_patch_num_frames,
-                    post_patch_height,
-                    post_patch_width,
-                    -1,
-                    p_t,
-                    p,
-                    p,
-                )
-                hidden_states = hidden_states.permute(0, 4, 1, 5, 2, 6, 3, 7)
-                hidden_states = (
-                    hidden_states.flatten(6, 7).flatten(4, 5).flatten(2, 3)
-                )
-
-                hidden_states = hidden_states.to(timestep.dtype)
-
-            if USE_PEFT_BACKEND:
-                # remove `lora_scale` from each PEFT layer
-                unscale_lora_layers(self, lora_scale)
-
-            if not return_dict:
-                return (hidden_states,)
-
-            return Transformer2DModelOutput(sample=hidden_states)
-
-    transformer.forward = new_forward.__get__(transformer)
-
-    def call_transformer_blocks(
-        self, hidden_states, encoder_hidden_states, *args, **kwargs
-    ):
-        if torch.is_grad_enabled() and self.gradient_checkpointing:
-
-            def create_custom_forward(module, return_dict=None):
-                def custom_forward(*inputs):
-                    if return_dict is not None:
-                        return module(*inputs, return_dict=return_dict)
-                    else:
-                        return module(*inputs)
-
-                return custom_forward
-
-            ckpt_kwargs: Dict[str, Any] = {"use_reentrant": False}
-
-            for block in self.transformer_blocks:
-                hidden_states, encoder_hidden_states = (
-                    torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
-                        hidden_states,
-                        encoder_hidden_states,
-                        *args,
-                        **kwargs,
-                        **ckpt_kwargs,
-                    )
-                )
-
-            for block in self.single_transformer_blocks:
-                hidden_states, encoder_hidden_states = (
-                    torch.utils.checkpoint.checkpoint(
-                        create_custom_forward(block),
-                        hidden_states,
-                        encoder_hidden_states,
-                        *args,
-                        **kwargs,
-                        **ckpt_kwargs,
-                    )
-                )
-
-        else:
-            for block in self.transformer_blocks:
-                hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, *args, **kwargs
-                )
-
-            for block in self.single_transformer_blocks:
-                hidden_states, encoder_hidden_states = block(
-                    hidden_states, encoder_hidden_states, *args, **kwargs
-                )
-
-        return hidden_states, encoder_hidden_states
-
-    transformer.call_transformer_blocks = call_transformer_blocks.__get__(
-        transformer
-    )
-
-    transformer._is_cached = True
-
-    return transformer
-
-
-def apply_cache_on_pipe(
-    pipe: DiffusionPipeline,
-    *,
-    shallow_patch: bool = False,
-    residual_diff_threshold=0.06,
-    downsample_factor=1,
-    warmup_steps=0,
-    max_cached_steps=-1,
-    **kwargs,
-):
-    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
-        default_attrs={
-            "residual_diff_threshold": residual_diff_threshold,
-            "downsample_factor": downsample_factor,
-            "warmup_steps": warmup_steps,
-            "max_cached_steps": max_cached_steps,
-        },
-        **kwargs,
-    )
-    if not getattr(pipe, "_is_cached", False):
-        original_call = pipe.__class__.__call__
-
-        @functools.wraps(original_call)
-        def new_call(self, *args, **kwargs):
-            with cache_context.cache_context(
-                cache_context.create_cache_context(
-                    **cache_kwargs,
-                )
-            ):
-                return original_call(self, *args, **kwargs)
-
-        pipe.__class__.__call__ = new_call
-        pipe.__class__._is_cached = True
-
-    if not shallow_patch:
-        apply_cache_on_transformer(pipe.transformer, **kwargs)
-
-    return pipe
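
One pattern worth noting in the removed HunyuanVideo adapter is the optional dependency on `para_attn`: when the import fails, a no-op `SparseKVAttnMode` stand-in with the same context-manager interface is defined, so the forward path can wrap the transformer blocks unconditionally. The sketch below isolates that fallback; `run_blocks` is a hypothetical helper, not part of the deleted file.

```python
# Optional-dependency fallback used by the deleted hunyuan_video adapter: prefer
# para_attn's sparse-KV attention context manager, otherwise install a no-op
# stand-in with the same interface so callers need no feature checks.
try:
    from para_attn.para_attn_interface import SparseKVAttnMode

    def is_sparse_kv_attn_available() -> bool:
        return True

except ImportError:

    class SparseKVAttnMode:
        def __enter__(self):
            pass

        def __exit__(self, exc_type, exc_value, traceback):
            pass

    def is_sparse_kv_attn_available() -> bool:
        return False


def run_blocks(blocks, hidden_states):
    # Hypothetical helper: the hot path is wrapped unconditionally either way.
    with SparseKVAttnMode():
        for block in blocks:
            hidden_states = block(hidden_states)
    return hidden_states
```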

cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py (removed)
@@ -1,98 +0,0 @@
-# Adapted from: https://github.com/chengzeyi/ParaAttention/tree/main/src/para_attn/first_block_cache/wan.py
-
-import functools
-import unittest
-
-import torch
-from diffusers import DiffusionPipeline, WanTransformer3DModel
-
-from cache_dit.cache_factory.first_block_cache import cache_context
-
-
-def apply_cache_on_transformer(
-    transformer: WanTransformer3DModel,
-):
-    if getattr(transformer, "_is_cached", False):
-        return transformer
-
-    blocks = torch.nn.ModuleList(
-        [
-            cache_context.CachedTransformerBlocks(
-                transformer.blocks,
-                transformer=transformer,
-                return_hidden_states_only=True,
-            )
-        ]
-    )
-
-    original_forward = transformer.forward
-
-    @functools.wraps(transformer.__class__.forward)
-    def new_forward(
-        self,
-        *args,
-        **kwargs,
-    ):
-        with unittest.mock.patch.object(
-            self,
-            "blocks",
-            blocks,
-        ):
-            return original_forward(
-                *args,
-                **kwargs,
-            )
-
-    transformer.forward = new_forward.__get__(transformer)
-
-    transformer._is_cached = True
-
-    return transformer
-
-
-def apply_cache_on_pipe(
-    pipe: DiffusionPipeline,
-    *,
-    shallow_patch: bool = False,
-    residual_diff_threshold=0.03,
-    downsample_factor=1,
-    slg_layers=None,
-    slg_start: float = 0.0,
-    slg_end: float = 0.1,
-    warmup_steps=0,
-    max_cached_steps=-1,
-    **kwargs,
-):
-    cache_kwargs, kwargs = cache_context.collect_cache_kwargs(
-        default_attrs={
-            "residual_diff_threshold": residual_diff_threshold,
-            "downsample_factor": downsample_factor,
-            "enable_alter_cache": True,
-            "slg_layers": slg_layers,
-            "slg_start": slg_start,
-            "slg_end": slg_end,
-            "num_inference_steps": kwargs.get("num_inference_steps", 50),
-            "warmup_steps": warmup_steps,
-            "max_cached_steps": max_cached_steps,
-        },
-        **kwargs,
-    )
-    if not getattr(pipe, "_is_cached", False):
-        original_call = pipe.__class__.__call__
-
-        @functools.wraps(original_call)
-        def new_call(self, *args, **kwargs):
-            with cache_context.cache_context(
-                cache_context.create_cache_context(
-                    **cache_kwargs,
-                )
-            ):
-                return original_call(self, *args, **kwargs)
-
-        pipe.__class__.__call__ = new_call
-        pipe.__class__._is_cached = True
-
-    if not shallow_patch:
-        apply_cache_on_transformer(pipe.transformer, **kwargs)
-
-    return pipe
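
The Wan adapter above adds skip-layer-guidance knobs (`slg_layers`, `slg_start`, `slg_end`) and an `enable_alter_cache` default on top of the common options, and it reads `num_inference_steps` from its own kwargs (defaulting to 50) when building the cache context. A hedged usage sketch against the removed API; the checkpoint id and the specific parameter values are illustrative.

```python
import torch
from diffusers import WanPipeline

# Usage sketch for the removed Wan first-block-cache adapter shown above.
from cache_dit.cache_factory.first_block_cache.diffusers_adapters.wan import (
    apply_cache_on_pipe,
)

pipe = WanPipeline.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",  # illustrative checkpoint id
    torch_dtype=torch.bfloat16,
).to("cuda")

apply_cache_on_pipe(
    pipe,
    residual_diff_threshold=0.03,  # default from the deleted adapter
    slg_layers=[9],                # illustrative skip-layer-guidance layer index
    slg_start=0.0,
    slg_end=0.1,
    warmup_steps=2,
)

output = pipe("a cat surfing a wave", num_inference_steps=50)
```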

cache_dit-0.2.14.dist-info/RECORD (removed)
@@ -1,49 +0,0 @@
-cache_dit/__init__.py,sha256=0-B173-fLi3IA8nJXoS71zK0zD33Xplysd9skmLfEOY,171
-cache_dit/_version.py,sha256=ut2sCt69XoYh0A1_KAmfCg1IKkN6zwqhu2eMFWAhMbQ,513
-cache_dit/logger.py,sha256=0zsu42hN-3-rgGC_C29ms1IvVpV4_b4_SwJCKSenxBE,4304
-cache_dit/primitives.py,sha256=A2iG9YLot3gOsZSPp-_gyjqjLgJvWQRx8aitD4JQ23Y,3877
-cache_dit/utils.py,sha256=4cFNh0asch6Zgsixq0bS1ElfwBu_6BG5ZSmaa1khjyg,144
-cache_dit/cache_factory/__init__.py,sha256=iYQwLwB_XLoYl0OB9unZGDbBtrYvZaLkOAmhGRwdW2E,191
-cache_dit/cache_factory/adapters.py,sha256=QMCaXnmqM7NT7sx4bCF1mMLn-QcXX9h1RmgLAypDedg,5256
-cache_dit/cache_factory/taylorseer.py,sha256=LKSNo2ode69EVo9xrxjxAMEjz0yDGiGADeDYnEqddA8,3987
-cache_dit/cache_factory/utils.py,sha256=V-Mb5Jn07geEUUWo4QAfh6pmSzkL-2OGDn0VAXbG6hQ,1799
-cache_dit/cache_factory/dual_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/cache_factory/dual_block_cache/cache_context.py,sha256=sJ9yxQlcrX4qkPln94FrL0WDe2WIn3_UD2-Mk8YtjSw,73301
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py,sha256=krNAICf-aS3JLmSG8vOB9tpLa04uYRcABsC8PMbVUKY,1870
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py,sha256=3xUjvDzor9AkBkDUc0N7kZqM86MIdajuigesnicNzXE,2260
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py,sha256=cIsov6Pf0dRyddqkzTA2CU-jSDotof8LQr-HIoY9T9M,2615
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py,sha256=SO4q39PQuQ5QVHy5Z-ubiKdstzvQPedONN2J5oiGUh0,9955
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py,sha256=8W9m-WeEVE2ytYi9udKEA8Wtb0EnvP3eT2A1Tu-d29k,2252
-cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py,sha256=EREHM5E1wxnL-uRXRAEege4HXraRp1oD_r1Zx4CsiKk,2596
-cache_dit/cache_factory/dynamic_block_prune/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/cache_factory/dynamic_block_prune/prune_context.py,sha256=1qarKAsEFiaaN2_ghko2dqGz_R7BTQSOyGtb_eQq38Y,35716
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py,sha256=hVBTXj9MMGFGVezT3j8MntFRBiphSaUL4YhSOd8JtuY,1870
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py,sha256=KP8NxtHAKzzBOoX0lhvlMgY_5dmP4Z3T5TOfwl4SSyg,2273
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py,sha256=kCB7lL4OIq8TZn-baMIF8D_PVPTFW60omCMVQCb8ebs,2628
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py,sha256=xAkd40BGsfuCKdW3Abrx35VwgZQg4CZFz13P4VY71eY,9968
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py,sha256=zXgoRDDjus3a2WSjtNh4ERtQp20ceb6nzohHMDlo2zY,2265
-cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py,sha256=PA7nuLgfAelnaI8usQx0Kxi8XATzMapyR1WndEdFoZA,2604
-cache_dit/cache_factory/first_block_cache/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/cache_factory/first_block_cache/cache_context.py,sha256=qn4zWJ_eEMIPYzrxXoslunxbzK0WueuNtC54Pp5Q57k,23241
-cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py,sha256=-FFgA2MoudEo7uDacg4aWgm1KwfLZFsEDTVxatgbq9M,2146
-cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py,sha256=qO5CWyurtwW30mvOe6cxeQPTSXLDlPJcezm72zEjDq8,2375
-cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py,sha256=Dcd4OzABCtyQCZNX2KNnUTdVoO1E1ApM7P8gcVYzcK0,2733
-cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py,sha256=OL7W4ukYlZz0IDmBR1zVV6XT3Mgciglj9Hqzv1wUAkQ,10092
-cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py,sha256=lQTClo52OwPbNEE4jiBZQhfC7hbtYqnYIABp_vbm_dk,2363
-cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py,sha256=dBNzHBECAuTTA1a7kLdvZL20YzaKTAS3iciVLzKKEWA,2638
-cache_dit/compile/__init__.py,sha256=DfMdPleFFGADXLsr7zXui8BTz_y9futY6rNmNdh9y7k,63
-cache_dit/compile/utils.py,sha256=N4A55_8uIbEd-S4xyJPcrdKceI2MGM9BTIhJE63jyL4,3786
-cache_dit/custom_ops/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/custom_ops/triton_taylorseer.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-cache_dit/metrics/__init__.py,sha256=RaUhl5dieF40RqnizGzR30qoJJ9dyMUEADwgwMaMQrE,575
-cache_dit/metrics/config.py,sha256=ieOgD9ayz722RjVzk24bSIqS2D6o7TZjGk8KeXV-OLQ,551
-cache_dit/metrics/fid.py,sha256=9Ivtazl6mW0Bon2VXa-Ia5Xj2ewxRD3V1Qkd69zYM3Y,17066
-cache_dit/metrics/inception.py,sha256=pBVe2X6ylLPIXTG4-GWDM9DWnCviMJbJ45R3ulhktR0,12759
-cache_dit/metrics/lpips.py,sha256=I2qCNi6qJh5TRsaIsdxO0WoRX1DN7U_H3zS0oCSahYM,1032
-cache_dit/metrics/metrics.py,sha256=8jvM1sF-nDxUuwCRy44QEoo4dYVLCQVh1QyAMs4eaQY,27840
-cache_dit-0.2.14.dist-info/licenses/LICENSE,sha256=Dqb07Ik2dV41s9nIdMUbiRWEfDqo7-dQeRiY7kPO8PE,3769
-cache_dit-0.2.14.dist-info/METADATA,sha256=EyZN75JcVcvTc5bopHXfl6w-nA-ro9Uit2Sjy5DU66A,25198
-cache_dit-0.2.14.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-cache_dit-0.2.14.dist-info/entry_points.txt,sha256=FX2gysXaZx6NeK1iCLMcIdP8Q4_qikkIHtEmi3oWn8o,65
-cache_dit-0.2.14.dist-info/top_level.txt,sha256=ZJDydonLEhujzz0FOkVbO-BqfzO9d_VqRHmZU-3MOZo,10
-cache_dit-0.2.14.dist-info/RECORD,,

The remaining entries in the file list ({first_block_cache → patch}/__init__.py, WHEEL, entry_points.txt, licenses/LICENSE, top_level.txt) are moved or carried over without content changes.