cache-dit 0.2.15__py3-none-any.whl → 0.2.17__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of cache-dit might be problematic.
- cache_dit/__init__.py +12 -0
- cache_dit/_version.py +16 -3
- cache_dit/cache_factory/.gitignore +2 -0
- cache_dit/cache_factory/__init__.py +52 -2
- cache_dit/cache_factory/cache_adapters.py +654 -0
- cache_dit/cache_factory/cache_blocks.py +487 -0
- cache_dit/cache_factory/{dual_block_cache/cache_context.py → cache_context.py} +11 -862
- cache_dit/cache_factory/patch/flux.py +249 -0
- cache_dit/cache_factory/utils.py +1 -1
- cache_dit/compile/__init__.py +1 -1
- cache_dit/compile/utils.py +1 -1
- {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/METADATA +87 -204
- cache_dit-0.2.17.dist-info/RECORD +30 -0
- cache_dit/cache_factory/adapters.py +0 -169
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/__init__.py +0 -55
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/cogvideox.py +0 -87
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/flux.py +0 -98
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/hunyuan_video.py +0 -294
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/mochi.py +0 -87
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/qwen_image.py +0 -88
- cache_dit/cache_factory/dual_block_cache/diffusers_adapters/wan.py +0 -97
- cache_dit/cache_factory/dynamic_block_prune/__init__.py +0 -0
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/__init__.py +0 -51
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/cogvideox.py +0 -87
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/flux.py +0 -98
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/hunyuan_video.py +0 -294
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/mochi.py +0 -87
- cache_dit/cache_factory/dynamic_block_prune/diffusers_adapters/wan.py +0 -97
- cache_dit/cache_factory/dynamic_block_prune/prune_context.py +0 -1005
- cache_dit/cache_factory/first_block_cache/__init__.py +0 -0
- cache_dit/cache_factory/first_block_cache/cache_context.py +0 -719
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/__init__.py +0 -57
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/cogvideox.py +0 -89
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/flux.py +0 -100
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/hunyuan_video.py +0 -295
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/mochi.py +0 -89
- cache_dit/cache_factory/first_block_cache/diffusers_adapters/wan.py +0 -98
- cache_dit-0.2.15.dist-info/RECORD +0 -50
- /cache_dit/cache_factory/{dual_block_cache → patch}/__init__.py +0 -0
- {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/WHEEL +0 -0
- {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/licenses/LICENSE +0 -0
- {cache_dit-0.2.15.dist-info → cache_dit-0.2.17.dist-info}/top_level.txt +0 -0
cache_dit/cache_factory/patch/flux.py ADDED

@@ -0,0 +1,249 @@
+import inspect
+
+import torch
+import numpy as np
+from typing import Tuple, Optional, Dict, Any, Union
+from diffusers import FluxTransformer2DModel
+from diffusers.models.transformers.transformer_flux import (
+    FluxSingleTransformerBlock,
+    Transformer2DModelOutput,
+)
+from diffusers.utils import (
+    USE_PEFT_BACKEND,
+    scale_lora_layers,
+    unscale_lora_layers,
+)
+
+
+from cache_dit.logger import init_logger
+
+logger = init_logger(__name__)
+
+
+# copy from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L380
+def __patch_single_forward__(
+    self: FluxSingleTransformerBlock,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor,
+    temb: torch.Tensor,
+    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    text_seq_len = encoder_hidden_states.shape[1]
+    hidden_states = torch.cat([encoder_hidden_states, hidden_states], dim=1)
+
+    residual = hidden_states
+    norm_hidden_states, gate = self.norm(hidden_states, emb=temb)
+    mlp_hidden_states = self.act_mlp(self.proj_mlp(norm_hidden_states))
+    joint_attention_kwargs = joint_attention_kwargs or {}
+    attn_output = self.attn(
+        hidden_states=norm_hidden_states,
+        image_rotary_emb=image_rotary_emb,
+        **joint_attention_kwargs,
+    )
+
+    hidden_states = torch.cat([attn_output, mlp_hidden_states], dim=2)
+    gate = gate.unsqueeze(1)
+    hidden_states = gate * self.proj_out(hidden_states)
+    hidden_states = residual + hidden_states
+    if hidden_states.dtype == torch.float16:
+        hidden_states = hidden_states.clip(-65504, 65504)
+
+    encoder_hidden_states, hidden_states = (
+        hidden_states[:, :text_seq_len],
+        hidden_states[:, text_seq_len:],
+    )
+    return encoder_hidden_states, hidden_states
+
+
+# copy from: https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/transformers/transformer_flux.py#L631
+def __patch_transformer_forward__(
+    self: FluxTransformer2DModel,
+    hidden_states: torch.Tensor,
+    encoder_hidden_states: torch.Tensor = None,
+    pooled_projections: torch.Tensor = None,
+    timestep: torch.LongTensor = None,
+    img_ids: torch.Tensor = None,
+    txt_ids: torch.Tensor = None,
+    guidance: torch.Tensor = None,
+    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
+    controlnet_block_samples=None,
+    controlnet_single_block_samples=None,
+    return_dict: bool = True,
+    controlnet_blocks_repeat: bool = False,
+) -> Union[torch.Tensor, Transformer2DModelOutput]:
+    if joint_attention_kwargs is not None:
+        joint_attention_kwargs = joint_attention_kwargs.copy()
+        lora_scale = joint_attention_kwargs.pop("scale", 1.0)
+    else:
+        lora_scale = 1.0
+
+    if USE_PEFT_BACKEND:
+        # weight the lora layers by setting `lora_scale` for each PEFT layer
+        scale_lora_layers(self, lora_scale)
+    else:
+        if (
+            joint_attention_kwargs is not None
+            and joint_attention_kwargs.get("scale", None) is not None
+        ):
+            logger.warning(
+                "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
+            )
+
+    hidden_states = self.x_embedder(hidden_states)
+
+    timestep = timestep.to(hidden_states.dtype) * 1000
+    if guidance is not None:
+        guidance = guidance.to(hidden_states.dtype) * 1000
+
+    temb = (
+        self.time_text_embed(timestep, pooled_projections)
+        if guidance is None
+        else self.time_text_embed(timestep, guidance, pooled_projections)
+    )
+    encoder_hidden_states = self.context_embedder(encoder_hidden_states)
+
+    if txt_ids.ndim == 3:
+        logger.warning(
+            "Passing `txt_ids` 3d torch.Tensor is deprecated."
+            "Please remove the batch dimension and pass it as a 2d torch Tensor"
+        )
+        txt_ids = txt_ids[0]
+    if img_ids.ndim == 3:
+        logger.warning(
+            "Passing `img_ids` 3d torch.Tensor is deprecated."
+            "Please remove the batch dimension and pass it as a 2d torch Tensor"
+        )
+        img_ids = img_ids[0]
+
+    ids = torch.cat((txt_ids, img_ids), dim=0)
+    image_rotary_emb = self.pos_embed(ids)
+
+    if (
+        joint_attention_kwargs is not None
+        and "ip_adapter_image_embeds" in joint_attention_kwargs
+    ):
+        ip_adapter_image_embeds = joint_attention_kwargs.pop(
+            "ip_adapter_image_embeds"
+        )
+        ip_hidden_states = self.encoder_hid_proj(ip_adapter_image_embeds)
+        joint_attention_kwargs.update({"ip_hidden_states": ip_hidden_states})
+
+    for index_block, block in enumerate(self.transformer_blocks):
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            encoder_hidden_states, hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    joint_attention_kwargs,
+                )
+            )
+
+        else:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+
+        # controlnet residual
+        if controlnet_block_samples is not None:
+            interval_control = len(self.transformer_blocks) / len(
+                controlnet_block_samples
+            )
+            interval_control = int(np.ceil(interval_control))
+            # For Xlabs ControlNet.
+            if controlnet_blocks_repeat:
+                hidden_states = (
+                    hidden_states
+                    + controlnet_block_samples[
+                        index_block % len(controlnet_block_samples)
+                    ]
+                )
+            else:
+                hidden_states = (
+                    hidden_states
+                    + controlnet_block_samples[index_block // interval_control]
+                )
+
+    for index_block, block in enumerate(self.single_transformer_blocks):
+        if torch.is_grad_enabled() and self.gradient_checkpointing:
+            encoder_hidden_states, hidden_states = (
+                self._gradient_checkpointing_func(
+                    block,
+                    hidden_states,
+                    encoder_hidden_states,
+                    temb,
+                    image_rotary_emb,
+                    joint_attention_kwargs,
+                )
+            )
+
+        else:
+            encoder_hidden_states, hidden_states = block(
+                hidden_states=hidden_states,
+                encoder_hidden_states=encoder_hidden_states,
+                temb=temb,
+                image_rotary_emb=image_rotary_emb,
+                joint_attention_kwargs=joint_attention_kwargs,
+            )
+
+        # controlnet residual
+        if controlnet_single_block_samples is not None:
+            interval_control = len(self.single_transformer_blocks) / len(
+                controlnet_single_block_samples
+            )
+            interval_control = int(np.ceil(interval_control))
+            hidden_states = (
+                hidden_states
+                + controlnet_single_block_samples[
+                    index_block // interval_control
+                ]
+            )
+
+    hidden_states = self.norm_out(hidden_states, temb)
+    output = self.proj_out(hidden_states)
+
+    if USE_PEFT_BACKEND:
+        # remove `lora_scale` from each PEFT layer
+        unscale_lora_layers(self, lora_scale)
+
+    if not return_dict:
+        return (output,)
+
+    return Transformer2DModelOutput(sample=output)
+
+
+def maybe_patch_flux_transformer(
+    transformer: FluxTransformer2DModel,
+    blocks: torch.nn.ModuleList = None,
+) -> FluxTransformer2DModel:
+    if blocks is None:
+        blocks = transformer.single_transformer_blocks
+
+    is_patched = False
+    for block in blocks:
+        if isinstance(block, FluxSingleTransformerBlock):
+            forward_parameters = inspect.signature(
+                blocks.forward
+            ).parameters.keys()
+            if "encoder_hidden_states" not in forward_parameters:
+                block.forward = __patch_single_forward__.__get__(block)
+                is_patched = True
+
+    if is_patched:
+        logger.warning("Patched Flux for cache-dit.")
+        assert not getattr(transformer, "_is_parallelized", False), (
+            "Please call apply_cache_on_pipe before Parallelize, "
+            "the __patch_transformer_forward__ will overwrite the "
+            "parallized forward and cause a downgrade of performance."
+        )
+        transformer.forward = __patch_transformer_forward__.__get__(transformer)
+        transformer._is_patched = True
+
+    return transformer
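Note on the patching approach above: `maybe_patch_flux_transformer` rebinds `forward` on individual module instances via the descriptor protocol, so only the inspected Flux blocks are affected while the classes and any other models stay untouched. Below is a minimal, self-contained sketch of that `__get__` binding technique; the `ToyBlock` class and `patched_forward` function are illustrative stand-ins, not part of cache-dit or diffusers.

import torch

class ToyBlock(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1  # original behavior

def patched_forward(self: ToyBlock, x: torch.Tensor) -> torch.Tensor:
    # Replacement behavior, analogous to __patch_single_forward__ above.
    return x * 2

block_a, block_b = ToyBlock(), ToyBlock()

# __get__ binds the plain function to block_a, producing a bound method that
# shadows ToyBlock.forward for this one instance only.
block_a.forward = patched_forward.__get__(block_a)

x = torch.zeros(3)
print(block_a(x))  # tensor([0., 0., 0.]) -- patched path
print(block_b(x))  # tensor([1., 1., 1.]) -- original path

In the same spirit, applying the helper above to a loaded Flux transformer would presumably look like `transformer = maybe_patch_flux_transformer(pipe.transformer)`, with the assertion guarding against calling it after the model has been parallelized.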
cache_dit/cache_factory/utils.py CHANGED
cache_dit/compile/__init__.py CHANGED
@@ -1 +1 @@
-from cache_dit.compile.utils import
+from cache_dit.compile.utils import set_compile_configs
cache_dit/compile/utils.py CHANGED
@@ -23,7 +23,7 @@ def epilogue_prologue_fusion_enabled(**kwargs) -> bool:
     return CACHE_DIT_EPILOGUE_PROLOGUE_FUSION or mode
 
 
-def
+def set_compile_configs(
     cuda_graphs: bool = False,
     force_disable_compile_caches: bool = False,
     use_fast_math: bool = False,