cache-dit 0.3.2__py3-none-any.whl → 1.0.14__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cache_dit/__init__.py +37 -19
- cache_dit/_version.py +2 -2
- cache_dit/caching/__init__.py +36 -0
- cache_dit/{cache_factory → caching}/block_adapters/__init__.py +149 -18
- cache_dit/{cache_factory → caching}/block_adapters/block_adapters.py +91 -7
- cache_dit/caching/block_adapters/block_registers.py +118 -0
- cache_dit/caching/cache_adapters/__init__.py +1 -0
- cache_dit/{cache_factory → caching}/cache_adapters/cache_adapter.py +262 -123
- cache_dit/caching/cache_blocks/__init__.py +226 -0
- cache_dit/caching/cache_blocks/offload_utils.py +115 -0
- cache_dit/caching/cache_blocks/pattern_0_1_2.py +26 -0
- cache_dit/caching/cache_blocks/pattern_3_4_5.py +543 -0
- cache_dit/caching/cache_blocks/pattern_base.py +748 -0
- cache_dit/caching/cache_blocks/pattern_utils.py +86 -0
- cache_dit/caching/cache_contexts/__init__.py +28 -0
- cache_dit/caching/cache_contexts/cache_config.py +120 -0
- cache_dit/{cache_factory → caching}/cache_contexts/cache_context.py +29 -90
- cache_dit/{cache_factory → caching}/cache_contexts/cache_manager.py +138 -10
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/__init__.py +25 -3
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/foca.py +1 -1
- cache_dit/{cache_factory → caching}/cache_contexts/calibrators/taylorseer.py +81 -9
- cache_dit/caching/cache_contexts/context_manager.py +36 -0
- cache_dit/caching/cache_contexts/prune_config.py +63 -0
- cache_dit/caching/cache_contexts/prune_context.py +155 -0
- cache_dit/caching/cache_contexts/prune_manager.py +167 -0
- cache_dit/caching/cache_interface.py +358 -0
- cache_dit/{cache_factory → caching}/cache_types.py +19 -2
- cache_dit/{cache_factory → caching}/forward_pattern.py +14 -14
- cache_dit/{cache_factory → caching}/params_modifier.py +10 -10
- cache_dit/caching/patch_functors/__init__.py +15 -0
- cache_dit/{cache_factory → caching}/patch_functors/functor_chroma.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_dit.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_flux.py +1 -1
- cache_dit/{cache_factory → caching}/patch_functors/functor_hidream.py +2 -4
- cache_dit/{cache_factory → caching}/patch_functors/functor_hunyuan_dit.py +1 -1
- cache_dit/caching/patch_functors/functor_qwen_image_controlnet.py +263 -0
- cache_dit/caching/utils.py +68 -0
- cache_dit/metrics/__init__.py +11 -0
- cache_dit/metrics/metrics.py +3 -0
- cache_dit/parallelism/__init__.py +3 -0
- cache_dit/parallelism/backends/native_diffusers/__init__.py +6 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/__init__.py +164 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/__init__.py +4 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/attention/_attention_dispatch.py +304 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_chroma.py +95 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogvideox.py +202 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cogview.py +299 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_cosisid.py +123 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_dit.py +94 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_flux.py +88 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_hunyuan.py +729 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_ltxvideo.py +264 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_nunchaku.py +407 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_pixart.py +285 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_qwen_image.py +104 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_registers.py +84 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_plan_wan.py +101 -0
- cache_dit/parallelism/backends/native_diffusers/context_parallelism/cp_planners.py +117 -0
- cache_dit/parallelism/backends/native_diffusers/parallel_difffusers.py +49 -0
- cache_dit/parallelism/backends/native_diffusers/utils.py +11 -0
- cache_dit/parallelism/backends/native_pytorch/__init__.py +6 -0
- cache_dit/parallelism/backends/native_pytorch/parallel_torch.py +62 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/__init__.py +48 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_flux.py +171 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_kandinsky5.py +79 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_qwen_image.py +78 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_registers.py +65 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_plan_wan.py +153 -0
- cache_dit/parallelism/backends/native_pytorch/tensor_parallelism/tp_planners.py +14 -0
- cache_dit/parallelism/parallel_backend.py +26 -0
- cache_dit/parallelism/parallel_config.py +88 -0
- cache_dit/parallelism/parallel_interface.py +77 -0
- cache_dit/quantize/__init__.py +7 -0
- cache_dit/quantize/backends/__init__.py +1 -0
- cache_dit/quantize/backends/bitsandbytes/__init__.py +0 -0
- cache_dit/quantize/backends/torchao/__init__.py +1 -0
- cache_dit/quantize/{quantize_ao.py → backends/torchao/quantize_ao.py} +44 -30
- cache_dit/quantize/quantize_backend.py +0 -0
- cache_dit/quantize/quantize_config.py +0 -0
- cache_dit/quantize/quantize_interface.py +3 -16
- cache_dit/summary.py +593 -0
- cache_dit/utils.py +46 -290
- cache_dit-1.0.14.dist-info/METADATA +301 -0
- cache_dit-1.0.14.dist-info/RECORD +102 -0
- cache_dit-1.0.14.dist-info/licenses/LICENSE +203 -0
- cache_dit/cache_factory/__init__.py +0 -28
- cache_dit/cache_factory/block_adapters/block_registers.py +0 -90
- cache_dit/cache_factory/cache_adapters/__init__.py +0 -1
- cache_dit/cache_factory/cache_blocks/__init__.py +0 -72
- cache_dit/cache_factory/cache_blocks/pattern_0_1_2.py +0 -16
- cache_dit/cache_factory/cache_blocks/pattern_3_4_5.py +0 -238
- cache_dit/cache_factory/cache_blocks/pattern_base.py +0 -404
- cache_dit/cache_factory/cache_blocks/utils.py +0 -41
- cache_dit/cache_factory/cache_contexts/__init__.py +0 -14
- cache_dit/cache_factory/cache_interface.py +0 -217
- cache_dit/cache_factory/patch_functors/__init__.py +0 -12
- cache_dit/cache_factory/utils.py +0 -57
- cache_dit-0.3.2.dist-info/METADATA +0 -753
- cache_dit-0.3.2.dist-info/RECORD +0 -56
- cache_dit-0.3.2.dist-info/licenses/LICENSE +0 -53
- /cache_dit/{cache_factory → caching}/.gitignore +0 -0
- /cache_dit/{cache_factory → caching}/cache_contexts/calibrators/base.py +0 -0
- /cache_dit/{cache_factory → caching}/patch_functors/functor_base.py +0 -0
- /cache_dit/{custom_ops → kernels}/__init__.py +0 -0
- /cache_dit/{custom_ops → kernels}/triton_taylorseer.py +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/WHEEL +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/entry_points.txt +0 -0
- {cache_dit-0.3.2.dist-info → cache_dit-1.0.14.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,263 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import numpy as np
|
|
3
|
+
from typing import Tuple, Optional, Dict, Any, Union, List
|
|
4
|
+
from diffusers import QwenImageTransformer2DModel
|
|
5
|
+
from diffusers.models.transformers.transformer_qwenimage import (
|
|
6
|
+
QwenImageTransformerBlock,
|
|
7
|
+
Transformer2DModelOutput,
|
|
8
|
+
)
|
|
9
|
+
from diffusers.utils import (
|
|
10
|
+
USE_PEFT_BACKEND,
|
|
11
|
+
scale_lora_layers,
|
|
12
|
+
unscale_lora_layers,
|
|
13
|
+
)
|
|
14
|
+
from cache_dit.caching.patch_functors.functor_base import (
|
|
15
|
+
PatchFunctor,
|
|
16
|
+
)
|
|
17
|
+
from cache_dit.logger import init_logger
|
|
18
|
+
|
|
19
|
+
# Module-level logger for this patch functor and the patched forwards.
logger = init_logger(__name__)
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class QwenImageControlNetPatchFunctor(PatchFunctor):
    """Patch a QwenImage transformer so ControlNet residuals are added per block.

    The patched transformer forward (``__patch_transformer_forward__``) does
    not add ControlNet residuals between blocks. Instead it hands
    ``controlnet_block_samples`` to every block, and each patched block
    (``__patch_block_forward__``) adds its own residual. Presumably this is so
    that block-level caching wrappers still see the ControlNet contribution.
    """

    def apply(
        self,
        transformer: QwenImageTransformer2DModel,
        **kwargs,
    ) -> QwenImageTransformer2DModel:
        """Patch ``transformer`` in place and return it.

        Every block in ``transformer.transformer_blocks`` gets ``_index_block``
        (its position) and ``_num_blocks`` (the block count) attached. Each
        block's ``forward`` is replaced by ``__patch_block_forward__``. The
        transformer's ``forward`` is replaced by ``__patch_transformer_forward__``
        and ``_is_patched`` is set to True.

        Args:
            transformer: The QwenImage transformer to patch.
            **kwargs: Accepted for interface compatibility; unused.

        Returns:
            The same ``transformer`` object. If it already has an
            ``_is_patched`` attribute, it is returned unchanged.

        Raises:
            AssertionError: If the transformer is already parallelized
                (``_is_parallelized`` is truthy) or a block is not a
                ``QwenImageTransformerBlock``. Both checks run before any
                attribute is modified.
        """
        if hasattr(transformer, "_is_patched"):
            return transformer

        cls_name = transformer.__class__.__name__
        blocks = transformer.transformer_blocks

        # Validate everything up front so that a failed check cannot leave the
        # transformer half-patched (some block forwards replaced, others not).
        assert not getattr(transformer, "_is_parallelized", False), (
            "Please call `cache_dit.enable_cache` before Parallelize, "
            "the __patch_transformer_forward__ will overwrite the "
            "parallized forward and cause a downgrade of performance."
        )
        for block in blocks:
            assert isinstance(block, QwenImageTransformerBlock)

        num_blocks = len(blocks)
        for index_block, block in enumerate(blocks):
            # The patched block forward reads these two attributes to select
            # which ControlNet residual it adds.
            block._index_block = index_block
            block._num_blocks = num_blocks
            block.forward = __patch_block_forward__.__get__(block)

        logger.warning(f"Patched {cls_name} for cache-dit.")
        transformer.forward = __patch_transformer_forward__.__get__(transformer)

        is_patched = True
        transformer._is_patched = is_patched

        logger.info(
            f"Applied {self.__class__.__name__} for {cls_name}, "
            f"Patch: {is_patched}."
        )

        return transformer
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def __patch_block_forward__(
    self: QwenImageTransformerBlock,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    encoder_hidden_states_mask: torch.Tensor,
    temb: torch.Tensor,
    image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    joint_attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_block_samples: Optional[List[torch.Tensor]] = None,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Patched ``forward`` for a ``QwenImageTransformerBlock``.

    Runs the dual-stream block: joint attention over the image and text
    streams, then a per-stream norm + MLP. If ControlNet residuals are given,
    the residual assigned to this block is added to the image stream. The
    block's position comes from ``self._index_block`` and ``self._num_blocks``,
    which ``QwenImageControlNetPatchFunctor.apply`` attaches.

    Args:
        hidden_states: Image-stream tokens.
        encoder_hidden_states: Text-stream tokens.
        encoder_hidden_states_mask: Text mask, forwarded to ``self.attn``.
        temb: Conditioning embedding fed to ``img_mod`` / ``txt_mod``.
        image_rotary_emb: Rotary embeddings, forwarded to ``self.attn``.
        joint_attention_kwargs: Extra keyword arguments for ``self.attn``.
        controlnet_block_samples: Optional sequence of ControlNet residuals.
            Consecutive groups of ``ceil(num_blocks / len(samples))`` blocks
            share one residual. ``None`` or an empty sequence adds nothing.

    Returns:
        ``(encoder_hidden_states, hidden_states)``: text stream first, then
        image stream.
    """
    # Get modulation parameters for both streams
    img_mod_params = self.img_mod(temb)  # [B, 6*dim]
    txt_mod_params = self.txt_mod(temb)  # [B, 6*dim]

    # Split modulation parameters for norm1 and norm2
    img_mod1, img_mod2 = img_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]
    txt_mod1, txt_mod2 = txt_mod_params.chunk(2, dim=-1)  # Each [B, 3*dim]

    # Process image stream - norm1 + modulation
    img_normed = self.img_norm1(hidden_states)
    img_modulated, img_gate1 = self._modulate(img_normed, img_mod1)

    # Process text stream - norm1 + modulation
    txt_normed = self.txt_norm1(encoder_hidden_states)
    txt_modulated, txt_gate1 = self._modulate(txt_normed, txt_mod1)

    # Joint attention over both streams; the attention module returns the
    # image and text outputs separately when encoder_hidden_states is given.
    joint_attention_kwargs = joint_attention_kwargs or {}
    attn_output = self.attn(
        hidden_states=img_modulated,  # Image stream
        encoder_hidden_states=txt_modulated,  # Text stream
        encoder_hidden_states_mask=encoder_hidden_states_mask,
        image_rotary_emb=image_rotary_emb,
        **joint_attention_kwargs,
    )
    img_attn_output, txt_attn_output = attn_output

    # Apply attention gates and add residual
    hidden_states = hidden_states + img_gate1 * img_attn_output
    encoder_hidden_states = encoder_hidden_states + txt_gate1 * txt_attn_output

    # Process image stream - norm2 + MLP
    img_normed2 = self.img_norm2(hidden_states)
    img_modulated2, img_gate2 = self._modulate(img_normed2, img_mod2)
    img_mlp_output = self.img_mlp(img_modulated2)
    hidden_states = hidden_states + img_gate2 * img_mlp_output

    # Process text stream - norm2 + MLP
    txt_normed2 = self.txt_norm2(encoder_hidden_states)
    txt_modulated2, txt_gate2 = self._modulate(txt_normed2, txt_mod2)
    txt_mlp_output = self.txt_mlp(txt_modulated2)
    encoder_hidden_states = encoder_hidden_states + txt_gate2 * txt_mlp_output

    # Clip to prevent overflow for fp16 (65504 is the largest finite fp16 value)
    if encoder_hidden_states.dtype == torch.float16:
        encoder_hidden_states = encoder_hidden_states.clip(-65504, 65504)
    if hidden_states.dtype == torch.float16:
        hidden_states = hidden_states.clip(-65504, 65504)

    # Add ControlNet conditioning. An empty sequence would otherwise divide by
    # zero below, so it is treated the same as None.
    if controlnet_block_samples is not None and len(controlnet_block_samples) > 0:
        num_samples = len(controlnet_block_samples)
        # Exact integer ceil(num_blocks / num_samples): the number of
        # consecutive blocks that share one ControlNet residual.
        interval_control = (self._num_blocks + num_samples - 1) // num_samples
        hidden_states = (
            hidden_states
            + controlnet_block_samples[self._index_block // interval_control]
        )

    return encoder_hidden_states, hidden_states
|
|
145
|
+
|
|
146
|
+
|
|
147
|
+
def __patch_transformer_forward__(
    self: QwenImageTransformer2DModel,
    hidden_states: torch.Tensor,
    encoder_hidden_states: torch.Tensor = None,
    encoder_hidden_states_mask: torch.Tensor = None,
    timestep: torch.LongTensor = None,
    img_shapes: Optional[List[Tuple[int, int, int]]] = None,
    txt_seq_lens: Optional[List[int]] = None,
    guidance: torch.Tensor = None,  # TODO: this should probably be removed
    attention_kwargs: Optional[Dict[str, Any]] = None,
    controlnet_block_samples=None,
    return_dict: bool = True,
) -> Union[torch.Tensor, Transformer2DModelOutput]:
    """
    The [`QwenImageTransformer2DModel`] forward method, patched for ControlNet.

    ControlNet residuals are not added in this loop. `controlnet_block_samples`
    is handed to every transformer block, and each patched block adds its own
    residual to the image stream.

    Args:
        hidden_states (`torch.Tensor` of shape `(batch_size, image_sequence_length, in_channels)`):
            Input `hidden_states`.
        encoder_hidden_states (`torch.Tensor` of shape `(batch_size, text_sequence_length, joint_attention_dim)`):
            Conditional embeddings (embeddings computed from the input conditions such as prompts) to use.
        encoder_hidden_states_mask (`torch.Tensor` of shape `(batch_size, text_sequence_length)`):
            Mask of the input conditions.
        timestep ( `torch.LongTensor`):
            Used to indicate denoising step.
        img_shapes (`List[Tuple[int, int, int]]`, *optional*):
            Passed, together with `txt_seq_lens`, to `self.pos_embed` to build
            the rotary embeddings.
        txt_seq_lens (`List[int]`, *optional*):
            See `img_shapes`.
        guidance (`torch.Tensor`, *optional*):
            If given, multiplied by 1000 and passed to `self.time_text_embed`.
        attention_kwargs (`dict`, *optional*):
            A kwargs dictionary that if specified is passed along to the `AttentionProcessor` as defined under
            `self.processor` in
            [diffusers.models.attention_processor](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention_processor.py).
            A `scale` entry is removed from a copy of this dict and used as
            the LoRA scale when the PEFT backend is active.
        controlnet_block_samples (sequence of `torch.Tensor`, *optional*):
            ControlNet residuals, forwarded to every transformer block.
        return_dict (`bool`, *optional*, defaults to `True`):
            Whether or not to return a [`~models.transformer_2d.Transformer2DModelOutput`] instead of a plain
            tuple.

    Returns:
        If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
        `tuple` where the first element is the sample tensor.
    """
    # Record whether a LoRA `scale` was passed *before* it is popped from the
    # copied kwargs; checking after the pop would always see None, which made
    # the non-PEFT warning below unreachable.
    scale_passed = (
        attention_kwargs is not None
        and attention_kwargs.get("scale", None) is not None
    )
    if attention_kwargs is not None:
        attention_kwargs = attention_kwargs.copy()
        lora_scale = attention_kwargs.pop("scale", 1.0)
    else:
        lora_scale = 1.0

    if USE_PEFT_BACKEND:
        # weight the lora layers by setting `lora_scale` for each PEFT layer
        scale_lora_layers(self, lora_scale)
    elif scale_passed:
        logger.warning(
            "Passing `scale` via `joint_attention_kwargs` when not using the PEFT backend is ineffective."
        )

    try:
        hidden_states = self.img_in(hidden_states)

        timestep = timestep.to(hidden_states.dtype)
        encoder_hidden_states = self.txt_norm(encoder_hidden_states)
        encoder_hidden_states = self.txt_in(encoder_hidden_states)

        if guidance is not None:
            guidance = guidance.to(hidden_states.dtype) * 1000

        temb = (
            self.time_text_embed(timestep, hidden_states)
            if guidance is None
            else self.time_text_embed(timestep, guidance, hidden_states)
        )

        image_rotary_emb = self.pos_embed(
            img_shapes, txt_seq_lens, device=hidden_states.device
        )

        for block in self.transformer_blocks:
            if torch.is_grad_enabled() and self.gradient_checkpointing:
                # Positional order must match the patched block signature:
                # (hidden_states, encoder_hidden_states,
                #  encoder_hidden_states_mask, temb, image_rotary_emb,
                #  joint_attention_kwargs, controlnet_block_samples).
                # Omitting attention_kwargs would shift the ControlNet samples
                # into the joint_attention_kwargs slot.
                encoder_hidden_states, hidden_states = (
                    self._gradient_checkpointing_func(
                        block,
                        hidden_states,
                        encoder_hidden_states,
                        encoder_hidden_states_mask,
                        temb,
                        image_rotary_emb,
                        attention_kwargs,
                        controlnet_block_samples,
                    )
                )
            else:
                encoder_hidden_states, hidden_states = block(
                    hidden_states=hidden_states,
                    encoder_hidden_states=encoder_hidden_states,
                    encoder_hidden_states_mask=encoder_hidden_states_mask,
                    temb=temb,
                    image_rotary_emb=image_rotary_emb,
                    controlnet_block_samples=controlnet_block_samples,
                    joint_attention_kwargs=attention_kwargs,
                )

        # Use only the image part (hidden_states) from the dual-stream blocks
        hidden_states = self.norm_out(hidden_states, temb)
        output = self.proj_out(hidden_states)
    finally:
        # Always remove `lora_scale` from each PEFT layer, even if the
        # forward pass raised, so the model is not left scaled.
        if USE_PEFT_BACKEND:
            unscale_lora_layers(self, lora_scale)

    if not return_dict:
        return (output,)

    return Transformer2DModelOutput(sample=output)
|
|
@@ -0,0 +1,68 @@
|
|
|
1
|
+
import yaml
|
|
2
|
+
|
|
3
|
+
|
|
4
|
+
def load_cache_options_from_yaml(yaml_file_path):
    """Load cache-dit cache options from a YAML file.

    The YAML document must be a mapping that contains at least
    ``residual_diff_threshold``. Recognized keys:

    * ``enable_taylorseer`` (truthy): builds a ``TaylorSeerCalibratorConfig``
      from ``enable_taylorseer``, ``enable_encoder_taylorseer`` (default
      False), ``taylorseer_cache_type`` (default ``"residual"``) and
      ``taylorseer_order`` (default 1). These keys are then removed from the
      remaining options.
    * ``cache_type``: ``"DBCache"`` or ``"DBPrune"``. When the key is absent,
      a ``BasicCacheConfig`` is used.

    All remaining keys are passed to the chosen config's ``update``.

    Args:
        yaml_file_path: Path to the YAML configuration file.

    Returns:
        A dict with ``"cache_config"`` and, when TaylorSeer is enabled,
        ``"calibrator_config"``.

    Raises:
        FileNotFoundError: If the file does not exist.
        yaml.YAMLError: If the file is not valid YAML.
        ValueError: If the document is not a mapping, a required key is
            missing, or ``cache_type`` is unsupported.
    """
    try:
        with open(yaml_file_path, "r") as f:
            options = yaml.safe_load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"Configuration file not found: {yaml_file_path}"
        ) from e
    except yaml.YAMLError as e:
        raise yaml.YAMLError(f"YAML file parsing error: {e}") from e

    # An empty file loads as None, and a scalar or list document is not a
    # mapping. Reject both with a clear message instead of a TypeError or
    # AttributeError further down.
    if not isinstance(options, dict):
        raise ValueError(
            "Configuration file must contain a mapping of options: "
            f"{yaml_file_path}"
        )

    required_keys = [
        "residual_diff_threshold",
    ]
    for key in required_keys:
        if key not in options:
            raise ValueError(
                f"Configuration file missing required item: {key}"
            )

    cache_context_kwargs = {}
    if options.get("enable_taylorseer", False):
        from cache_dit.caching.cache_contexts.calibrators import (
            TaylorSeerCalibratorConfig,
        )

        cache_context_kwargs["calibrator_config"] = TaylorSeerCalibratorConfig(
            enable_calibrator=options.pop("enable_taylorseer"),
            enable_encoder_calibrator=options.pop(
                "enable_encoder_taylorseer", False
            ),
            calibrator_cache_type=options.pop(
                "taylorseer_cache_type", "residual"
            ),
            taylorseer_order=options.pop("taylorseer_order", 1),
        )

    if "cache_type" not in options:
        from cache_dit.caching.cache_contexts import BasicCacheConfig

        cache_config = BasicCacheConfig()
    else:
        cache_type = options.pop("cache_type")
        if cache_type == "DBCache":
            from cache_dit.caching.cache_contexts import DBCacheConfig

            cache_config = DBCacheConfig()
        elif cache_type == "DBPrune":
            from cache_dit.caching.cache_contexts import DBPruneConfig

            cache_config = DBPruneConfig()
        else:
            raise ValueError(f"Unsupported cache_type: {cache_type}.")

    cache_context_kwargs["cache_config"] = cache_config
    cache_config.update(**options)

    return cache_context_kwargs
|
|
65
|
+
|
|
66
|
+
|
|
67
|
+
def load_options(path: str):
    """Shorthand for ``load_cache_options_from_yaml``.

    Args:
        path: Path to the YAML configuration file.

    Returns:
        The cache-context kwargs dict built by ``load_cache_options_from_yaml``.
    """
    options = load_cache_options_from_yaml(path)
    return options
|
cache_dit/metrics/__init__.py
CHANGED
|
@@ -1,3 +1,14 @@
|
|
|
1
|
+
try:
|
|
2
|
+
import ImageReward
|
|
3
|
+
import lpips
|
|
4
|
+
import skimage
|
|
5
|
+
import scipy
|
|
6
|
+
except ImportError:
|
|
7
|
+
raise ImportError(
|
|
8
|
+
"Metrics functionality requires the 'metrics' extra dependencies. "
|
|
9
|
+
"Install with:\npip install cache-dit[metrics]"
|
|
10
|
+
)
|
|
11
|
+
|
|
1
12
|
from cache_dit.metrics.metrics import compute_psnr
|
|
2
13
|
from cache_dit.metrics.metrics import compute_ssim
|
|
3
14
|
from cache_dit.metrics.metrics import compute_mse
|
cache_dit/metrics/metrics.py
CHANGED
|
@@ -646,6 +646,7 @@ def entrypoint():
|
|
|
646
646
|
not os.path.exists(img_test),
|
|
647
647
|
)
|
|
648
648
|
):
|
|
649
|
+
logger.error(f"Not exist: {img_true} or {img_test}, skip.")
|
|
649
650
|
return
|
|
650
651
|
# img_true and img_test can be files or dirs
|
|
651
652
|
img_true_info = os.path.basename(img_true)
|
|
@@ -684,6 +685,7 @@ def entrypoint():
|
|
|
684
685
|
not os.path.exists(img_test), # dir
|
|
685
686
|
)
|
|
686
687
|
):
|
|
688
|
+
logger.error(f"Not exist: {prompt_true} or {img_test}, skip.")
|
|
687
689
|
return
|
|
688
690
|
|
|
689
691
|
# img_true and img_test can be files or dirs
|
|
@@ -714,6 +716,7 @@ def entrypoint():
|
|
|
714
716
|
not os.path.exists(video_test),
|
|
715
717
|
)
|
|
716
718
|
):
|
|
719
|
+
logger.error(f"Not exist: {video_true} or {video_test}, skip.")
|
|
717
720
|
return
|
|
718
721
|
|
|
719
722
|
# video_true and video_test can be files or dirs
|
|
@@ -0,0 +1,164 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from typing import Optional
|
|
3
|
+
|
|
4
|
+
from diffusers.models.modeling_utils import ModelMixin
|
|
5
|
+
from cache_dit.parallelism.parallel_backend import ParallelismBackend
|
|
6
|
+
from cache_dit.parallelism.parallel_config import ParallelismConfig
|
|
7
|
+
from cache_dit.logger import init_logger
|
|
8
|
+
from ..utils import (
|
|
9
|
+
native_diffusers_parallelism_available,
|
|
10
|
+
ContextParallelConfig,
|
|
11
|
+
)
|
|
12
|
+
from .attention import maybe_resigter_native_attention_backend
|
|
13
|
+
from .cp_planners import *
|
|
14
|
+
|
|
15
|
+
# Run the attention-backend hook from ``.attention`` at import time
# (presumably registers cache-dit's native attention backend; see that
# module). Any ImportError it raises propagates unchanged, keeping its
# original traceback and ``name``/``path`` attributes.
maybe_resigter_native_attention_backend()

logger = init_logger(__name__)
|
|
21
|
+
|
|
22
|
+
|
|
23
|
+
def maybe_enable_context_parallelism(
    transformer: torch.nn.Module,
    parallelism_config: Optional[ParallelismConfig],
) -> torch.nn.Module:
    """Enable diffusers-native context parallelism on ``transformer`` if configured.

    Nothing happens unless all of the following hold:
    ``parallelism_config`` is given, its backend is
    ``ParallelismBackend.NATIVE_DIFFUSER``,
    ``native_diffusers_parallelism_available()`` returns True, and at least
    one of ``ulysses_size`` / ``ring_size`` is set.

    ``parallelism_config.parallel_kwargs`` may provide ``attention_backend``
    (defaults to ``"_native_cudnn"`` when absent) and ``cp_plan`` (otherwise
    obtained from ``ContextParallelismPlannerRegister``).

    Args:
        transformer: A diffusers ``ModelMixin`` instance.
        parallelism_config: Parallelism settings, or ``None`` to do nothing.

    Returns:
        The same ``transformer`` (modified in place when parallelism is enabled).

    Raises:
        AssertionError: If ``transformer`` is not a ``ModelMixin`` or
            ``parallelism_config`` is not a ``ParallelismConfig``.
        ValueError: If context parallelism is requested but the transformer
            has no ``enable_parallelism`` method.
    """
    assert isinstance(transformer, ModelMixin), (
        "transformer must be an instance of diffusers' ModelMixin, "
        f"but got {type(transformer)}"
    )
    if parallelism_config is None:
        return transformer

    assert isinstance(parallelism_config, ParallelismConfig), (
        "parallelism_config must be an instance of ParallelismConfig"
        f" but got {type(parallelism_config)}"
    )

    if not (
        parallelism_config.backend == ParallelismBackend.NATIVE_DIFFUSER
        and native_diffusers_parallelism_available()
    ):
        return transformer

    if (
        parallelism_config.ulysses_size is None
        and parallelism_config.ring_size is None
    ):
        # Neither a Ulysses nor a Ring degree is configured.
        return transformer

    cp_config = ContextParallelConfig(
        ulysses_degree=parallelism_config.ulysses_size,
        ring_degree=parallelism_config.ring_size,
    )

    # `parallel_kwargs` may be None. Normalize it once so every lookup below
    # is safe (previously `.get` was called before the None check).
    parallel_kwargs = parallelism_config.parallel_kwargs
    if parallel_kwargs is None:
        parallel_kwargs = {}
    attention_backend = parallel_kwargs.get("attention_backend", None)

    if not hasattr(transformer, "enable_parallelism"):
        raise ValueError(
            f"{transformer.__class__.__name__} does not support context parallelism."
        )

    if hasattr(transformer, "set_attention_backend"):
        # native, _native_cudnn, flash, etc.
        if attention_backend is None:
            # Now only _native_cudnn is supported for parallelism
            # issue: https://github.com/huggingface/diffusers/pull/12443
            transformer.set_attention_backend("_native_cudnn")
            logger.warning(
                "attention_backend is None, set default attention backend "
                "to _native_cudnn for parallelism because of the issue: "
                "https://github.com/huggingface/diffusers/pull/12443"
            )
        else:
            transformer.set_attention_backend(attention_backend)
            logger.info(
                "Found attention_backend from config, set attention "
                f"backend to: {attention_backend}"
            )

    # Prefer a custom cp_plan if provided; otherwise ask the planner registry.
    cp_plan = parallel_kwargs.get("cp_plan", None)
    if cp_plan is not None:
        logger.info(f"Using custom context parallelism plan: {cp_plan}")
    else:
        cp_plan = ContextParallelismPlannerRegister.get_planner(
            transformer
        )().apply(transformer=transformer, **parallel_kwargs)

    transformer.enable_parallelism(config=cp_config, cp_plan=cp_plan)
    _maybe_patch_native_parallel_config(transformer)

    return transformer
|
|
103
|
+
|
|
104
|
+
|
|
105
|
+
def _maybe_patch_native_parallel_config(
    transformer: torch.nn.Module,
) -> torch.nn.Module:
    """Copy ``transformer._parallel_config`` into Nunchaku attention processors.

    Models whose class name does not start with ``"Nunchaku"`` are returned
    untouched. For the supported Flux / QwenImage Nunchaku models, every
    submodule that is one of the Nunchaku attention or processor classes,
    and whose ``processor`` attribute exists and already has a
    ``_parallel_config`` attribute, gets that attribute set to the
    transformer's ``_parallel_config``.

    Args:
        transformer: The (possibly Nunchaku) transformer.

    Returns:
        The same ``transformer``.

    Raises:
        ImportError: If a Nunchaku model is given but ``nunchaku`` is missing.
        AssertionError: If the model is not one of the supported classes.
    """
    if not transformer.__class__.__name__.startswith("Nunchaku"):
        return transformer

    from diffusers import FluxTransformer2DModel, QwenImageTransformer2DModel

    try:
        from nunchaku.models.transformers.transformer_flux_v2 import (
            NunchakuFluxTransformer2DModelV2,
            NunchakuFluxAttention,
            NunchakuFluxFA2Processor,
        )
        from nunchaku.models.transformers.transformer_qwenimage import (
            NunchakuQwenAttention,
            NunchakuQwenImageNaiveFA2Processor,
            NunchakuQwenImageTransformer2DModel,
        )
    except ImportError:
        raise ImportError(
            "NunchakuFluxTransformer2DModelV2 or NunchakuQwenImageTransformer2DModel "
            "requires the 'nunchaku' package. Please install nunchaku before using "
            "the context parallelism for nunchaku 4-bits models."
        )

    supported_models = (
        NunchakuFluxTransformer2DModelV2,
        FluxTransformer2DModel,
        NunchakuQwenImageTransformer2DModel,
        QwenImageTransformer2DModel,
    )
    assert isinstance(transformer, supported_models), (
        "transformer must be an instance of NunchakuFluxTransformer2DModelV2 "
        f"or NunchakuQwenImageTransformer2DModel, but got {type(transformer)}"
    )
    parallel_config = transformer._parallel_config

    nunchaku_attention_types = (
        NunchakuFluxAttention,
        NunchakuFluxFA2Processor,
        NunchakuQwenAttention,
        NunchakuQwenImageNaiveFA2Processor,
    )
    for submodule in transformer.modules():
        if isinstance(submodule, nunchaku_attention_types):
            attn_processor = getattr(submodule, "processor", None)
            if attn_processor is not None and hasattr(
                attn_processor, "_parallel_config"
            ):
                attn_processor._parallel_config = parallel_config

    return transformer
|